mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ae59a5e88a | |||
| 79e9b164f9 | |||
| 93888e56d1 | |||
| eff9bd7bd2 | |||
| bde1db1c3b |
@@ -0,0 +1,38 @@
|
||||
# Code Owners for traefik-oidc
|
||||
# These owners will be automatically requested for review when someone opens a PR
|
||||
|
||||
# Default owner for everything in the repo
|
||||
* @lukaszraczylo
|
||||
|
||||
# Core authentication and middleware
|
||||
/middleware/ @lukaszraczylo
|
||||
/auth/ @lukaszraczylo
|
||||
/handlers/ @lukaszraczylo
|
||||
|
||||
# OIDC providers
|
||||
/internal/providers/ @lukaszraczylo
|
||||
|
||||
# Session management and security
|
||||
/session/ @lukaszraczylo
|
||||
/internal/security/ @lukaszraczylo
|
||||
/security/ @lukaszraczylo
|
||||
|
||||
# Token management
|
||||
/internal/token/ @lukaszraczylo
|
||||
|
||||
# Configuration
|
||||
/config/ @lukaszraczylo
|
||||
/.traefik.yml @lukaszraczylo
|
||||
|
||||
# GitHub Actions and CI/CD
|
||||
/.github/ @lukaszraczylo
|
||||
/.github/workflows/ @lukaszraczylo
|
||||
/.golangci.yml @lukaszraczylo
|
||||
|
||||
# Documentation
|
||||
/docs/ @lukaszraczylo
|
||||
README.md @lukaszraczylo
|
||||
|
||||
# Dependencies
|
||||
go.mod @lukaszraczylo
|
||||
go.sum @lukaszraczylo
|
||||
@@ -0,0 +1,123 @@
|
||||
## Description
|
||||
|
||||
<!-- Provide a brief description of the changes in this PR -->
|
||||
|
||||
## Type of Change
|
||||
|
||||
<!-- Mark the relevant option with an "x" -->
|
||||
|
||||
- [ ] Bug fix (non-breaking change which fixes an issue)
|
||||
- [ ] New feature (non-breaking change which adds functionality)
|
||||
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
|
||||
- [ ] Documentation update
|
||||
- [ ] Performance improvement
|
||||
- [ ] Code refactoring
|
||||
- [ ] Security fix
|
||||
- [ ] Provider-specific fix/enhancement
|
||||
|
||||
## Related Issues
|
||||
|
||||
<!-- Link to related issues using #issue_number -->
|
||||
|
||||
Fixes #
|
||||
Related to #
|
||||
|
||||
## Changes Made
|
||||
|
||||
<!-- List the main changes made in this PR -->
|
||||
|
||||
-
|
||||
-
|
||||
-
|
||||
|
||||
## Provider Impact
|
||||
|
||||
<!-- If this affects specific OIDC providers, list them here -->
|
||||
|
||||
- [ ] Google
|
||||
- [ ] Azure AD
|
||||
- [ ] Auth0
|
||||
- [ ] Okta
|
||||
- [ ] Keycloak
|
||||
- [ ] AWS Cognito
|
||||
- [ ] GitLab
|
||||
- [ ] GitHub
|
||||
- [ ] Generic OIDC
|
||||
- [ ] All providers
|
||||
|
||||
## Testing Performed
|
||||
|
||||
<!-- Describe the tests you ran to verify your changes -->
|
||||
|
||||
- [ ] Unit tests pass locally
|
||||
- [ ] Integration tests pass locally
|
||||
- [ ] Race detector shows no issues
|
||||
- [ ] Memory leak tests pass
|
||||
- [ ] Manual testing performed
|
||||
|
||||
### Test Configuration
|
||||
|
||||
<!-- Provide details about your test configuration if applicable -->
|
||||
|
||||
**Provider tested:**
|
||||
**Go version:**
|
||||
**Traefik version:**
|
||||
|
||||
## Security Considerations
|
||||
|
||||
<!-- Describe any security implications of these changes -->
|
||||
|
||||
- [ ] This PR does not introduce security vulnerabilities
|
||||
- [ ] Security scanning has been performed
|
||||
- [ ] Credentials/secrets are properly handled
|
||||
- [ ] Input validation is implemented
|
||||
|
||||
## Performance Impact
|
||||
|
||||
<!-- Describe any performance implications -->
|
||||
|
||||
- [ ] No performance impact expected
|
||||
- [ ] Performance improved (describe how)
|
||||
- [ ] Performance may be affected (describe why and mitigation)
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
<!-- If this is a breaking change, describe the impact and migration path -->
|
||||
|
||||
**Breaking changes:**
|
||||
|
||||
|
||||
**Migration guide:**
|
||||
|
||||
|
||||
## Checklist
|
||||
|
||||
<!-- Ensure all items are checked before requesting review -->
|
||||
|
||||
- [ ] My code follows the project's code style
|
||||
- [ ] I have performed a self-review of my code
|
||||
- [ ] I have commented my code, particularly in hard-to-understand areas
|
||||
- [ ] I have made corresponding changes to the documentation
|
||||
- [ ] My changes generate no new warnings
|
||||
- [ ] I have added tests that prove my fix is effective or that my feature works
|
||||
- [ ] New and existing unit tests pass locally with my changes
|
||||
- [ ] Any dependent changes have been merged and published
|
||||
|
||||
## Additional Context
|
||||
|
||||
<!-- Add any other context, screenshots, or information about the PR here -->
|
||||
|
||||
## Screenshots (if applicable)
|
||||
|
||||
<!-- Add screenshots to help explain your changes -->
|
||||
|
||||
---
|
||||
|
||||
**For Reviewers:**
|
||||
|
||||
Please verify:
|
||||
- [ ] Code quality and style
|
||||
- [ ] Test coverage is adequate
|
||||
- [ ] Security implications reviewed
|
||||
- [ ] Documentation is updated
|
||||
- [ ] No performance regressions
|
||||
@@ -0,0 +1,52 @@
|
||||
version: 2
|
||||
updates:
|
||||
# Maintain dependencies for GitHub Actions
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 5
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "github-actions"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
|
||||
# Maintain Go module dependencies
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "monday"
|
||||
time: "09:00"
|
||||
open-pull-requests-limit: 10
|
||||
commit-message:
|
||||
prefix: "chore(deps)"
|
||||
include: "scope"
|
||||
labels:
|
||||
- "dependencies"
|
||||
- "go"
|
||||
reviewers:
|
||||
- "lukaszraczylo"
|
||||
# Group patch updates together
|
||||
groups:
|
||||
patch-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "patch"
|
||||
minor-updates:
|
||||
patterns:
|
||||
- "*"
|
||||
update-types:
|
||||
- "minor"
|
||||
# Ignore certain dependencies if needed
|
||||
ignore:
|
||||
# Example: ignore specific versions
|
||||
# - dependency-name: "github.com/example/package"
|
||||
# versions: ["1.x", "2.x"]
|
||||
@@ -0,0 +1,9 @@
|
||||
# Ensure consistent line endings
|
||||
* text=auto eol=lf
|
||||
|
||||
# GitHub Actions files should use LF
|
||||
*.yml text eol=lf
|
||||
*.yaml text eol=lf
|
||||
|
||||
# Shell scripts should use LF
|
||||
*.sh text eol=lf
|
||||
@@ -0,0 +1,225 @@
|
||||
# GitHub Actions Workflows
|
||||
|
||||
This directory contains CI/CD workflows for the Traefik OIDC middleware.
|
||||
|
||||
## Workflows
|
||||
|
||||
### PR Validation (`pr-validation.yml`)
|
||||
|
||||
A comprehensive validation workflow that runs **all checks in parallel** for maximum speed and thorough testing.
|
||||
|
||||
**Triggered on:**
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
**Parallel Jobs (20+ concurrent checks):**
|
||||
|
||||
#### Code Quality
|
||||
- **Quick Checks** - Format, go vet, go mod verify
|
||||
- **golangci-lint** - Comprehensive linting
|
||||
- **Staticcheck** - Static analysis
|
||||
|
||||
#### Security
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database check
|
||||
- **CodeQL** - GitHub's code analysis
|
||||
|
||||
#### Testing
|
||||
- **Race Detector** - Concurrent access bug detection
|
||||
- **Coverage** - Test coverage with 75% threshold
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration test suite
|
||||
- **Regression Tests** - Prevent previously fixed bugs
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management validation
|
||||
- **Token Tests** - Token validation scenarios
|
||||
- **CSRF Tests** - CSRF protection validation
|
||||
|
||||
#### Provider Testing (Matrix)
|
||||
Tests run in parallel for each OIDC provider:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Compatibility
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Build Matrix** - linux/darwin × amd64/arm64
|
||||
- **Go Versions** - Go 1.23 and 1.24 compatibility
|
||||
|
||||
#### Final Validation
|
||||
- **All Checks Passed** - Ensures all jobs succeeded
|
||||
|
||||
## Workflow Features
|
||||
|
||||
### 🚀 Parallel Execution
|
||||
All independent checks run simultaneously for fastest feedback (~5-10 minutes for full suite).
|
||||
|
||||
### 📊 Coverage Reporting
|
||||
- Automatic PR comments with coverage statistics
|
||||
- Per-package coverage breakdown
|
||||
- 75% coverage threshold enforcement
|
||||
|
||||
### 🔒 Security First
|
||||
- Multiple security scanners (gosec, govulncheck, CodeQL)
|
||||
- SARIF report uploads for GitHub Security tab
|
||||
- Security edge case testing
|
||||
|
||||
### 🎯 Comprehensive Testing
|
||||
- Race condition detection
|
||||
- Memory leak detection
|
||||
- Provider-specific testing
|
||||
- Integration and regression tests
|
||||
|
||||
### 📈 Performance Tracking
|
||||
- Benchmark results stored as artifacts
|
||||
- Performance regression detection
|
||||
|
||||
### ✅ Quality Gates
|
||||
All checks must pass before PR can be merged:
|
||||
- Code formatting and style
|
||||
- Security vulnerabilities
|
||||
- Test coverage threshold
|
||||
- Race conditions
|
||||
- Memory leaks
|
||||
- Build success on all platforms
|
||||
|
||||
## Local Development
|
||||
|
||||
### Run checks locally before pushing:
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Run tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Check coverage
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# Run specific test suites
|
||||
go test -v -run='.*Leak.*' ./... # Memory leak tests
|
||||
go test -v -run='.*Integration.*' ./... # Integration tests
|
||||
go test -v -run='.*Regression.*' ./... # Regression tests
|
||||
|
||||
# Run benchmarks
|
||||
go test -bench=. -benchmem ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
govulncheck ./...
|
||||
```
|
||||
|
||||
### Required Tools
|
||||
|
||||
Install these tools for local development:
|
||||
|
||||
```bash
|
||||
# golangci-lint
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Workflow Fails
|
||||
|
||||
1. **Check job status** - Click on failed job for details
|
||||
2. **Review logs** - Expand failed steps to see error messages
|
||||
3. **Run locally** - Reproduce issue with local commands above
|
||||
4. **Check coverage** - Ensure test coverage meets 75% threshold
|
||||
|
||||
### Coverage Below Threshold
|
||||
|
||||
Add tests to increase coverage:
|
||||
```bash
|
||||
# See which lines aren't covered
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Detected
|
||||
|
||||
Run with race detector locally:
|
||||
```bash
|
||||
go test -race -v ./...
|
||||
```
|
||||
|
||||
### Provider Test Failure
|
||||
|
||||
Test specific provider:
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./internal/providers/...
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
The workflow is optimized for speed:
|
||||
|
||||
- **Parallel execution** - All independent jobs run simultaneously
|
||||
- **Go caching** - Dependencies cached between runs
|
||||
- **Strategic ordering** - Quick checks run first for fast feedback
|
||||
- **Fail-fast disabled** - Continue running all tests even if some fail
|
||||
|
||||
## Workflow Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
Monitor workflow runs at: `https://github.com/{owner}/{repo}/actions`
|
||||
|
||||
### Status Badges
|
||||
Add to README.md:
|
||||
```markdown
|
||||

|
||||
```
|
||||
|
||||
### Notifications
|
||||
Configure in repository settings:
|
||||
- Settings → Notifications
|
||||
- Choose email or Slack notifications for workflow failures
|
||||
|
||||
## Maintenance
|
||||
|
||||
### Update Go Version
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
go-version: '1.24' # Update this
|
||||
```
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit in workflow file:
|
||||
```yaml
|
||||
THRESHOLD=75 # Adjust this value
|
||||
```
|
||||
|
||||
### Add New Provider
|
||||
Add to provider matrix:
|
||||
```yaml
|
||||
matrix:
|
||||
provider:
|
||||
- new_provider # Add here
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
- [golangci-lint Configuration](../.golangci.yml)
|
||||
- [Dependabot Configuration](../dependabot.yml)
|
||||
- [PR Template](../PULL_REQUEST_TEMPLATE.md)
|
||||
@@ -0,0 +1,629 @@
|
||||
name: PR Validation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
checks: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
# Fast feedback - format and basic checks
|
||||
quick-checks:
|
||||
name: Quick Checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
# Exclude vendor directory from format checks
|
||||
UNFORMATTED=$(gofmt -s -l . | grep -v "^vendor/" || true)
|
||||
if [ -n "$UNFORMATTED" ]; then
|
||||
echo "Code is not formatted. Run: gofmt -s -w ."
|
||||
echo "Unformatted files:"
|
||||
echo "$UNFORMATTED"
|
||||
gofmt -s -d $(echo "$UNFORMATTED")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Go mod verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Go mod tidy check
|
||||
run: |
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum
|
||||
|
||||
# Static analysis with golangci-lint (advisory - will not fail the build)
|
||||
golangci-lint:
|
||||
name: golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=10m
|
||||
continue-on-error: true # Allow pipeline to continue even with linting warnings
|
||||
|
||||
# Staticcheck analysis
|
||||
staticcheck:
|
||||
name: Staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install staticcheck
|
||||
run: go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run staticcheck
|
||||
run: staticcheck ./...
|
||||
|
||||
# Security scanning with gosec
|
||||
gosec:
|
||||
name: Gosec Security Scanner
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run Gosec Security Scanner
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -no-fail -fmt sarif -out results.sarif ./... || echo "Gosec completed with warnings"
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: always() && hashFiles('results.sarif') != ''
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
continue-on-error: true
|
||||
|
||||
# Vulnerability scanning
|
||||
govulncheck:
|
||||
name: Vulnerability Scan
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Run govulncheck
|
||||
run: govulncheck ./...
|
||||
|
||||
# CodeQL analysis
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: go
|
||||
continue-on-error: true
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
continue-on-error: true
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
continue-on-error: true
|
||||
|
||||
# Unit tests with race detection
|
||||
test-race:
|
||||
name: Unit Tests (Race Detector)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with race detector
|
||||
run: go test -race -timeout=15m -count=1 -v ./...
|
||||
env:
|
||||
GOMAXPROCS: 4
|
||||
|
||||
# Coverage analysis with threshold check
|
||||
test-coverage:
|
||||
name: Test Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test -coverprofile=coverage.out -covermode=atomic -timeout=15m ./...
|
||||
go tool cover -func=coverage.out -o=coverage.txt
|
||||
|
||||
- name: Calculate coverage
|
||||
id: coverage
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//')
|
||||
echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT
|
||||
echo "Total Coverage: $COVERAGE%"
|
||||
|
||||
# Get per-package coverage
|
||||
echo "## Coverage by Package" >> coverage_report.md
|
||||
echo "" >> coverage_report.md
|
||||
go tool cover -func=coverage.out | grep -v "total:" | awk '{print "- " $1 ": " $3}' >> coverage_report.md || true
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
continue-on-error: true
|
||||
|
||||
- name: Comment coverage on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const coverage = '${{ steps.coverage.outputs.coverage }}';
|
||||
let coverageReport = '';
|
||||
|
||||
try {
|
||||
coverageReport = fs.readFileSync('coverage_report.md', 'utf8');
|
||||
} catch (e) {
|
||||
coverageReport = 'Coverage details not available';
|
||||
}
|
||||
|
||||
const threshold = 70;
|
||||
const coverageNum = parseFloat(coverage);
|
||||
const emoji = coverageNum >= threshold ? '✅' : '⚠️';
|
||||
|
||||
const body = `## ${emoji} Test Coverage Report\n\n**Total Coverage:** ${coverage}%\n**Threshold:** ${threshold}%\n\n${coverageReport}`;
|
||||
|
||||
// Find existing comment
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
|
||||
const botComment = comments.find(comment =>
|
||||
comment.user.type === 'Bot' &&
|
||||
comment.body.includes('Test Coverage Report')
|
||||
);
|
||||
|
||||
if (botComment) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: botComment.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
}
|
||||
|
||||
- name: Check coverage threshold
|
||||
run: |
|
||||
COVERAGE=${{ steps.coverage.outputs.coverage }}
|
||||
THRESHOLD=70
|
||||
echo "Coverage: $COVERAGE%"
|
||||
echo "Threshold: $THRESHOLD%"
|
||||
if (( $(echo "$COVERAGE < $THRESHOLD" | bc -l) )); then
|
||||
echo "⚠️ Coverage $COVERAGE% is below threshold $THRESHOLD%"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Coverage $COVERAGE% meets threshold $THRESHOLD%"
|
||||
|
||||
# Memory leak detection
|
||||
test-memory-leaks:
|
||||
name: Memory Leak Detection
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run goroutine leak tests
|
||||
run: |
|
||||
echo "Running goroutine leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Gg]oroutine.*[Ll]eak.*' ./... || echo "No goroutine leak tests found"
|
||||
|
||||
- name: Run memory leak tests
|
||||
run: |
|
||||
echo "Running memory leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Mm]emory.*[Ll]eak.*' ./... || echo "No memory leak tests found"
|
||||
|
||||
- name: Run cleanup tests
|
||||
run: |
|
||||
echo "Running cleanup and resource management tests..."
|
||||
go test -v -timeout=20m -run='.*[Cc]leanup.*' ./... || echo "No cleanup tests found"
|
||||
|
||||
# Integration tests
|
||||
test-integration:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
if [ -d "./integration" ]; then
|
||||
go test -v -timeout=20m ./integration/...
|
||||
else
|
||||
echo "Running integration tests from all packages..."
|
||||
go test -v -timeout=20m -run='.*[Ii]ntegration.*' ./...
|
||||
fi
|
||||
|
||||
# Regression tests
|
||||
test-regression:
|
||||
name: Regression Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
echo "Running regression tests..."
|
||||
go test -v -timeout=20m -run='.*[Rr]egression.*' ./...
|
||||
|
||||
# Provider-specific tests (parallel matrix)
|
||||
test-providers:
|
||||
name: Provider Tests (${{ matrix.provider }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
provider:
|
||||
- google
|
||||
- azure
|
||||
- auth0
|
||||
- okta
|
||||
- keycloak
|
||||
- cognito
|
||||
- gitlab
|
||||
- github
|
||||
- generic
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run ${{ matrix.provider }} provider tests
|
||||
run: |
|
||||
PROVIDER_CAP=$(echo "${{ matrix.provider }}" | sed 's/.*/\u&/')
|
||||
echo "Testing $PROVIDER_CAP provider..."
|
||||
go test -v -timeout=10m -run=".*$PROVIDER_CAP.*" ./internal/providers/... || true
|
||||
go test -v -timeout=10m -run=".*${{ matrix.provider }}.*" ./... || true
|
||||
|
||||
# Benchmark tests with performance tracking
|
||||
benchmark:
|
||||
name: Benchmark Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run benchmarks
|
||||
run: |
|
||||
echo "Running benchmark tests..."
|
||||
go test -bench=. -benchmem -benchtime=1s -run=^$ ./... | tee benchmark.txt
|
||||
|
||||
- name: Upload benchmark results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results
|
||||
path: benchmark.txt
|
||||
retention-days: 30
|
||||
|
||||
- name: Compare benchmarks
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
echo "Benchmark results available in artifacts"
|
||||
echo "To compare with main branch, download previous benchmark results"
|
||||
|
||||
# Build validation across platforms
|
||||
build:
|
||||
name: Build (${{ matrix.os }}/${{ matrix.arch }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [linux, darwin]
|
||||
arch: [amd64, arm64]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Build for ${{ matrix.os }}/${{ matrix.arch }}
|
||||
env:
|
||||
GOOS: ${{ matrix.os }}
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
run: |
|
||||
echo "Building for $GOOS/$GOARCH..."
|
||||
go build -v -ldflags="-s -w" ./...
|
||||
|
||||
# Security-specific edge case tests
|
||||
test-security-edge-cases:
|
||||
name: Security Edge Cases
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run security edge case tests
|
||||
run: |
|
||||
echo "Running security edge case tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ecurity.*' ./...
|
||||
|
||||
# Session management tests
|
||||
test-session:
|
||||
name: Session Management Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run session tests
|
||||
run: |
|
||||
echo "Running session management tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ession.*' ./...
|
||||
|
||||
# Token validation tests
|
||||
test-token:
|
||||
name: Token Validation Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run token validation tests
|
||||
run: |
|
||||
echo "Running token validation tests..."
|
||||
go test -v -timeout=15m -run='.*[Tt]oken.*' ./...
|
||||
|
||||
# CSRF and security tests
|
||||
test-csrf:
|
||||
name: CSRF and Security Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run CSRF tests
|
||||
run: |
|
||||
echo "Running CSRF and security tests..."
|
||||
go test -v -timeout=15m -run='.*[Cc][Ss][Rr][Ff].*' ./...
|
||||
|
||||
# Multi-Go version compatibility
|
||||
test-go-versions:
|
||||
name: Go ${{ matrix.go-version }} Compatibility
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ['1.24']
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Run tests on Go ${{ matrix.go-version }}
|
||||
run: go test -short -timeout=10m ./...
|
||||
|
||||
# Final validation - all checks must pass (golangci-lint is advisory)
|
||||
all-checks-passed:
|
||||
name: ✅ All Checks Passed
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- quick-checks
|
||||
- golangci-lint
|
||||
- staticcheck
|
||||
- gosec
|
||||
- govulncheck
|
||||
- codeql
|
||||
- test-race
|
||||
- test-coverage
|
||||
- test-memory-leaks
|
||||
- test-integration
|
||||
- test-regression
|
||||
- test-providers
|
||||
- benchmark
|
||||
- build
|
||||
- test-security-edge-cases
|
||||
- test-session
|
||||
- test-token
|
||||
- test-csrf
|
||||
- test-go-versions
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
echo "Checking status of all jobs..."
|
||||
|
||||
# Check critical jobs (excluding golangci-lint which is advisory)
|
||||
CRITICAL_FAILURES=false
|
||||
|
||||
if [ "${{ needs.quick-checks.result }}" == "failure" ] || \
|
||||
[ "${{ needs.staticcheck.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-race.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-coverage.result }}" == "failure" ] || \
|
||||
[ "${{ needs.build.result }}" == "failure" ]; then
|
||||
CRITICAL_FAILURES=true
|
||||
fi
|
||||
|
||||
if [ "$CRITICAL_FAILURES" == "true" ]; then
|
||||
echo "❌ Critical checks failed"
|
||||
exit 1
|
||||
elif [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
|
||||
echo "⚠️ Some checks were cancelled"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All critical checks passed successfully!"
|
||||
if [ "${{ needs.golangci-lint.result }}" != "success" ]; then
|
||||
echo "ℹ️ Note: golangci-lint reported issues (advisory only)"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Post summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "# PR Validation Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "## Job Status" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Quick Checks: ${{ needs.quick-checks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Linting (advisory): ${{ needs.golangci-lint.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Static Analysis: ${{ needs.staticcheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Scan (gosec): ${{ needs.gosec.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Vulnerability Scan: ${{ needs.govulncheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CodeQL: ${{ needs.codeql.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Race Detection: ${{ needs.test-race.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Coverage: ${{ needs.test-coverage.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Memory Leaks: ${{ needs.test-memory-leaks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Integration Tests: ${{ needs.test-integration.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Regression Tests: ${{ needs.test-regression.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Provider Tests: ${{ needs.test-providers.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Benchmarks: ${{ needs.benchmark.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Build: ${{ needs.build.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Edge Cases: ${{ needs.test-security-edge-cases.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Session Tests: ${{ needs.test-session.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Token Tests: ${{ needs.test-token.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CSRF Tests: ${{ needs.test-csrf.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Go Version Compatibility: ${{ needs.test-go-versions.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
+192
@@ -0,0 +1,192 @@
|
||||
version: "2"
|
||||
run:
|
||||
go: "1.24"
|
||||
modules-download-mode: readonly
|
||||
tests: true
|
||||
linters:
|
||||
enable:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- goconst
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- goprintffuncname
|
||||
- gosec
|
||||
- misspell
|
||||
- noctx
|
||||
- nolintlint
|
||||
- prealloc
|
||||
- revive
|
||||
- rowserrcheck
|
||||
- sqlclosecheck
|
||||
- unconvert
|
||||
- unparam
|
||||
- whitespace
|
||||
disable:
|
||||
- exhaustive
|
||||
- funlen
|
||||
- gocognit
|
||||
- lll
|
||||
- mnd
|
||||
- testpackage
|
||||
- wsl
|
||||
settings:
|
||||
dupl:
|
||||
threshold: 200 # Allow intentional duplication in provider patterns and token management
|
||||
errcheck:
|
||||
check-type-assertions: true
|
||||
check-blank: false # Allow explicit blank assignments (_ = ...) to ignore errors
|
||||
exclude-functions:
|
||||
- (io.Closer).Close
|
||||
- (*database/sql.Rows).Close
|
||||
- (*database/sql.Stmt).Close
|
||||
- (io.Writer).Write
|
||||
- (*net/http.ResponseWriter).Write
|
||||
- fmt.Fprintf
|
||||
- fmt.Fprint
|
||||
- fmt.Fprintln
|
||||
goconst:
|
||||
min-len: 3
|
||||
min-occurrences: 10 # Increased to reduce noise for standard OAuth2/OIDC strings
|
||||
ignore-tests: true
|
||||
gocritic:
|
||||
# Using default enabled checks in v2
|
||||
enabled-checks:
|
||||
- appendCombine
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- commentedOutCode
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- hexLiteral
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nestingReduce
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- stringXbytes
|
||||
- typeAssertChain
|
||||
- typeUnparen
|
||||
- unlabelStmt
|
||||
- yodaStyleExpr
|
||||
gocyclo:
|
||||
min-complexity: 30 # OAuth/OIDC flows are inherently complex; set higher for Yaegi compatibility
|
||||
gosec:
|
||||
excludes:
|
||||
- G104
|
||||
- G404
|
||||
severity: medium
|
||||
confidence: medium
|
||||
govet:
|
||||
disable:
|
||||
- fieldalignment
|
||||
- shadow
|
||||
enable-all: true
|
||||
misspell:
|
||||
locale: US
|
||||
ignore-rules:
|
||||
- traefik
|
||||
- oidc
|
||||
- keycloak
|
||||
nolintlint:
|
||||
require-explanation: true
|
||||
require-specific: true
|
||||
allow-unused: false
|
||||
prealloc:
|
||||
simple: true
|
||||
range-loops: true
|
||||
for-loops: false
|
||||
revive:
|
||||
rules:
|
||||
- name: blank-imports
|
||||
- name: context-as-argument
|
||||
- name: context-keys-type
|
||||
- name: dot-imports
|
||||
- name: error-return
|
||||
- name: error-strings
|
||||
- name: error-naming
|
||||
- name: exported
|
||||
- name: if-return
|
||||
- name: increment-decrement
|
||||
- name: var-naming
|
||||
- name: var-declaration
|
||||
- name: package-comments
|
||||
- name: range
|
||||
- name: receiver-naming
|
||||
- name: time-naming
|
||||
- name: unexported-return
|
||||
- name: indent-error-flow
|
||||
- name: errorf
|
||||
- name: empty-block
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
- name: unreachable-code
|
||||
- name: redefines-builtin-id
|
||||
unparam:
|
||||
check-exported: false
|
||||
staticcheck:
|
||||
checks:
|
||||
- all
|
||||
- -QF1001 # De Morgan's law - style preference, may affect Yaegi
|
||||
- -QF1003 # Tagged switch - style preference, may affect Yaegi
|
||||
- -QF1007 # Merge conditional assignment - style preference
|
||||
- -QF1008 # Remove embedded field - may break Yaegi compatibility
|
||||
- -QF1012 # Use fmt.Fprintf - style preference
|
||||
- -ST1003 # Package name format - allowed for test packages
|
||||
exclusions:
|
||||
generated: lax
|
||||
rules:
|
||||
- linters:
|
||||
- bodyclose
|
||||
- dupl
|
||||
- errcheck
|
||||
- goconst
|
||||
- gocyclo
|
||||
- gosec
|
||||
- noctx
|
||||
- prealloc
|
||||
- unparam
|
||||
path: _test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
- gocyclo
|
||||
path: test.*\.go
|
||||
- linters:
|
||||
- gocritic
|
||||
- unused
|
||||
path: mocks.*\.go
|
||||
- linters:
|
||||
- gosec
|
||||
text: 'G404:'
|
||||
- linters:
|
||||
- all
|
||||
path: vendor/
|
||||
- linters:
|
||||
- goconst
|
||||
path: (.+)_test\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: internal/providers/(auth0|keycloak|okta|google|azure|github|gitlab|cognito|generic)\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session\.go
|
||||
- linters:
|
||||
- dupl
|
||||
path: session_chunk_manager\.go
|
||||
text: "(extractJWTExpiration|extractJWTIssuedAt)"
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
issues:
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
uniq-by-line: true
|
||||
formatters:
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
+254
-9
@@ -73,7 +73,11 @@ testData:
|
||||
- admin
|
||||
- developer
|
||||
|
||||
forceHTTPS: false # Forces the use of HTTPS for all URLs (default: true for security)
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
# When running behind load balancer that terminates TLS: MUST set to TRUE
|
||||
# See: https://github.com/lukaszraczylo/traefikoidc/issues/82
|
||||
forceHTTPS: true # Forces HTTPS scheme for redirect URIs (default when not specified: false)
|
||||
logLevel: debug # Sets logging verbosity: debug, info, error (default: info)
|
||||
rateLimit: 100 # Maximum number of requests per second (default: 100, minimum: 10)
|
||||
|
||||
@@ -102,7 +106,14 @@ testData:
|
||||
cookieDomain: "" # Explicit domain for session cookies (e.g., ".example.com" for multi-subdomain setups)
|
||||
overrideScopes: false # When true, replaces default scopes instead of appending (default: false)
|
||||
refreshGracePeriodSeconds: 60 # Seconds before token expiry to attempt proactive refresh (default: 60)
|
||||
|
||||
|
||||
# Auth0 / Custom API Audience Configuration
|
||||
audience: "" # Custom audience for access token validation (default: clientID)
|
||||
strictAudienceValidation: false # Reject sessions with audience mismatch (prevents token confusion attacks)
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
@@ -115,6 +126,12 @@ testData:
|
||||
- "https://*.example.com"
|
||||
corsAllowCredentials: true
|
||||
|
||||
# Cross-origin policies
|
||||
permissionsPolicy: "geolocation=(), camera=(), microphone=()"
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "production"
|
||||
@@ -312,6 +329,20 @@ testData:
|
||||
# clientSecret: your-auth0-client-secret # Store securely
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-auth0"
|
||||
#
|
||||
# # Auth0 Audience Configuration (for custom APIs)
|
||||
# # Scenario 1 (Recommended): Custom API with JWT access tokens
|
||||
# audience: "https://my-api.example.com" # Your API identifier from Auth0 dashboard
|
||||
# strictAudienceValidation: true # Enforce proper audience validation for security
|
||||
#
|
||||
# # Scenario 2 (Backward Compatible): Default audience (uses client_id)
|
||||
# # audience: "" # Leave empty or omit to use client_id as audience
|
||||
# # strictAudienceValidation: false # Allows fallback to ID token validation (logs warnings)
|
||||
#
|
||||
# # Scenario 3: Opaque (non-JWT) access tokens
|
||||
# # allowOpaqueTokens: true # Enable opaque token support
|
||||
# # requireTokenIntrospection: true # Require RFC 7662 token introspection
|
||||
#
|
||||
# scopes: # Defaults ["openid", "profile", "email"]. Add custom scopes if your Auth0 Rules/Actions require them.
|
||||
# - read:custom_data # Example custom scope
|
||||
# allowedRolesAndGroups: # Based on claims added via Auth0 Rules or Actions (e.g. namespaced claims)
|
||||
@@ -319,7 +350,7 @@ testData:
|
||||
# - editor
|
||||
# # Use Auth0 Rules or Actions to add custom claims (roles, permissions) to the ID Token.
|
||||
# # Ensure postLogoutRedirectURI is in Auth0 app's "Allowed Logout URLs".
|
||||
# # See README.md "Provider Configuration Recommendations" for Auth0.
|
||||
# # For detailed Auth0 audience configuration, see AUTH0_AUDIENCE_GUIDE.md
|
||||
|
||||
# --- Generic OIDC Provider Example ---
|
||||
# testDataGenericOIDC:
|
||||
@@ -448,9 +479,24 @@ configuration:
|
||||
forceHTTPS:
|
||||
type: boolean
|
||||
description: |
|
||||
Forces the use of HTTPS for all URLs.
|
||||
This is recommended for security in production environments.
|
||||
Default: true
|
||||
Forces HTTPS scheme for redirect URIs regardless of request headers or TLS state.
|
||||
|
||||
⚠️ CRITICAL CONFIGURATION for TLS Termination Scenarios:
|
||||
|
||||
When running Traefik behind a load balancer that terminates TLS (AWS ALB,
|
||||
Google Cloud Load Balancer, Azure Application Gateway, etc.), you MUST set
|
||||
this to true. Without it, redirect URIs will use http:// instead of https://,
|
||||
causing OAuth callback failures.
|
||||
|
||||
How it works:
|
||||
- When true: Always uses https:// for redirect URIs (highest priority)
|
||||
- When false: Detects scheme from X-Forwarded-Proto header or TLS state
|
||||
- When NOT specified: Defaults to false (Go zero value for bool)
|
||||
|
||||
Default: false (when not specified in configuration)
|
||||
Recommended: true (for production environments and TLS termination scenarios)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/82
|
||||
required: false
|
||||
|
||||
rateLimit:
|
||||
@@ -588,16 +634,159 @@ configuration:
|
||||
type: integer
|
||||
description: |
|
||||
The number of seconds before a token expires to attempt proactive refresh.
|
||||
|
||||
|
||||
When a request is made and the access token will expire within this grace period,
|
||||
the middleware will attempt to refresh the token proactively. This helps prevent
|
||||
authentication interruptions for active users.
|
||||
|
||||
|
||||
Setting this to 0 disables proactive refresh (tokens are only refreshed after expiry).
|
||||
|
||||
|
||||
Default: 60 (1 minute before expiry)
|
||||
required: false
|
||||
|
||||
audience:
|
||||
type: string
|
||||
description: |
|
||||
Custom audience value for access token validation.
|
||||
|
||||
This configures the expected audience claim in access tokens. Per OAuth 2.0 and OIDC
|
||||
specifications:
|
||||
- ID tokens always have aud=client_id (per OIDC Core 1.0)
|
||||
- Access tokens can have custom audiences (e.g., API identifiers)
|
||||
|
||||
Auth0 Scenarios:
|
||||
1. Custom API audience (recommended): Set to your API identifier from Auth0
|
||||
Example: "https://my-api.example.com"
|
||||
Result: Access tokens will contain this audience
|
||||
|
||||
2. Default audience: Leave empty or omit (uses client_id)
|
||||
Result: Access tokens may not contain client_id, triggering warnings
|
||||
|
||||
3. Opaque tokens: Use with allowOpaqueTokens=true for non-JWT tokens
|
||||
|
||||
When configured and different from client_id, the middleware automatically adds
|
||||
the audience parameter to the authorize endpoint request.
|
||||
|
||||
Default: "" (uses client_id as audience)
|
||||
See: AUTH0_AUDIENCE_GUIDE.md for detailed configuration
|
||||
required: false
|
||||
|
||||
strictAudienceValidation:
|
||||
type: boolean
|
||||
description: |
|
||||
Enforce strict audience validation for access tokens.
|
||||
|
||||
When enabled, sessions are rejected if access token validation fails due to
|
||||
audience mismatch. This prevents falling back to ID token validation, addressing
|
||||
potential token confusion attacks where tokens intended for different APIs could
|
||||
be used to grant access.
|
||||
|
||||
Auth0 Scenario 2 Protection:
|
||||
- When true: Rejects sessions with mismatched access token audience
|
||||
- When false: Logs security warnings but allows fallback to ID token (backward compatible)
|
||||
|
||||
Security Recommendation:
|
||||
- Production environments: Set to true for maximum security
|
||||
- Development/testing: Can use false with monitoring of security warnings
|
||||
|
||||
This setting addresses security concerns where access tokens without proper
|
||||
audience claims could bypass API-specific authorization checks.
|
||||
|
||||
Default: false (backward compatible)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/74
|
||||
required: false
|
||||
|
||||
allowOpaqueTokens:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable acceptance of opaque (non-JWT) access tokens.
|
||||
|
||||
When enabled, the middleware accepts access tokens that are not in JWT format
|
||||
(3-part base64 structure). Opaque tokens are validated using OAuth 2.0 Token
|
||||
Introspection (RFC 7662) if the provider exposes an introspection endpoint.
|
||||
|
||||
Auth0 Scenario 3:
|
||||
Some Auth0 configurations issue opaque access tokens when no default API is
|
||||
configured. This setting allows those tokens to be validated.
|
||||
|
||||
Requirements:
|
||||
- Provider must support introspection_endpoint in OIDC discovery
|
||||
- Client must have appropriate introspection permissions
|
||||
|
||||
Validation Process:
|
||||
1. Detects opaque token (not 3-part JWT structure)
|
||||
2. Calls provider's introspection endpoint with client credentials
|
||||
3. Validates response (active status, expiration, audience if present)
|
||||
4. Caches result for 5 minutes or token expiry (whichever is shorter)
|
||||
5. Falls back to ID token validation if introspection unavailable
|
||||
(unless requireTokenIntrospection=true)
|
||||
|
||||
Default: false (only JWT access tokens accepted)
|
||||
See: AUTH0_AUDIENCE_GUIDE.md for configuration examples
|
||||
required: false
|
||||
|
||||
requireTokenIntrospection:
|
||||
type: boolean
|
||||
description: |
|
||||
Require token introspection for all opaque access tokens.
|
||||
|
||||
When enabled with allowOpaqueTokens=true, opaque tokens are rejected if:
|
||||
- Introspection endpoint is not available from provider metadata
|
||||
- Introspection request fails
|
||||
- Introspection response indicates token is not active
|
||||
|
||||
Security Levels:
|
||||
- requireTokenIntrospection=true + allowOpaqueTokens=true:
|
||||
Maximum security - rejects opaque tokens without successful introspection
|
||||
|
||||
- requireTokenIntrospection=false + allowOpaqueTokens=true:
|
||||
Backward compatible - falls back to ID token validation if introspection fails
|
||||
|
||||
- requireTokenIntrospection=true + allowOpaqueTokens=false:
|
||||
No effect - opaque tokens are already rejected
|
||||
|
||||
Recommended Configuration:
|
||||
When accepting opaque tokens, always set this to true for maximum security:
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
Default: false (allows fallback to ID token)
|
||||
See: RFC 7662 OAuth 2.0 Token Introspection specification
|
||||
required: false
|
||||
|
||||
disableReplayDetection:
|
||||
type: boolean
|
||||
description: |
|
||||
Disable JTI-based replay attack detection for multi-replica deployments.
|
||||
|
||||
When running multiple Traefik replicas, each instance maintains its own in-memory
|
||||
JTI (JWT Token ID) cache. This causes false positives when the same valid token
|
||||
hits different replicas:
|
||||
- Request → Replica A → JTI added to cache → OK
|
||||
- Request → Replica B → JTI not in Replica B's cache → OK
|
||||
- Request → Replica A again → JTI found → FALSE POSITIVE "replay detected"
|
||||
|
||||
Security Impact:
|
||||
When disabled, the following validations remain active:
|
||||
- RSA/ECDSA signature verification
|
||||
- Token expiration (exp claim)
|
||||
- Issuer validation (iss claim)
|
||||
- Audience validation (aud claim)
|
||||
- Not-before validation (nbf claim)
|
||||
- Issued-at validation (iat claim)
|
||||
|
||||
Only the JTI replay check is skipped.
|
||||
|
||||
Recommendations:
|
||||
- Single-instance deployment: false (default, enables replay protection)
|
||||
- Multi-replica deployment: true (prevents false positives)
|
||||
- Production with shared cache: false (use Redis/Memcached for shared JTI cache)
|
||||
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
@@ -899,3 +1088,59 @@ configuration:
|
||||
Remove the X-Powered-By header to hide technology stack information.
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
permissionsPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Permissions-Policy header to control browser feature permissions.
|
||||
This header allows you to control which features and APIs can be used.
|
||||
|
||||
Examples:
|
||||
- "geolocation=(), camera=(), microphone=()" (deny all)
|
||||
- "geolocation=(self), camera=()" (allow geolocation for same origin only)
|
||||
|
||||
Common directives: accelerometer, camera, geolocation, gyroscope,
|
||||
magnetometer, microphone, payment, usb
|
||||
required: false
|
||||
|
||||
crossOriginEmbedderPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Embedder-Policy (COEP) header to prevent untrusted
|
||||
resources from being loaded.
|
||||
|
||||
Options:
|
||||
- "require-corp": Resources must explicitly grant permission
|
||||
- "credentialless": Load without credentials for cross-origin resources
|
||||
- "unsafe-none": No restrictions (default)
|
||||
|
||||
Required for certain browser features like SharedArrayBuffer.
|
||||
required: false
|
||||
|
||||
crossOriginOpenerPolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Opener-Policy (COOP) header to isolate browsing context
|
||||
from cross-origin windows.
|
||||
|
||||
Options:
|
||||
- "same-origin": Isolate from cross-origin documents
|
||||
- "same-origin-allow-popups": Allow popups that don't set COOP
|
||||
- "unsafe-none": No isolation (default)
|
||||
|
||||
Helps prevent cross-origin attacks and Spectre-like vulnerabilities.
|
||||
required: false
|
||||
|
||||
crossOriginResourcePolicy:
|
||||
type: string
|
||||
description: |
|
||||
Cross-Origin-Resource-Policy (CORP) header to control which origins
|
||||
can load this resource.
|
||||
|
||||
Options:
|
||||
- "same-origin": Only same-origin requests can load the resource
|
||||
- "same-site": Only same-site requests can load the resource
|
||||
- "cross-origin": Any origin can load the resource (default)
|
||||
|
||||
Prevents your resources from being embedded on other sites.
|
||||
required: false
|
||||
|
||||
+286
@@ -0,0 +1,286 @@
|
||||
# CI/CD Setup Guide
|
||||
|
||||
## 📋 Overview
|
||||
|
||||
This repository now has a comprehensive CI/CD pipeline that runs **20+ parallel checks** on every pull request to ensure code quality, security, and reliability.
|
||||
|
||||
## 🎯 What Was Added
|
||||
|
||||
### GitHub Actions Workflow
|
||||
- **`.github/workflows/pr-validation.yml`** - Main CI/CD pipeline (single file, all parallel)
|
||||
|
||||
### Configuration Files
|
||||
- **`.golangci.yml`** - Linter configuration with 30+ enabled checks
|
||||
- **`.github/dependabot.yml`** - Automated dependency updates
|
||||
- **`.github/CODEOWNERS`** - Automatic PR reviewer assignment
|
||||
- **`.github/PULL_REQUEST_TEMPLATE.md`** - Standardized PR descriptions
|
||||
- **`.github/workflows/README.md`** - Detailed workflow documentation
|
||||
- **`.github/workflows/.gitattributes`** - Consistent line endings
|
||||
|
||||
## ✅ What Gets Tested (All in Parallel)
|
||||
|
||||
### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters including style, complexity, bugs
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning with SARIF reports
|
||||
- **Govulncheck** - Go vulnerability database scanning
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
### Testing (9 test suites)
|
||||
- **Race Detector** - Concurrent access bugs
|
||||
- **Coverage** - 75% threshold with PR comments
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration suite
|
||||
- **Regression Tests** - Prevent old bugs from returning
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management
|
||||
- **Token Tests** - Token validation
|
||||
- **CSRF Tests** - CSRF protection
|
||||
|
||||
### Provider Testing (9 providers in parallel)
|
||||
- Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, GitHub, Generic
|
||||
|
||||
### Performance & Build (3 checks)
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Multi-platform Build** - 4 combinations (linux/darwin × amd64/arm64)
|
||||
- **Go Version Compatibility** - Go 1.23 & 1.24
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Push to GitHub
|
||||
```bash
|
||||
git add .github .golangci.yml CI_SETUP.md
|
||||
git commit -m "Add comprehensive CI/CD pipeline"
|
||||
git push origin main
|
||||
```
|
||||
|
||||
### 2. Create a Test PR
|
||||
```bash
|
||||
# Create a feature branch
|
||||
git checkout -b feature/test-ci
|
||||
echo "# Test" >> test.md
|
||||
git add test.md
|
||||
git commit -m "Test CI pipeline"
|
||||
git push origin feature/test-ci
|
||||
|
||||
# Create PR on GitHub
|
||||
# Watch all 20+ checks run in parallel! ⚡
|
||||
```
|
||||
|
||||
### 3. Monitor Results
|
||||
- Go to Actions tab: `https://github.com/{owner}/{repo}/actions`
|
||||
- Click on latest workflow run
|
||||
- See all parallel checks in action
|
||||
- Review coverage comment on PR
|
||||
|
||||
## 📊 Key Features
|
||||
|
||||
### ⚡ Maximum Speed
|
||||
- **Parallel execution** - All checks run simultaneously
|
||||
- **Smart caching** - Go modules and build cache
|
||||
- **Optimized order** - Quick checks first for fast feedback
|
||||
- **Expected runtime**: 5-10 minutes for full suite
|
||||
|
||||
### 🔒 Security First
|
||||
- **3 security scanners** - gosec, govulncheck, CodeQL
|
||||
- **SARIF integration** - Results in GitHub Security tab
|
||||
- **Dependency scanning** - Automated with Dependabot
|
||||
- **Security edge case tests**
|
||||
|
||||
### 📈 Coverage Tracking
|
||||
- **Automatic PR comments** with coverage stats
|
||||
- **Per-package breakdown** included
|
||||
- **75% threshold** enforced (configurable)
|
||||
- **Codecov integration** ready (optional)
|
||||
|
||||
### 🎨 Developer Experience
|
||||
- **Clear PR template** guides contributors
|
||||
- **Auto code owners** assignment
|
||||
- **Detailed error messages** for failures
|
||||
- **Benchmark tracking** for performance
|
||||
|
||||
## 🛠️ Local Development
|
||||
|
||||
### Install Required Tools
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
### Run Checks Locally
|
||||
```bash
|
||||
# Quick validation (before committing)
|
||||
gofmt -s -w . # Format code
|
||||
go vet ./... # Basic checks
|
||||
go mod tidy # Clean dependencies
|
||||
|
||||
# Linting
|
||||
golangci-lint run # Full lint suite
|
||||
staticcheck ./... # Static analysis
|
||||
|
||||
# Testing
|
||||
go test -race -timeout=15m ./... # Tests with race detector
|
||||
go test -coverprofile=coverage.out ./... # Coverage
|
||||
go tool cover -func=coverage.out # View coverage
|
||||
|
||||
# Security
|
||||
gosec ./... # Security scan
|
||||
govulncheck ./... # Vulnerability check
|
||||
|
||||
# Benchmarks
|
||||
go test -bench=. -benchmem ./... # Performance tests
|
||||
```
|
||||
|
||||
### Pre-commit Checklist
|
||||
```bash
|
||||
# Run this before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "✅ Ready to commit!"
|
||||
```
|
||||
|
||||
## 📝 Configuration
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
THRESHOLD=75 # Change to desired percentage
|
||||
```
|
||||
|
||||
### Modify Linter Rules
|
||||
Edit `.golangci.yml`:
|
||||
```yaml
|
||||
linters:
|
||||
enable:
|
||||
- newlinter # Add new linters here
|
||||
```
|
||||
|
||||
### Update Go Version
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
go-version: '1.24' # Update version
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Coverage Below Threshold
|
||||
```bash
|
||||
# See uncovered lines in browser
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Found
|
||||
```bash
|
||||
# Run specific test with race detector
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
### Linter Errors
|
||||
```bash
|
||||
# See detailed lint errors
|
||||
golangci-lint run -v
|
||||
|
||||
# Auto-fix some issues
|
||||
golangci-lint run --fix
|
||||
```
|
||||
|
||||
### Provider Test Fails
|
||||
```bash
|
||||
# Test specific provider
|
||||
go test -v -run='.*Azure.*' ./internal/providers/
|
||||
```
|
||||
|
||||
## 📈 Metrics & Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
- View all runs: `Actions` tab
|
||||
- Filter by workflow, branch, status
|
||||
- Download logs and artifacts
|
||||
|
||||
### Status Badge
|
||||
Add to README.md:
|
||||
```markdown
|
||||
[](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml)
|
||||
```
|
||||
|
||||
### Notifications
|
||||
- Configure in: Settings → Notifications
|
||||
- Email alerts for workflow failures
|
||||
- Slack/Discord webhooks supported
|
||||
|
||||
## 🔄 Continuous Improvement
|
||||
|
||||
### Dependabot Updates
|
||||
- Automatic weekly dependency checks (Mondays 9 AM)
|
||||
- Security updates prioritized
|
||||
- Groups patch updates together
|
||||
|
||||
### Code Owners
|
||||
- Auto-assigns reviewers based on file paths
|
||||
- Ensures expertise reviews changes
|
||||
- Speeds up PR review process
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Dependabot Config](.github/dependabot.yml)
|
||||
|
||||
## 🎉 Benefits
|
||||
|
||||
### For Contributors
|
||||
- Clear expectations via PR template
|
||||
- Fast feedback (5-10 min)
|
||||
- Comprehensive local tooling
|
||||
- Detailed error messages
|
||||
|
||||
### For Maintainers
|
||||
- Automated code review
|
||||
- Security scanning
|
||||
- Performance tracking
|
||||
- Quality gates enforcement
|
||||
|
||||
### For Users
|
||||
- Higher code quality
|
||||
- Fewer bugs in production
|
||||
- Better security
|
||||
- Consistent performance
|
||||
|
||||
## 🚦 Success Criteria
|
||||
|
||||
All PRs must pass:
|
||||
- ✅ All 20+ parallel checks
|
||||
- ✅ 75% test coverage minimum
|
||||
- ✅ Zero security vulnerabilities
|
||||
- ✅ No race conditions
|
||||
- ✅ No memory leaks
|
||||
- ✅ All providers tested
|
||||
- ✅ Builds on all platforms
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
1. **Run checks locally** before pushing to save CI time
|
||||
2. **Watch for PR comments** - coverage stats posted automatically
|
||||
3. **Check Security tab** for gosec/CodeQL findings
|
||||
4. **Review benchmark results** in artifacts
|
||||
5. **Use draft PRs** for work-in-progress to skip some checks
|
||||
|
||||
---
|
||||
|
||||
**Ready to go!** 🚀 Push your changes and create a PR to see it in action.
|
||||
@@ -8,6 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
|
||||
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
|
||||
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
|
||||
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
|
||||
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
|
||||
- **Domain restrictions**: Limit access to specific email domains or individual users
|
||||
- **Role-based access control**: Restrict access based on roles and groups from OIDC claims
|
||||
@@ -75,7 +76,7 @@ experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.2.1 # Use the latest version
|
||||
version: v0.7.8 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
@@ -114,8 +115,22 @@ The middleware supports the following configuration options:
|
||||
| `scopes` | OAuth 2.0 scopes to use for authentication | `["openid", "profile", "email"]` (always included by default) | `["roles", "custom_scope"]` (appended to defaults) |
|
||||
| `overrideScopes` | When true, replaces default scopes with provided scopes instead of appending | `false` | `true` (use only the scopes explicitly provided) |
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces the use of HTTPS for all URLs | `true` | `true`, `false` |
|
||||
| `forceHTTPS` | Forces HTTPS scheme for redirect URIs (**REQUIRED** for TLS termination at load balancer like AWS ALB) | `false` (when not specified) | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
>
|
||||
> If you're running Traefik behind a load balancer (AWS ALB, Google Cloud Load Balancer, Azure Application Gateway, etc.) that terminates TLS:
|
||||
> - **You MUST set `forceHTTPS: true`** in your configuration
|
||||
> - Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures
|
||||
> - This is especially critical for AWS ALB which may overwrite the `X-Forwarded-Proto` header
|
||||
>
|
||||
> **Default behavior:**
|
||||
> - When `forceHTTPS` is **not specified** in your config → defaults to `false` (Go zero value)
|
||||
> - When `forceHTTPS: true` is explicitly set → always uses `https://` for redirect URIs
|
||||
> - When `forceHTTPS: false` is explicitly set → scheme detection based on headers/TLS
|
||||
>
|
||||
> See [GitHub Issue #82](https://github.com/lukaszraczylo/traefikoidc/issues/82) for details.
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
|
||||
@@ -125,8 +140,13 @@ The middleware supports the following configuration options:
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `cookieDomain` | Explicit domain for session cookies (important for multi-subdomain setups) | auto-detected | `.example.com`, `app.example.com` |
|
||||
| `audience` | Custom audience for access token validation (for Auth0 custom APIs, etc.) | `clientID` | `https://my-api.example.com` |
|
||||
| `strictAudienceValidation` | Reject sessions with access token audience mismatch (prevents token confusion attacks) | `false` | `true` |
|
||||
| `allowOpaqueTokens` | Enable opaque (non-JWT) access token support via RFC 7662 introspection | `false` | `true` |
|
||||
| `requireTokenIntrospection` | Require introspection for opaque tokens (force validation, no fallback) | `false` | `true` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
@@ -201,6 +221,103 @@ scopes: []
|
||||
|
||||
The default append behavior ensures essential OIDC scopes are always present, while the override mode gives you complete control over the exact scopes requested from the provider.
|
||||
|
||||
## Auth0 Audience Validation & Security
|
||||
|
||||
The middleware provides comprehensive support for Auth0 audience validation to prevent token confusion attacks. Auth0 can issue tokens in three different scenarios, each requiring specific configuration.
|
||||
|
||||
### Understanding Token Audiences
|
||||
|
||||
Per OAuth 2.0 and OIDC specifications:
|
||||
- **ID Tokens**: MUST have `aud = client_id` (OIDC Core 1.0 spec)
|
||||
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
|
||||
### Auth0 Scenarios
|
||||
|
||||
#### Scenario 1: Custom API Audience ✅ (RECOMMENDED)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
strictAudienceValidation: true # Enforce strict validation
|
||||
```
|
||||
|
||||
**Result**: Fully secure, OIDC compliant with proper access token audience validation.
|
||||
|
||||
#### Scenario 2: Default Audience ⚠️ (USE WITH CAUTION)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# audience not specified (defaults to client_id)
|
||||
strictAudienceValidation: true # Recommended: reject mismatched tokens
|
||||
```
|
||||
|
||||
**Behavior**: Access tokens may not contain client_id in audience, triggering security warnings. Set `strictAudienceValidation: true` to reject such sessions.
|
||||
|
||||
#### Scenario 3: Opaque Access Tokens ✅ (SUPPORTED)
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true # Enable opaque token support
|
||||
requireTokenIntrospection: true # Require introspection (recommended)
|
||||
```
|
||||
|
||||
**Result**: Secure with OAuth 2.0 Token Introspection (RFC 7662).
|
||||
|
||||
### Security Configuration Options
|
||||
|
||||
| Option | Purpose | Recommended Value |
|
||||
|--------|---------|-------------------|
|
||||
| `audience` | Expected audience for access tokens | Your API identifier or leave empty |
|
||||
| `strictAudienceValidation` | Reject sessions with audience mismatch | `true` for production |
|
||||
| `allowOpaqueTokens` | Accept non-JWT access tokens | `true` if provider issues opaque tokens |
|
||||
| `requireTokenIntrospection` | Force introspection for opaque tokens | `true` when `allowOpaqueTokens=true` |
|
||||
|
||||
### Complete Auth0 Configuration Examples
|
||||
|
||||
**Production Configuration (Scenario 1):**
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-secure
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-auth0-domain.auth0.com
|
||||
clientID: your-auth0-client-id
|
||||
clientSecret: your-auth0-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
allowedRolesAndGroups:
|
||||
- "https://your-app.com/roles:admin"
|
||||
- editor
|
||||
```
|
||||
|
||||
**Opaque Token Configuration (Scenario 3):**
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-opaque
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-auth0-domain.auth0.com
|
||||
clientID: your-auth0-client-id
|
||||
clientSecret: your-auth0-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
For detailed Auth0 configuration including all three scenarios, troubleshooting, and security best practices, see **[AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md)**.
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The middleware includes comprehensive security headers support to protect your applications against common web vulnerabilities. Security headers are applied to all authenticated responses.
|
||||
@@ -319,6 +436,10 @@ securityHeaders:
|
||||
| `customHeaders` | Additional custom headers | `{}` | `{"X-Custom": "value"}` |
|
||||
| `disableServerHeader` | Remove Server header | `true` | `true`, `false` |
|
||||
| `disablePoweredByHeader` | Remove X-Powered-By header | `true` | `true`, `false` |
|
||||
| `permissionsPolicy` | Permissions-Policy header | `` | `"geolocation=(), camera=(), microphone=()"` |
|
||||
| `crossOriginEmbedderPolicy` | Cross-Origin-Embedder-Policy header | `` | `"require-corp"`, `"credentialless"`, `"unsafe-none"` |
|
||||
| `crossOriginOpenerPolicy` | Cross-Origin-Opener-Policy header | `` | `"same-origin"`, `"same-origin-allow-popups"`, `"unsafe-none"` |
|
||||
| `crossOriginResourcePolicy` | Cross-Origin-Resource-Policy header | `` | `"same-origin"`, `"same-site"`, `"cross-origin"` |
|
||||
|
||||
### CORS Wildcard Support
|
||||
|
||||
@@ -390,6 +511,47 @@ securityHeaders:
|
||||
corsAllowedOrigins: ["http://localhost:*"]
|
||||
```
|
||||
|
||||
### Multi-Replica Deployment Configuration
|
||||
|
||||
When running multiple Traefik replicas with the OIDC plugin, you may encounter false positive replay detection errors. Each replica maintains its own in-memory JTI (JWT Token ID) cache, causing legitimate token reuse to be flagged as replay attacks.
|
||||
|
||||
**Problem**: When the same valid token hits different replicas:
|
||||
- Request → Replica A → JTI added to Replica A's cache ✓
|
||||
- Request → Replica B → JTI NOT in Replica B's cache ✓
|
||||
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
|
||||
|
||||
**Solution**: Disable replay detection for distributed deployments:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
|
||||
```
|
||||
|
||||
**Security Note**: When `disableReplayDetection: true`:
|
||||
- ✅ Token signatures still validated
|
||||
- ✅ Expiration still checked
|
||||
- ✅ All other claims still verified
|
||||
- ❌ JTI replay check **skipped**
|
||||
|
||||
**Example Configuration**:
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-multi-replica
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
disableReplayDetection: true # Required for multi-replica deployments
|
||||
```
|
||||
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, set to `true` and consider implementing a shared cache backend (Redis/Memcached) if replay detection is required.
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Configuration
|
||||
@@ -740,6 +902,11 @@ spec:
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Audience configuration for custom APIs
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
strictAudienceValidation: true # Enforce proper audience validation
|
||||
|
||||
scopes:
|
||||
- read:custom_data # Custom scopes as needed
|
||||
allowedRolesAndGroups:
|
||||
@@ -748,6 +915,8 @@ spec:
|
||||
postLogoutRedirectURI: /logged-out-page # Must be in Auth0 Allowed Logout URLs
|
||||
```
|
||||
|
||||
**Note**: For detailed Auth0 audience configuration including opaque tokens and all security scenarios, see [AUTH0_AUDIENCE_GUIDE.md](docs/AUTH0_AUDIENCE_GUIDE.md).
|
||||
|
||||
### Okta Configuration
|
||||
|
||||
```yaml
|
||||
@@ -920,7 +1089,7 @@ services:
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.2.1"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.7.8"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./traefik-config/traefik.yml:/etc/traefik/traefik.yml
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAudienceConfiguration tests the custom audience configuration feature
|
||||
func TestAudienceConfiguration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
clientID string
|
||||
expectedAudience string
|
||||
}{
|
||||
{
|
||||
name: "no custom audience - uses clientID",
|
||||
configAudience: "",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "test-client-id",
|
||||
},
|
||||
{
|
||||
name: "custom audience specified",
|
||||
configAudience: "api://custom-audience",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "api://custom-audience",
|
||||
},
|
||||
{
|
||||
name: "auth0 style custom audience",
|
||||
configAudience: "https://api.example.com",
|
||||
clientID: "test-client-id",
|
||||
expectedAudience: "https://api.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create config with custom audience
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = tt.clientID
|
||||
config.ClientSecret = "test-secret"
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
config.CallbackURL = "/callback"
|
||||
config.Audience = tt.configAudience
|
||||
|
||||
// Create middleware instance
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
traefikOidc, err := NewWithContext(context.Background(), config, next, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create middleware: %v", err)
|
||||
}
|
||||
|
||||
// Verify audience is set correctly
|
||||
if traefikOidc.audience != tt.expectedAudience {
|
||||
t.Errorf("Expected audience %s, got %s", tt.expectedAudience, traefikOidc.audience)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
_ = traefikOidc.Close()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceValidation tests the audience validation in Config.Validate()
|
||||
func TestAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
expectError bool
|
||||
errorContains string
|
||||
}{
|
||||
{
|
||||
name: "valid custom audience URL",
|
||||
audience: "https://api.example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "valid azure style audience",
|
||||
audience: "api://12345678-1234-1234-1234-123456789012",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty audience is valid (uses clientID)",
|
||||
audience: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "http URL not allowed",
|
||||
audience: "http://api.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience URL must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "wildcard not allowed",
|
||||
audience: "https://*.example.com",
|
||||
expectError: true,
|
||||
errorContains: "audience must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "too long audience",
|
||||
audience: "https://" + string(make([]byte, 250)) + ".com",
|
||||
expectError: true,
|
||||
errorContains: "audience must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "invalid characters",
|
||||
audience: "api://test\ninjection",
|
||||
expectError: true,
|
||||
errorContains: "audience contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client"
|
||||
config.ClientSecret = "test-secret"
|
||||
config.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||
config.CallbackURL = "/callback"
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got none")
|
||||
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
|
||||
t.Errorf("Expected error containing '%s', got: %v", tt.errorContains, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,931 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
||||
func TestConfigAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty audience is valid for backward compatibility",
|
||||
audience: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTPS URL audience Auth0 format",
|
||||
audience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid identifier audience",
|
||||
audience: "my-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Azure AD Application ID URI format",
|
||||
audience: "api://12345-guid-67890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Auth0 API identifier",
|
||||
audience: "https://my-company.auth0.com/api/v2/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL audience should fail",
|
||||
audience: "http://api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "Audience with wildcard should fail",
|
||||
audience: "https://api.*.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience with single asterisk should fail",
|
||||
audience: "*",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience over 256 characters should fail",
|
||||
audience: strings.Repeat("a", 257),
|
||||
wantErr: true,
|
||||
errContains: "must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with newline should fail",
|
||||
audience: "my-api\ninjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with carriage return should fail",
|
||||
audience: "my-api\rinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with tab should fail",
|
||||
audience: "my-api\tinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Valid audience exactly 256 characters",
|
||||
audience: strings.Repeat("a", 256),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid simple identifier",
|
||||
audience: "my-service-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid URN format",
|
||||
audience: "urn:myservice:api:v1",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
||||
func TestJWTAudienceVerification(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate RSA key for signing JWTs
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
// Create JWK
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
tokenAudience interface{}
|
||||
wantErr bool
|
||||
errContains string
|
||||
skipReplayCheck bool
|
||||
}{
|
||||
{
|
||||
name: "JWT with string aud matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with string aud NOT matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://wrong-api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud NOT containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with clientID as aud when no custom audience configured",
|
||||
configAudience: "",
|
||||
tokenAudience: "test-client-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with empty string aud",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Azure AD Application ID URI format",
|
||||
configAudience: "api://12345-app-id",
|
||||
tokenAudience: "api://12345-app-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Auth0 custom API audience",
|
||||
configAudience: "https://mycompany.com/api",
|
||||
tokenAudience: "https://mycompany.com/api",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Token confusion attack - audience for different service",
|
||||
configAudience: "https://service-a.example.com",
|
||||
tokenAudience: "https://service-b.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Set up the token verifier and JWT verifier
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Determine the expected audience for validation
|
||||
expectedAudience := tt.configAudience
|
||||
if expectedAudience == "" {
|
||||
expectedAudience = tOidc.clientID
|
||||
}
|
||||
|
||||
// Set the audience field on the tOidc instance
|
||||
tOidc.audience = expectedAudience
|
||||
|
||||
// Create JWT with specified audience
|
||||
jti := generateRandomString(16)
|
||||
if tt.skipReplayCheck {
|
||||
// Use a unique JTI for each test to avoid replay detection
|
||||
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
||||
}
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": tt.tokenAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
||||
// when the Audience field is not set
|
||||
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Test with no custom audience configured - should use clientID
|
||||
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // Should match clientID
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
err = ts.tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
||||
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Auth0 scenario: custom audience for API access
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://mycompany.auth0.com"
|
||||
config.ClientID = "auth0-client-id"
|
||||
config.ClientSecret = "auth0-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "https://api.mycompany.com" // Custom API audience
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Auth0 config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "auth0-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches configured audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Auth0 token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.ClientID, // Using clientID instead of API audience
|
||||
"scope": "openid profile email", // Mark as access token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err == nil {
|
||||
t.Error("Auth0 access token with wrong audience should have been rejected")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
||||
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Azure AD scenario: Application ID URI format
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
||||
config.ClientID = "azure-client-id"
|
||||
config.ClientSecret = "azure-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Azure AD config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "azure-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches Application ID URI
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
||||
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
// Service A configuration
|
||||
serviceA := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "service-a-client-id",
|
||||
clientSecret: "service-a-secret",
|
||||
audience: "service-a-client-id", // Service A uses its clientID as audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
serviceA.jwtVerifier = serviceA
|
||||
serviceA.tokenVerifier = serviceA
|
||||
|
||||
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
||||
// Create a token intended for service B
|
||||
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": "https://service-b.example.com", // For service B
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "attacker@example.com",
|
||||
"email": "attacker@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service B token: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify the service B token on service A
|
||||
err = serviceA.VerifyToken(serviceBToken)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
||||
case !strings.Contains(err.Error(), "invalid audience"):
|
||||
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
||||
default:
|
||||
t.Logf("Token confusion attack correctly prevented: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
||||
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
}{
|
||||
{
|
||||
name: "Single asterisk",
|
||||
audience: "*",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in URL",
|
||||
audience: "https://*.example.com",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in path",
|
||||
audience: "https://api.example.com/*",
|
||||
},
|
||||
{
|
||||
name: "Multiple wildcards",
|
||||
audience: "https://*.*.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
||||
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
||||
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
||||
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Newline injection",
|
||||
audience: "api.example.com\nmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Carriage return injection",
|
||||
audience: "api.example.com\rmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Tab injection",
|
||||
audience: "api.example.com\tmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection",
|
||||
audience: "api.example.com\x00malicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
||||
func TestAudienceWithReplayProtection(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a token with custom audience and fixed JTI
|
||||
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
||||
customAudience := "https://api.example.com"
|
||||
|
||||
// Set the audience field to match what we expect
|
||||
tOidc.audience = customAudience
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "user@example.com",
|
||||
"jti": fixedJTI,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Fatalf("First verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the JTI was blacklisted
|
||||
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
||||
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
||||
} else {
|
||||
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
||||
func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Create a test next handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Authenticated with custom audience"))
|
||||
})
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
close(tOidc.initComplete)
|
||||
|
||||
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
||||
// Create a valid token with the custom audience
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.company.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "user-123",
|
||||
"email": "user@company.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a request with authenticated session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
|
||||
// Create session with token
|
||||
resp := httptest.NewRecorder()
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
if err := session.Save(req, resp); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies and add them to a new request
|
||||
cookies := resp.Result().Cookies()
|
||||
req = httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
+120
-71
@@ -11,17 +11,24 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// AuthHandler provides core authentication functionality for OIDC flows
|
||||
type AuthHandler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
|
||||
type ScopeFilter interface {
|
||||
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
|
||||
}
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -30,29 +37,31 @@ type Logger interface {
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler instance
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
func (h *Handler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
session SessionData, redirectURL string,
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
||||
|
||||
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
const maxRedirects = 5
|
||||
@@ -128,7 +137,7 @@ func (h *AuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.R
|
||||
// BuildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
func (h *Handler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.clientID)
|
||||
params.Set("response_type", "code")
|
||||
@@ -144,49 +153,19 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
scopes := make([]string, len(h.scopes))
|
||||
copy(scopes, h.scopes)
|
||||
|
||||
if h.isGoogleProv() {
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
// Apply discovery-based scope filtering if available
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
|
||||
}
|
||||
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else if h.isAzureProv() {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
// Apply provider-specific modifications
|
||||
scopes, params = h.applyProviderSpecificConfig(scopes, params)
|
||||
|
||||
hasOfflineAccess := false
|
||||
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
if !h.overrideScopes || (h.overrideScopes && len(h.scopes) == 0) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)", h.overrideScopes, len(h.scopes))
|
||||
}
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.", len(h.scopes))
|
||||
}
|
||||
// Final filtering pass to remove anything the provider doesn't support
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
@@ -198,10 +177,80 @@ func (h *AuthHandler) BuildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
return h.buildURLWithParams(h.authURL, params)
|
||||
}
|
||||
|
||||
// applyProviderSpecificConfig applies provider-specific scope and parameter modifications
|
||||
func (h *Handler) applyProviderSpecificConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
switch {
|
||||
case h.isGoogleProv():
|
||||
return h.applyGoogleConfig(scopes, params)
|
||||
case h.isAzureProv():
|
||||
return h.applyAzureConfig(scopes, params)
|
||||
default:
|
||||
return h.applyStandardProviderConfig(scopes, params)
|
||||
}
|
||||
}
|
||||
|
||||
// applyGoogleConfig applies Google-specific configuration
|
||||
func (h *Handler) applyGoogleConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
// Google: Remove offline_access if present, add access_type=offline
|
||||
filteredScopes := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
return filteredScopes, params
|
||||
}
|
||||
|
||||
// applyAzureConfig applies Azure AD-specific configuration
|
||||
func (h *Handler) applyAzureConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// applyStandardProviderConfig applies configuration for standard OIDC providers
|
||||
func (h *Handler) applyStandardProviderConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// shouldAddOfflineAccess determines if offline_access scope should be added
|
||||
func (h *Handler) shouldAddOfflineAccess(scopes []string) bool {
|
||||
if h.overrideScopes && len(h.scopes) > 0 {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
// It handles both relative and absolute URLs, validates URL security,
|
||||
// and properly encodes query parameters.
|
||||
func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
func (h *Handler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if baseURL != "" {
|
||||
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
||||
if err := h.validateURL(baseURL); err != nil {
|
||||
@@ -252,7 +301,7 @@ func (h *AuthHandler) buildURLWithParams(baseURL string, params url.Values) stri
|
||||
|
||||
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
||||
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
||||
func (h *AuthHandler) validateURL(urlStr string) error {
|
||||
func (h *Handler) validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("empty URL")
|
||||
}
|
||||
@@ -267,7 +316,7 @@ func (h *AuthHandler) validateURL(urlStr string) error {
|
||||
|
||||
// validateParsedURL validates a parsed URL structure for security.
|
||||
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
||||
func (h *AuthHandler) validateParsedURL(u *url.URL) error {
|
||||
func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
allowedSchemes := map[string]bool{
|
||||
"https": true,
|
||||
"http": true,
|
||||
@@ -298,7 +347,7 @@ func (h *AuthHandler) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
func (h *AuthHandler) validateHost(host string) error {
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
+581
-12
@@ -22,6 +22,28 @@ func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
// mockScopeFilter is a mock implementation of the ScopeFilter interface for testing
|
||||
type mockScopeFilter struct{}
|
||||
|
||||
func (m *mockScopeFilter) FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string {
|
||||
// For testing, just return requested scopes if no supported scopes provided
|
||||
if len(supportedScopes) == 0 {
|
||||
return requestedScopes
|
||||
}
|
||||
// Simple filter logic for tests
|
||||
filtered := make([]string, 0, len(requestedScopes))
|
||||
supportedMap := make(map[string]bool)
|
||||
for _, s := range supportedScopes {
|
||||
supportedMap[s] = true
|
||||
}
|
||||
for _, s := range requestedScopes {
|
||||
if supportedMap[s] {
|
||||
filtered = append(filtered, s)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
@@ -64,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
||||
"test-client-id", "https://example.com/auth", "https://example.com",
|
||||
scopes, false)
|
||||
scopes, false, nil, nil)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -103,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{redirectCount: 5} // At the limit
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -138,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -169,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -200,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -231,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{saveError: &testError{"save failed"}}
|
||||
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
||||
@@ -275,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
||||
@@ -378,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
||||
[]string{"openid", "profile", "email"}, false)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -418,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
[]string{"openid", "profile", "email"}, false)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -446,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false)
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -471,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false)
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -543,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
tt.scopes, tt.overrideScopes)
|
||||
tt.scopes, tt.overrideScopes, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -597,3 +619,550 @@ type testError struct {
|
||||
func (e *testError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// SCOPE FILTERING INTEGRATION TESTS
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_WithScopeFiltering tests scope filtering when enabled
|
||||
func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
// Requested scopes include offline_access
|
||||
scopes := []string{"openid", "profile", "email", "offline_access"}
|
||||
// Provider only supports these
|
||||
scopesSupported := []string{"openid", "profile", "email"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
// offline_access should have been filtered out in the first pass
|
||||
// The standard provider logic then tries to add it back
|
||||
// But the final filtering pass removes it again
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered out when not in scopesSupported")
|
||||
}
|
||||
}
|
||||
|
||||
// Should contain the supported scopes
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in final scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in final scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in final scope string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_WithoutScopeFiltering tests backward compatibility
|
||||
func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
// No scopeFilter or scopesSupported (backward compatibility)
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
|
||||
// All scopes should be present, plus offline_access added by standard provider logic
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "offline_access") {
|
||||
t.Error("Expected offline_access added by standard provider logic")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess tests GitLab scenario
|
||||
func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email", "offline_access"}
|
||||
// GitLab discovery doc doesn't include offline_access
|
||||
scopesSupported := []string{"openid", "profile", "email", "read_user", "read_api"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
|
||||
"https://gitlab.example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
// offline_access should be filtered out
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("GitLab scenario: offline_access should have been filtered out")
|
||||
}
|
||||
}
|
||||
|
||||
// Should contain standard scopes
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in final scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in final scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in final scope string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess tests Google provider
|
||||
func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email", "offline_access"}
|
||||
scopesSupported := []string{"openid", "profile", "email"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://accounts.google.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
actualScope := query.Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
// Google removes offline_access and uses access_type=offline instead
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("Google scenario: offline_access should have been removed by Google-specific logic")
|
||||
}
|
||||
}
|
||||
|
||||
// Google-specific parameters should be present
|
||||
if query.Get("access_type") != "offline" {
|
||||
t.Error("Expected access_type=offline for Google")
|
||||
}
|
||||
if query.Get("prompt") != "consent" {
|
||||
t.Error("Expected prompt=consent for Google")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess tests Azure provider
|
||||
func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
// Azure supports offline_access
|
||||
scopesSupported := []string{"openid", "profile", "email", "offline_access"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
actualScope := query.Get("scope")
|
||||
|
||||
// Azure should add offline_access automatically and it should pass filtering
|
||||
if !strings.Contains(actualScope, "offline_access") {
|
||||
t.Error("Azure scenario: offline_access should be present")
|
||||
}
|
||||
|
||||
// Azure-specific parameter
|
||||
if query.Get("response_mode") != "query" {
|
||||
t.Error("Expected response_mode=query for Azure")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_GenericWithFiltering tests generic provider with discovery filtering
|
||||
func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email", "custom_scope", "offline_access"}
|
||||
scopesSupported := []string{"openid", "profile", "email", "custom_scope"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"generic-client", "https://auth.provider.com/authorize",
|
||||
"https://auth.provider.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
|
||||
// Should contain supported scopes including custom_scope
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "custom_scope") {
|
||||
t.Error("Expected custom_scope in scope string")
|
||||
}
|
||||
|
||||
// offline_access should be filtered out (not in scopesSupported)
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered out when not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering tests override scopes + filtering
|
||||
func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
// User explicitly overrides scopes
|
||||
scopes := []string{"openid", "custom:read", "custom:write"}
|
||||
scopesSupported := []string{"openid", "custom:read"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, true, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
// Should contain only supported scopes from override
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "custom:read") {
|
||||
t.Error("Expected custom:read in scope string")
|
||||
}
|
||||
|
||||
// custom:write should be filtered out
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "custom:write" {
|
||||
t.Error("custom:write should have been filtered out (not supported)")
|
||||
}
|
||||
}
|
||||
|
||||
// offline_access should NOT be auto-added when overrideScopes=true
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should not be auto-added when user overrides scopes")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_DoubleFiltering tests initial + final filtering passes
|
||||
func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
// Provider supports offline_access
|
||||
scopesSupported := []string{"openid", "profile", "email", "offline_access"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
|
||||
// Initial filtering: All requested scopes pass (all in scopesSupported)
|
||||
// Provider-specific logic: Adds offline_access (standard provider)
|
||||
// Final filtering: offline_access should still be present (it's in scopesSupported)
|
||||
if !strings.Contains(actualScope, "offline_access") {
|
||||
t.Error("offline_access should be present (supported by provider and added by logic)")
|
||||
}
|
||||
|
||||
// Original scopes should be present
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in scope string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_NoScopeFilterProvided tests when scopeFilter is nil
|
||||
func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
scopesSupported := []string{"openid", "profile"} // Even with scopesSupported, no filter
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, scopesSupported) // scopeFilter is nil
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
|
||||
// Without scopeFilter, all scopes should be present (no filtering)
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in scope string (no filtering without scopeFilter)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_EmptyScopesSupported tests empty scopesSupported list
|
||||
func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
scopesSupported := []string{} // Empty - backward compatibility mode
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
actualScope := parsedURL.Query().Get("scope")
|
||||
|
||||
// With empty scopesSupported, mockScopeFilter returns requested scopes unchanged
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in scope string")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in scope string")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_FilteringWithPKCE tests scope filtering with PKCE enabled
|
||||
func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "offline_access"}
|
||||
scopesSupported := []string{"openid", "profile"}
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// PKCE parameters should be present
|
||||
if query.Get("code_challenge") != "test-challenge" {
|
||||
t.Error("Expected code_challenge parameter with PKCE enabled")
|
||||
}
|
||||
if query.Get("code_challenge_method") != "S256" {
|
||||
t.Error("Expected code_challenge_method=S256 with PKCE enabled")
|
||||
}
|
||||
|
||||
// Scope filtering should still work
|
||||
actualScope := query.Get("scope")
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered out even with PKCE")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_ComplexScenario tests realistic complex scenario
|
||||
func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
// User configures: openid, profile, email, custom:read, offline_access
|
||||
scopes := []string{"openid", "profile", "email", "custom:read", "offline_access"}
|
||||
|
||||
// Provider discovery returns: openid, profile, email, custom:read, custom:write, admin:all
|
||||
scopesSupported := []string{"openid", "profile", "email", "custom:read", "custom:write", "admin:all"}
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
|
||||
|
||||
parsedURL, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse auth URL: %v", err)
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
|
||||
// Verify basic OAuth parameters
|
||||
if query.Get("client_id") != "complex-client" {
|
||||
t.Error("Expected correct client_id")
|
||||
}
|
||||
if query.Get("response_type") != "code" {
|
||||
t.Error("Expected response_type=code")
|
||||
}
|
||||
if query.Get("state") != "state-123" {
|
||||
t.Error("Expected correct state")
|
||||
}
|
||||
if query.Get("nonce") != "nonce-456" {
|
||||
t.Error("Expected correct nonce")
|
||||
}
|
||||
|
||||
// Verify PKCE parameters
|
||||
if query.Get("code_challenge") != "challenge-789" {
|
||||
t.Error("Expected correct code_challenge")
|
||||
}
|
||||
|
||||
// Verify scope filtering
|
||||
actualScope := query.Get("scope")
|
||||
|
||||
// Should contain: openid, profile, email, custom:read
|
||||
if !strings.Contains(actualScope, "openid") {
|
||||
t.Error("Expected openid in scope")
|
||||
}
|
||||
if !strings.Contains(actualScope, "profile") {
|
||||
t.Error("Expected profile in scope")
|
||||
}
|
||||
if !strings.Contains(actualScope, "email") {
|
||||
t.Error("Expected email in scope")
|
||||
}
|
||||
if !strings.Contains(actualScope, "custom:read") {
|
||||
t.Error("Expected custom:read in scope")
|
||||
}
|
||||
|
||||
// offline_access should be filtered (not in scopesSupported)
|
||||
actualScopes := strings.Split(actualScope, " ")
|
||||
for _, scope := range actualScopes {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered (not in scopesSupported)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_BuildAuthURL_LoggingVerification tests that logging occurs correctly
|
||||
func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
scopeFilter := &mockScopeFilter{}
|
||||
|
||||
scopes := []string{"openid", "profile", "offline_access"}
|
||||
scopesSupported := []string{"openid", "profile"}
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
|
||||
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
// Should have logged debug messages about filtering
|
||||
if len(logger.debugMessages) == 0 {
|
||||
t.Error("Expected debug messages to be logged during scope filtering")
|
||||
}
|
||||
|
||||
// Verify specific log messages were generated
|
||||
hasDiscoveryFilterLog := false
|
||||
hasFinalFilterLog := false
|
||||
hasFinalScopeLog := false
|
||||
|
||||
for _, msg := range logger.debugMessages {
|
||||
if strings.Contains(msg, "After discovery filtering") {
|
||||
hasDiscoveryFilterLog = true
|
||||
}
|
||||
if strings.Contains(msg, "After final filtering") {
|
||||
hasFinalFilterLog = true
|
||||
}
|
||||
if strings.Contains(msg, "Final scope string being sent") {
|
||||
hasFinalScopeLog = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasDiscoveryFilterLog {
|
||||
t.Error("Expected log message about discovery filtering")
|
||||
}
|
||||
if !hasFinalFilterLog {
|
||||
t.Error("Expected log message about final filtering")
|
||||
}
|
||||
if !hasFinalScopeLog {
|
||||
t.Error("Expected log message about final scope string")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -0,0 +1,428 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file contains tests for Auth0-specific audience validation scenarios.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuth0Scenario1WithCustomAudience tests Auth0 scenario 1:
|
||||
// - Custom audience configured in plugin
|
||||
// - Authorize endpoint called WITH audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, custom_audience]
|
||||
// Expected: Both tokens validate correctly
|
||||
func TestAuth0Scenario1WithCustomAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (OIDC standard)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token always has client_id
|
||||
"nonce": "test-nonce-scenario1", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, custom_audience]
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience, // Custom API audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data", // Access tokens have scope
|
||||
"jti": "access-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed (should validate against client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify access token validates against custom audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token validation failed (should validate against custom audience): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL includes audience parameter (URL-encoded)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if !strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should contain audience parameter when custom audience is configured, got: %s", authURL)
|
||||
}
|
||||
// Verify the audience is properly URL-encoded (contains %3A for :, %2F for /)
|
||||
if !strings.Contains(authURL, "audience=https%3A%2F%2Fmy-api.example.com") {
|
||||
t.Errorf("Auth URL should contain URL-encoded custom audience, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2DefaultAudience tests Auth0 scenario 2:
|
||||
// - No custom audience configured (defaults to client_id)
|
||||
// - Authorize endpoint called WITHOUT audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, default_audience] (no client_id)
|
||||
// Expected: ID token validates, access token falls back to ID token validation
|
||||
func TestAuth0Scenario2DefaultAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// No custom audience - defaults to client_id
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token with aud = client_id
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario2", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, some_default_audience]
|
||||
// This represents Auth0's default audience behavior
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://test-issuer.com/api/v2/", // Default Auth0 Management API
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Access token won't have client_id in aud, so it will fail validation
|
||||
// This is expected for scenario 2 - the session validation relies on ID token
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Logf("Access token validation passed (unexpected but OK if client_id is in aud array)")
|
||||
} else {
|
||||
// Expected failure - access token doesn't have client_id in aud
|
||||
t.Logf("Access token validation failed as expected (aud doesn't contain client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL does NOT include audience parameter (since audience == client_id)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should NOT contain audience parameter when audience equals client_id, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario3OpaqueAccessToken tests Auth0 scenario 3:
|
||||
// - No custom audience configured
|
||||
// - No default audience in Auth0
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: opaque (not JWT)
|
||||
// Expected: ID token validates, opaque access token is accepted
|
||||
func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable opaque tokens for this scenario (Option C requirement)
|
||||
ts.tOidc.allowOpaqueTokens = true
|
||||
|
||||
// No custom audience
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario3", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token (not a JWT - just a random string)
|
||||
opaqueAccessToken := "opaque_access_token_random_string_12345"
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token should fail JWT validation (expected)
|
||||
err = ts.tOidc.VerifyToken(opaqueAccessToken)
|
||||
if err == nil {
|
||||
t.Error("Opaque access token should fail JWT validation")
|
||||
} else {
|
||||
t.Logf("Opaque access token correctly rejected by JWT validator: %v", err)
|
||||
}
|
||||
|
||||
// Test that validateStandardTokens handles opaque tokens correctly
|
||||
// by falling back to ID token validation
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0AudienceArrayValidation tests that audience validation
|
||||
// correctly handles array audiences (common in Auth0)
|
||||
func TestAuth0AudienceArrayValidation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with audience as array containing our custom audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience,
|
||||
"https://another-api.example.com",
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data write:data",
|
||||
"jti": "array-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - custom audience is in the array
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token with audience array should validate when custom audience is present: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0MismatchedAudience tests that tokens with wrong audience fail validation
|
||||
func TestAuth0MismatchedAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with WRONG audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://different-api.example.com", // Wrong audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "wrong-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail validation - audience doesn't match
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Error("Access token with wrong audience should fail validation")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2StrictMode tests strict audience validation mode:
|
||||
// - Scenario 2 (access token with wrong audience) should be REJECTED
|
||||
// - strictAudienceValidation=true prevents fallback to ID token
|
||||
// - This addresses Allan's security concerns about audience bypass
|
||||
func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable strict mode to prevent Scenario 2 bypass (Option C)
|
||||
ts.tOidc.strictAudienceValidation = true
|
||||
|
||||
// Configure custom audience
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (valid)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-strict",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with WRONG audience (doesn't include custom audience)
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://wrong-api.example.com", // Wrong audience - not our custom audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Test session validation with wrong access token and valid ID token
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetIDToken(idToken)
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
if !needsRefresh {
|
||||
t.Errorf("Strict mode: Should signal refresh needed after rejection, got needsRefresh=%v", needsRefresh)
|
||||
}
|
||||
if expired {
|
||||
t.Errorf("Strict mode: Should not mark as expired (should try refresh first), got expired=%v", expired)
|
||||
}
|
||||
|
||||
t.Logf("✓ Strict mode correctly rejected Scenario 2 (access token audience mismatch)")
|
||||
}
|
||||
|
||||
// TestIDTokenAlwaysValidatesAgainstClientID verifies that ID tokens
|
||||
// are ALWAYS validated against client_id, regardless of configured audience
|
||||
func TestIDTokenAlwaysValidatesAgainstClientID(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure a custom audience different from client_id
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (per OIDC spec)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token MUST have client_id
|
||||
"nonce": "test-nonce-123", // ID tokens have nonce for replay protection
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-client-id-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - ID tokens are checked against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token should validate against client_id even when custom audience is configured: %v", err)
|
||||
}
|
||||
|
||||
// Create ID token with WRONG audience (should fail)
|
||||
wrongIDToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": customAudience, // WRONG - should be client_id
|
||||
"nonce": "test-nonce-wrong-456", // ID token has nonce, so it will be detected as ID token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "wrong-id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create wrong ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail - ID tokens must have client_id as audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(wrongIDToken)
|
||||
if err == nil {
|
||||
t.Error("ID token with custom audience (not client_id) should fail validation")
|
||||
}
|
||||
}
|
||||
+2
-2
@@ -47,7 +47,7 @@ func (t *TraefikOidc) generatePKCEParameters() (string, string, error) {
|
||||
// prepareSessionForAuthentication clears existing session data and sets new authentication state
|
||||
func (t *TraefikOidc) prepareSessionForAuthentication(session *SessionData, csrfToken, nonce, codeVerifier, incomingPath string) {
|
||||
// Clear all existing session data
|
||||
session.SetAuthenticated(false)
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication state on new flow
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
@@ -276,7 +276,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
// - redirectURL: The callback URL to be used in the new authentication flow.
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.logger.Debug("Handling expired token: Clearing session and initiating re-authentication.")
|
||||
session.SetAuthenticated(false)
|
||||
_ = session.SetAuthenticated(false) // Safe to ignore: clearing authentication on expired token
|
||||
session.SetIDToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGeneratePKCEParameters tests the generatePKCEParameters method
|
||||
func TestGeneratePKCEParameters(t *testing.T) {
|
||||
t.Run("PKCE enabled - successful generation", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE enabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier, "code verifier should not be empty when PKCE is enabled")
|
||||
assert.NotEmpty(t, challenge, "code challenge should not be empty when PKCE is enabled")
|
||||
|
||||
// Verify the challenge is derived from the verifier
|
||||
expectedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, expectedChallenge, challenge, "challenge should match derived challenge from verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE disabled - returns empty strings", func(t *testing.T) {
|
||||
// Create a TraefikOidc instance with PKCE disabled
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: false,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, verifier, "code verifier should be empty when PKCE is disabled")
|
||||
assert.Empty(t, challenge, "code challenge should be empty when PKCE is disabled")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - generates different values each time", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier1, challenge1, err1 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err1)
|
||||
|
||||
verifier2, challenge2, err2 := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.NotEqual(t, verifier1, verifier2, "verifiers should be different")
|
||||
assert.NotEqual(t, challenge1, challenge2, "challenges should be different")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier and challenge relationship", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// The challenge should always be derivable from the verifier
|
||||
recalculatedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, challenge, recalculatedChallenge,
|
||||
"challenge should always match the SHA256 hash of verifier")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - verifier meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
verifier, _, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// RFC 7636 requires verifier to be 43-128 characters
|
||||
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
|
||||
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
|
||||
})
|
||||
|
||||
t.Run("PKCE enabled - challenge meets RFC 7636 requirements", func(t *testing.T) {
|
||||
plugin := &TraefikOidc{
|
||||
enablePKCE: true,
|
||||
logger: NewLogger("debug"),
|
||||
}
|
||||
|
||||
_, challenge, err := plugin.generatePKCEParameters()
|
||||
require.NoError(t, err)
|
||||
|
||||
// SHA256 hash base64 encoded should be 43 characters
|
||||
assert.Equal(t, 43, len(challenge), "S256 challenge should be exactly 43 characters")
|
||||
})
|
||||
}
|
||||
+1
-1
@@ -538,7 +538,7 @@ func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
|
||||
|
||||
// Start the task if not already running
|
||||
if !rm.IsTaskRunning(name) {
|
||||
rm.StartBackgroundTask(name)
|
||||
_ = rm.StartBackgroundTask(name) // Safe to ignore: task registration succeeded, start is best-effort
|
||||
}
|
||||
|
||||
// Get the task from resource manager's internal registry
|
||||
|
||||
@@ -58,6 +58,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tokenURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/token",
|
||||
jwksURL: "https://login.microsoftonline.com/tenant-id/discovery/v2.0/keys",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
refreshGracePeriod: 60 * time.Second,
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestMemoryMonitorComprehensive tests memory monitor edge cases
|
||||
func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
t.Run("TriggerGC calls runtime GC", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.TriggerGC()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("GetMemoryPressure returns pressure level", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Initially should return None (no stats yet)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Equal(t, MemoryPressureNone, pressure)
|
||||
|
||||
// Collect stats to populate lastStats
|
||||
monitor.GetCurrentStats()
|
||||
|
||||
// Now should return a valid pressure level
|
||||
pressure = monitor.GetMemoryPressure()
|
||||
assert.NotNil(t, pressure)
|
||||
})
|
||||
|
||||
t.Run("StartMonitoring can be called", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Start monitoring should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 100*time.Millisecond)
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
})
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("StopMonitoring can be called safely", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// StopMonitoring should not panic even if not started
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
// Can be called multiple times safely
|
||||
assert.NotPanics(t, func() {
|
||||
monitor.StopMonitoring()
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalMemoryMonitor resets singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
// Get initial instance
|
||||
GetGlobalMemoryMonitor()
|
||||
|
||||
// Reset
|
||||
ResetGlobalMemoryMonitor()
|
||||
|
||||
// Should be able to get a new instance
|
||||
monitor := GetGlobalMemoryMonitor()
|
||||
assert.NotNil(t, monitor)
|
||||
|
||||
// Clean up
|
||||
monitor.StopMonitoring()
|
||||
ResetGlobalMemoryMonitor()
|
||||
})
|
||||
|
||||
t.Run("String method returns pressure name", func(t *testing.T) {
|
||||
pressures := []struct {
|
||||
level MemoryPressureLevel
|
||||
name string
|
||||
}{
|
||||
{MemoryPressureNone, "None"},
|
||||
{MemoryPressureLow, "Low"},
|
||||
{MemoryPressureModerate, "Moderate"},
|
||||
{MemoryPressureHigh, "High"},
|
||||
{MemoryPressureCritical, "Critical"},
|
||||
{MemoryPressureLevel(999), "Unknown"},
|
||||
}
|
||||
|
||||
for _, p := range pressures {
|
||||
assert.Equal(t, p.name, p.level.String(), "pressure level %d should return %s", p.level, p.name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetCurrentStats collects statistics", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.NotZero(t, stats.Timestamp)
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskRegistry tests background task registry edge cases
|
||||
func TestBackgroundTaskRegistry(t *testing.T) {
|
||||
t.Run("GetGlobalTaskRegistry returns singleton", func(t *testing.T) {
|
||||
registry1 := GetGlobalTaskRegistry()
|
||||
registry2 := GetGlobalTaskRegistry()
|
||||
|
||||
assert.Equal(t, registry1, registry2, "should return same instance")
|
||||
})
|
||||
|
||||
t.Run("RegisterTask adds task to registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-register-task"
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask(taskName, task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify task was registered
|
||||
_, exists := registry.GetTask(taskName)
|
||||
assert.True(t, exists, "task should be registered")
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("CreateSingletonTask is idempotent", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
taskName := "test-singleton-idempotent"
|
||||
callCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
taskFunc := func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
// First creation should succeed
|
||||
task1, err1 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NotNil(t, task1)
|
||||
|
||||
// Second creation should also succeed (idempotent)
|
||||
// Returns same task without error
|
||||
task2, err2 := registry.CreateSingletonTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
taskFunc,
|
||||
newNoOpLogger(),
|
||||
nil,
|
||||
)
|
||||
|
||||
assert.NoError(t, err2, "CreateSingletonTask should be idempotent")
|
||||
assert.NotNil(t, task2)
|
||||
|
||||
// Clean up
|
||||
if task1 != nil {
|
||||
task1.Stop()
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTaskCount returns active task count", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Initially should be 0 or small number
|
||||
initialCount := registry.GetTaskCount()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"count-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
err := registry.RegisterTask("count-test-task", task)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Count should increase
|
||||
newCount := registry.GetTaskCount()
|
||||
assert.Equal(t, initialCount+1, newCount)
|
||||
|
||||
// Clean up
|
||||
task.Stop()
|
||||
})
|
||||
|
||||
t.Run("StopAllTasks stops all tasks", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
// Create multiple tasks
|
||||
for i := 0; i < 3; i++ {
|
||||
taskName := "multi-task-" + string(rune(i+'0'))
|
||||
task := NewBackgroundTask(
|
||||
taskName,
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask(taskName, task)
|
||||
}
|
||||
|
||||
// Verify tasks were created
|
||||
assert.GreaterOrEqual(t, registry.GetTaskCount(), 3)
|
||||
|
||||
// Stop all tasks
|
||||
registry.StopAllTasks()
|
||||
|
||||
// Verify all tasks are removed
|
||||
taskCount := registry.GetTaskCount()
|
||||
assert.Equal(t, 0, taskCount, "all tasks should be stopped")
|
||||
})
|
||||
|
||||
t.Run("ResetGlobalTaskRegistry clears registry", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
registry := GetGlobalTaskRegistry()
|
||||
|
||||
// Create a task
|
||||
task := NewBackgroundTask(
|
||||
"reset-test-task",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
registry.RegisterTask("reset-test-task", task)
|
||||
|
||||
// Reset
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
// Get new registry
|
||||
newRegistry := GetGlobalTaskRegistry()
|
||||
assert.Equal(t, 0, newRegistry.GetTaskCount(), "new registry should be empty")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBackgroundTaskLifecycle tests background task lifecycle
|
||||
func TestBackgroundTaskLifecycle(t *testing.T) {
|
||||
t.Run("Start begins task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
executed := false
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"lifecycle-test",
|
||||
50*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
executed = true
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Wait for execution
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Verify it executed
|
||||
mu.Lock()
|
||||
wasExecuted := executed
|
||||
mu.Unlock()
|
||||
|
||||
assert.True(t, wasExecuted, "task should have executed")
|
||||
})
|
||||
|
||||
t.Run("Stop halts task execution", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"stop-test",
|
||||
30*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
// Let it run a few times
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Record count
|
||||
mu.Lock()
|
||||
countAfterStop := execCount
|
||||
mu.Unlock()
|
||||
|
||||
// Wait more
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
// Count should not increase
|
||||
mu.Lock()
|
||||
finalCount := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.Equal(t, countAfterStop, finalCount, "task should not execute after stop")
|
||||
})
|
||||
|
||||
t.Run("Multiple Start calls are safe", func(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping background task test in short mode")
|
||||
}
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
execCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-start-test",
|
||||
100*time.Millisecond,
|
||||
func() {
|
||||
mu.Lock()
|
||||
execCount++
|
||||
mu.Unlock()
|
||||
},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Multiple starts should be safe
|
||||
task.Start()
|
||||
task.Start()
|
||||
task.Start()
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(50 * time.Millisecond))
|
||||
|
||||
// Stop task
|
||||
task.Stop()
|
||||
|
||||
// Should have executed, but only one goroutine
|
||||
mu.Lock()
|
||||
count := execCount
|
||||
mu.Unlock()
|
||||
|
||||
assert.GreaterOrEqual(t, count, 0, "task should have executed at least once")
|
||||
})
|
||||
|
||||
t.Run("Multiple Stop calls are safe", func(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
task := NewBackgroundTask(
|
||||
"multi-stop-test",
|
||||
100*time.Millisecond,
|
||||
func() {},
|
||||
newNoOpLogger(),
|
||||
)
|
||||
|
||||
// Start and stop
|
||||
task.Start()
|
||||
time.Sleep(GetTestDuration(20 * time.Millisecond))
|
||||
|
||||
// Multiple stops should be safe
|
||||
assert.NotPanics(t, func() {
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
task.Stop()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryMonitorIntegration tests memory monitor integration
|
||||
func TestMemoryMonitorIntegration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory monitor integration test in short mode")
|
||||
}
|
||||
|
||||
t.Run("monitoring updates stats", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
ResetGlobalTaskRegistry()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalTaskRegistry()
|
||||
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
defer monitor.StopMonitoring()
|
||||
|
||||
// Start monitoring
|
||||
ctx := context.Background()
|
||||
monitor.StartMonitoring(ctx, 50*time.Millisecond)
|
||||
|
||||
// Wait for at least one check
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
|
||||
// Get pressure (should be a valid pressure level)
|
||||
pressure := monitor.GetMemoryPressure()
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, pressure, "pressure should be a valid level")
|
||||
|
||||
// Stop monitoring
|
||||
monitor.StopMonitoring()
|
||||
})
|
||||
|
||||
t.Run("global memory monitor singleton", func(t *testing.T) {
|
||||
ResetGlobalMemoryMonitor()
|
||||
defer ResetGlobalMemoryMonitor()
|
||||
|
||||
monitor1 := GetGlobalMemoryMonitor()
|
||||
monitor2 := GetGlobalMemoryMonitor()
|
||||
|
||||
assert.Equal(t, monitor1, monitor2, "should return same instance")
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryStatsCollection tests memory statistics collection
|
||||
func TestMemoryStatsCollection(t *testing.T) {
|
||||
t.Run("GetCurrentStats returns valid data", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
assert.NotNil(t, stats)
|
||||
assert.Greater(t, stats.HeapAllocBytes, uint64(0))
|
||||
assert.Greater(t, stats.HeapSysBytes, uint64(0))
|
||||
assert.Greater(t, stats.NumGoroutines, 0)
|
||||
assert.False(t, stats.Timestamp.IsZero())
|
||||
})
|
||||
|
||||
t.Run("Stats include memory pressure", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
stats := monitor.GetCurrentStats()
|
||||
|
||||
// Should calculate and include pressure level
|
||||
assert.NotNil(t, stats.MemoryPressure)
|
||||
assert.Contains(t, []MemoryPressureLevel{
|
||||
MemoryPressureNone,
|
||||
MemoryPressureLow,
|
||||
MemoryPressureModerate,
|
||||
MemoryPressureHigh,
|
||||
MemoryPressureCritical,
|
||||
}, stats.MemoryPressure)
|
||||
})
|
||||
|
||||
t.Run("TriggerGC reduces memory", func(t *testing.T) {
|
||||
thresholds := DefaultMemoryAlertThresholds()
|
||||
monitor := NewMemoryMonitor(newNoOpLogger(), thresholds)
|
||||
|
||||
// Allocate some memory
|
||||
_ = make([]byte, 1024*1024) // 1MB
|
||||
|
||||
// Get stats before GC
|
||||
beforeStats := monitor.GetCurrentStats()
|
||||
|
||||
// Trigger GC
|
||||
monitor.TriggerGC()
|
||||
|
||||
// Get stats after GC
|
||||
afterStats := monitor.GetCurrentStats()
|
||||
|
||||
// After GC should have different stats
|
||||
assert.NotEqual(t, beforeStats.LastGCTime, afterStats.LastGCTime)
|
||||
})
|
||||
}
|
||||
+18
-2
@@ -61,6 +61,22 @@ func (cm *CacheManager) GetSharedJWKCache() JWKCacheInterface {
|
||||
return &JWKCache{cache: cm.manager.GetJWKCache()}
|
||||
}
|
||||
|
||||
// GetSharedIntrospectionCache returns the shared token introspection cache
|
||||
// for caching OAuth 2.0 Token Introspection (RFC 7662) results
|
||||
func (cm *CacheManager) GetSharedIntrospectionCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetIntrospectionCache()}
|
||||
}
|
||||
|
||||
// GetSharedTokenTypeCache returns the shared token type cache
|
||||
// for caching token type detection results to improve performance
|
||||
func (cm *CacheManager) GetSharedTokenTypeCache() CacheInterface {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return &CacheInterfaceWrapper{cache: cm.manager.GetTokenTypeCache()}
|
||||
}
|
||||
|
||||
// Close gracefully shuts down all cache components
|
||||
func (cm *CacheManager) Close() error {
|
||||
cm.mu.Lock()
|
||||
@@ -83,7 +99,7 @@ type CacheInterfaceWrapper struct {
|
||||
|
||||
// Set stores a value
|
||||
func (c *CacheInterfaceWrapper) Set(key string, value interface{}, ttl time.Duration) {
|
||||
c.cache.Set(key, value, ttl)
|
||||
_ = c.cache.Set(key, value, ttl) // Safe to ignore: cache set failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves a value
|
||||
@@ -110,7 +126,7 @@ func (c *CacheInterfaceWrapper) Cleanup() {
|
||||
func (c *CacheInterfaceWrapper) Close() {
|
||||
// Close the underlying cache to stop goroutines
|
||||
if c.cache != nil {
|
||||
c.cache.Close()
|
||||
_ = c.cache.Close() // Safe to ignore: closing cache is best-effort during shutdown
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,424 @@
|
||||
# Auth0 Audience Validation Guide
|
||||
|
||||
## Overview
|
||||
|
||||
This guide explains how to configure audience validation for Auth0 and other OIDC providers that support custom API audiences. It covers three common Auth0 scenarios and how to configure the middleware for maximum security.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Understanding Audiences](#understanding-audiences)
|
||||
2. [The Three Auth0 Scenarios](#the-three-auth0-scenarios)
|
||||
3. [Configuration Options](#configuration-options)
|
||||
4. [Security Recommendations](#security-recommendations)
|
||||
5. [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## Understanding Audiences
|
||||
|
||||
### What is an Audience?
|
||||
|
||||
The **audience** (`aud`) claim in a JWT identifies the intended recipient of the token. Per OAuth 2.0 and OIDC specifications:
|
||||
|
||||
- **ID Tokens**: MUST have `aud = client_id` (per OIDC Core 1.0 spec)
|
||||
- **Access Tokens**: Can have custom audiences (e.g., API identifiers)
|
||||
|
||||
### Why Does This Matter?
|
||||
|
||||
Proper audience validation prevents **token confusion attacks** where a token intended for one API is used to access another API.
|
||||
|
||||
---
|
||||
|
||||
## The Three Auth0 Scenarios
|
||||
|
||||
### Scenario 1: Custom API Audience ✅ **RECOMMENDED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com" # Your API identifier from Auth0
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request includes `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "https://my-api.example.com"]`
|
||||
3. Middleware validates:
|
||||
- ID tokens against `client_id`
|
||||
- Access tokens against custom audience
|
||||
|
||||
**Result:** ✅ Fully secure, OIDC compliant
|
||||
|
||||
---
|
||||
|
||||
### Scenario 2: Default Audience (No Custom API) ⚠️ **USE WITH CAUTION**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# audience not specified (defaults to client_id)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Authorization request WITHOUT `audience` parameter
|
||||
2. Auth0 issues:
|
||||
- **ID Token**: `aud = client_id`
|
||||
- **Access Token**: `aud = ["https://issuer/userinfo", "default_api"]` (no `client_id`)
|
||||
3. Access token validation fails (audience mismatch)
|
||||
4. Middleware falls back to ID token validation
|
||||
|
||||
**Security Warning:**
|
||||
```
|
||||
⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!
|
||||
⚠️ This could allow tokens intended for different APIs to grant access
|
||||
⚠️ Set strictAudienceValidation=true to enforce proper audience validation
|
||||
⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74
|
||||
```
|
||||
|
||||
**Recommended Fix:**
|
||||
```yaml
|
||||
strictAudienceValidation: true # Reject sessions with audience mismatch
|
||||
```
|
||||
|
||||
**Result:**
|
||||
- Default: ⚠️ Works but logs security warnings
|
||||
- With strict mode: ✅ Secure (rejects mismatched tokens)
|
||||
|
||||
---
|
||||
|
||||
### Scenario 3: Opaque Access Tokens ✅ **SUPPORTED**
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true # Enable opaque token support
|
||||
requireTokenIntrospection: true # Require introspection (recommended)
|
||||
```
|
||||
|
||||
**What Happens:**
|
||||
1. Auth0 issues opaque (non-JWT) access token
|
||||
2. Middleware detects opaque token (not 3 parts separated by dots)
|
||||
3. Uses OAuth 2.0 Token Introspection (RFC 7662) to validate
|
||||
4. Falls back to ID token if introspection unavailable (unless `requireTokenIntrospection=true`)
|
||||
|
||||
**Requirements:**
|
||||
- Provider must support `introspection_endpoint` in OIDC discovery
|
||||
- Client must have introspection permissions
|
||||
|
||||
**Result:** ✅ Secure with introspection, ⚠️ risky without
|
||||
|
||||
---
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Audience Settings
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `audience` | string | `client_id` | Expected audience for access tokens |
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
# .traefik.yml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
audience: "https://my-api.example.com"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Security Mode Settings
|
||||
|
||||
#### `strictAudienceValidation`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` for production
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects sessions if access token audience doesn't match (prevents Scenario 2)
|
||||
- When `false`: Logs warnings but allows fallback to ID token (backward compatible)
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use in production environments
|
||||
- ✅ When you have custom API audiences configured in Auth0
|
||||
- ⚠️ May break existing deployments relying on Scenario 2 behavior
|
||||
|
||||
---
|
||||
|
||||
#### `allowOpaqueTokens`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Accepts opaque (non-JWT) access tokens
|
||||
- When `false`: Only accepts JWT access tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ When Auth0 issues opaque tokens (no default API configured)
|
||||
- ✅ When using Auth0 Management API tokens
|
||||
- ⚠️ Requires introspection endpoint for security
|
||||
|
||||
---
|
||||
|
||||
#### `requireTokenIntrospection`
|
||||
|
||||
**Type:** boolean
|
||||
**Default:** `false`
|
||||
**Recommended:** `true` when `allowOpaqueTokens=true`
|
||||
|
||||
**What it does:**
|
||||
- When `true`: Rejects opaque tokens if introspection fails or endpoint unavailable
|
||||
- When `false`: Falls back to ID token validation for opaque tokens
|
||||
|
||||
**Example:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- ✅ Always use when `allowOpaqueTokens=true` for maximum security
|
||||
- ⚠️ Requires provider to expose introspection endpoint
|
||||
|
||||
---
|
||||
|
||||
## Security Recommendations
|
||||
|
||||
### Recommended Configuration for Auth0
|
||||
|
||||
**For APIs with custom audiences (Scenario 1):**
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
allowOpaqueTokens: false
|
||||
```
|
||||
|
||||
**For default Auth0 setup (Scenario 2):**
|
||||
```yaml
|
||||
# Don't set audience (defaults to client_id)
|
||||
strictAudienceValidation: true # Enforce proper configuration
|
||||
```
|
||||
|
||||
**For opaque tokens (Scenario 3):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. ✅ **Always set `strictAudienceValidation: true` in production**
|
||||
2. ✅ **Configure custom API audiences in Auth0 dashboard**
|
||||
3. ✅ **Use `requireTokenIntrospection: true` if accepting opaque tokens**
|
||||
4. ✅ **Monitor logs for security warnings**
|
||||
5. ❌ **Don't rely on Scenario 2 fallback behavior**
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Access token validation failed due to audience mismatch"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch
|
||||
```
|
||||
|
||||
**Cause:** Access token audience doesn't match configured audience
|
||||
|
||||
**Solutions:**
|
||||
1. **Configure correct audience:**
|
||||
```yaml
|
||||
audience: "https://your-api-identifier" # From Auth0 API settings
|
||||
```
|
||||
|
||||
2. **Update Auth0 authorization request:**
|
||||
- Ensure `audience` parameter is included in authorize URL
|
||||
- Middleware automatically adds this when `audience != client_id`
|
||||
|
||||
3. **Accept the behavior (not recommended):**
|
||||
```yaml
|
||||
strictAudienceValidation: false # Logs warnings but allows
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### "Opaque token detected but allowOpaqueTokens=false"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque access token detected but allowOpaqueTokens=false
|
||||
```
|
||||
|
||||
**Cause:** Auth0 issued non-JWT access token but middleware not configured to accept them
|
||||
|
||||
**Solutions:**
|
||||
1. **Enable opaque tokens:**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT access tokens:**
|
||||
- Create an API in Auth0 dashboard
|
||||
- Set API identifier as `audience` in configuration
|
||||
|
||||
---
|
||||
|
||||
### "Introspection endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
⚠️ Opaque tokens enabled but no introspection endpoint available from provider
|
||||
```
|
||||
|
||||
**Cause:** Auth0 provider metadata doesn't include `introspection_endpoint`
|
||||
|
||||
**Solutions:**
|
||||
1. **Check provider discovery:**
|
||||
```bash
|
||||
curl https://YOUR_DOMAIN/.well-known/openid-configuration
|
||||
```
|
||||
Look for `introspection_endpoint`
|
||||
|
||||
2. **Disable required introspection (less secure):**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: false # Falls back to ID token
|
||||
```
|
||||
|
||||
3. **Use JWT access tokens instead** (recommended)
|
||||
|
||||
---
|
||||
|
||||
### "Token introspection required but endpoint not available"
|
||||
|
||||
**Symptom:**
|
||||
```
|
||||
❌ SECURITY: Opaque token rejected (introspection required but failed)
|
||||
```
|
||||
|
||||
**Cause:** `requireTokenIntrospection=true` but provider doesn't support it
|
||||
|
||||
**Solutions:**
|
||||
1. **Disable required introspection:**
|
||||
```yaml
|
||||
requireTokenIntrospection: false
|
||||
```
|
||||
|
||||
2. **Configure Auth0 to issue JWT tokens** (better solution)
|
||||
|
||||
---
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Token Type Detection
|
||||
|
||||
The middleware uses a sophisticated 6-step detection algorithm:
|
||||
|
||||
1. **RFC 9068 `typ` header**: `at+jwt` → Access Token
|
||||
2. **Explicit type claims**: `token_use`, `token_type`
|
||||
3. **`scope` claim**: Present → Access Token
|
||||
4. **`nonce` claim**: Present → ID Token (OIDC spec)
|
||||
5. **Audience check**: `aud == client_id` only → ID Token
|
||||
6. **Default**: Access Token
|
||||
|
||||
### OAuth 2.0 Token Introspection (RFC 7662)
|
||||
|
||||
When opaque tokens are detected:
|
||||
|
||||
1. Middleware calls provider's `introspection_endpoint`
|
||||
2. Authenticates using client credentials
|
||||
3. Receives response with `active` status and claims
|
||||
4. Caches result for 5 minutes (configurable via TTL)
|
||||
5. Validates expiration, not-before, and audience if present
|
||||
|
||||
**Cache behavior:**
|
||||
- Cache key: Token hash
|
||||
- TTL: 5 minutes or token expiry (whichever is shorter)
|
||||
- Reduces introspection requests for frequently used tokens
|
||||
|
||||
---
|
||||
|
||||
## Reference Links
|
||||
|
||||
- [GitHub Issue #74](https://github.com/lukaszraczylo/traefikoidc/issues/74) - Original Auth0 audience discussion
|
||||
- [OIDC Core 1.0 Spec](https://openid.net/specs/openid-connect-core-1_0.html) - ID Token requirements
|
||||
- [OAuth 2.0 RFC 6749](https://datatracker.ietf.org/doc/html/rfc6749) - OAuth 2.0 specification
|
||||
- [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662) - OAuth 2.0 Token Introspection
|
||||
- [RFC 9068](https://datatracker.ietf.org/doc/html/rfc9068) - JWT Access Token Profile
|
||||
- [Auth0 API Authorization](https://auth0.com/docs/secure/tokens/access-tokens) - Auth0 audience documentation
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Previous Versions
|
||||
|
||||
**If you're upgrading from a version without these features:**
|
||||
|
||||
1. **No action required for default behavior** - backward compatible
|
||||
2. **Recommended: Enable strict mode gradually**
|
||||
```yaml
|
||||
# Step 1: Enable and monitor logs
|
||||
strictAudienceValidation: false # Default
|
||||
|
||||
# Step 2: After confirming no warnings, enable
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
3. **For opaque tokens: Enable explicitly**
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
```
|
||||
|
||||
### Testing Your Configuration
|
||||
|
||||
1. **Check logs for warnings:**
|
||||
```bash
|
||||
# Look for Scenario 2 warnings
|
||||
grep "SCENARIO 2 DETECTED" /var/log/traefik.log
|
||||
|
||||
# Look for opaque token warnings
|
||||
grep "Opaque" /var/log/traefik.log
|
||||
```
|
||||
|
||||
2. **Test with curl:**
|
||||
```bash
|
||||
# Get token from Auth0
|
||||
ACCESS_TOKEN="your_access_token"
|
||||
|
||||
# Test request
|
||||
curl -H "Authorization: Bearer $ACCESS_TOKEN" \
|
||||
https://your-app.example.com/api
|
||||
```
|
||||
|
||||
3. **Monitor for security warnings in production logs**
|
||||
|
||||
---
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- GitHub Issues: https://github.com/lukaszraczylo/traefikoidc/issues
|
||||
- Security issues: See SECURITY.md for responsible disclosure
|
||||
|
||||
---
|
||||
|
||||
**Last Updated:** 2025-01-09
|
||||
**Version:** 0.7.8+
|
||||
@@ -89,8 +89,9 @@ scopes: ["openid", "profile", "email", "offline_access"]
|
||||
- **Offline access**: Requires `offline_access` scope for refresh tokens
|
||||
- **Access token validation**: Supports both JWT and opaque access tokens
|
||||
- **Tenant isolation**: Can restrict to specific Azure AD tenants
|
||||
- **Application ID URI**: Supports custom audience for protected APIs
|
||||
|
||||
### Example Configuration
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
@@ -108,6 +109,33 @@ http:
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Azure AD API Configuration (Application ID URI)
|
||||
|
||||
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
# Specify the Application ID URI as audience
|
||||
audience: "api://12345678-1234-1234-1234-123456789abc"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
|
||||
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected
|
||||
- For ID token validation only (no API access), you can omit the `audience` parameter
|
||||
|
||||
### Azure App Registration Setup
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to "Azure Active Directory" > "App registrations"
|
||||
@@ -116,6 +144,12 @@ http:
|
||||
5. Create client secret in "Certificates & secrets"
|
||||
6. Configure API permissions for required scopes
|
||||
|
||||
### Azure AD API Exposure Setup (for custom audiences)
|
||||
1. In your App Registration, go to "Expose an API"
|
||||
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
|
||||
3. Add any custom scopes your API exposes
|
||||
4. Update the middleware configuration to include the `audience` parameter with this URI
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
@@ -138,8 +172,9 @@ scopes: ["openid", "profile", "email", "offline_access"]
|
||||
- **Rules and hooks**: Leverages Auth0's extensibility
|
||||
- **Social connections**: Works with Auth0's social identity providers
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
- **API audiences**: Supports custom audience for API access tokens
|
||||
|
||||
### Example Configuration
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
@@ -158,6 +193,34 @@ http:
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Auth0 API Configuration (Custom Audience)
|
||||
|
||||
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
# Specify the Auth0 API identifier as audience
|
||||
audience: "https://api.company.com"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
|
||||
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
|
||||
- For ID token validation only (no APIs), you can omit the `audience` parameter
|
||||
|
||||
### Auth0 Application Setup
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create new application (Regular Web Application)
|
||||
@@ -165,6 +228,14 @@ http:
|
||||
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
|
||||
### Auth0 API Setup (for custom audiences)
|
||||
1. Go to Auth0 Dashboard → APIs
|
||||
2. Create a new API or select existing API
|
||||
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
|
||||
4. In API Settings → Machine to Machine Applications, authorize your application
|
||||
5. Configure API permissions/scopes as needed
|
||||
6. Use the API identifier as the `audience` parameter in your configuration
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
@@ -236,7 +307,7 @@ scopes: ["openid", "profile", "email"]
|
||||
- **Self-hosted support**: Works with self-hosted GitLab instances
|
||||
- **Group membership**: Can restrict by GitLab groups
|
||||
- **Project access**: Can validate project permissions
|
||||
- **Offline access**: Supports refresh tokens with `offline_access`
|
||||
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
@@ -250,7 +321,9 @@ http:
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
scopes: ["openid", "profile", "email"]
|
||||
# Note: GitLab doesn't support the offline_access scope.
|
||||
# Refresh tokens are issued automatically for the openid scope.
|
||||
allowedRolesAndGroups: ["developers", "maintainers"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
@@ -459,8 +532,120 @@ http:
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
### Overview
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
|
||||
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
|
||||
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
|
||||
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
|
||||
|
||||
### Example Scenarios
|
||||
|
||||
#### Self-Hosted GitLab
|
||||
|
||||
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
|
||||
```
|
||||
The requested scope is invalid, unknown, or malformed.
|
||||
```
|
||||
|
||||
**Solution**: The middleware automatically detects this by:
|
||||
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
|
||||
2. Observing that `offline_access` is NOT in the `scopes_supported` list
|
||||
3. Filtering out `offline_access` from the request
|
||||
4. Authentication succeeds
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.example.com"
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
# Even though offline_access is listed, it will be automatically
|
||||
# filtered out if GitLab doesn't support it
|
||||
```
|
||||
|
||||
#### Auth0 or Keycloak
|
||||
|
||||
These providers typically support `offline_access` and it will be included:
|
||||
|
||||
```yaml
|
||||
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
|
||||
# Result: All requested scopes are sent
|
||||
```
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
|
||||
2. **No Manual Configuration**: No need to know which scopes each provider supports
|
||||
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
|
||||
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
|
||||
5. **Backward Compatible**: Existing configurations continue to work
|
||||
|
||||
### Logging
|
||||
|
||||
The middleware provides detailed logging for scope filtering:
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
|
||||
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Issue**: Provider rejects scope even after filtering
|
||||
|
||||
**Possible Causes**:
|
||||
1. Provider's discovery document is outdated
|
||||
2. Provider doesn't properly implement `scopes_supported`
|
||||
3. Custom authorization server with non-standard behavior
|
||||
|
||||
**Solutions**:
|
||||
1. Use `overrideScopes: true` and explicitly list only supported scopes
|
||||
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
|
||||
---
|
||||
|
||||
## Common Configuration Options
|
||||
|
||||
### Audience Configuration
|
||||
|
||||
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
|
||||
|
||||
```yaml
|
||||
# Optional: Custom audience for JWT validation
|
||||
# If not set, defaults to clientID for backward compatibility
|
||||
audience: "https://api.example.com" # Auth0 API identifier
|
||||
# OR
|
||||
audience: "api://12345-guid" # Azure AD Application ID URI
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- **Auth0**: When using Auth0 APIs with custom audience parameters
|
||||
- **Azure AD**: When exposing your app as an API with Application ID URI
|
||||
- **Keycloak**: When using audience-restricted tokens
|
||||
- **Okta**: When using custom authorization servers with API audiences
|
||||
|
||||
**When to omit**:
|
||||
- For standard ID token validation (default behavior)
|
||||
- When the provider sets `aud` claim to your `clientID`
|
||||
- For backward compatibility with existing configurations
|
||||
|
||||
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
|
||||
|
||||
### Security Settings
|
||||
```yaml
|
||||
# Force HTTPS (recommended for production)
|
||||
|
||||
+4
-2
@@ -123,8 +123,10 @@ func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
metrics["seconds_since_last_success"] = time.Since(b.lastSuccessTime).Seconds()
|
||||
}
|
||||
|
||||
if metrics["total_requests"].(int64) > 0 {
|
||||
successRate := float64(metrics["total_successes"].(int64)) / float64(metrics["total_requests"].(int64))
|
||||
totalReq, _ := metrics["total_requests"].(int64) // Safe to ignore: type assertion with fallback
|
||||
totalSucc, _ := metrics["total_successes"].(int64) // Safe to ignore: type assertion with fallback
|
||||
if totalReq > 0 {
|
||||
successRate := float64(totalSucc) / float64(totalReq)
|
||||
metrics["success_rate"] = successRate
|
||||
} else {
|
||||
metrics["success_rate"] = 1.0
|
||||
|
||||
@@ -0,0 +1,560 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRetryExecutorReset tests the Reset method
|
||||
func TestRetryExecutorReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
require.NotNil(t, executor)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
executor.Reset()
|
||||
})
|
||||
|
||||
// Multiple resets should be safe
|
||||
executor.Reset()
|
||||
executor.Reset()
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsAvailable tests the IsAvailable method
|
||||
func TestRetryExecutorIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
// Retry executor should always be available
|
||||
assert.True(t, executor.IsAvailable())
|
||||
|
||||
// Should remain available after operations
|
||||
ctx := context.Background()
|
||||
executor.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.True(t, executor.IsAvailable())
|
||||
}
|
||||
|
||||
// TestSessionErrorUnwrap tests SessionError.Unwrap
|
||||
func TestSessionErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("root cause")
|
||||
sessionErr := NewSessionError("save", "failed to save session", rootErr)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
sessionErr := NewSessionError("load", "failed to load session", nil)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("database error")
|
||||
sessionErr := NewSessionError("delete", "failed to delete session", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(sessionErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestTokenErrorUnwrap tests TokenError.Unwrap
|
||||
func TestTokenErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature verification failed")
|
||||
tokenErr := NewTokenError("id_token", "invalid", "token is invalid", rootErr)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
tokenErr := NewTokenError("access_token", "expired", "token has expired", nil)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("crypto error")
|
||||
tokenErr := NewTokenError("refresh_token", "malformed", "token is malformed", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(tokenErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterFallback tests fallback registration
|
||||
func TestGracefulDegradationRegisterFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register single fallback", func(t *testing.T) {
|
||||
fallback := func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
}
|
||||
|
||||
gd.RegisterFallback("service1", fallback)
|
||||
|
||||
// Verify fallback was registered (indirectly)
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return nil, errors.New("service failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("register multiple fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback2", nil
|
||||
})
|
||||
gd.RegisterFallback("service3", func() (interface{}, error) {
|
||||
return "fallback3", nil
|
||||
})
|
||||
|
||||
result2, _ := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
result3, _ := gd.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "fallback2", result2)
|
||||
assert.Equal(t, "fallback3", result3)
|
||||
})
|
||||
|
||||
t.Run("override existing fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "old fallback", nil
|
||||
})
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "new fallback", nil
|
||||
})
|
||||
|
||||
result, _ := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "new fallback", result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterHealthCheck tests health check registration
|
||||
func TestGracefulDegradationRegisterHealthCheck(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register health check", func(t *testing.T) {
|
||||
healthy := true
|
||||
healthCheck := func() bool {
|
||||
return healthy
|
||||
}
|
||||
|
||||
gd.RegisterHealthCheck("service1", healthCheck)
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
|
||||
// Set healthy and wait for health check to run
|
||||
healthy = true
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (may still be degraded due to timing, but health check was registered)
|
||||
})
|
||||
|
||||
t.Run("multiple health checks", func(t *testing.T) {
|
||||
gd.RegisterHealthCheck("service2", func() bool { return true })
|
||||
gd.RegisterHealthCheck("service3", func() bool { return false })
|
||||
|
||||
// Health checks are registered and will be called periodically
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithContext tests ExecuteWithContext
|
||||
func TestGracefulDegradationExecuteWithContext(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("successful execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testErr := errors.New("operation failed")
|
||||
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return testErr
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("uses fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("default", func() (interface{}, error) {
|
||||
return nil, nil // Success fallback
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return errors.New("primary failed")
|
||||
})
|
||||
|
||||
// With fallback succeeding, overall operation succeeds
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithFallback tests ExecuteWithFallback
|
||||
func TestGracefulDegradationExecuteWithFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("primary succeeds", func(t *testing.T) {
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return "primary result", nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "primary result", result)
|
||||
})
|
||||
|
||||
t.Run("fallback succeeds when primary fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("error when no fallback available", func(t *testing.T) {
|
||||
config.EnableFallbacks = false
|
||||
gdNoFallback := NewGracefulDegradation(config, logger)
|
||||
defer gdNoFallback.Close()
|
||||
|
||||
result, err := gdNoFallback.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("fallback also fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback also failed")
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback also failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsServiceDegraded tests service degradation status
|
||||
func TestGracefulDegradationIsServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 100 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("service not degraded initially", func(t *testing.T) {
|
||||
assert.False(t, gd.isServiceDegraded("new-service"))
|
||||
})
|
||||
|
||||
t.Run("service degraded after marking", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
})
|
||||
|
||||
t.Run("service recovers after timeout", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
assert.True(t, gd.isServiceDegraded("service2"))
|
||||
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("service2"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationMarkServiceDegraded tests marking services as degraded
|
||||
func TestGracefulDegradationMarkServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("mark single service", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service1")
|
||||
})
|
||||
|
||||
t.Run("mark multiple services", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service2")
|
||||
assert.Contains(t, degraded, "service3")
|
||||
})
|
||||
|
||||
t.Run("marking same service multiple times updates timestamp", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service4")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
// Service should still be marked as degraded
|
||||
assert.True(t, gd.isServiceDegraded("service4"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteFallback tests fallback execution
|
||||
func TestGracefulDegradationExecuteFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("execute registered fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) {
|
||||
return "fallback value", nil
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback value", result)
|
||||
})
|
||||
|
||||
t.Run("error when fallback not registered", func(t *testing.T) {
|
||||
result, err := gd.executeFallback("non-existent-service")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "no fallback available")
|
||||
})
|
||||
|
||||
t.Run("propagate fallback errors", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback error")
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service2")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback error")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationReset tests Reset method
|
||||
func TestGracefulDegradationReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("reset clears degraded services", func(t *testing.T) {
|
||||
// Mark several services as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 3)
|
||||
|
||||
// Reset
|
||||
gd.Reset()
|
||||
|
||||
// All should be cleared
|
||||
assert.Len(t, gd.GetDegradedServices(), 0)
|
||||
})
|
||||
|
||||
t.Run("can mark services degraded after reset", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 1)
|
||||
assert.Contains(t, gd.GetDegradedServices(), "service4")
|
||||
})
|
||||
|
||||
t.Run("multiple resets are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsAvailable tests IsAvailable method
|
||||
func TestGracefulDegradationIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Should always return true
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even with degraded services
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even after reset
|
||||
gd.Reset()
|
||||
assert.True(t, gd.IsAvailable())
|
||||
}
|
||||
|
||||
// TestGracefulDegradationGetMetrics tests GetMetrics method
|
||||
func TestGracefulDegradationGetMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("basic metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
require.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "degraded_services_count")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
assert.Contains(t, metrics, "registered_fallbacks_count")
|
||||
assert.Contains(t, metrics, "registered_health_checks_count")
|
||||
assert.Contains(t, metrics, "health_check_interval_seconds")
|
||||
assert.Contains(t, metrics, "recovery_timeout_seconds")
|
||||
assert.Contains(t, metrics, "fallbacks_enabled")
|
||||
})
|
||||
|
||||
t.Run("metrics reflect degraded services", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.Equal(t, 2, metrics["degraded_services_count"])
|
||||
degradedList := metrics["degraded_services"].([]string)
|
||||
assert.Len(t, degradedList, 2)
|
||||
})
|
||||
|
||||
t.Run("metrics reflect registered fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) { return nil, nil })
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) { return nil, nil })
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.GreaterOrEqual(t, metrics["registered_fallbacks_count"], 2)
|
||||
})
|
||||
|
||||
t.Run("metrics include base metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
// Should include base recovery mechanism metrics
|
||||
assert.Contains(t, metrics, "name")
|
||||
assert.Contains(t, metrics, "uptime_seconds")
|
||||
assert.Contains(t, metrics, "total_requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationFullScenario tests a complete degradation scenario
|
||||
func TestGracefulDegradationFullScenario(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping full scenario test in short mode")
|
||||
}
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 200 * time.Millisecond
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register fallback
|
||||
gd.RegisterFallback("critical-service", func() (interface{}, error) {
|
||||
return "fallback data", nil
|
||||
})
|
||||
|
||||
// Register health check
|
||||
serviceHealthy := false
|
||||
gd.RegisterHealthCheck("critical-service", func() bool {
|
||||
return serviceHealthy
|
||||
})
|
||||
|
||||
// First call - primary succeeds
|
||||
result1, err1 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "primary data", nil
|
||||
})
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, "primary data", result1)
|
||||
|
||||
// Second call - primary fails, fallback succeeds
|
||||
result2, err2 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service down")
|
||||
})
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, "fallback data", result2)
|
||||
|
||||
// Service is now degraded
|
||||
assert.True(t, gd.isServiceDegraded("critical-service"))
|
||||
|
||||
// Third call - should use fallback immediately
|
||||
result3, err3 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
assert.NoError(t, err3)
|
||||
assert.Equal(t, "fallback data", result3)
|
||||
|
||||
// Mark service as healthy and wait for health check
|
||||
serviceHealthy = true
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (timing-dependent, so we don't assert)
|
||||
|
||||
// Get metrics
|
||||
metrics := gd.GetMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreakerAllowRequestEdgeCases tests edge cases in circuit breaker request allowing
|
||||
func TestCircuitBreakerAllowRequestEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("invalid state returns false", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Force invalid state
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerState(999) // Invalid state
|
||||
cb.mutex.Unlock()
|
||||
|
||||
// Should return false for invalid state
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "invalid state should not allow requests")
|
||||
})
|
||||
|
||||
t.Run("open to half-open transition on timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: baseTimeout,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Verify circuit is open
|
||||
assert.Equal(t, CircuitBreakerOpen, cb.GetState())
|
||||
assert.False(t, cb.allowRequest())
|
||||
|
||||
// Wait for timeout (longer than timeout to ensure transition)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should transition to half-open
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "should allow request after timeout")
|
||||
assert.Equal(t, CircuitBreakerHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("half-open allows requests", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Manually set to half-open
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
cb.mutex.Unlock()
|
||||
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "half-open should allow requests")
|
||||
})
|
||||
|
||||
t.Run("open blocks requests before timeout", func(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 1 * time.Hour, // Long timeout
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should be blocked
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "open circuit should block requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsRetryableErrorEdgeCases tests edge cases for retry decision
|
||||
func TestRetryExecutorIsRetryableErrorEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRetryConfig()
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("nil error is not retryable", func(t *testing.T) {
|
||||
retryable := re.isRetryableError(nil)
|
||||
assert.False(t, retryable)
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 429 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 429,
|
||||
Message: "Too Many Requests",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "429 Too Many Requests should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 500 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 500,
|
||||
Message: "Internal Server Error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "500 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 503 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 503,
|
||||
Message: "Service Unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "503 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 400 is not retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 400,
|
||||
Message: "Bad Request",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.False(t, retryable, "400 errors should not be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with timeout is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: true,
|
||||
temporary: false,
|
||||
msg: "timeout error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "timeout errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection refused is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection refused",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection refused should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection reset is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection reset by peer",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection reset should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with network unreachable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "network is unreachable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "network unreachable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with no route to host is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "no route to host",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "no route to host should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with temporary failure is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "temporary failure in name resolution",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "temporary failure should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with try again is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "try again later",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "try again should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with resource temporarily unavailable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "resource temporarily unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "resource temporarily unavailable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("configured retryable error patterns", func(t *testing.T) {
|
||||
err := errors.New("connection refused by server")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.True(t, retryable, "configured pattern should be retryable")
|
||||
})
|
||||
|
||||
t.Run("non-retryable error", func(t *testing.T) {
|
||||
err := errors.New("invalid input data")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.False(t, retryable, "non-configured error should not be retryable")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorCalculateDelayEdgeCases tests delay calculation edge cases
|
||||
func TestRetryExecutorCalculateDelayEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("delay calculation without jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false, // Jitter disabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 1: 100ms * 2^0 = 100ms
|
||||
delay1 := re.calculateDelay(1)
|
||||
assert.Equal(t, 100*time.Millisecond, delay1)
|
||||
|
||||
// Attempt 2: 100ms * 2^1 = 200ms
|
||||
delay2 := re.calculateDelay(2)
|
||||
assert.Equal(t, 200*time.Millisecond, delay2)
|
||||
|
||||
// Attempt 3: 100ms * 2^2 = 400ms
|
||||
delay3 := re.calculateDelay(3)
|
||||
assert.Equal(t, 400*time.Millisecond, delay3)
|
||||
})
|
||||
|
||||
t.Run("delay calculation with jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true, // Jitter enabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// With jitter, delay should be within 10% of expected
|
||||
delay := re.calculateDelay(2)
|
||||
expectedBase := 200 * time.Millisecond
|
||||
minDelay := time.Duration(float64(expectedBase) * 0.9)
|
||||
maxDelay := time.Duration(float64(expectedBase) * 1.1)
|
||||
|
||||
assert.GreaterOrEqual(t, delay, minDelay, "delay should be >= 90% of base")
|
||||
assert.LessOrEqual(t, delay, maxDelay, "delay should be <= 110% of base")
|
||||
})
|
||||
|
||||
t.Run("delay capped at max delay", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 10,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 500 * time.Millisecond, // Low max delay
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 10: would be 100ms * 2^9 = 51200ms, but capped at 500ms
|
||||
delay := re.calculateDelay(10)
|
||||
assert.Equal(t, 500*time.Millisecond, delay, "delay should be capped at max")
|
||||
})
|
||||
|
||||
t.Run("delay with large backoff factor", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 3.0, // Larger backoff
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 3: 50ms * 3^2 = 450ms
|
||||
delay := re.calculateDelay(3)
|
||||
assert.Equal(t, 450*time.Millisecond, delay)
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorTypesErrorMethodsWithoutCause tests error type Error() methods without cause
|
||||
func TestErrorTypesErrorMethodsWithoutCause(t *testing.T) {
|
||||
t.Run("HTTPError.Error without cause", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 404,
|
||||
Message: "Not Found",
|
||||
}
|
||||
|
||||
errStr := httpErr.Error()
|
||||
assert.Equal(t, "HTTP 404: Not Found", errStr)
|
||||
})
|
||||
|
||||
t.Run("HTTPError.Error with different status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{200, "OK", "HTTP 200: OK"},
|
||||
{301, "Moved", "HTTP 301: Moved"},
|
||||
{401, "Unauthorized", "HTTP 401: Unauthorized"},
|
||||
{500, "Server Error", "HTTP 500: Server Error"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: tc.code,
|
||||
Message: tc.message,
|
||||
}
|
||||
assert.Equal(t, tc.expected, httpErr.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error without cause", func(t *testing.T) {
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_token",
|
||||
Message: "Token validation failed",
|
||||
Context: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Equal(t, "OIDC error [invalid_token]: Token validation failed", errStr)
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature mismatch")
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_signature",
|
||||
Message: "JWT signature invalid",
|
||||
Context: make(map[string]interface{}),
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Contains(t, errStr, "OIDC error [invalid_signature]: JWT signature invalid")
|
||||
assert.Contains(t, errStr, "caused by: signature mismatch")
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error without cause", func(t *testing.T) {
|
||||
sessErr := &SessionError{
|
||||
Operation: "load",
|
||||
Message: "Session not found",
|
||||
SessionID: "sess123",
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Equal(t, "Session error in load: Session not found", errStr)
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("database connection failed")
|
||||
sessErr := &SessionError{
|
||||
Operation: "save",
|
||||
Message: "Failed to persist session",
|
||||
SessionID: "sess456",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Contains(t, errStr, "Session error in save: Failed to persist session")
|
||||
assert.Contains(t, errStr, "caused by: database connection failed")
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error without cause", func(t *testing.T) {
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "access_token",
|
||||
Reason: "expired",
|
||||
Message: "Token has expired",
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Equal(t, "Token error (access_token) - expired: Token has expired", errStr)
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("time check failed")
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "id_token",
|
||||
Reason: "expired",
|
||||
Message: "Token validity period exceeded",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Contains(t, errStr, "Token error (id_token) - expired: Token validity period exceeded")
|
||||
assert.Contains(t, errStr, "caused by: time check failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationHealthChecks tests health check functionality
|
||||
func TestGracefulDegradationHealthChecks(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("performHealthChecks recovers degraded service", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns true
|
||||
healthCheckCalled := false
|
||||
gd.RegisterHealthCheck("test-service", func() bool {
|
||||
healthCheckCalled = true
|
||||
return true // Service is healthy
|
||||
})
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("test-service")
|
||||
|
||||
// Verify service is degraded
|
||||
assert.True(t, gd.isServiceDegraded("test-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Health check should have been called
|
||||
assert.True(t, healthCheckCalled, "health check should be called")
|
||||
|
||||
// Service should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("test-service"), "service should be recovered")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks marks service degraded on failure", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns false
|
||||
gd.RegisterHealthCheck("failing-service", func() bool {
|
||||
return false // Service is unhealthy
|
||||
})
|
||||
|
||||
// Initially not degraded
|
||||
assert.False(t, gd.isServiceDegraded("failing-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Service should be marked degraded
|
||||
assert.True(t, gd.isServiceDegraded("failing-service"), "service should be degraded")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks runs multiple health checks independently", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
service1Checked := false
|
||||
service2Checked := false
|
||||
|
||||
gd.RegisterHealthCheck("service1", func() bool {
|
||||
service1Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
gd.RegisterHealthCheck("service2", func() bool {
|
||||
service2Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
// Manually trigger health checks
|
||||
gd.performHealthChecks()
|
||||
|
||||
assert.True(t, service1Checked, "service1 health check should run")
|
||||
assert.True(t, service2Checked, "service2 health check should run")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks handles empty health checks", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Call performHealthChecks with no registered health checks
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
gd.performHealthChecks()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationServiceRecoveryTimeout tests recovery timeout behavior
|
||||
func TestGracefulDegradationServiceRecoveryTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("service auto-recovers after timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour, // Long interval, won't run during test
|
||||
RecoveryTimeout: baseTimeout,
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("auto-recover-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("auto-recover-service"))
|
||||
|
||||
// Wait for recovery timeout (longer than timeout to ensure recovery)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should auto-recover
|
||||
assert.False(t, gd.isServiceDegraded("auto-recover-service"), "service should auto-recover after timeout")
|
||||
})
|
||||
|
||||
t.Run("service remains degraded before timeout", func(t *testing.T) {
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour,
|
||||
RecoveryTimeout: 1 * time.Hour, // Very long timeout
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("long-timeout-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"))
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(10 * time.Millisecond))
|
||||
|
||||
// Should still be degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"), "service should remain degraded before timeout")
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorRecoveryManagerIntegration tests full integration of error recovery mechanisms
|
||||
func TestErrorRecoveryManagerIntegration(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
erm := NewErrorRecoveryManager(logger)
|
||||
|
||||
t.Run("circuit breaker and retry integration", func(t *testing.T) {
|
||||
// Create a circuit breaker with higher max failures to allow retries
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10, // High threshold
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}, logger)
|
||||
|
||||
erm.mutex.Lock()
|
||||
erm.circuitBreakers["test-service-integration"] = cb
|
||||
erm.mutex.Unlock()
|
||||
|
||||
attempts := 0
|
||||
fn := func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service-integration", fn)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, attempts, 3, "should retry until success")
|
||||
})
|
||||
|
||||
t.Run("circuit breaker opens on repeated failures", func(t *testing.T) {
|
||||
fn := func() error {
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
// First call - should fail after retries
|
||||
err1 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err1)
|
||||
|
||||
// Second call - should fail after retries
|
||||
err2 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err2)
|
||||
|
||||
// Check circuit breaker state
|
||||
cb := erm.GetCircuitBreaker("failing-service")
|
||||
state := cb.GetState()
|
||||
assert.Equal(t, CircuitBreakerOpen, state, "circuit should be open after repeated failures")
|
||||
})
|
||||
|
||||
t.Run("recovery metrics include all mechanisms", func(t *testing.T) {
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "circuit_breakers")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
})
|
||||
}
|
||||
|
||||
// TestContainsHelperFunction tests the contains helper function edge cases
|
||||
func TestContainsHelperFunction(t *testing.T) {
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("prefix match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout error occurred", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("suffix match", func(t *testing.T) {
|
||||
assert.True(t, contains("connection timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("middle match", func(t *testing.T) {
|
||||
assert.True(t, contains("a connection timeout error", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
assert.False(t, contains("connection refused", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("substring longer than string", func(t *testing.T) {
|
||||
assert.False(t, contains("abc", "abcdef"))
|
||||
})
|
||||
|
||||
t.Run("empty substring", func(t *testing.T) {
|
||||
assert.True(t, contains("test", ""))
|
||||
})
|
||||
|
||||
t.Run("empty string", func(t *testing.T) {
|
||||
assert.False(t, contains("", "test"))
|
||||
})
|
||||
|
||||
t.Run("both empty", func(t *testing.T) {
|
||||
assert.True(t, contains("", ""))
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,848 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test Circuit Breaker State Transitions
|
||||
|
||||
func TestCircuitBreakerStateTransitions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
failures int
|
||||
maxFailures int
|
||||
expectedStateBefore string
|
||||
expectedStateAfter string
|
||||
}{
|
||||
{
|
||||
name: "stays closed below threshold",
|
||||
failures: 1,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "closed",
|
||||
},
|
||||
{
|
||||
name: "opens at threshold",
|
||||
failures: 3,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "open",
|
||||
},
|
||||
{
|
||||
name: "opens above threshold",
|
||||
failures: 5,
|
||||
maxFailures: 3,
|
||||
expectedStateBefore: "closed",
|
||||
expectedStateAfter: "open",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: tt.maxFailures,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Verify initial state
|
||||
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateBefore {
|
||||
t.Errorf("Expected initial state %s, got %s", tt.expectedStateBefore, state)
|
||||
}
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < tt.failures; i++ {
|
||||
_ = cb.Execute(func() error {
|
||||
return errors.New("test failure")
|
||||
})
|
||||
}
|
||||
|
||||
// Verify final state
|
||||
if state := circuitBreakerStateToString(cb.GetState()); state != tt.expectedStateAfter {
|
||||
t.Errorf("Expected final state %s, got %s", tt.expectedStateAfter, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerHalfOpenTransition(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Circuit should be open after failures")
|
||||
}
|
||||
|
||||
// Wait for timeout to trigger half-open
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Next request should be allowed (half-open)
|
||||
allowed := false
|
||||
_ = cb.Execute(func() error {
|
||||
allowed = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if !allowed {
|
||||
t.Error("Request should be allowed in half-open state")
|
||||
}
|
||||
|
||||
// Successful request should close the circuit
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Circuit should be closed after successful half-open request, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerHalfOpenFailure(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Wait for half-open
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Fail in half-open state
|
||||
_ = cb.Execute(func() error {
|
||||
return errors.New("fail again")
|
||||
})
|
||||
|
||||
// Should return to open state
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Circuit should be open after half-open failure, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerConcurrency(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
failureCount := int64(0)
|
||||
|
||||
// Concurrent successful requests
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&failureCount, 1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if successCount != 100 {
|
||||
t.Errorf("Expected 100 successful requests, got %d", successCount)
|
||||
}
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics["total_requests"].(int64) != 100 {
|
||||
t.Errorf("Expected 100 total requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerReset(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Circuit should be open")
|
||||
}
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Error("Circuit should be closed after reset")
|
||||
}
|
||||
|
||||
// Should allow requests after reset
|
||||
err := cb.Execute(func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Should allow requests after reset, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerMetrics(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}, nil)
|
||||
|
||||
// Execute some requests
|
||||
_ = cb.Execute(func() error { return nil })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return nil })
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != 3 {
|
||||
t.Errorf("Expected 3 requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["total_successes"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 successes, got %d", metrics["total_successes"])
|
||||
}
|
||||
|
||||
if metrics["total_failures"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
||||
}
|
||||
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state 'closed', got %v", metrics["state"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreakerIsAvailable(t *testing.T) {
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}, nil)
|
||||
|
||||
// Should be available initially
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Circuit should be available initially")
|
||||
}
|
||||
|
||||
// Open the circuit
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
_ = cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should not be available when open
|
||||
if cb.IsAvailable() {
|
||||
t.Error("Circuit should not be available when open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be available in half-open
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Circuit should be available in half-open state")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Retry Executor
|
||||
|
||||
func TestRetryExecutorSuccess(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt for immediate success, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorEventualSuccess(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected success after retries, got %v", err)
|
||||
}
|
||||
|
||||
if attempts != 3 {
|
||||
t.Errorf("Expected 3 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorMaxAttemptsExceeded(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error after max attempts")
|
||||
}
|
||||
|
||||
if attempts != 3 {
|
||||
t.Errorf("Expected 3 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorNonRetryableError(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
err := re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("permanent failure")
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-retryable failure")
|
||||
}
|
||||
|
||||
if attempts != 1 {
|
||||
t.Errorf("Expected 1 attempt for non-retryable error, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorContextCancellation(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
attempts := 0
|
||||
done := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
done <- re.ExecuteWithContext(ctx, func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
}()
|
||||
|
||||
// Cancel after short delay
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
err := <-done
|
||||
|
||||
if err != context.Canceled {
|
||||
t.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
if attempts == 0 {
|
||||
t.Error("Should have attempted at least once")
|
||||
}
|
||||
|
||||
if attempts >= 5 {
|
||||
t.Error("Should not have completed all attempts after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorExponentialBackoff(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 4,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
attempts := 0
|
||||
startTime := time.Now()
|
||||
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
|
||||
elapsed := time.Since(startTime)
|
||||
|
||||
// Should have delays: 100ms, 200ms, 400ms = 700ms total (approx)
|
||||
if elapsed < 650*time.Millisecond || elapsed > 850*time.Millisecond {
|
||||
t.Errorf("Expected ~700ms elapsed with exponential backoff, got %v", elapsed)
|
||||
}
|
||||
|
||||
if attempts != 4 {
|
||||
t.Errorf("Expected 4 attempts, got %d", attempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorWithJitter(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
RetryableErrors: []string{"temporary failure"},
|
||||
}, nil)
|
||||
|
||||
// Run multiple times to verify jitter adds variability
|
||||
durations := make([]time.Duration, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
startTime := time.Now()
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
return errors.New("temporary failure")
|
||||
})
|
||||
durations[i] = time.Since(startTime)
|
||||
}
|
||||
|
||||
// Check that not all durations are identical (jitter should add variance)
|
||||
allSame := true
|
||||
for i := 1; i < len(durations); i++ {
|
||||
if durations[i] != durations[0] {
|
||||
allSame = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if allSame {
|
||||
t.Error("Expected jitter to add variability to retry delays")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorNetworkErrors(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "timeout error",
|
||||
err: &mockNetError{timeout: true, temporary: true},
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "temporary network error",
|
||||
err: &mockNetError{timeout: false, temporary: true, msg: "temporary failure"},
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
err: &mockNetError{timeout: false, temporary: false, msg: "connection refused"},
|
||||
shouldRetry: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attempts := 0
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return tt.err
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorHTTPErrors(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
shouldRetry bool
|
||||
}{
|
||||
{"500 Internal Server Error", 500, true},
|
||||
{"502 Bad Gateway", 502, true},
|
||||
{"503 Service Unavailable", 503, true},
|
||||
{"429 Too Many Requests", 429, true},
|
||||
{"400 Bad Request", 400, false},
|
||||
{"404 Not Found", 404, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
attempts := 0
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return &HTTPError{StatusCode: tt.statusCode, Message: "test"}
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts, got %d", expectedAttempts, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryExecutorMetrics(t *testing.T) {
|
||||
re := NewRetryExecutor(RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true,
|
||||
}, nil)
|
||||
|
||||
_ = re.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
metrics := re.GetMetrics()
|
||||
|
||||
if metrics["max_attempts"] != 3 {
|
||||
t.Errorf("Expected max_attempts 3, got %v", metrics["max_attempts"])
|
||||
}
|
||||
|
||||
if metrics["backoff_factor"] != 2.0 {
|
||||
t.Errorf("Expected backoff_factor 2.0, got %v", metrics["backoff_factor"])
|
||||
}
|
||||
|
||||
if metrics["enable_jitter"] != true {
|
||||
t.Errorf("Expected enable_jitter true, got %v", metrics["enable_jitter"])
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Types
|
||||
|
||||
func TestOIDCErrorCreation(t *testing.T) {
|
||||
err := NewOIDCError("invalid_token", "Token is expired", nil)
|
||||
|
||||
if err.Code != "invalid_token" {
|
||||
t.Errorf("Expected code 'invalid_token', got %s", err.Code)
|
||||
}
|
||||
|
||||
if err.Message != "Token is expired" {
|
||||
t.Errorf("Expected message 'Token is expired', got %s", err.Message)
|
||||
}
|
||||
|
||||
expectedMsg := "OIDC error [invalid_token]: Token is expired"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCErrorWithCause(t *testing.T) {
|
||||
cause := errors.New("underlying error")
|
||||
err := NewOIDCError("token_error", "Failed to validate", cause)
|
||||
|
||||
if err.Unwrap() != cause {
|
||||
t.Error("Expected unwrap to return underlying cause")
|
||||
}
|
||||
|
||||
if err.Error() == "" {
|
||||
t.Error("Error string should include cause")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCErrorWithContext(t *testing.T) {
|
||||
err := NewOIDCError("auth_failed", "Authentication failed", nil).
|
||||
WithContext("provider", "google").
|
||||
WithContext("user_id", "12345")
|
||||
|
||||
if err.Context["provider"] != "google" {
|
||||
t.Errorf("Expected provider 'google', got %v", err.Context["provider"])
|
||||
}
|
||||
|
||||
if err.Context["user_id"] != "12345" {
|
||||
t.Errorf("Expected user_id '12345', got %v", err.Context["user_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionErrorCreation(t *testing.T) {
|
||||
err := NewSessionError("save", "Failed to save session", nil)
|
||||
|
||||
if err.Operation != "save" {
|
||||
t.Errorf("Expected operation 'save', got %s", err.Operation)
|
||||
}
|
||||
|
||||
expectedMsg := "Session error in save: Failed to save session"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionErrorWithSessionID(t *testing.T) {
|
||||
err := NewSessionError("load", "Session not found", nil).
|
||||
WithSessionID("sess_12345")
|
||||
|
||||
if err.SessionID != "sess_12345" {
|
||||
t.Errorf("Expected session ID 'sess_12345', got %s", err.SessionID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenErrorCreation(t *testing.T) {
|
||||
err := NewTokenError("id_token", "expired", "Token has expired", nil)
|
||||
|
||||
if err.TokenType != "id_token" {
|
||||
t.Errorf("Expected token type 'id_token', got %s", err.TokenType)
|
||||
}
|
||||
|
||||
if err.Reason != "expired" {
|
||||
t.Errorf("Expected reason 'expired', got %s", err.Reason)
|
||||
}
|
||||
|
||||
expectedMsg := "Token error (id_token) - expired: Token has expired"
|
||||
if err.Error() != expectedMsg {
|
||||
t.Errorf("Expected error string '%s', got '%s'", expectedMsg, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Base Recovery Mechanism
|
||||
|
||||
func TestBaseRecoveryMechanismMetrics(t *testing.T) {
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", nil)
|
||||
|
||||
base.RecordRequest()
|
||||
base.RecordSuccess()
|
||||
base.RecordRequest()
|
||||
base.RecordFailure()
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 requests, got %d", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["total_successes"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 success, got %d", metrics["total_successes"])
|
||||
}
|
||||
|
||||
if metrics["total_failures"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 failure, got %d", metrics["total_failures"])
|
||||
}
|
||||
|
||||
if metrics["success_rate"].(float64) != 0.5 {
|
||||
t.Errorf("Expected success rate 0.5, got %v", metrics["success_rate"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseRecoveryMechanismConcurrentUpdates(t *testing.T) {
|
||||
base := NewBaseRecoveryMechanism("concurrent-test", nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
iterations := 1000
|
||||
|
||||
// Concurrent requests
|
||||
for i := 0; i < iterations; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
base.RecordRequest()
|
||||
if i%2 == 0 {
|
||||
base.RecordSuccess()
|
||||
} else {
|
||||
base.RecordFailure()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics["total_requests"].(int64) != int64(iterations) {
|
||||
t.Errorf("Expected %d requests, got %d", iterations, metrics["total_requests"])
|
||||
}
|
||||
|
||||
totalSuccessesAndFailures := metrics["total_successes"].(int64) + metrics["total_failures"].(int64)
|
||||
if totalSuccessesAndFailures != int64(iterations) {
|
||||
t.Errorf("Expected %d total successes+failures, got %d", iterations, totalSuccessesAndFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Recovery Manager
|
||||
|
||||
func TestErrorRecoveryManagerCreation(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
if erm == nil {
|
||||
t.Fatal("Expected non-nil error recovery manager")
|
||||
}
|
||||
|
||||
if erm.retryExecutor == nil {
|
||||
t.Error("Expected retry executor to be initialized")
|
||||
}
|
||||
|
||||
if erm.gracefulDegradation == nil {
|
||||
t.Error("Expected graceful degradation to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerGetCircuitBreaker(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
cb1 := erm.GetCircuitBreaker("service1")
|
||||
cb2 := erm.GetCircuitBreaker("service1")
|
||||
cb3 := erm.GetCircuitBreaker("service2")
|
||||
|
||||
if cb1 == nil || cb2 == nil || cb3 == nil {
|
||||
t.Fatal("Expected non-nil circuit breakers")
|
||||
}
|
||||
|
||||
// Should return same instance for same service
|
||||
if cb1 != cb2 {
|
||||
t.Error("Expected same circuit breaker instance for same service")
|
||||
}
|
||||
|
||||
// Should return different instances for different services
|
||||
if cb1 == cb3 {
|
||||
t.Error("Expected different circuit breaker instances for different services")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerExecuteWithRecovery(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
success := false
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service", func() error {
|
||||
success = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
t.Error("Expected function to execute")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorRecoveryManagerMetrics(t *testing.T) {
|
||||
erm := NewErrorRecoveryManager(nil)
|
||||
|
||||
// Create some circuit breakers
|
||||
_ = erm.GetCircuitBreaker("service1")
|
||||
_ = erm.GetCircuitBreaker("service2")
|
||||
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
cbMetrics, ok := metrics["circuit_breakers"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected circuit_breakers in metrics")
|
||||
}
|
||||
|
||||
if len(cbMetrics) != 2 {
|
||||
t.Errorf("Expected 2 circuit breakers in metrics, got %d", len(cbMetrics))
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions and types
|
||||
|
||||
func circuitBreakerStateToString(state CircuitBreakerState) string {
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Mock network error for testing
|
||||
type mockNetError struct {
|
||||
timeout bool
|
||||
temporary bool
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *mockNetError) Error() string { return e.msg }
|
||||
func (e *mockNetError) Timeout() bool { return e.timeout }
|
||||
func (e *mockNetError) Temporary() bool { return e.temporary }
|
||||
|
||||
// Ensure mockNetError implements net.Error
|
||||
var _ net.Error = (*mockNetError)(nil)
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.13.0
|
||||
golang.org/x/time v0.14.0
|
||||
)
|
||||
|
||||
require (
|
||||
|
||||
@@ -12,8 +12,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
|
||||
golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
|
||||
@@ -86,7 +86,7 @@ func (m *GoroutineManager) StartPeriodicTask(name string, interval time.Duration
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.logger.Debugf("Periodic task %s cancelled", name)
|
||||
m.logger.Debugf("Periodic task %s canceled", name)
|
||||
return
|
||||
case <-ticker.C:
|
||||
task()
|
||||
|
||||
@@ -0,0 +1,625 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test GoroutineManager Creation
|
||||
|
||||
func TestNewGoroutineManager(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
if gm == nil {
|
||||
t.Fatal("Expected non-nil goroutine manager")
|
||||
}
|
||||
|
||||
if gm.ctx == nil {
|
||||
t.Error("Expected context to be initialized")
|
||||
}
|
||||
|
||||
if gm.cancel == nil {
|
||||
t.Error("Expected cancel function to be initialized")
|
||||
}
|
||||
|
||||
if gm.goroutines == nil {
|
||||
t.Error("Expected goroutines map to be initialized")
|
||||
}
|
||||
|
||||
if gm.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Starting Goroutines
|
||||
|
||||
func TestStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("test-goroutine", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
})
|
||||
|
||||
// Give goroutine time to execute
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if len(status) != 1 {
|
||||
t.Errorf("Expected 1 goroutine in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if _, exists := status["test-goroutine"]; !exists {
|
||||
t.Error("Expected goroutine 'test-goroutine' in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineDuplicate(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start a long-running goroutine
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
// Give first goroutine time to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to start another with same name (should be skipped)
|
||||
gm.StartGoroutine("duplicate-test", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Should only have executed once
|
||||
if counter.Load() != 1 {
|
||||
t.Errorf("Expected counter to be 1 (duplicate should be skipped), got %d", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineContextCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
started := atomic.Bool{}
|
||||
canceled := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("cancel-test", func(ctx context.Context) {
|
||||
started.Store(true)
|
||||
<-ctx.Done()
|
||||
canceled.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !started.Load() {
|
||||
t.Error("Expected goroutine to start")
|
||||
}
|
||||
|
||||
// Stop the goroutine
|
||||
gm.StopGoroutine("cancel-test")
|
||||
|
||||
// Wait for cancellation
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !canceled.Load() {
|
||||
t.Error("Expected goroutine to be canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartGoroutineWithPanic(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("panic-test", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// Give goroutine time to panic and recover
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute before panic")
|
||||
}
|
||||
|
||||
// Check that goroutine is marked as not running after panic
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running after panic")
|
||||
}
|
||||
}
|
||||
|
||||
// Manager should still be functional
|
||||
counter := atomic.Int32{}
|
||||
gm.StartGoroutine("after-panic", func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 1 {
|
||||
t.Error("Expected manager to still be functional after panic recovery")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Periodic Tasks
|
||||
|
||||
func TestStartPeriodicTask(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("periodic-test", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for multiple executions
|
||||
time.Sleep(160 * time.Millisecond)
|
||||
|
||||
// Should have executed at least 2-3 times
|
||||
count := counter.Load()
|
||||
if count < 2 {
|
||||
t.Errorf("Expected periodic task to execute at least 2 times, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartPeriodicTaskCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("cancel-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
})
|
||||
|
||||
// Wait for some executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
// Stop the task
|
||||
gm.StopGoroutine("cancel-periodic")
|
||||
|
||||
countBeforeStop := counter.Load()
|
||||
|
||||
// Wait and verify no more executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
countAfterStop := counter.Load()
|
||||
|
||||
// Allow 1 additional execution (could be in progress when stopped)
|
||||
if countAfterStop > countBeforeStop+1 {
|
||||
t.Errorf("Expected periodic task to stop executing, before: %d, after: %d",
|
||||
countBeforeStop, countAfterStop)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Stopping Goroutines
|
||||
|
||||
func TestStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
stopped := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("stop-test", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
stopped.Store(true)
|
||||
})
|
||||
|
||||
// Wait for goroutine to start
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
gm.StopGoroutine("stop-test")
|
||||
|
||||
// Wait for goroutine to stop
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !stopped.Load() {
|
||||
t.Error("Expected goroutine to be stopped")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["stop-test"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStopGoroutineNonExistent(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Should not panic or error when stopping non-existent goroutine
|
||||
gm.StopGoroutine("non-existent")
|
||||
}
|
||||
|
||||
func TestStopGoroutineAlreadyStopped(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
gm.StartGoroutine("already-stopped", func(ctx context.Context) {
|
||||
// Exit immediately
|
||||
})
|
||||
|
||||
// Wait for goroutine to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Try to stop already-stopped goroutine (should be safe)
|
||||
gm.StopGoroutine("already-stopped")
|
||||
}
|
||||
|
||||
// Test Shutdown
|
||||
|
||||
func TestShutdownGraceful(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
// Start multiple goroutines
|
||||
for i := 0; i < 5; i++ {
|
||||
name := "goroutine-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
<-ctx.Done()
|
||||
counter.Add(-1)
|
||||
})
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if counter.Load() != 5 {
|
||||
t.Errorf("Expected 5 goroutines running, got %d", counter.Load())
|
||||
}
|
||||
|
||||
// Shutdown with generous timeout
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected graceful shutdown, got error: %v", err)
|
||||
}
|
||||
|
||||
if counter.Load() != 0 {
|
||||
t.Errorf("Expected all goroutines to complete cleanup, got %d still running", counter.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownWithTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Start a goroutine that ignores cancellation (bad behavior, but testing timeout)
|
||||
gm.StartGoroutine("stubborn", func(ctx context.Context) {
|
||||
// Simulate a goroutine that takes too long to stop
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Shutdown with very short timeout
|
||||
err := gm.Shutdown(10 * time.Millisecond)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error")
|
||||
}
|
||||
|
||||
if err != ErrShutdownTimeout {
|
||||
t.Errorf("Expected ErrShutdownTimeout, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdownEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown with no goroutines should succeed immediately
|
||||
err := gm.Shutdown(time.Second)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for empty shutdown, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Status
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start multiple goroutines with different states
|
||||
gm.StartGoroutine("running", func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
gm.StartGoroutine("quick", func(ctx context.Context) {
|
||||
// Exits immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if len(status) != 2 {
|
||||
t.Errorf("Expected 2 goroutines in status, got %d", len(status))
|
||||
}
|
||||
|
||||
if runningStatus, exists := status["running"]; exists {
|
||||
if !runningStatus.Running {
|
||||
t.Error("Expected 'running' goroutine to be marked as running")
|
||||
}
|
||||
|
||||
if runningStatus.Name != "running" {
|
||||
t.Errorf("Expected name 'running', got %s", runningStatus.Name)
|
||||
}
|
||||
|
||||
if runningStatus.StartTime.IsZero() {
|
||||
t.Error("Expected non-zero start time")
|
||||
}
|
||||
|
||||
if runningStatus.Runtime <= 0 {
|
||||
t.Error("Expected positive runtime")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'running' goroutine in status")
|
||||
}
|
||||
|
||||
if quickStatus, exists := status["quick"]; exists {
|
||||
if quickStatus.Running {
|
||||
t.Error("Expected 'quick' goroutine to be marked as not running")
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected 'quick' goroutine in status")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatusEmpty(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
status := gm.GetStatus()
|
||||
|
||||
if status == nil {
|
||||
t.Fatal("Expected non-nil status map")
|
||||
}
|
||||
|
||||
if len(status) != 0 {
|
||||
t.Errorf("Expected empty status, got %d entries", len(status))
|
||||
}
|
||||
}
|
||||
|
||||
// Test Concurrent Operations
|
||||
|
||||
func TestConcurrentStartGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(2 * time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
const numGoroutines = 50
|
||||
|
||||
// Start many goroutines concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "concurrent-" + string(rune('0'+id%10)) + string(rune('0'+id/10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
counter.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
counter.Add(-1)
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all to start
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Verify goroutines are tracked
|
||||
status := gm.GetStatus()
|
||||
if len(status) < numGoroutines/2 {
|
||||
t.Errorf("Expected at least %d goroutines, got %d", numGoroutines/2, len(status))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentStopGoroutine(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
const numGoroutines = 20
|
||||
|
||||
// Start goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
name := "stop-concurrent-" + string(rune('0'+i%10))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Stop all concurrently
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
name := "stop-concurrent-" + string(rune('0'+id%10))
|
||||
gm.StopGoroutine(name)
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Verify all stopped
|
||||
status := gm.GetStatus()
|
||||
for _, s := range status {
|
||||
if s.Running {
|
||||
t.Errorf("Expected goroutine %s to be stopped", s.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentGetStatus(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
// Start some goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
name := "status-test-" + string(rune('0'+i))
|
||||
gm.StartGoroutine(name, func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
})
|
||||
}
|
||||
|
||||
// Concurrently read status many times (should not race)
|
||||
done := make(chan struct{})
|
||||
for i := 0; i < 20; i++ {
|
||||
go func() {
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = gm.GetStatus()
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all concurrent reads
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// Test Error Cases
|
||||
|
||||
func TestShutdownTimeoutError(t *testing.T) {
|
||||
err := ErrShutdownTimeout
|
||||
|
||||
if err.Error() != "shutdown timeout: some goroutines did not stop in time" {
|
||||
t.Errorf("Unexpected error message: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// Test Edge Cases
|
||||
|
||||
func TestStartGoroutineAfterShutdown(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// Shutdown immediately
|
||||
_ = gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
// Try to start goroutine after shutdown
|
||||
gm.StartGoroutine("after-shutdown", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
<-ctx.Done()
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Goroutine should have started but context already canceled
|
||||
// It may or may not execute depending on timing, but shouldn't panic
|
||||
status := gm.GetStatus()
|
||||
if _, exists := status["after-shutdown"]; exists {
|
||||
// If it's in status, it was tracked (acceptable)
|
||||
t.Log("Goroutine was tracked even after shutdown")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleShutdowns(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
|
||||
// First shutdown
|
||||
err1 := gm.Shutdown(time.Second)
|
||||
if err1 != nil {
|
||||
t.Errorf("Expected first shutdown to succeed, got: %v", err1)
|
||||
}
|
||||
|
||||
// Second shutdown (should not panic or error)
|
||||
err2 := gm.Shutdown(time.Second)
|
||||
if err2 != nil {
|
||||
t.Errorf("Expected second shutdown to succeed, got: %v", err2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoroutineWithImmediateReturn(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
executed := atomic.Bool{}
|
||||
|
||||
gm.StartGoroutine("immediate", func(ctx context.Context) {
|
||||
executed.Store(true)
|
||||
// Return immediately
|
||||
})
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
if !executed.Load() {
|
||||
t.Error("Expected goroutine to execute")
|
||||
}
|
||||
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["immediate"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected immediately-returning goroutine to be marked as not running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeriodicTaskPanicRecovery(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
gm := NewGoroutineManager(logger)
|
||||
defer gm.Shutdown(time.Second)
|
||||
|
||||
counter := atomic.Int32{}
|
||||
|
||||
gm.StartPeriodicTask("panic-periodic", 50*time.Millisecond, func() {
|
||||
counter.Add(1)
|
||||
if counter.Load() == 2 {
|
||||
panic("periodic panic")
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for panic to occur
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// After panic, the goroutine should have stopped
|
||||
status := gm.GetStatus()
|
||||
if goroutineStatus, exists := status["panic-periodic"]; exists {
|
||||
if goroutineStatus.Running {
|
||||
t.Error("Expected panicked periodic task to stop")
|
||||
}
|
||||
}
|
||||
}
|
||||
+18
-8
@@ -109,7 +109,7 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
client := t.tokenHTTPClient
|
||||
if client == nil {
|
||||
// Use shared transport pool to prevent memory leaks
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
pooledClient := CreateTokenHTTPClient()
|
||||
client = &http.Client{
|
||||
Transport: pooledClient.Transport,
|
||||
@@ -124,7 +124,12 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
// Read tokenURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
tokenURL := t.tokenURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -135,13 +140,13 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType string, code
|
||||
return nil, fmt.Errorf("failed to exchange tokens: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining response body on defer
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on defer
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
limitReader := io.LimitReader(resp.Body, 1024*10)
|
||||
bodyBytes, _ := io.ReadAll(limitReader)
|
||||
bodyBytes, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
@@ -232,7 +237,7 @@ func NewTokenCache() *TokenCache {
|
||||
// - expiration: The duration for which the cache entry should be valid
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
_ = tc.cache.Set(token, claims, expiration) // Safe to ignore: cache failures are non-critical
|
||||
}
|
||||
|
||||
// Get retrieves cached claims for a token.
|
||||
@@ -355,8 +360,13 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
// Read endSessionURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
endSessionURL := t.endSessionURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
|
||||
@@ -245,7 +245,7 @@ func (f *HTTPClientFactory) CreateHTTPClient(config HTTPClientConfig) *http.Clie
|
||||
|
||||
// Add cookie jar if requested
|
||||
if config.UseCookieJar {
|
||||
jar, _ := cookiejar.New(nil)
|
||||
jar, _ := cookiejar.New(nil) // Safe to ignore: cookiejar creation with nil options rarely fails
|
||||
client.Jar = jar
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestOIDCProviderHTTPClientConfigUnit tests OIDCProviderHTTPClientConfig function
|
||||
func TestOIDCProviderHTTPClientConfigUnit(t *testing.T) {
|
||||
config := OIDCProviderHTTPClientConfig()
|
||||
|
||||
// Verify OIDC-specific settings
|
||||
assert.Equal(t, 15*time.Second, config.Timeout, "OIDC provider should have 15s timeout")
|
||||
assert.Equal(t, 100, config.MaxIdleConns, "OIDC provider should have 100 max idle conns")
|
||||
assert.Equal(t, 25, config.MaxIdleConnsPerHost, "OIDC provider should have 25 max idle conns per host")
|
||||
assert.Equal(t, 50, config.MaxConnsPerHost, "OIDC provider should have 50 max conns per host")
|
||||
assert.Equal(t, 90*time.Second, config.IdleConnTimeout, "OIDC provider should have 90s idle conn timeout")
|
||||
assert.True(t, config.UseCookieJar, "OIDC provider should have cookie jar enabled")
|
||||
}
|
||||
|
||||
// TestCreateDefaultClientUnit tests CreateDefaultClient function
|
||||
func TestCreateDefaultClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateDefaultClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "default client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateTokenClientUnit tests CreateTokenClient function
|
||||
func TestCreateTokenClientUnit(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
client := factory.CreateTokenClient()
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport, "client should have transport")
|
||||
assert.NotNil(t, client.Jar, "token client should have cookie jar")
|
||||
assert.Equal(t, 10*time.Second, client.Timeout, "token client should have 10s timeout")
|
||||
}
|
||||
|
||||
// TestCreateHTTPClientWithConfigUnit tests CreateHTTPClientWithConfig function
|
||||
func TestCreateHTTPClientWithConfigUnit(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 5 * time.Second,
|
||||
MaxIdleConns: 20,
|
||||
MaxIdleConnsPerHost: 5,
|
||||
UseCookieJar: true,
|
||||
}
|
||||
|
||||
client := CreateHTTPClientWithConfig(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 5*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar, "client should have cookie jar when configured")
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryCreateHTTPClientValidation tests validation in CreateHTTPClient
|
||||
func TestHTTPClientFactoryCreateHTTPClientValidation(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
t.Run("zero values get defaults", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
// All zero values
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Verify defaults were applied
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("custom values preserved", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: 15 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxRedirects: 3,
|
||||
UseCookieJar: true,
|
||||
ForceHTTP2: true,
|
||||
DisableKeepAlives: true,
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.Equal(t, 15*time.Second, client.Timeout)
|
||||
assert.NotNil(t, client.Jar)
|
||||
})
|
||||
|
||||
t.Run("invalid timeout gets default", func(t *testing.T) {
|
||||
config := HTTPClientConfig{
|
||||
Timeout: -1 * time.Second, // Invalid
|
||||
}
|
||||
|
||||
client := factory.CreateHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
// Should get default due to validation failure
|
||||
assert.Equal(t, 30*time.Second, client.Timeout)
|
||||
})
|
||||
}
|
||||
|
||||
// TestHTTPClientFactoryValidateHTTPClientConfig tests ValidateHTTPClientConfig
|
||||
func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
|
||||
factory := NewHTTPClientFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config HTTPClientConfig
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 50,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 20,
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConns",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConns too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConns: 1500,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConns too high",
|
||||
},
|
||||
{
|
||||
name: "negative MaxIdleConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: -1,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost cannot be negative",
|
||||
},
|
||||
{
|
||||
name: "timeout too high",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Minute,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout too high",
|
||||
},
|
||||
{
|
||||
name: "negative timeout",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: -1 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "timeout must be positive",
|
||||
},
|
||||
{
|
||||
name: "MaxIdleConnsPerHost exceeds MaxConnsPerHost",
|
||||
config: HTTPClientConfig{
|
||||
Timeout: 10 * time.Second,
|
||||
DialTimeout: 5 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
MaxIdleConnsPerHost: 50,
|
||||
MaxConnsPerHost: 10,
|
||||
},
|
||||
wantError: true,
|
||||
errorMsg: "MaxIdleConnsPerHost (50) cannot exceed MaxConnsPerHost (10)",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := factory.ValidateHTTPClientConfig(&tt.config)
|
||||
|
||||
if tt.wantError {
|
||||
assert.Error(t, err)
|
||||
if tt.errorMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errorMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,691 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestSharedTransportPoolGetOrCreateTransport tests transport creation and reuse
|
||||
func TestSharedTransportPoolGetOrCreateTransport(t *testing.T) {
|
||||
t.Run("create new transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
require.NotNil(t, transport)
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount))
|
||||
assert.Len(t, pool.transports, 1)
|
||||
})
|
||||
|
||||
t.Run("reuse existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config)
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Equal(t, transport1, transport2, "should reuse same transport")
|
||||
assert.Equal(t, int32(1), atomic.LoadInt32(&pool.clientCount), "client count should not increase")
|
||||
|
||||
// Check ref count
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 2, shared.refCount, "ref count should be 2")
|
||||
})
|
||||
|
||||
t.Run("client limit enforcement", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 5, // Already at max
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
assert.Nil(t, transport, "should return nil when at client limit")
|
||||
})
|
||||
|
||||
t.Run("client limit with existing transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create first transport
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
transport1 := pool.GetOrCreateTransport(config1)
|
||||
require.NotNil(t, transport1)
|
||||
|
||||
// Set client count to max
|
||||
atomic.StoreInt32(&pool.clientCount, 5)
|
||||
|
||||
// Try to create with different config
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15 // Different config
|
||||
transport2 := pool.GetOrCreateTransport(config2)
|
||||
|
||||
// Should return existing transport since at limit
|
||||
assert.NotNil(t, transport2)
|
||||
assert.Equal(t, transport1, transport2)
|
||||
})
|
||||
|
||||
t.Run("updates last used time", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
firstTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Get again
|
||||
transport2 := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport2)
|
||||
|
||||
pool.mu.RLock()
|
||||
secondTime := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, secondTime.After(firstTime), "lastUsed should be updated")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolReleaseTransport tests transport release
|
||||
func TestSharedTransportPoolReleaseTransport(t *testing.T) {
|
||||
t.Run("decrement ref count", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Get again to increase ref count
|
||||
pool.GetOrCreateTransport(config)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 2, refCount)
|
||||
|
||||
// Release
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
newRefCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
assert.Equal(t, 1, newRefCount)
|
||||
})
|
||||
|
||||
t.Run("ref count reaches zero", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
pool.mu.RUnlock()
|
||||
|
||||
// Release to zero
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
shared := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 0, shared.refCount)
|
||||
assert.NotZero(t, shared.lastUsed, "lastUsed should be set")
|
||||
})
|
||||
|
||||
t.Run("release non-existent transport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
// Create a transport not in the pool
|
||||
fakeTransport := &http.Transport{}
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
pool.ReleaseTransport(fakeTransport)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("release updates last used", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
beforeRelease := time.Now()
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
lastUsed := pool.transports[key].lastUsed
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, lastUsed.After(beforeRelease) || lastUsed.Equal(beforeRelease))
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanup tests cleanup functionality
|
||||
func TestSharedTransportPoolCleanup(t *testing.T) {
|
||||
t.Run("cleanup all transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create multiple transports
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
pool.GetOrCreateTransport(config1)
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 15
|
||||
pool.GetOrCreateTransport(config2)
|
||||
|
||||
assert.Greater(t, len(pool.transports), 0)
|
||||
|
||||
// Cleanup
|
||||
pool.Cleanup()
|
||||
|
||||
assert.Len(t, pool.transports, 0, "all transports should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup cancels context", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
pool.Cleanup()
|
||||
|
||||
select {
|
||||
case <-pool.ctx.Done():
|
||||
// Context was canceled
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("context should be canceled")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cleanup with no transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
pool.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("cleanup closes idle connections", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Cleanup should call CloseIdleConnections on each transport
|
||||
pool.Cleanup()
|
||||
|
||||
// Verify transports map is cleared
|
||||
assert.Empty(t, pool.transports)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolCleanupIdleTransports tests periodic cleanup
|
||||
func TestSharedTransportPoolCleanupIdleTransports(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cleanup goroutine test in short mode")
|
||||
}
|
||||
|
||||
t.Run("cleanup removes idle transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport and release it
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
pool.ReleaseTransport(transport)
|
||||
|
||||
// Set lastUsed to old time
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Start cleanup in background (simulating what would happen)
|
||||
// Note: We're testing the cleanup logic manually here
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should be removed
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.False(t, exists, "old idle transport should be removed")
|
||||
})
|
||||
|
||||
t.Run("cleanup preserves active transports", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create transport with refs
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
// Keep ref count > 0, but set old lastUsed
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup logic
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Transport should still exist (has ref count)
|
||||
pool.mu.RLock()
|
||||
_, exists := pool.transports[key]
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.True(t, exists, "transport with references should be preserved")
|
||||
})
|
||||
|
||||
t.Run("cleanup respects context cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
pool.cleanupIdleTransports(ctx)
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Cancel context
|
||||
cancel()
|
||||
|
||||
// Should exit quickly
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("cleanup goroutine should exit on context cancellation")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestCreatePooledHTTPClient tests pooled client creation
|
||||
func TestCreatePooledHTTPClient(t *testing.T) {
|
||||
t.Run("create client with default config", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.Transport)
|
||||
assert.Equal(t, config.Timeout, client.Timeout)
|
||||
})
|
||||
|
||||
t.Run("create multiple clients reuse transport", func(t *testing.T) {
|
||||
// Reset global pool for clean test
|
||||
globalTransportPoolOnce = sync.Once{}
|
||||
globalTransportPool = nil
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
client1 := CreatePooledHTTPClient(config)
|
||||
client2 := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client1)
|
||||
require.NotNil(t, client2)
|
||||
|
||||
// Should use same transport
|
||||
assert.Equal(t, client1.Transport, client2.Transport)
|
||||
})
|
||||
|
||||
t.Run("redirect policy is set", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 3
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
assert.NotNil(t, client.CheckRedirect)
|
||||
|
||||
// Test redirect limit
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 3; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after max redirects")
|
||||
})
|
||||
|
||||
t.Run("default redirect limit", func(t *testing.T) {
|
||||
config := DefaultHTTPClientConfig()
|
||||
config.MaxRedirects = 0 // Should default to 10
|
||||
|
||||
client := CreatePooledHTTPClient(config)
|
||||
|
||||
require.NotNil(t, client)
|
||||
|
||||
// Test default redirect limit (10)
|
||||
var redirects []*http.Request
|
||||
for i := 0; i < 10; i++ {
|
||||
redirects = append(redirects, &http.Request{})
|
||||
}
|
||||
|
||||
err := client.CheckRedirect(nil, redirects)
|
||||
assert.Error(t, err, "should error after 10 redirects")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGetGlobalTransportPool tests singleton pattern
|
||||
func TestGetGlobalTransportPool(t *testing.T) {
|
||||
t.Run("returns same instance", func(t *testing.T) {
|
||||
pool1 := GetGlobalTransportPool()
|
||||
pool2 := GetGlobalTransportPool()
|
||||
|
||||
assert.Equal(t, pool1, pool2, "should return same singleton instance")
|
||||
})
|
||||
|
||||
t.Run("pool is initialized", func(t *testing.T) {
|
||||
pool := GetGlobalTransportPool()
|
||||
|
||||
require.NotNil(t, pool)
|
||||
assert.NotNil(t, pool.transports)
|
||||
assert.Equal(t, 20, pool.maxConns)
|
||||
assert.Equal(t, int32(5), pool.maxClients)
|
||||
assert.NotNil(t, pool.ctx)
|
||||
assert.NotNil(t, pool.cancel)
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolConcurrency tests thread safety
|
||||
func TestSharedTransportPoolConcurrency(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrency test in short mode")
|
||||
}
|
||||
|
||||
t.Run("concurrent GetOrCreateTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10, // Allow more for concurrency test
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
const numGoroutines = 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
transports := make([]*http.Transport, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
transports[idx] = pool.GetOrCreateTransport(config)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All should get same transport
|
||||
firstTransport := transports[0]
|
||||
for i := 1; i < numGoroutines; i++ {
|
||||
if transports[i] != nil {
|
||||
assert.Equal(t, firstTransport, transports[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent ReleaseTransport", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 10,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
|
||||
// Increase ref count
|
||||
for i := 0; i < 20; i++ {
|
||||
pool.GetOrCreateTransport(config)
|
||||
}
|
||||
|
||||
const numReleases = 20
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < numReleases; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pool.ReleaseTransport(transport)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic and ref count should be decremented
|
||||
pool.mu.RLock()
|
||||
key := pool.configKey(config)
|
||||
refCount := pool.transports[key].refCount
|
||||
pool.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, 1, refCount, "ref count should be 1 after 20 releases from initial 21")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSharedTransportPoolEdgeCases tests edge cases
|
||||
func TestSharedTransportPoolEdgeCases(t *testing.T) {
|
||||
t.Run("config key generation", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
config1.MaxIdleConnsPerHost = 5
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 10
|
||||
config2.MaxIdleConnsPerHost = 5
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.Equal(t, key1, key2, "same config should produce same key")
|
||||
})
|
||||
|
||||
t.Run("different configs produce different keys", func(t *testing.T) {
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
}
|
||||
|
||||
config1 := DefaultHTTPClientConfig()
|
||||
config1.MaxConnsPerHost = 10
|
||||
|
||||
config2 := DefaultHTTPClientConfig()
|
||||
config2.MaxConnsPerHost = 20
|
||||
|
||||
key1 := pool.configKey(config1)
|
||||
key2 := pool.configKey(config2)
|
||||
|
||||
assert.NotEqual(t, key1, key2, "different configs should produce different keys")
|
||||
})
|
||||
|
||||
t.Run("client count decrements on cleanup", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
pool := &SharedTransportPool{
|
||||
transports: make(map[string]*sharedTransport),
|
||||
maxConns: 20,
|
||||
clientCount: 0,
|
||||
maxClients: 5,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
config := DefaultHTTPClientConfig()
|
||||
transport := pool.GetOrCreateTransport(config)
|
||||
require.NotNil(t, transport)
|
||||
|
||||
initialCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(1), initialCount)
|
||||
|
||||
// Release and mark as old
|
||||
pool.ReleaseTransport(transport)
|
||||
pool.mu.Lock()
|
||||
key := pool.configKey(config)
|
||||
pool.transports[key].lastUsed = time.Now().Add(-3 * time.Minute)
|
||||
pool.mu.Unlock()
|
||||
|
||||
// Run cleanup
|
||||
pool.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range pool.transports {
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(pool.transports, transportKey)
|
||||
atomic.AddInt32(&pool.clientCount, -1)
|
||||
}
|
||||
}
|
||||
pool.mu.Unlock()
|
||||
|
||||
finalCount := atomic.LoadInt32(&pool.clientCount)
|
||||
assert.Equal(t, int32(0), finalCount, "client count should decrement on cleanup")
|
||||
})
|
||||
}
|
||||
Vendored
+1
-1
@@ -355,7 +355,7 @@ func (c *Cache) removeItem(key string, item *Item) {
|
||||
|
||||
func (c *Cache) evictLRU() {
|
||||
if elem := c.lruList.Back(); elem != nil {
|
||||
item := elem.Value.(*Item)
|
||||
item, _ := elem.Value.(*Item) // Safe to ignore: type assertion from known type
|
||||
c.removeItem(item.Key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
c.logger.Debugf("Cache: Evicted LRU item key=%s", item.Key)
|
||||
|
||||
Vendored
+2
@@ -1,3 +1,5 @@
|
||||
// Package cache provides high-performance caching implementations for OIDC tokens, metadata, and JWKs.
|
||||
// It includes compatibility wrappers for backward compatibility with existing cache interfaces.
|
||||
package cache
|
||||
|
||||
import (
|
||||
|
||||
@@ -91,7 +91,8 @@ func (e *OIDCError) ToJSON() map[string]any {
|
||||
}
|
||||
|
||||
if e.Details != "" {
|
||||
result["error"].(map[string]any)["details"] = e.Details
|
||||
errorMap, _ := result["error"].(map[string]any) // Safe to ignore: type assertion from known type
|
||||
errorMap["details"] = e.Details
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -130,7 +130,7 @@ func (h *AuthFlowHandler) waitForInitialization(req *http.Request) bool {
|
||||
}
|
||||
return true
|
||||
case <-req.Context().Done():
|
||||
h.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
||||
h.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
return false
|
||||
case <-time.After(30 * time.Second):
|
||||
h.logger.Error("Timeout waiting for OIDC initialization")
|
||||
|
||||
@@ -246,7 +246,7 @@ func TestAuthFlowHandler_waitForInitialization(t *testing.T) {
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "Request cancelled",
|
||||
name: "Request canceled",
|
||||
setupHandler: func() (*AuthFlowHandler, context.CancelFunc) {
|
||||
initComplete := make(chan struct{})
|
||||
handler := &AuthFlowHandler{
|
||||
|
||||
@@ -215,12 +215,12 @@ func (h *SessionHandler) SendErrorResponse(rw http.ResponseWriter, req *http.Req
|
||||
// For AJAX requests, send JSON response
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(statusCode)
|
||||
fmt.Fprintf(rw, `{"error": "%s"}`, message)
|
||||
_, _ = fmt.Fprintf(rw, `{"error": "%s"}`, message) // Safe to ignore: writing error response
|
||||
} else {
|
||||
// For browser requests, send HTML response
|
||||
rw.Header().Set("Content-Type", "text/html")
|
||||
rw.WriteHeader(statusCode)
|
||||
fmt.Fprintf(rw, `<html><body><h1>Error %d</h1><p>%s</p></body></html>`, statusCode, message)
|
||||
_, _ = fmt.Fprintf(rw, `<html><body><h1>Error %d</h1><p>%s</p></body></html>`, statusCode, message) // Safe to ignore: writing error response
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -81,8 +81,8 @@ func (rp *RequestProcessor) WaitForInitialization(req *http.Request, initComplet
|
||||
case <-initComplete:
|
||||
return nil
|
||||
case <-req.Context().Done():
|
||||
rp.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
||||
return fmt.Errorf("request cancelled")
|
||||
rp.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
return fmt.Errorf("request canceled")
|
||||
case <-time.After(30 * time.Second):
|
||||
rp.logger.Error("Timeout waiting for OIDC initialization")
|
||||
return fmt.Errorf("timeout waiting for OIDC provider initialization")
|
||||
|
||||
@@ -383,7 +383,7 @@ func TestWaitForInitialization(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Request context cancelled", func(t *testing.T) {
|
||||
t.Run("Request context canceled", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
req = req.WithContext(ctx)
|
||||
@@ -396,15 +396,15 @@ func TestWaitForInitialization(t *testing.T) {
|
||||
|
||||
err := processor.WaitForInitialization(req, initComplete)
|
||||
if err == nil {
|
||||
t.Error("Expected error when request context is cancelled")
|
||||
t.Error("Expected error when request context is canceled")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "request cancelled") {
|
||||
t.Errorf("Expected 'request cancelled' error, got: %v", err)
|
||||
if !strings.Contains(err.Error(), "request canceled") {
|
||||
t.Errorf("Expected 'request canceled' error, got: %v", err)
|
||||
}
|
||||
|
||||
if len(logger.DebugCalls) == 0 {
|
||||
t.Error("Expected debug log when request is cancelled")
|
||||
t.Error("Expected debug log when request is canceled")
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
+20
-12
@@ -119,7 +119,7 @@ func newManager() *Manager {
|
||||
// Initialize compression pools
|
||||
m.gzipWriterPool = &sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
|
||||
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed) // Safe to ignore: factory function
|
||||
return w
|
||||
},
|
||||
}
|
||||
@@ -178,13 +178,17 @@ func (m *Manager) GetBuffer(sizeHint int) *bytes.Buffer {
|
||||
|
||||
switch {
|
||||
case sizeHint <= 1024:
|
||||
return m.smallBufferPool.Get().(*bytes.Buffer)
|
||||
buf, _ := m.smallBufferPool.Get().(*bytes.Buffer) // Safe to ignore: pool return is best-effort
|
||||
return buf
|
||||
case sizeHint <= 4096:
|
||||
return m.mediumBufferPool.Get().(*bytes.Buffer)
|
||||
buf, _ := m.mediumBufferPool.Get().(*bytes.Buffer) // Safe to ignore: pool return is best-effort
|
||||
return buf
|
||||
case sizeHint <= 8192:
|
||||
return m.largeBufferPool.Get().(*bytes.Buffer)
|
||||
buf, _ := m.largeBufferPool.Get().(*bytes.Buffer) // Safe to ignore: pool return is best-effort
|
||||
return buf
|
||||
case sizeHint <= 16384:
|
||||
return m.xlBufferPool.Get().(*bytes.Buffer)
|
||||
buf, _ := m.xlBufferPool.Get().(*bytes.Buffer) // Safe to ignore: pool return is best-effort
|
||||
return buf
|
||||
default:
|
||||
// For very large buffers, create new ones
|
||||
return bytes.NewBuffer(make([]byte, 0, sizeHint))
|
||||
@@ -225,7 +229,8 @@ func (m *Manager) PutBuffer(buf *bytes.Buffer) {
|
||||
// GetGzipWriter returns a gzip writer from the pool
|
||||
func (m *Manager) GetGzipWriter() *gzip.Writer {
|
||||
atomic.AddUint64(&m.stats.GzipGets, 1)
|
||||
return m.gzipWriterPool.Get().(*gzip.Writer)
|
||||
w, _ := m.gzipWriterPool.Get().(*gzip.Writer) // Safe to ignore: pool return is best-effort
|
||||
return w
|
||||
}
|
||||
|
||||
// PutGzipWriter returns a gzip writer to the pool
|
||||
@@ -245,7 +250,8 @@ func (m *Manager) GetGzipReader() *gzip.Reader {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.(*gzip.Reader)
|
||||
reader, _ := r.(*gzip.Reader) // Safe to ignore: pool return is best-effort
|
||||
return reader
|
||||
}
|
||||
|
||||
// PutGzipReader returns a gzip reader to the pool
|
||||
@@ -254,14 +260,14 @@ func (m *Manager) PutGzipReader(r *gzip.Reader) {
|
||||
return
|
||||
}
|
||||
atomic.AddUint64(&m.stats.GzipPuts, 1)
|
||||
r.Reset(nil)
|
||||
_ = r.Reset(nil) // Safe to ignore: resetting to nil reader for pool reuse
|
||||
m.gzipReaderPool.Put(r)
|
||||
}
|
||||
|
||||
// GetStringBuilder returns a string builder from the pool
|
||||
func (m *Manager) GetStringBuilder() *strings.Builder {
|
||||
atomic.AddUint64(&m.stats.StringGets, 1)
|
||||
sb := m.stringBuilderPool.Get().(*strings.Builder)
|
||||
sb, _ := m.stringBuilderPool.Get().(*strings.Builder) // Safe to ignore: pool return is best-effort
|
||||
sb.Reset()
|
||||
return sb
|
||||
}
|
||||
@@ -287,7 +293,8 @@ func (m *Manager) PutStringBuilder(sb *strings.Builder) {
|
||||
// GetJWTBuffer returns JWT parsing buffers from the pool
|
||||
func (m *Manager) GetJWTBuffer() *JWTBuffer {
|
||||
atomic.AddUint64(&m.stats.JWTGets, 1)
|
||||
return m.jwtBufferPool.Get().(*JWTBuffer)
|
||||
buf, _ := m.jwtBufferPool.Get().(*JWTBuffer) // Safe to ignore: pool return is best-effort
|
||||
return buf
|
||||
}
|
||||
|
||||
// PutJWTBuffer returns JWT parsing buffers to the pool
|
||||
@@ -314,7 +321,8 @@ func (m *Manager) PutJWTBuffer(buf *JWTBuffer) {
|
||||
// GetHTTPResponseBuffer returns an HTTP response buffer from the pool
|
||||
func (m *Manager) GetHTTPResponseBuffer() []byte {
|
||||
atomic.AddUint64(&m.stats.HTTPGets, 1)
|
||||
return *m.httpResponsePool.Get().(*[]byte)
|
||||
buf, _ := m.httpResponsePool.Get().(*[]byte) // Safe to ignore: pool return is best-effort
|
||||
return *buf
|
||||
}
|
||||
|
||||
// PutHTTPResponseBuffer returns an HTTP response buffer to the pool
|
||||
@@ -363,7 +371,7 @@ func (m *Manager) GetByteSlice(size int) []byte {
|
||||
m.poolMu.Unlock()
|
||||
}
|
||||
|
||||
b := pool.Get().(*[]byte)
|
||||
b, _ := pool.Get().(*[]byte) // Safe to ignore: pool return is best-effort
|
||||
return (*b)[:size]
|
||||
}
|
||||
|
||||
|
||||
@@ -155,7 +155,9 @@ func (r *ProviderRegistry) detectProviderUnsafe(issuerURL string) OIDCProvider {
|
||||
return p
|
||||
}
|
||||
case ProviderTypeGitLab:
|
||||
if strings.Contains(host, "gitlab.com") {
|
||||
// Match gitlab.com, self-hosted (gitlab.*), and instances with gitlab in subdomain
|
||||
if strings.Contains(host, "gitlab.com") ||
|
||||
strings.Contains(host, "gitlab") {
|
||||
return p
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
@@ -238,6 +239,26 @@ func TestProviderRegistry_DetectProvider(t *testing.T) {
|
||||
issuerURL: "https://gitlab.com/oauth",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "GitLab self-hosted detection - gitlab subdomain",
|
||||
issuerURL: "https://gitlab.example.com",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "GitLab self-hosted detection - gitlab in domain",
|
||||
issuerURL: "https://my-gitlab.company.io",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "GitLab self-hosted detection - gitlab prefix",
|
||||
issuerURL: "https://gitlab-prod.internal.net",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "GitLab self-hosted detection - gitlab suffix",
|
||||
issuerURL: "https://company-gitlab.net",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "Generic provider fallback",
|
||||
issuerURL: "https://auth.example.com",
|
||||
@@ -482,6 +503,206 @@ func TestProviderRegistry_DoubleCheckedLocking(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderRegistry_DetectGitLabSelfHosted tests improved GitLab detection for issue #61
|
||||
func TestProviderRegistry_DetectGitLabSelfHosted(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
genericProvider := NewGenericProvider()
|
||||
gitlabProvider := NewGitLabProvider()
|
||||
githubProvider := NewGitHubProvider()
|
||||
|
||||
registry.RegisterProvider(genericProvider)
|
||||
registry.RegisterProvider(gitlabProvider)
|
||||
registry.RegisterProvider(githubProvider)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
expected OIDCProvider
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "GitLab.com official",
|
||||
issuerURL: "https://gitlab.com",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect official GitLab.com",
|
||||
},
|
||||
{
|
||||
name: "GitLab.com with path",
|
||||
issuerURL: "https://gitlab.com/oauth/authorize",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect GitLab.com with path",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted gitlab.example.com",
|
||||
issuerURL: "https://gitlab.example.com",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab as subdomain",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted my.gitlab.io",
|
||||
issuerURL: "https://my.gitlab.io",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab in domain",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted example-gitlab.com",
|
||||
issuerURL: "https://example-gitlab.com",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab as suffix",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted gitlab-prod.company.net",
|
||||
issuerURL: "https://gitlab-prod.company.net",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab as prefix",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted my-gitlab.internal",
|
||||
issuerURL: "https://my-gitlab.internal",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab in middle of host",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted company.gitlab.services",
|
||||
issuerURL: "https://company.gitlab.services",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect gitlab in middle of domain",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted with port",
|
||||
issuerURL: "https://gitlab.example.com:8443",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect GitLab with custom port",
|
||||
},
|
||||
{
|
||||
name: "Self-hosted with path and query",
|
||||
issuerURL: "https://gitlab.example.com/oauth?param=value",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect GitLab with complex URL",
|
||||
},
|
||||
{
|
||||
name: "Case insensitive - GITLAB",
|
||||
issuerURL: "https://GITLAB.example.com",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect GitLab case-insensitively",
|
||||
},
|
||||
{
|
||||
name: "Case insensitive - GitLab",
|
||||
issuerURL: "https://GitLab.example.com",
|
||||
expected: gitlabProvider,
|
||||
description: "Should detect GitLab with mixed case",
|
||||
},
|
||||
{
|
||||
name: "Not GitLab - git prefix only",
|
||||
issuerURL: "https://github.com",
|
||||
expected: githubProvider, // Should match GitHub provider, not GitLab
|
||||
description: "Should not match github.com as GitLab",
|
||||
},
|
||||
{
|
||||
name: "Not GitLab - lab suffix only",
|
||||
issuerURL: "https://mylab.example.com",
|
||||
expected: genericProvider,
|
||||
description: "Should not match partial gitlab string",
|
||||
},
|
||||
{
|
||||
name: "Not GitLab - git and lab separate",
|
||||
issuerURL: "https://git.mylab.example.com",
|
||||
expected: genericProvider,
|
||||
description: "Should not match git and lab when not together",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Clear cache to ensure fresh detection
|
||||
registry.ClearCache()
|
||||
|
||||
result := registry.DetectProvider(tt.issuerURL)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("%s: Expected %v, got %v", tt.description, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderRegistry_GitLabDetection_RealWorldURLs tests real-world GitLab URLs
|
||||
func TestProviderRegistry_GitLabDetection_RealWorldURLs(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
genericProvider := NewGenericProvider()
|
||||
gitlabProvider := NewGitLabProvider()
|
||||
githubProvider := NewGitHubProvider()
|
||||
|
||||
registry.RegisterProvider(genericProvider)
|
||||
registry.RegisterProvider(gitlabProvider)
|
||||
registry.RegisterProvider(githubProvider)
|
||||
|
||||
realWorldTests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
expected OIDCProvider
|
||||
}{
|
||||
// Actual self-hosted GitLab examples from issue #61
|
||||
{
|
||||
name: "Company self-hosted GitLab",
|
||||
issuerURL: "https://gitlab.company.com",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "Organization GitLab instance with gitlab in subdomain",
|
||||
issuerURL: "https://gitlab.organization.org",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "Internal GitLab server",
|
||||
issuerURL: "https://gitlab.internal.corp",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
{
|
||||
name: "GitLab with custom subdomain",
|
||||
issuerURL: "https://code.gitlab.mycompany.com",
|
||||
expected: gitlabProvider,
|
||||
},
|
||||
// Negative cases to ensure we don't over-match
|
||||
{
|
||||
name: "GitHub should not match GitLab",
|
||||
issuerURL: "https://github.com",
|
||||
expected: githubProvider,
|
||||
},
|
||||
{
|
||||
name: "Generic git server",
|
||||
issuerURL: "https://git.example.com",
|
||||
expected: genericProvider,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range realWorldTests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
registry.ClearCache()
|
||||
result := registry.DetectProvider(tt.issuerURL)
|
||||
|
||||
if result != tt.expected {
|
||||
var expectedType, resultType string
|
||||
if tt.expected != nil {
|
||||
expectedType = fmt.Sprintf("%v", tt.expected.GetType())
|
||||
} else {
|
||||
expectedType = "nil"
|
||||
}
|
||||
if result != nil {
|
||||
resultType = fmt.Sprintf("%v", result.GetType())
|
||||
} else {
|
||||
resultType = "nil"
|
||||
}
|
||||
|
||||
t.Errorf("Expected provider type %s, got %s for URL %s",
|
||||
expectedType, resultType, tt.issuerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkProviderRegistry_DetectProvider_Cached(b *testing.B) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
@@ -381,7 +381,7 @@ func NewTestSuite() *TestSuite {
|
||||
func (ts *TestSuite) Setup() {
|
||||
// Common test setup
|
||||
ts.Logger.Clear()
|
||||
ts.Session.Clear(nil, nil)
|
||||
_ = ts.Session.Clear(nil, nil) // Safe to ignore: test helper function
|
||||
ts.TokenCache.Clear()
|
||||
ts.TokenVerifier.ShouldFail = false
|
||||
ts.TokenVerifier.Error = nil
|
||||
|
||||
@@ -586,6 +586,7 @@ func TestIssue67_TokenResilienceRecursionBug(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenResilienceManager: resilienceManager,
|
||||
tokenHTTPClient: &http.Client{
|
||||
@@ -671,6 +672,7 @@ func TestIssue67_TokenResilienceManager_NoRecursion(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenResilienceManager: resilienceManager,
|
||||
tokenHTTPClient: &http.Client{
|
||||
@@ -738,6 +740,7 @@ func TestIssue67_DirectRecursionDetection(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test",
|
||||
audience: "test",
|
||||
clientSecret: "test",
|
||||
tokenResilienceManager: NewTokenResilienceManager(config, logger),
|
||||
tokenHTTPClient: &http.Client{Timeout: 2 * time.Second},
|
||||
|
||||
@@ -100,7 +100,7 @@ func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
c.cache.Set(jwksURL, jwks, 1*time.Hour)
|
||||
_ = c.cache.Set(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
return jwks, nil
|
||||
}
|
||||
@@ -126,10 +126,10 @@ func fetchJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*J
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching JWKS: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
body, _ := io.ReadAll(resp.Body) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("JWKS fetch failed with status %d: %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,413 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewJWKCache tests JWK cache creation
|
||||
func TestNewJWKCache(t *testing.T) {
|
||||
cache := NewJWKCache()
|
||||
|
||||
require.NotNil(t, cache)
|
||||
assert.NotNil(t, cache.cache, "cache should have underlying universal cache")
|
||||
}
|
||||
|
||||
// TestJWKCacheGetJWKS tests JWKS fetching and caching
|
||||
func TestJWKCacheGetJWKS(t *testing.T) {
|
||||
t.Run("fetch from remote on cache miss", func(t *testing.T) {
|
||||
// Create mock server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := JWKSet{
|
||||
Keys: []JWK{
|
||||
{
|
||||
Kid: "key1",
|
||||
Kty: "RSA",
|
||||
Use: "sig",
|
||||
Alg: "RS256",
|
||||
N: "test-n-value",
|
||||
E: "AQAB",
|
||||
},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, jwks)
|
||||
assert.Len(t, jwks.Keys, 1)
|
||||
assert.Equal(t, "key1", jwks.Keys[0].Kid)
|
||||
assert.Equal(t, "RSA", jwks.Keys[0].Kty)
|
||||
})
|
||||
|
||||
t.Run("return cached value on cache hit", func(t *testing.T) {
|
||||
fetchCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fetchCount++
|
||||
jwks := JWKSet{
|
||||
Keys: []JWK{
|
||||
{Kid: "key1", Kty: "RSA"},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
// First fetch - should hit server
|
||||
jwks1, err1 := cache.GetJWKS(ctx, server.URL, client)
|
||||
require.NoError(t, err1)
|
||||
assert.Equal(t, 1, fetchCount, "should fetch from server on first call")
|
||||
|
||||
// Second fetch - should use cache
|
||||
jwks2, err2 := cache.GetJWKS(ctx, server.URL, client)
|
||||
require.NoError(t, err2)
|
||||
assert.Equal(t, 1, fetchCount, "should not fetch from server on second call")
|
||||
|
||||
// Both should return same data
|
||||
assert.Equal(t, jwks1.Keys[0].Kid, jwks2.Keys[0].Kid)
|
||||
})
|
||||
|
||||
t.Run("handle server error", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("server error"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, jwks)
|
||||
assert.Contains(t, err.Error(), "500")
|
||||
})
|
||||
|
||||
t.Run("handle empty JWKS", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := JWKSet{Keys: []JWK{}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, jwks)
|
||||
assert.Contains(t, err.Error(), "no keys")
|
||||
})
|
||||
|
||||
t.Run("handle invalid JSON", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("invalid json"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, jwks)
|
||||
assert.Contains(t, err.Error(), "parsing")
|
||||
})
|
||||
|
||||
t.Run("handle multiple keys", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := JWKSet{
|
||||
Keys: []JWK{
|
||||
{Kid: "key1", Kty: "RSA", Alg: "RS256"},
|
||||
{Kid: "key2", Kty: "RSA", Alg: "RS256"},
|
||||
{Kid: "key3", Kty: "EC", Alg: "ES256"},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx := context.Background()
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, jwks.Keys, 3)
|
||||
assert.Equal(t, "key1", jwks.Keys[0].Kid)
|
||||
assert.Equal(t, "key2", jwks.Keys[1].Kid)
|
||||
assert.Equal(t, "key3", jwks.Keys[2].Kid)
|
||||
})
|
||||
|
||||
t.Run("context cancellation", func(t *testing.T) {
|
||||
// Create server that delays response
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
jwks := JWKSet{Keys: []JWK{{Kid: "key1"}}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
client := http.DefaultClient
|
||||
|
||||
jwks, err := cache.GetJWKS(ctx, server.URL, client)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, jwks)
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWKSetGetKey tests the GetKey method
|
||||
func TestJWKSetGetKey(t *testing.T) {
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{
|
||||
{Kid: "key1", Kty: "RSA", Alg: "RS256"},
|
||||
{Kid: "key2", Kty: "RSA", Alg: "RS384"},
|
||||
{Kid: "key3", Kty: "EC", Alg: "ES256"},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("find existing key", func(t *testing.T) {
|
||||
key := jwks.GetKey("key2")
|
||||
|
||||
require.NotNil(t, key)
|
||||
assert.Equal(t, "key2", key.Kid)
|
||||
assert.Equal(t, "RS384", key.Alg)
|
||||
})
|
||||
|
||||
t.Run("return nil for non-existent key", func(t *testing.T) {
|
||||
key := jwks.GetKey("non-existent")
|
||||
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
|
||||
t.Run("find first key", func(t *testing.T) {
|
||||
key := jwks.GetKey("key1")
|
||||
|
||||
require.NotNil(t, key)
|
||||
assert.Equal(t, "key1", key.Kid)
|
||||
})
|
||||
|
||||
t.Run("find last key", func(t *testing.T) {
|
||||
key := jwks.GetKey("key3")
|
||||
|
||||
require.NotNil(t, key)
|
||||
assert.Equal(t, "key3", key.Kid)
|
||||
assert.Equal(t, "EC", key.Kty)
|
||||
})
|
||||
|
||||
t.Run("empty key set returns nil", func(t *testing.T) {
|
||||
emptyJWKS := &JWKSet{Keys: []JWK{}}
|
||||
key := emptyJWKS.GetKey("any-key")
|
||||
|
||||
assert.Nil(t, key)
|
||||
})
|
||||
|
||||
t.Run("case sensitive key ID", func(t *testing.T) {
|
||||
key1 := jwks.GetKey("key1")
|
||||
key2 := jwks.GetKey("KEY1")
|
||||
|
||||
assert.NotNil(t, key1)
|
||||
assert.Nil(t, key2, "key ID lookup should be case sensitive")
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWKCacheCleanupAndClose tests the no-op Cleanup and Close methods
|
||||
func TestJWKCacheCleanupAndClose(t *testing.T) {
|
||||
cache := NewJWKCache()
|
||||
require.NotNil(t, cache)
|
||||
|
||||
t.Run("cleanup is safe to call", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("close is safe to call", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("multiple cleanup calls are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Cleanup()
|
||||
cache.Cleanup()
|
||||
cache.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("multiple close calls are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Close()
|
||||
cache.Close()
|
||||
cache.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("operations work after cleanup", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := JWKSet{Keys: []JWK{{Kid: "key1", Kty: "RSA"}}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache.Cleanup()
|
||||
|
||||
// Should still work
|
||||
jwks, err := cache.GetJWKS(context.Background(), server.URL, http.DefaultClient)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jwks)
|
||||
})
|
||||
|
||||
t.Run("operations work after close", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
jwks := JWKSet{Keys: []JWK{{Kid: "key2", Kty: "RSA"}}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache.Close()
|
||||
|
||||
// Should still work (close is a no-op)
|
||||
jwks, err := cache.GetJWKS(context.Background(), server.URL, http.DefaultClient)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jwks)
|
||||
})
|
||||
}
|
||||
|
||||
// TestFetchJWKS tests the fetchJWKS helper function indirectly through GetJWKS
|
||||
func TestFetchJWKSEdgeCases(t *testing.T) {
|
||||
t.Run("handles various HTTP status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
status int
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{200, false, ""},
|
||||
{400, true, "400"},
|
||||
{401, true, "401"},
|
||||
{403, true, "403"},
|
||||
{404, true, "404"},
|
||||
{500, true, "500"},
|
||||
{502, true, "502"},
|
||||
{503, true, "503"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("status_%d", tc.status), func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(tc.status)
|
||||
if tc.status == 200 {
|
||||
jwks := JWKSet{Keys: []JWK{{Kid: "key1"}}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
} else {
|
||||
w.Write([]byte("error"))
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
jwks, err := cache.GetJWKS(context.Background(), server.URL, http.DefaultClient)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tc.errContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errContains)
|
||||
}
|
||||
assert.Nil(t, jwks)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jwks)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("handles response body reading", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Write valid JSON
|
||||
jwks := JWKSet{
|
||||
Keys: []JWK{
|
||||
{Kid: "test-key", Kty: "RSA", Alg: "RS256"},
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
jwks, err := cache.GetJWKS(context.Background(), server.URL, http.DefaultClient)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, jwks.Keys, 1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWKCacheConcurrency tests concurrent access to JWK cache
|
||||
func TestJWKCacheConcurrency(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrency test in short mode")
|
||||
}
|
||||
|
||||
fetchCount := 0
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fetchCount++
|
||||
time.Sleep(10 * time.Millisecond) // Simulate some processing
|
||||
jwks := JWKSet{Keys: []JWK{{Kid: "key1", Kty: "RSA"}}}
|
||||
json.NewEncoder(w).Encode(jwks)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cache := NewJWKCache()
|
||||
const numGoroutines = 10
|
||||
|
||||
// Launch multiple concurrent requests
|
||||
done := make(chan bool, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
jwks, err := cache.GetJWKS(context.Background(), server.URL, http.DefaultClient)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, jwks)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all to complete
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// With caching and mutex protection, server should only be hit once or very few times
|
||||
// (may be hit more than once due to race between first requests)
|
||||
assert.LessOrEqual(t, fetchCount, 3, "should use cache for most requests")
|
||||
}
|
||||
@@ -257,12 +257,12 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
// not-before time (if present), and prevents replay attacks using JTI claims.
|
||||
// Parameters:
|
||||
// - issuerURL: Expected issuer URL to validate against
|
||||
// - clientID: Expected audience (client ID) to validate against
|
||||
// - expectedAudience: Expected audience to validate against (can be clientID or custom audience)
|
||||
// - skipReplayCheck: Optional parameter to skip replay attack protection
|
||||
//
|
||||
// Returns:
|
||||
// - An error describing the first validation failure encountered
|
||||
func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error {
|
||||
func (j *JWT) Verify(issuerURL, expectedAudience string, skipReplayCheck ...bool) error {
|
||||
alg, ok := j.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'alg' header")
|
||||
@@ -290,7 +290,7 @@ func (j *JWT) Verify(issuerURL, clientID string, skipReplayCheck ...bool) error
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'aud' claim")
|
||||
}
|
||||
if err := verifyAudience(aud, clientID); err != nil {
|
||||
if err := verifyAudience(aud, expectedAudience); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -11,12 +11,30 @@ var (
|
||||
singletonNoOpLogger *Logger
|
||||
// noOpLoggerOnce ensures the singleton is created only once
|
||||
noOpLoggerOnce sync.Once
|
||||
// noOpLoggerMu protects access to the singleton logger during reset
|
||||
noOpLoggerMu sync.RWMutex
|
||||
)
|
||||
|
||||
// GetSingletonNoOpLogger returns the singleton no-op logger instance.
|
||||
// This reduces memory allocation by reusing the same no-op logger
|
||||
// instance across the entire application.
|
||||
func GetSingletonNoOpLogger() *Logger {
|
||||
noOpLoggerMu.RLock()
|
||||
if singletonNoOpLogger != nil {
|
||||
logger := singletonNoOpLogger
|
||||
noOpLoggerMu.RUnlock()
|
||||
return logger
|
||||
}
|
||||
noOpLoggerMu.RUnlock()
|
||||
|
||||
noOpLoggerMu.Lock()
|
||||
defer noOpLoggerMu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if singletonNoOpLogger != nil {
|
||||
return singletonNoOpLogger
|
||||
}
|
||||
|
||||
noOpLoggerOnce.Do(func() {
|
||||
singletonNoOpLogger = &Logger{
|
||||
logError: log.New(io.Discard, "", 0),
|
||||
@@ -29,6 +47,9 @@ func GetSingletonNoOpLogger() *Logger {
|
||||
|
||||
// ResetSingletonNoOpLogger resets the singleton instance (mainly for testing)
|
||||
func ResetSingletonNoOpLogger() {
|
||||
noOpLoggerMu.Lock()
|
||||
defer noOpLoggerMu.Unlock()
|
||||
|
||||
noOpLoggerOnce = sync.Once{}
|
||||
singletonNoOpLogger = nil
|
||||
}
|
||||
|
||||
@@ -152,14 +152,26 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return config.PostLogoutRedirectURI
|
||||
}(),
|
||||
tokenBlacklist: cacheManager.GetSharedTokenBlacklist(),
|
||||
jwkCache: cacheManager.GetSharedJWKCache(),
|
||||
metadataCache: cacheManager.GetSharedMetadataCache(),
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
tokenBlacklist: cacheManager.GetSharedTokenBlacklist(),
|
||||
tokenTypeCache: cacheManager.GetSharedTokenTypeCache(), // Cache for token type detection
|
||||
jwkCache: cacheManager.GetSharedJWKCache(),
|
||||
metadataCache: cacheManager.GetSharedMetadataCache(),
|
||||
introspectionCache: cacheManager.GetSharedIntrospectionCache(), // Cache for introspection results
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: func() string {
|
||||
if config.Audience != "" {
|
||||
return config.Audience
|
||||
}
|
||||
return config.ClientID
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
strictAudienceValidation: config.StrictAudienceValidation,
|
||||
allowOpaqueTokens: config.AllowOpaqueTokens,
|
||||
requireTokenIntrospection: config.RequireTokenIntrospection,
|
||||
disableReplayDetection: config.DisableReplayDetection,
|
||||
scopes: func() []string {
|
||||
userProvidedScopes := deduplicateScopes(config.Scopes)
|
||||
|
||||
@@ -192,9 +204,17 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
cancelFunc: cancelFunc,
|
||||
suppressDiagnosticLogs: isTestMode(),
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
}
|
||||
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, t.logger)
|
||||
// Log audience configuration
|
||||
if config.Audience != "" && config.Audience != config.ClientID {
|
||||
t.logger.Infof("Custom audience configured: %s", config.Audience)
|
||||
} else {
|
||||
t.logger.Debugf("No custom audience specified, using clientID as audience: %s", t.clientID)
|
||||
}
|
||||
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, config.CookieDomain, t.logger) // Safe to ignore: session manager creation with fallback to defaults
|
||||
t.errorRecoveryManager = NewErrorRecoveryManager(t.logger)
|
||||
|
||||
// Initialize token resilience manager with default configuration
|
||||
@@ -284,11 +304,11 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
t.initializeMetadata(config.ProviderURL)
|
||||
}()
|
||||
|
||||
// Setup cleanup hook for when context is cancelled
|
||||
// Setup cleanup hook for when context is canceled
|
||||
if pluginCtx != nil {
|
||||
go func() {
|
||||
<-pluginCtx.Done()
|
||||
t.Close()
|
||||
_ = t.Close() // Safe to ignore: cleanup on context cancellation
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -341,16 +361,31 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
|
||||
// updateMetadataEndpoints updates internal endpoint URLs with discovered metadata.
|
||||
// It sets the authorization URL, token URL, JWKS URL, issuer URL, revocation URL,
|
||||
// and end session URL based on the provider's metadata.
|
||||
// end session URL, and introspection URL based on the provider's metadata.
|
||||
// Parameters:
|
||||
// - metadata: A pointer to the ProviderMetadata struct containing the discovered endpoints.
|
||||
func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
t.metadataMu.Lock()
|
||||
defer t.metadataMu.Unlock()
|
||||
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.scopesSupported = metadata.ScopesSupported // Store supported scopes from discovery
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
t.endSessionURL = metadata.EndSessionURL
|
||||
t.introspectionURL = metadata.IntrospectionURL // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
|
||||
// Log introspection endpoint availability for opaque token support
|
||||
if t.introspectionURL != "" {
|
||||
t.logger.Debugf("Token introspection endpoint discovered: %s", t.introspectionURL)
|
||||
if t.allowOpaqueTokens {
|
||||
t.logger.Debugf("Opaque token support enabled with introspection endpoint")
|
||||
}
|
||||
} else if t.allowOpaqueTokens || t.requireTokenIntrospection {
|
||||
t.logger.Infof("⚠️ Opaque tokens enabled but no introspection endpoint available from provider")
|
||||
}
|
||||
}
|
||||
|
||||
// startMetadataRefresh starts a background goroutine that periodically refreshes provider metadata.
|
||||
@@ -390,7 +425,7 @@ func (t *TraefikOidc) startMetadataRefresh(providerURL string) {
|
||||
|
||||
// Start the task if not already running
|
||||
if !rm.IsTaskRunning(taskName) {
|
||||
rm.StartBackgroundTask(taskName)
|
||||
_ = rm.StartBackgroundTask(taskName) // Safe to ignore: task registration succeeded, start is best-effort
|
||||
t.logger.Debug("Started singleton metadata refresh task")
|
||||
} else {
|
||||
t.logger.Debug("Metadata refresh task already running, skipping duplicate")
|
||||
|
||||
+18
-1
@@ -10,6 +10,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -37,6 +38,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -82,6 +84,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
enablePKCE: true,
|
||||
tokenHTTPClient: &http.Client{
|
||||
@@ -116,6 +119,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/invalid",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -146,6 +150,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/expired",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -176,6 +181,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/timeout",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 100 * time.Millisecond,
|
||||
@@ -206,6 +212,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/error",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -236,6 +243,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/malformed",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -266,6 +274,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/incomplete",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -299,6 +308,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/slow",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -329,6 +339,7 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/ratelimit",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -482,13 +493,17 @@ func TestExchangeCodeForToken_Comprehensive(t *testing.T) {
|
||||
// TestExchangeCodeForToken_Integration tests integration scenarios
|
||||
func TestExchangeCodeForToken_Integration(t *testing.T) {
|
||||
t.Run("multiple concurrent exchanges", func(t *testing.T) {
|
||||
// Use atomic counter for unique token generation to handle race detector slowdown
|
||||
var tokenCounter int64
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Add small delay to test concurrency
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Generate unique token using atomic counter
|
||||
tokenID := atomic.AddInt64(&tokenCounter, 1)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(TokenResponse{
|
||||
AccessToken: fmt.Sprintf("token_%d", time.Now().UnixNano()),
|
||||
AccessToken: fmt.Sprintf("token_%d", tokenID),
|
||||
IDToken: "test_id_token",
|
||||
RefreshToken: "test_refresh_token",
|
||||
TokenType: "Bearer",
|
||||
@@ -500,6 +515,7 @@ func TestExchangeCodeForToken_Integration(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -586,6 +602,7 @@ func TestExchangeCodeForToken_Integration(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// TestGoroutineLeakPrevention_ContextCancellation tests that goroutines are properly cleaned up
|
||||
// when the context is cancelled during middleware initialization and operation
|
||||
// when the context is canceled during middleware initialization and operation
|
||||
func TestGoroutineLeakPrevention_ContextCancellation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -21,19 +21,19 @@ func TestGoroutineLeakPrevention_ContextCancellation(t *testing.T) {
|
||||
name: "immediate_cancellation",
|
||||
cancelAfter: 1 * time.Millisecond,
|
||||
expectedLeaks: 10, // Allow for background tasks (replay-cache-cleanup, health-check, etc.)
|
||||
description: "Context cancelled immediately during initialization",
|
||||
description: "Context canceled immediately during initialization",
|
||||
},
|
||||
{
|
||||
name: "quick_cancellation",
|
||||
cancelAfter: 50 * time.Millisecond,
|
||||
expectedLeaks: 5, // Allow for some background task leaks during cancellation
|
||||
description: "Context cancelled during metadata initialization",
|
||||
description: "Context canceled during metadata initialization",
|
||||
},
|
||||
{
|
||||
name: "delayed_cancellation",
|
||||
cancelAfter: 200 * time.Millisecond,
|
||||
expectedLeaks: 5, // Allow for some background task leaks during cancellation
|
||||
description: "Context cancelled after partial initialization",
|
||||
description: "Context canceled after partial initialization",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func TestGoroutineLeakPrevention_ContextCancellation(t *testing.T) {
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Initialization completed (or was cancelled)
|
||||
// Initialization completed (or was canceled)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Plugin initialization did not complete within timeout")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -71,6 +72,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/expired",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -97,6 +99,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/invalid",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -123,6 +126,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/revoked",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -149,6 +153,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/timeout",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 100 * time.Millisecond,
|
||||
@@ -175,6 +180,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/error",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -201,6 +207,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/malformed",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -228,6 +235,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/partial",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -259,6 +267,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/ratelimit",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -285,6 +294,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -315,6 +325,7 @@ func TestGetNewTokenWithRefreshToken(t *testing.T) {
|
||||
return &TraefikOidc{
|
||||
tokenURL: server.URL + "/token/rotating",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -519,6 +530,7 @@ func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -588,6 +600,7 @@ func TestGetNewTokenWithRefreshToken_Concurrency(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -642,6 +655,7 @@ func TestGetNewTokenWithRefreshToken_ErrorRecovery(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
tokenURL: server.URL + "/token",
|
||||
clientID: "test_client",
|
||||
audience: "test_client",
|
||||
clientSecret: "test_secret",
|
||||
tokenHTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
|
||||
@@ -135,7 +135,7 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
go func() {
|
||||
time.Sleep(shortTimeout)
|
||||
if time.Since(start) >= shortTimeout {
|
||||
// Simulate timeout by cancelling
|
||||
// Simulate timeout by canceling
|
||||
close(done)
|
||||
}
|
||||
}()
|
||||
@@ -192,6 +192,7 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logoutURLPath: "/logout",
|
||||
tokenURL: "https://provider.example.com/token",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
tokenHTTPClient: http.DefaultClient,
|
||||
}
|
||||
@@ -297,6 +298,7 @@ func TestProcessAuthorizedRequest(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
authURL: "https://provider.example.com/auth",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
redirURLPath: "/callback",
|
||||
}
|
||||
},
|
||||
|
||||
@@ -124,6 +124,7 @@ func (ts *TestSuite) Setup() {
|
||||
ts.tOidc = &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -1304,6 +1305,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
|
||||
// Add potentially missing fields based on New() comparison
|
||||
clientID: ts.tOidc.clientID,
|
||||
audience: ts.tOidc.clientID,
|
||||
issuerURL: ts.tOidc.issuerURL,
|
||||
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
|
||||
httpClient: ts.tOidc.httpClient,
|
||||
@@ -1668,6 +1670,7 @@ func TestHandleLogout(t *testing.T) {
|
||||
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
||||
httpClient: &http.Client{},
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
tokenCache: NewTokenCache(),
|
||||
forceHTTPS: false,
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -1035,6 +1036,305 @@ func TestGoroutineLeakPrevention(t *testing.T) {
|
||||
suite.runner.RunMemoryLeakTests(t, tests)
|
||||
}
|
||||
|
||||
// TestLazyBackgroundTask tests LazyBackgroundTask specific functionality
|
||||
func TestLazyBackgroundTask(t *testing.T) {
|
||||
config := GetTestConfig()
|
||||
if config.ShouldSkipTest(t, TestTypeLeakDetection) {
|
||||
return
|
||||
}
|
||||
|
||||
suite := NewMemoryLeakFixesTestSuite()
|
||||
|
||||
tests := []MemoryLeakTestCase{
|
||||
{
|
||||
Name: "LazyBackgroundTask delayed start",
|
||||
Description: "Test that lazy background task doesn't start until StartIfNeeded is called",
|
||||
Operation: func() error {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
callCount := 0
|
||||
taskFunc := func() {
|
||||
callCount++
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("lazy-test", 50*time.Millisecond, taskFunc, logger)
|
||||
|
||||
// Wait - should not execute yet
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
if callCount != 0 {
|
||||
return fmt.Errorf("task should not have executed before StartIfNeeded")
|
||||
}
|
||||
|
||||
// Now start it
|
||||
task.StartIfNeeded()
|
||||
time.Sleep(GetTestDuration(150 * time.Millisecond))
|
||||
|
||||
if callCount < 2 {
|
||||
return fmt.Errorf("task should have executed at least twice after starting")
|
||||
}
|
||||
|
||||
task.Stop()
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
return nil
|
||||
},
|
||||
Iterations: 5,
|
||||
MaxGoroutineGrowth: 2,
|
||||
MaxMemoryGrowthMB: 1.0,
|
||||
GCBetweenRuns: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
Name: "LazyBackgroundTask multiple StartIfNeeded calls",
|
||||
Description: "Test that multiple StartIfNeeded calls only start task once",
|
||||
Operation: func() error {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
execCount := 0
|
||||
|
||||
taskFunc := func() {
|
||||
execCount++
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("lazy-multiple", 50*time.Millisecond, taskFunc, logger)
|
||||
|
||||
// Call multiple times - should be idempotent
|
||||
task.StartIfNeeded()
|
||||
task.StartIfNeeded()
|
||||
task.StartIfNeeded()
|
||||
|
||||
// Verify it started (should execute)
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
if execCount < 1 {
|
||||
return fmt.Errorf("task should have executed at least once")
|
||||
}
|
||||
|
||||
// Verify started flag is set
|
||||
if !task.started {
|
||||
return fmt.Errorf("task should be marked as started")
|
||||
}
|
||||
|
||||
task.Stop()
|
||||
|
||||
return nil
|
||||
},
|
||||
Iterations: 5,
|
||||
MaxGoroutineGrowth: 2,
|
||||
MaxMemoryGrowthMB: 1.0,
|
||||
GCBetweenRuns: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
Name: "LazyBackgroundTask stop and restart",
|
||||
Description: "Test that task can be stopped and restarted",
|
||||
Operation: func() error {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
execCount := 0
|
||||
taskFunc := func() {
|
||||
execCount++
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("lazy-restart", 50*time.Millisecond, taskFunc, logger)
|
||||
|
||||
// Start
|
||||
task.StartIfNeeded()
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
countAfterFirst := execCount
|
||||
|
||||
// Stop
|
||||
task.Stop()
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
countAfterStop := execCount
|
||||
|
||||
// Should not have executed much more after stop (allow 1 in-flight)
|
||||
if countAfterStop > countAfterFirst+1 {
|
||||
return fmt.Errorf("task executed after stop: %d > %d", countAfterStop, countAfterFirst+1)
|
||||
}
|
||||
|
||||
// Restart
|
||||
task.StartIfNeeded()
|
||||
time.Sleep(GetTestDuration(100 * time.Millisecond))
|
||||
|
||||
if execCount <= countAfterStop {
|
||||
return fmt.Errorf("task should execute after restart")
|
||||
}
|
||||
|
||||
task.Stop()
|
||||
return nil
|
||||
},
|
||||
Iterations: 3,
|
||||
MaxGoroutineGrowth: 2,
|
||||
MaxMemoryGrowthMB: 1.0,
|
||||
GCBetweenRuns: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
suite.runner.RunMemoryLeakTests(t, tests)
|
||||
}
|
||||
|
||||
// TestLazyCache tests NewLazyCache and NewLazyCacheWithLogger
|
||||
func TestLazyCache(t *testing.T) {
|
||||
config := GetTestConfig()
|
||||
if config.ShouldSkipTest(t, TestTypeLeakDetection) {
|
||||
return
|
||||
}
|
||||
|
||||
suite := NewMemoryLeakFixesTestSuite()
|
||||
|
||||
tests := []MemoryLeakTestCase{
|
||||
{
|
||||
Name: "LazyCache basic operations",
|
||||
Description: "Test NewLazyCache with basic cache operations",
|
||||
Operation: func() error {
|
||||
cache := NewLazyCache()
|
||||
if cache == nil {
|
||||
return fmt.Errorf("NewLazyCache returned nil")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("key1", "value1", time.Minute)
|
||||
val, found := cache.Get("key1")
|
||||
if !found || val != "value1" {
|
||||
return fmt.Errorf("cache operation failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Iterations: 10,
|
||||
MaxGoroutineGrowth: 2,
|
||||
MaxMemoryGrowthMB: 2.0,
|
||||
GCBetweenRuns: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
Name: "LazyCacheWithLogger operations",
|
||||
Description: "Test NewLazyCacheWithLogger with custom logger",
|
||||
Operation: func() error {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cache := NewLazyCacheWithLogger(logger)
|
||||
if cache == nil {
|
||||
return fmt.Errorf("NewLazyCacheWithLogger returned nil")
|
||||
}
|
||||
|
||||
// Test with multiple entries
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("lazy-key-%d", i)
|
||||
cache.Set(key, i, time.Minute)
|
||||
}
|
||||
|
||||
// Verify
|
||||
for i := 0; i < 50; i++ {
|
||||
key := fmt.Sprintf("lazy-key-%d", i)
|
||||
val, found := cache.Get(key)
|
||||
if !found || val != i {
|
||||
return fmt.Errorf("cache value mismatch for %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Iterations: 5,
|
||||
MaxGoroutineGrowth: 2,
|
||||
MaxMemoryGrowthMB: 3.0,
|
||||
GCBetweenRuns: true,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
suite.runner.RunMemoryLeakTests(t, tests)
|
||||
}
|
||||
|
||||
// TestOptimizedMiddlewareConfig tests DefaultOptimizedConfig
|
||||
func TestOptimizedMiddlewareConfig(t *testing.T) {
|
||||
t.Run("DefaultOptimizedConfig", func(t *testing.T) {
|
||||
config := DefaultOptimizedConfig()
|
||||
|
||||
assert.NotNil(t, config)
|
||||
assert.True(t, config.DelayBackgroundTasks)
|
||||
assert.True(t, config.ReducedCleanupIntervals)
|
||||
assert.True(t, config.AggressiveConnectionCleanup)
|
||||
assert.True(t, config.MinimalCacheSize)
|
||||
})
|
||||
|
||||
t.Run("CustomOptimizedConfig", func(t *testing.T) {
|
||||
config := &OptimizedMiddlewareConfig{
|
||||
DelayBackgroundTasks: false,
|
||||
ReducedCleanupIntervals: true,
|
||||
AggressiveConnectionCleanup: false,
|
||||
MinimalCacheSize: true,
|
||||
}
|
||||
|
||||
assert.False(t, config.DelayBackgroundTasks)
|
||||
assert.True(t, config.ReducedCleanupIntervals)
|
||||
assert.False(t, config.AggressiveConnectionCleanup)
|
||||
assert.True(t, config.MinimalCacheSize)
|
||||
})
|
||||
}
|
||||
|
||||
// TestCleanupIdleConnections tests the HTTP connection cleanup function
|
||||
func TestCleanupIdleConnections(t *testing.T) {
|
||||
config := GetTestConfig()
|
||||
if config.ShouldSkipTest(t, TestTypeLeakDetection) {
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("CleanupIdleConnections basic", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DisableCompression: true,
|
||||
},
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start cleanup in background
|
||||
go CleanupIdleConnections(client, 50*time.Millisecond, stopChan)
|
||||
|
||||
// Let it run a couple of cycles
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Stop cleanup
|
||||
close(stopChan)
|
||||
|
||||
// Wait for cleanup to finish
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("CleanupIdleConnections stop immediately", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start and immediately stop
|
||||
go CleanupIdleConnections(client, 100*time.Millisecond, stopChan)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(stopChan)
|
||||
|
||||
// Wait for cleanup
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("CleanupIdleConnections with nil transport", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: nil,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Should handle gracefully
|
||||
go CleanupIdleConnections(client, 50*time.Millisecond, stopChan)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
close(stopChan)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkMemoryLeakFixes provides performance benchmarks for memory leak fixes
|
||||
func BenchmarkMemoryLeakFixes(b *testing.B) {
|
||||
suite := NewMemoryLeakFixesTestSuite()
|
||||
@@ -1060,6 +1360,26 @@ func BenchmarkMemoryLeakFixes(b *testing.B) {
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LazyBackgroundTaskLifecycle", func(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
taskFunc := func() {}
|
||||
task := NewLazyBackgroundTask("bench-lazy-task", 100*time.Millisecond, taskFunc, logger)
|
||||
task.StartIfNeeded()
|
||||
task.Stop()
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("LazyCacheLifecycle", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := NewLazyCache()
|
||||
cache.Set("bench-key", "bench-value", time.Minute)
|
||||
_, _ = cache.Get("bench-key")
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("MetadataCacheLifecycle", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
||||
@@ -0,0 +1,225 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestNewLazyBackgroundTaskUnit tests LazyBackgroundTask creation without leak detection
|
||||
func TestNewLazyBackgroundTaskUnit(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
callCount := 0
|
||||
taskFunc := func() {
|
||||
callCount++
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("test-task", 50*time.Millisecond, taskFunc, logger)
|
||||
|
||||
require.NotNil(t, task)
|
||||
assert.NotNil(t, task.BackgroundTask)
|
||||
assert.False(t, task.started)
|
||||
|
||||
// Should not execute before StartIfNeeded
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
assert.Equal(t, 0, callCount, "task should not execute before StartIfNeeded")
|
||||
|
||||
// Cleanup
|
||||
if task.started {
|
||||
task.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// TestLazyBackgroundTaskStartIfNeededUnit tests the StartIfNeeded method
|
||||
func TestLazyBackgroundTaskStartIfNeededUnit(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
callCount := 0
|
||||
var mu sync.Mutex
|
||||
taskFunc := func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("test-start", 30*time.Millisecond, taskFunc, logger)
|
||||
require.NotNil(t, task)
|
||||
|
||||
// Start the task
|
||||
task.StartIfNeeded()
|
||||
assert.True(t, task.started)
|
||||
|
||||
// Wait for execution
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
firstCount := callCount
|
||||
mu.Unlock()
|
||||
assert.Greater(t, firstCount, 0, "task should execute after StartIfNeeded")
|
||||
|
||||
// Multiple calls should be idempotent
|
||||
task.StartIfNeeded()
|
||||
task.StartIfNeeded()
|
||||
|
||||
// Cleanup
|
||||
task.Stop()
|
||||
}
|
||||
|
||||
// TestLazyBackgroundTaskStopUnit tests the Stop method
|
||||
func TestLazyBackgroundTaskStopUnit(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
callCount := 0
|
||||
var mu sync.Mutex
|
||||
taskFunc := func() {
|
||||
mu.Lock()
|
||||
callCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
task := NewLazyBackgroundTask("test-stop", 30*time.Millisecond, taskFunc, logger)
|
||||
require.NotNil(t, task)
|
||||
|
||||
// Start and let it run
|
||||
task.StartIfNeeded()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
countAfterStart := callCount
|
||||
mu.Unlock()
|
||||
assert.Greater(t, countAfterStart, 0)
|
||||
|
||||
// Stop the task
|
||||
task.Stop()
|
||||
assert.False(t, task.started)
|
||||
|
||||
// Wait and verify it stopped
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
mu.Lock()
|
||||
countAfterStop := callCount
|
||||
mu.Unlock()
|
||||
|
||||
// Allow 1 in-flight execution
|
||||
assert.LessOrEqual(t, countAfterStop, countAfterStart+1, "task should stop executing")
|
||||
}
|
||||
|
||||
// TestNewLazyCacheUnit tests NewLazyCache creation
|
||||
func TestNewLazyCacheUnit(t *testing.T) {
|
||||
cache := NewLazyCache()
|
||||
|
||||
require.NotNil(t, cache)
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test-key", "test-value", time.Minute)
|
||||
val, found := cache.Get("test-key")
|
||||
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "test-value", val)
|
||||
}
|
||||
|
||||
// TestNewLazyCacheWithLoggerUnit tests NewLazyCacheWithLogger creation
|
||||
func TestNewLazyCacheWithLoggerUnit(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cache := NewLazyCacheWithLogger(logger)
|
||||
|
||||
require.NotNil(t, cache)
|
||||
|
||||
// Test with multiple entries
|
||||
for i := 0; i < 10; i++ {
|
||||
key := "key-" + string(rune('0'+i))
|
||||
cache.Set(key, i, time.Minute)
|
||||
}
|
||||
|
||||
// Verify entries
|
||||
for i := 0; i < 10; i++ {
|
||||
key := "key-" + string(rune('0'+i))
|
||||
val, found := cache.Get(key)
|
||||
assert.True(t, found, "should find key %s", key)
|
||||
assert.Equal(t, i, val, "should get correct value for key %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLazyCacheWithLoggerNilUnit tests NewLazyCacheWithLogger with nil logger
|
||||
func TestNewLazyCacheWithLoggerNilUnit(t *testing.T) {
|
||||
cache := NewLazyCacheWithLogger(nil)
|
||||
|
||||
require.NotNil(t, cache)
|
||||
|
||||
// Should work with nil logger (uses no-op logger)
|
||||
cache.Set("nil-test", "value", time.Minute)
|
||||
val, found := cache.Get("nil-test")
|
||||
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "value", val)
|
||||
}
|
||||
|
||||
// TestCleanupIdleConnectionsUnit tests CleanupIdleConnections function
|
||||
func TestCleanupIdleConnectionsUnit(t *testing.T) {
|
||||
t.Run("basic cleanup cycle", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
DisableCompression: true,
|
||||
},
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start cleanup in background
|
||||
go CleanupIdleConnections(client, 40*time.Millisecond, stopChan)
|
||||
|
||||
// Let it run a couple of cycles
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Stop cleanup
|
||||
close(stopChan)
|
||||
|
||||
// Wait for cleanup to finish
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("immediate stop", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Start and immediately stop
|
||||
go CleanupIdleConnections(client, 100*time.Millisecond, stopChan)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(stopChan)
|
||||
|
||||
// Wait for cleanup
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
|
||||
t.Run("nil transport", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: nil,
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
|
||||
// Should handle gracefully
|
||||
go CleanupIdleConnections(client, 40*time.Millisecond, stopChan)
|
||||
time.Sleep(80 * time.Millisecond)
|
||||
close(stopChan)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
|
||||
// TestDefaultOptimizedConfigUnit tests DefaultOptimizedConfig function (already has 100% coverage)
|
||||
func TestDefaultOptimizedConfigUnit(t *testing.T) {
|
||||
config := DefaultOptimizedConfig()
|
||||
|
||||
require.NotNil(t, config)
|
||||
assert.True(t, config.DelayBackgroundTasks)
|
||||
assert.True(t, config.ReducedCleanupIntervals)
|
||||
assert.True(t, config.AggressiveConnectionCleanup)
|
||||
assert.True(t, config.MinimalCacheSize)
|
||||
}
|
||||
+10
-6
@@ -58,7 +58,7 @@ func NewBufferPool(maxSize int) *BufferPool {
|
||||
|
||||
// Get retrieves a buffer from the pool
|
||||
func (p *BufferPool) Get() *bytes.Buffer {
|
||||
buf := p.pool.Get().(*bytes.Buffer)
|
||||
buf, _ := p.pool.Get().(*bytes.Buffer) // Safe to ignore: pool return is best-effort
|
||||
buf.Reset()
|
||||
return buf
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func NewGzipWriterPool() *GzipWriterPool {
|
||||
return &GzipWriterPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed)
|
||||
w, _ := gzip.NewWriterLevel(nil, gzip.BestSpeed) // Safe to ignore: factory function
|
||||
return w
|
||||
},
|
||||
},
|
||||
@@ -94,7 +94,8 @@ func NewGzipWriterPool() *GzipWriterPool {
|
||||
|
||||
// Get retrieves a gzip writer from the pool
|
||||
func (p *GzipWriterPool) Get() *gzip.Writer {
|
||||
return p.pool.Get().(*gzip.Writer)
|
||||
w, _ := p.pool.Get().(*gzip.Writer) // Safe to ignore: pool return is best-effort
|
||||
return w
|
||||
}
|
||||
|
||||
// Put returns a gzip writer to the pool
|
||||
@@ -128,13 +129,14 @@ func (p *GzipReaderPool) Get() *gzip.Reader {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return r.(*gzip.Reader)
|
||||
reader, _ := r.(*gzip.Reader) // Safe to ignore: pool return is best-effort
|
||||
return reader
|
||||
}
|
||||
|
||||
// Put returns a gzip reader to the pool
|
||||
func (p *GzipReaderPool) Put(r *gzip.Reader) {
|
||||
if r != nil {
|
||||
r.Reset(nil)
|
||||
_ = r.Reset(nil) // Safe to ignore: resetting to nil reader for pool reuse
|
||||
p.pool.Put(r)
|
||||
}
|
||||
}
|
||||
@@ -187,7 +189,9 @@ func DecompressTokenOptimized(compressed string) (string, error) {
|
||||
if err != nil {
|
||||
return compressed, err
|
||||
}
|
||||
defer gzipReader.Close()
|
||||
defer func() {
|
||||
_ = gzipReader.Close() // Safe to ignore: closing resource in defer
|
||||
}()
|
||||
|
||||
outputBuf := opts.bufferPool.Get()
|
||||
defer opts.bufferPool.Put(outputBuf)
|
||||
|
||||
+1
-1
@@ -109,7 +109,7 @@ func (mc *MetadataCache) GetProviderMetadata(ctx context.Context, providerURL st
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch metadata: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { _ = resp.Body.Close() }() // Safe to ignore: closing body on defer
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("metadata fetch returned status %d", resp.StatusCode)
|
||||
|
||||
+9
-4
@@ -46,14 +46,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
if t.issuerURL == "" {
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if issuerURL == "" {
|
||||
t.logger.Error("OIDC provider metadata initialization failed or incomplete")
|
||||
t.sendErrorResponse(rw, req, "OIDC provider metadata initialization failed - please check provider availability and configuration", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Request cancelled", http.StatusRequestTimeout)
|
||||
t.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
t.sendErrorResponse(rw, req, "Request canceled", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
t.logger.Error("Timeout waiting for OIDC initialization")
|
||||
@@ -79,7 +84,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v. Initiating authentication.", err)
|
||||
cleanReq := req.Clone(req.Context())
|
||||
session, _ = t.sessionManager.GetSession(cleanReq)
|
||||
session, _ = t.sessionManager.GetSession(cleanReq) // Safe to ignore: error already logged, proceeding with new session
|
||||
if session != nil {
|
||||
defer session.returnToPoolSafely()
|
||||
if clearErr := session.Clear(cleanReq, rw); clearErr != nil {
|
||||
|
||||
@@ -179,8 +179,8 @@ func (m *AuthMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
m.logger.Debug("Request cancelled while waiting for OIDC initialization")
|
||||
m.sendErrorResponseFunc(rw, req, "Request cancelled", http.StatusRequestTimeout)
|
||||
m.logger.Debug("Request canceled while waiting for OIDC initialization")
|
||||
m.sendErrorResponseFunc(rw, req, "Request canceled", http.StatusRequestTimeout)
|
||||
return
|
||||
case <-time.After(30 * time.Second):
|
||||
m.logger.Error("Timeout waiting for OIDC initialization")
|
||||
|
||||
@@ -301,7 +301,7 @@ func TestServeHTTP_ComprehensiveCoverage(t *testing.T) {
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
// This should timeout or be cancelled
|
||||
// This should timeout or be canceled
|
||||
m.ServeHTTP(rw, req)
|
||||
|
||||
if !errorResponseSent {
|
||||
|
||||
@@ -0,0 +1,370 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestMiddlewareContextCancellation tests request context cancellation
|
||||
func TestMiddlewareContextCancellation(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate waiting
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
}
|
||||
|
||||
// Create request with canceled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil).WithContext(ctx)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Should return timeout/cancel error
|
||||
if rw.Code != http.StatusRequestTimeout && rw.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("Expected timeout status for canceled context, got %d", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareSessionErrorRecovery tests session error recovery
|
||||
func TestMiddlewareSessionErrorRecovery(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
authURL: "https://provider.example.com/auth",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create request with corrupted session cookie
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "_oidc_session",
|
||||
Value: "corrupted!!!invalid!!!",
|
||||
})
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Should handle gracefully and initiate auth
|
||||
if rw.Code != http.StatusFound && rw.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected redirect for corrupted session, got %d", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareAJAXRequestHandling tests AJAX-specific request handling
|
||||
func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// AJAX request without auth should get 401, not redirect
|
||||
if rw.Code != http.StatusUnauthorized {
|
||||
t.Errorf("Expected 401 for unauthenticated AJAX request, got %d", rw.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMiddlewareDomainRestrictions tests domain-based access control
|
||||
// NOTE: Currently commented out due to complex session setup requirements
|
||||
// These scenarios are tested indirectly through integration tests
|
||||
/*
|
||||
func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
t.Run("allowed_domain_passes", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
allowedUserDomains: map[string]struct{}{
|
||||
"example.com": {},
|
||||
},
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create authenticated session
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
session.SetIDToken("dummy-token")
|
||||
session.Save(req, httptest.NewRecorder())
|
||||
|
||||
// Add session cookies to request
|
||||
rw := httptest.NewRecorder()
|
||||
session.Save(req, rw)
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200 for allowed domain, got %d", rw.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("forbidden_domain_blocked", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
allowedUserDomains: map[string]struct{}{
|
||||
"example.com": {},
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create session with forbidden domain
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@forbidden.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
rw := httptest.NewRecorder()
|
||||
session.Save(req, rw)
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != http.StatusForbidden {
|
||||
t.Errorf("Expected 403 for forbidden domain, got %d", rw.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
*/
|
||||
|
||||
// TestMiddlewareOpaqueTokenHandling tests opaque (non-JWT) token handling
|
||||
// NOTE: Currently commented out due to complex session setup requirements
|
||||
/*
|
||||
func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "user@example.com"}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create session with opaque token
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken("sk_live_abcdefghijklmnopqrstuvwxyz") // Opaque token (no dots)
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Save and inject cookies
|
||||
rw := httptest.NewRecorder()
|
||||
session.Save(req, rw)
|
||||
for _, cookie := range rw.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
oidc.ServeHTTP(rw, req)
|
||||
|
||||
// Should process successfully without JWT verification
|
||||
if rw.Code != http.StatusOK {
|
||||
t.Errorf("Expected 200 for opaque token, got %d", rw.Code)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// TestMiddlewareProcessAuthorizedRequestEdgeCases tests processAuthorizedRequest edge cases
|
||||
func TestMiddlewareProcessAuthorizedRequestEdgeCases(t *testing.T) {
|
||||
sessionManager := createTestSessionManager(t)
|
||||
|
||||
t.Run("missing_email_initiates_reauth", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
logger: NewLogger("debug"),
|
||||
sessionManager: sessionManager,
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
authURL: "https://provider.example.com/auth",
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("") // No email
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
redirectURL := "https://example.com/callback"
|
||||
oidc.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
|
||||
// Should initiate re-auth
|
||||
if rw.Code != http.StatusFound && rw.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected redirect when email is missing, got %d", rw.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing_token_with_role_checks", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
|
||||
logger: NewLogger("debug"),
|
||||
sessionManager: sessionManager,
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
clientID: "test-client",
|
||||
audience: "test-client",
|
||||
authURL: "https://provider.example.com/auth",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetIDToken("") // No ID token
|
||||
session.SetAccessToken("") // No access token
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
redirectURL := "https://example.com/callback"
|
||||
oidc.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
|
||||
// Should initiate re-auth when token is missing but role checks required
|
||||
if rw.Code != http.StatusFound && rw.Code != http.StatusSeeOther {
|
||||
t.Errorf("Expected redirect when token is missing with role checks, got %d", rw.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("security_headers_applied", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
logger: NewLogger("debug"),
|
||||
sessionManager: sessionManager,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetIDToken("dummy-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
redirectURL := "https://example.com/callback"
|
||||
oidc.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
|
||||
// Verify security headers are set
|
||||
if rw.Header().Get("X-Frame-Options") == "" {
|
||||
t.Error("Expected X-Frame-Options header to be set")
|
||||
}
|
||||
if rw.Header().Get("X-Content-Type-Options") == "" {
|
||||
t.Error("Expected X-Content-Type-Options header to be set")
|
||||
}
|
||||
if rw.Header().Get("X-XSS-Protection") == "" {
|
||||
t.Error("Expected X-XSS-Protection header to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authentication_headers_set", func(t *testing.T) {
|
||||
oidc := &TraefikOidc{
|
||||
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}),
|
||||
logger: NewLogger("debug"),
|
||||
sessionManager: sessionManager,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
session, _ := sessionManager.GetSession(req)
|
||||
testEmail := "user@example.com"
|
||||
session.SetEmail(testEmail)
|
||||
session.SetIDToken("dummy-id-token")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
redirectURL := "https://example.com/callback"
|
||||
oidc.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
|
||||
// Verify authentication headers
|
||||
if req.Header.Get("X-Forwarded-User") != testEmail {
|
||||
t.Errorf("Expected X-Forwarded-User=%s, got %s", testEmail, req.Header.Get("X-Forwarded-User"))
|
||||
}
|
||||
if req.Header.Get("X-Auth-Request-User") != testEmail {
|
||||
t.Errorf("Expected X-Auth-Request-User=%s, got %s", testEmail, req.Header.Get("X-Auth-Request-User"))
|
||||
}
|
||||
// Token header may not be set in all scenarios, just verify it's not causing errors
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestGenerateNonce tests the nonce generation for OIDC flows
|
||||
func TestGenerateNonce(t *testing.T) {
|
||||
t.Run("basic generation", func(t *testing.T) {
|
||||
nonce, err := generateNonce()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, nonce)
|
||||
|
||||
// 32 bytes base64 URL encoded should produce 44 characters (with potential padding)
|
||||
// but typically 43 characters without padding
|
||||
assert.GreaterOrEqual(t, len(nonce), 43, "nonce should be at least 43 characters")
|
||||
})
|
||||
|
||||
t.Run("nonce is base64 URL encoded", func(t *testing.T) {
|
||||
nonce, err := generateNonce()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid base64 URL encoding
|
||||
_, err = base64.URLEncoding.DecodeString(nonce)
|
||||
assert.NoError(t, err, "nonce should be valid base64 URL encoding")
|
||||
})
|
||||
|
||||
t.Run("multiple generations produce different values", func(t *testing.T) {
|
||||
nonce1, err1 := generateNonce()
|
||||
nonce2, err2 := generateNonce()
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.NotEqual(t, nonce1, nonce2, "consecutive generations should produce different nonces")
|
||||
})
|
||||
|
||||
t.Run("nonce has sufficient entropy", func(t *testing.T) {
|
||||
// Generate multiple nonces and verify they're all unique
|
||||
nonces := make(map[string]bool)
|
||||
iterations := 100
|
||||
|
||||
for i := 0; i < iterations; i++ {
|
||||
nonce, err := generateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check for duplicates
|
||||
assert.False(t, nonces[nonce], "nonce should be unique across multiple generations")
|
||||
nonces[nonce] = true
|
||||
}
|
||||
|
||||
assert.Len(t, nonces, iterations, "all nonces should be unique")
|
||||
})
|
||||
|
||||
t.Run("nonce length is consistent", func(t *testing.T) {
|
||||
nonce1, err1 := generateNonce()
|
||||
nonce2, err2 := generateNonce()
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.Equal(t, len(nonce1), len(nonce2), "nonce length should be consistent")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGenerateCodeVerifier tests the PKCE code verifier generation
|
||||
func TestGenerateCodeVerifier(t *testing.T) {
|
||||
t.Run("basic generation", func(t *testing.T) {
|
||||
verifier, err := generateCodeVerifier()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier)
|
||||
|
||||
// RFC 7636 requires 43-128 characters for code verifier
|
||||
// With 32 bytes base64 raw URL encoded, we get 43 characters
|
||||
assert.Len(t, verifier, 43, "code verifier should be 43 characters (32 bytes base64 encoded)")
|
||||
})
|
||||
|
||||
t.Run("verifier is base64 URL encoded", func(t *testing.T) {
|
||||
verifier, err := generateCodeVerifier()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be valid base64 URL encoding
|
||||
_, err = base64.RawURLEncoding.DecodeString(verifier)
|
||||
assert.NoError(t, err, "verifier should be valid base64 URL encoding")
|
||||
})
|
||||
|
||||
t.Run("multiple generations produce different values", func(t *testing.T) {
|
||||
verifier1, err1 := generateCodeVerifier()
|
||||
verifier2, err2 := generateCodeVerifier()
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
|
||||
assert.NotEqual(t, verifier1, verifier2, "consecutive generations should produce different verifiers")
|
||||
})
|
||||
|
||||
t.Run("verifier contains only URL-safe characters", func(t *testing.T) {
|
||||
verifier, err := generateCodeVerifier()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Base64 URL encoding should only contain A-Z, a-z, 0-9, -, _
|
||||
for _, char := range verifier {
|
||||
validChar := (char >= 'A' && char <= 'Z') ||
|
||||
(char >= 'a' && char <= 'z') ||
|
||||
(char >= '0' && char <= '9') ||
|
||||
char == '-' || char == '_'
|
||||
assert.True(t, validChar, "verifier should only contain URL-safe characters")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("no padding characters", func(t *testing.T) {
|
||||
verifier, err := generateCodeVerifier()
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Raw URL encoding should not have padding
|
||||
assert.False(t, strings.Contains(verifier, "="), "verifier should not contain padding")
|
||||
})
|
||||
}
|
||||
|
||||
// TestDeriveCodeChallenge tests the PKCE code challenge derivation
|
||||
func TestDeriveCodeChallenge(t *testing.T) {
|
||||
t.Run("basic derivation", func(t *testing.T) {
|
||||
verifier := "test-verifier-value-1234567890abcdefghij"
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
assert.NotEmpty(t, challenge)
|
||||
assert.NotEqual(t, verifier, challenge, "challenge should be different from verifier")
|
||||
})
|
||||
|
||||
t.Run("challenge is SHA256 hash", func(t *testing.T) {
|
||||
verifier := "test-code-verifier"
|
||||
|
||||
// Manually compute expected challenge
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(verifier))
|
||||
expectedHash := hasher.Sum(nil)
|
||||
expectedChallenge := base64.RawURLEncoding.EncodeToString(expectedHash)
|
||||
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
assert.Equal(t, expectedChallenge, challenge, "challenge should match SHA256 hash")
|
||||
})
|
||||
|
||||
t.Run("same verifier produces same challenge", func(t *testing.T) {
|
||||
verifier := "consistent-verifier-12345"
|
||||
|
||||
challenge1 := deriveCodeChallenge(verifier)
|
||||
challenge2 := deriveCodeChallenge(verifier)
|
||||
|
||||
assert.Equal(t, challenge1, challenge2, "same verifier should always produce same challenge")
|
||||
})
|
||||
|
||||
t.Run("different verifiers produce different challenges", func(t *testing.T) {
|
||||
verifier1 := "verifier-one"
|
||||
verifier2 := "verifier-two"
|
||||
|
||||
challenge1 := deriveCodeChallenge(verifier1)
|
||||
challenge2 := deriveCodeChallenge(verifier2)
|
||||
|
||||
assert.NotEqual(t, challenge1, challenge2, "different verifiers should produce different challenges")
|
||||
})
|
||||
|
||||
t.Run("challenge is base64 URL encoded", func(t *testing.T) {
|
||||
verifier := "test-verifier"
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
// Should be valid base64 URL encoding
|
||||
_, err := base64.RawURLEncoding.DecodeString(challenge)
|
||||
assert.NoError(t, err, "challenge should be valid base64 URL encoding")
|
||||
})
|
||||
|
||||
t.Run("challenge length is correct", func(t *testing.T) {
|
||||
verifier := "some-random-verifier"
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
// SHA256 produces 32 bytes, which when base64 encoded becomes 43 characters
|
||||
assert.Len(t, challenge, 43, "SHA256 hash should produce 43-character base64 string")
|
||||
})
|
||||
|
||||
t.Run("no padding in challenge", func(t *testing.T) {
|
||||
verifier := "test-verifier-no-padding"
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
assert.False(t, strings.Contains(challenge, "="), "challenge should not contain padding")
|
||||
})
|
||||
|
||||
t.Run("empty verifier produces valid challenge", func(t *testing.T) {
|
||||
verifier := ""
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
assert.NotEmpty(t, challenge, "even empty verifier should produce a challenge")
|
||||
assert.Len(t, challenge, 43, "challenge should still be 43 characters")
|
||||
})
|
||||
}
|
||||
|
||||
// TestPKCEFlowIntegration tests the complete PKCE flow
|
||||
func TestPKCEFlowIntegration(t *testing.T) {
|
||||
t.Run("complete PKCE flow", func(t *testing.T) {
|
||||
// Step 1: Generate code verifier
|
||||
verifier, err := generateCodeVerifier()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 2: Derive code challenge
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
// Verify challenge was derived from verifier
|
||||
expectedChallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, expectedChallenge, challenge)
|
||||
|
||||
// Verify verifier can be used to recreate challenge
|
||||
rechallenge := deriveCodeChallenge(verifier)
|
||||
assert.Equal(t, challenge, rechallenge, "verifier should consistently produce same challenge")
|
||||
})
|
||||
|
||||
t.Run("multiple PKCE flows are independent", func(t *testing.T) {
|
||||
// Flow 1
|
||||
verifier1, err1 := generateCodeVerifier()
|
||||
require.NoError(t, err1)
|
||||
challenge1 := deriveCodeChallenge(verifier1)
|
||||
|
||||
// Flow 2
|
||||
verifier2, err2 := generateCodeVerifier()
|
||||
require.NoError(t, err2)
|
||||
challenge2 := deriveCodeChallenge(verifier2)
|
||||
|
||||
// Flows should be independent
|
||||
assert.NotEqual(t, verifier1, verifier2)
|
||||
assert.NotEqual(t, challenge1, challenge2)
|
||||
|
||||
// Each flow should be internally consistent
|
||||
assert.Equal(t, challenge1, deriveCodeChallenge(verifier1))
|
||||
assert.Equal(t, challenge2, deriveCodeChallenge(verifier2))
|
||||
})
|
||||
|
||||
t.Run("RFC 7636 compliance", func(t *testing.T) {
|
||||
verifier, err := generateCodeVerifier()
|
||||
require.NoError(t, err)
|
||||
|
||||
challenge := deriveCodeChallenge(verifier)
|
||||
|
||||
// RFC 7636 Section 4.2:
|
||||
// - code_verifier: high-entropy cryptographic random string
|
||||
// - Minimum length: 43 characters
|
||||
// - Maximum length: 128 characters
|
||||
// - Character set: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
|
||||
assert.GreaterOrEqual(t, len(verifier), 43, "verifier should be at least 43 characters")
|
||||
assert.LessOrEqual(t, len(verifier), 128, "verifier should be at most 128 characters")
|
||||
|
||||
// RFC 7636 Section 4.2:
|
||||
// - code_challenge = BASE64URL(SHA256(code_verifier))
|
||||
assert.NotEmpty(t, challenge)
|
||||
assert.Len(t, challenge, 43, "S256 challenge should be 43 characters")
|
||||
})
|
||||
}
|
||||
|
||||
// TestTokenCacheCleanupAndClose tests the no-op Cleanup and Close methods
|
||||
func TestTokenCacheCleanupAndClose(t *testing.T) {
|
||||
cache := NewTokenCache()
|
||||
require.NotNil(t, cache)
|
||||
|
||||
t.Run("cleanup is safe to call", func(t *testing.T) {
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("close is safe to call", func(t *testing.T) {
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("multiple cleanup calls are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Cleanup()
|
||||
cache.Cleanup()
|
||||
cache.Cleanup()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("multiple close calls are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
cache.Close()
|
||||
cache.Close()
|
||||
cache.Close()
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("operations work after cleanup", func(t *testing.T) {
|
||||
cache.Cleanup()
|
||||
|
||||
// Should still work
|
||||
testClaims := map[string]interface{}{"sub": "user123"}
|
||||
cache.Set("token1", testClaims, 1*time.Minute)
|
||||
|
||||
claims, found := cache.Get("token1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, testClaims, claims)
|
||||
})
|
||||
|
||||
t.Run("operations work after close", func(t *testing.T) {
|
||||
cache.Close()
|
||||
|
||||
// Should still work (close is a no-op)
|
||||
testClaims := map[string]interface{}{"sub": "user456"}
|
||||
cache.Set("token2", testClaims, 1*time.Minute)
|
||||
|
||||
claims, found := cache.Get("token2")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, testClaims, claims)
|
||||
})
|
||||
}
|
||||
|
||||
// TestCreateStringMap tests the createStringMap utility function
|
||||
func TestCreateStringMap(t *testing.T) {
|
||||
t.Run("empty slice", func(t *testing.T) {
|
||||
result := createStringMap([]string{})
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("single item", func(t *testing.T) {
|
||||
result := createStringMap([]string{"key1"})
|
||||
assert.Len(t, result, 1)
|
||||
_, exists := result["key1"]
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("multiple items", func(t *testing.T) {
|
||||
result := createStringMap([]string{"key1", "key2", "key3"})
|
||||
assert.Len(t, result, 3)
|
||||
|
||||
for _, key := range []string{"key1", "key2", "key3"} {
|
||||
_, exists := result[key]
|
||||
assert.True(t, exists, "key %s should exist", key)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("duplicate items", func(t *testing.T) {
|
||||
result := createStringMap([]string{"key1", "key2", "key1", "key3", "key2"})
|
||||
// Map should only contain unique keys
|
||||
assert.Len(t, result, 3)
|
||||
|
||||
for _, key := range []string{"key1", "key2", "key3"} {
|
||||
_, exists := result[key]
|
||||
assert.True(t, exists, "key %s should exist", key)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -557,7 +557,8 @@ func TestSessionWindowReset(t *testing.T) {
|
||||
config := DefaultRefreshCoordinatorConfig()
|
||||
config.MaxRefreshAttempts = 2
|
||||
config.RefreshAttemptWindow = 500 * time.Millisecond
|
||||
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
||||
config.RefreshCooldownPeriod = 2 * time.Second // Explicitly set cooldown > window
|
||||
config.DeduplicationCleanupDelay = 0 // Immediate cleanup for deterministic test behavior
|
||||
|
||||
coordinator := NewRefreshCoordinator(config, logger)
|
||||
defer coordinator.Shutdown()
|
||||
@@ -578,22 +579,25 @@ func TestSessionWindowReset(t *testing.T) {
|
||||
for i := 0; i < config.MaxRefreshAttempts; i++ {
|
||||
ctx := context.Background()
|
||||
_, _ = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
// Add small delay to ensure attempts are registered separately
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Next attempt should trigger cooldown
|
||||
ctx := context.Background()
|
||||
_, err := coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Expected cooldown after max attempts")
|
||||
t.Errorf("Expected cooldown after max attempts, got: %v", err)
|
||||
}
|
||||
|
||||
// Wait for window to expire (but not cooldown)
|
||||
time.Sleep(config.RefreshAttemptWindow + 100*time.Millisecond)
|
||||
// Use generous buffer for CI environments
|
||||
time.Sleep(config.RefreshAttemptWindow + 200*time.Millisecond)
|
||||
|
||||
// Should still be in cooldown (cooldown > window)
|
||||
// Should still be in cooldown (cooldown=2s > window=500ms)
|
||||
_, err = coordinator.CoordinateRefresh(ctx, sessionID, refreshToken, refreshFunc)
|
||||
if err == nil || err.Error() != "refresh attempts exceeded for session, in cooldown period" {
|
||||
t.Error("Should still be in cooldown period")
|
||||
t.Errorf("Should still be in cooldown period after window expiry, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ScopeFilterLogger interface for dependency injection
|
||||
type ScopeFilterLogger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// ScopeFilter handles OAuth scope validation and filtering based on provider capabilities.
|
||||
type ScopeFilter struct {
|
||||
logger ScopeFilterLogger
|
||||
}
|
||||
|
||||
// NewScopeFilter creates a new ScopeFilter instance.
|
||||
func NewScopeFilter(logger ScopeFilterLogger) *ScopeFilter {
|
||||
return &ScopeFilter{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// FilterSupportedScopes returns the intersection of requested and supported scopes.
|
||||
// It preserves the order of requested scopes and returns all requested scopes
|
||||
// if supportedScopes is empty (fallback for providers without scopes_supported).
|
||||
//
|
||||
// Parameters:
|
||||
// - requestedScopes: Scopes the application wants to request
|
||||
// - supportedScopes: Scopes advertised by the provider (from discovery doc)
|
||||
// - providerURL: Provider URL for logging purposes
|
||||
//
|
||||
// Returns:
|
||||
// - Filtered list of scopes safe to request from the provider
|
||||
func (sf *ScopeFilter) FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string {
|
||||
// If no supported scopes declared, return all requested (backward compatibility)
|
||||
if len(supportedScopes) == 0 {
|
||||
sf.logger.Debugf("ScopeFilter: Provider %s has no scopes_supported in discovery doc, using all requested scopes", providerURL)
|
||||
return requestedScopes
|
||||
}
|
||||
|
||||
// Build lookup map for efficient checking
|
||||
supportedMap := make(map[string]bool, len(supportedScopes))
|
||||
for _, scope := range supportedScopes {
|
||||
supportedMap[strings.TrimSpace(scope)] = true
|
||||
}
|
||||
|
||||
// Filter requested scopes
|
||||
filtered := make([]string, 0, len(requestedScopes))
|
||||
removed := make([]string, 0)
|
||||
|
||||
for _, scope := range requestedScopes {
|
||||
trimmed := strings.TrimSpace(scope)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if supportedMap[trimmed] {
|
||||
filtered = append(filtered, trimmed)
|
||||
} else {
|
||||
removed = append(removed, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
// Log filtering results
|
||||
if len(removed) > 0 {
|
||||
sf.logger.Infof("ScopeFilter: Filtered unsupported scopes for %s: %v (not in provider's scopes_supported)",
|
||||
providerURL, removed)
|
||||
sf.logger.Debugf("ScopeFilter: Provider %s supported scopes: %v", providerURL, supportedScopes)
|
||||
sf.logger.Debugf("ScopeFilter: Final filtered scopes: %v", filtered)
|
||||
} else {
|
||||
sf.logger.Debugf("ScopeFilter: All requested scopes are supported by %s", providerURL)
|
||||
}
|
||||
|
||||
// If all scopes were filtered out, return at least "openid"
|
||||
if len(filtered) == 0 {
|
||||
sf.logger.Infof("ScopeFilter: All scopes filtered out for %s, falling back to 'openid'", providerURL)
|
||||
return []string{"openid"}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// EnsureOpenIDScope ensures "openid" scope is present in the scope list.
|
||||
// This is required for OIDC compliance.
|
||||
func (sf *ScopeFilter) EnsureOpenIDScope(scopes []string) []string {
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
return scopes
|
||||
}
|
||||
}
|
||||
|
||||
sf.logger.Debugf("ScopeFilter: Adding required 'openid' scope")
|
||||
return append([]string{"openid"}, scopes...)
|
||||
}
|
||||
@@ -0,0 +1,724 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockLogger for testing
|
||||
type mockScopeFilterLogger struct {
|
||||
debugMessages []string
|
||||
infoMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockScopeFilterLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockScopeFilterLogger) Infof(format string, args ...interface{}) {
|
||||
l.infoMessages = append(l.infoMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockScopeFilterLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
// TestNewScopeFilter tests the ScopeFilter constructor
|
||||
func TestNewScopeFilter(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
if filter == nil {
|
||||
t.Fatal("Expected ScopeFilter to be created, got nil")
|
||||
}
|
||||
|
||||
// Logger is set correctly (we can't directly compare interface values)
|
||||
if filter.logger == nil {
|
||||
t.Error("Logger not set in ScopeFilter")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_AllSupported tests when all requested scopes are supported
|
||||
func TestFilterSupportedScopes_AllSupported(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email"}
|
||||
supported := []string{"openid", "profile", "email", "address", "phone"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log debug message that all scopes are supported
|
||||
if len(logger.debugMessages) == 0 {
|
||||
t.Error("Expected debug messages to be logged")
|
||||
}
|
||||
|
||||
// Should not log any info messages (no filtering occurred)
|
||||
if len(logger.infoMessages) > 0 {
|
||||
t.Error("Expected no info messages when all scopes supported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_SomeFiltered tests when some scopes need to be filtered
|
||||
func TestFilterSupportedScopes_SomeFiltered(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "offline_access", "custom_scope"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://gitlab.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Verify offline_access and custom_scope were filtered out
|
||||
for _, scope := range result {
|
||||
if scope == "offline_access" || scope == "custom_scope" {
|
||||
t.Errorf("Scope '%s' should have been filtered out", scope)
|
||||
}
|
||||
}
|
||||
|
||||
// Should log info message about filtered scopes
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about filtered scopes")
|
||||
}
|
||||
|
||||
// Should log debug messages about supported scopes and final result
|
||||
if len(logger.debugMessages) < 2 {
|
||||
t.Error("Expected debug messages about provider supported scopes and final result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_AllFiltered tests when all scopes are filtered (fallback to openid)
|
||||
func TestFilterSupportedScopes_AllFiltered(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"custom_scope1", "custom_scope2", "unsupported"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected fallback to %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log info message about all scopes being filtered (falling back to openid)
|
||||
if len(logger.infoMessages) < 2 { // One for filtered scopes, one for fallback
|
||||
t.Error("Expected info messages when all scopes filtered")
|
||||
}
|
||||
|
||||
// Should log info message about filtered scopes
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about filtered scopes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_NoSupportedScopes tests fallback behavior when no scopes_supported
|
||||
func TestFilterSupportedScopes_NoSupportedScopes(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "offline_access"}
|
||||
supported := []string{} // Empty supported list (backward compatibility)
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should return all requested scopes unchanged
|
||||
if !reflect.DeepEqual(result, requested) {
|
||||
t.Errorf("Expected all requested scopes %v, got %v", requested, result)
|
||||
}
|
||||
|
||||
// Should log debug message about no scopes_supported
|
||||
if len(logger.debugMessages) == 0 {
|
||||
t.Error("Expected debug message about no scopes_supported")
|
||||
}
|
||||
|
||||
// Should not log info messages (backward compatibility mode)
|
||||
if len(logger.infoMessages) > 0 {
|
||||
t.Error("Expected no info messages when no supported scopes provided")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_EmptyRequested tests when requested scopes are empty
|
||||
func TestFilterSupportedScopes_EmptyRequested(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should return openid as fallback
|
||||
expected := []string{"openid"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected fallback to %v when requested empty, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log info message about empty result (fallback to openid)
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message when no scopes requested")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_DuplicateScopes tests handling of duplicate scope names
|
||||
func TestFilterSupportedScopes_DuplicateScopes(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "openid", "email"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should preserve duplicates from requested
|
||||
expected := []string{"openid", "profile", "openid", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v (preserving duplicates), got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_WhitespaceHandling tests trimming of whitespace
|
||||
func TestFilterSupportedScopes_WhitespaceHandling(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{" openid ", "profile", " email"}
|
||||
supported := []string{"openid", "profile", "email", "phone"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should trim whitespace from scopes
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected trimmed scopes %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_EmptyStrings tests filtering out empty strings
|
||||
func TestFilterSupportedScopes_EmptyStrings(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "", "profile", " ", "email"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should filter out empty strings
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v (without empty strings), got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_CasePreservation tests that scope case is preserved
|
||||
func TestFilterSupportedScopes_CasePreservation(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"OpenID", "Profile", "Email"}
|
||||
supported := []string{"OpenID", "Profile", "Email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should preserve case exactly
|
||||
expected := []string{"OpenID", "Profile", "Email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected case-preserved %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_CaseSensitiveMatching tests case-sensitive matching
|
||||
func TestFilterSupportedScopes_CaseSensitiveMatching(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "Profile", "EMAIL"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Only "openid" should match (case-sensitive)
|
||||
// Profile and EMAIL won't match profile and email in supported list
|
||||
expected := []string{"openid"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected case-sensitive filtering %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log info about filtered scopes
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about filtered scopes due to case mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_OrderPreservation tests that order is preserved
|
||||
func TestFilterSupportedScopes_OrderPreservation(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"email", "profile", "openid", "phone"}
|
||||
supported := []string{"openid", "profile", "email", "phone", "address"}
|
||||
providerURL := "https://auth.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
// Should preserve order from requested
|
||||
expected := []string{"email", "profile", "openid", "phone"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected order-preserved %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_GitLabScenario simulates GitLab rejecting offline_access
|
||||
func TestFilterSupportedScopes_GitLabScenario(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
// User requests offline_access but GitLab doesn't support it
|
||||
requested := []string{"openid", "profile", "email", "offline_access"}
|
||||
supported := []string{"openid", "profile", "email", "read_user", "read_api"}
|
||||
providerURL := "https://gitlab.example.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v (without offline_access), got %v", expected, result)
|
||||
}
|
||||
|
||||
// Verify offline_access was filtered out
|
||||
for _, scope := range result {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered out for GitLab")
|
||||
}
|
||||
}
|
||||
|
||||
// Should log info about filtered scopes
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about offline_access being filtered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_GoogleScenario simulates Google's scope handling
|
||||
func TestFilterSupportedScopes_GoogleScenario(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
// Google supports these standard scopes
|
||||
requested := []string{"openid", "profile", "email"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://accounts.google.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// No scopes should be filtered
|
||||
if len(logger.infoMessages) > 0 {
|
||||
t.Error("Expected no filtering for standard Google scopes")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_AzureScenario simulates Azure's scope handling
|
||||
func TestFilterSupportedScopes_AzureScenario(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
// Azure supports offline_access and OIDC scopes
|
||||
requested := []string{"openid", "profile", "email", "offline_access"}
|
||||
supported := []string{"openid", "profile", "email", "offline_access"}
|
||||
providerURL := "https://login.microsoftonline.com/tenant"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email", "offline_access"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v (including offline_access), got %v", expected, result)
|
||||
}
|
||||
|
||||
// All scopes should be retained
|
||||
if len(logger.infoMessages) > 0 {
|
||||
t.Error("Expected no filtering for standard Azure scopes with offline_access")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_GenericWithFiltering simulates generic provider with filtering
|
||||
func TestFilterSupportedScopes_GenericWithFiltering(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "offline_access", "custom:scope"}
|
||||
supported := []string{"openid", "profile", "email", "custom:scope"}
|
||||
providerURL := "https://auth.custom-provider.com"
|
||||
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email", "custom:scope"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v (without offline_access), got %v", expected, result)
|
||||
}
|
||||
|
||||
// offline_access should be filtered
|
||||
for _, scope := range result {
|
||||
if scope == "offline_access" {
|
||||
t.Error("offline_access should have been filtered for this provider")
|
||||
}
|
||||
}
|
||||
|
||||
// Should log info about filtering
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about filtered offline_access")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_MultipleProviderURLs tests different provider URLs
|
||||
func TestFilterSupportedScopes_MultipleProviderURLs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
providerURL string
|
||||
requested []string
|
||||
supported []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "GitLab.com",
|
||||
providerURL: "https://gitlab.com",
|
||||
requested: []string{"openid", "offline_access"},
|
||||
supported: []string{"openid"},
|
||||
expected: []string{"openid"},
|
||||
},
|
||||
{
|
||||
name: "Self-hosted GitLab",
|
||||
providerURL: "https://gitlab.example.com",
|
||||
requested: []string{"openid", "profile", "offline_access"},
|
||||
supported: []string{"openid", "profile"},
|
||||
expected: []string{"openid", "profile"},
|
||||
},
|
||||
{
|
||||
name: "Keycloak",
|
||||
providerURL: "https://keycloak.example.com/realms/master",
|
||||
requested: []string{"openid", "profile", "email"},
|
||||
supported: []string{"openid", "profile", "email", "offline_access"},
|
||||
expected: []string{"openid", "profile", "email"},
|
||||
},
|
||||
{
|
||||
name: "Auth0",
|
||||
providerURL: "https://tenant.auth0.com",
|
||||
requested: []string{"openid", "profile", "offline_access"},
|
||||
supported: []string{"openid", "profile", "offline_access"},
|
||||
expected: []string{"openid", "profile", "offline_access"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
result := filter.FilterSupportedScopes(tt.requested, tt.supported, tt.providerURL)
|
||||
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureOpenIDScope_Present tests when openid is already present
|
||||
func TestEnsureOpenIDScope_Present(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
result := filter.EnsureOpenIDScope(scopes)
|
||||
|
||||
// Should return scopes unchanged
|
||||
if !reflect.DeepEqual(result, scopes) {
|
||||
t.Errorf("Expected scopes unchanged %v, got %v", scopes, result)
|
||||
}
|
||||
|
||||
// Should not log anything (openid already present)
|
||||
if len(logger.debugMessages) > 0 {
|
||||
t.Error("Expected no debug messages when openid already present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureOpenIDScope_Missing tests when openid needs to be added
|
||||
func TestEnsureOpenIDScope_Missing(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
scopes := []string{"profile", "email"}
|
||||
result := filter.EnsureOpenIDScope(scopes)
|
||||
|
||||
// Should prepend openid
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected openid prepended %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log debug message about adding openid
|
||||
if len(logger.debugMessages) == 0 {
|
||||
t.Error("Expected debug message about adding openid scope")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureOpenIDScope_Empty tests with empty scopes list
|
||||
func TestEnsureOpenIDScope_Empty(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
scopes := []string{}
|
||||
result := filter.EnsureOpenIDScope(scopes)
|
||||
|
||||
// Should return just openid
|
||||
expected := []string{"openid"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Should log debug message
|
||||
if len(logger.debugMessages) == 0 {
|
||||
t.Error("Expected debug message about adding openid scope")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureOpenIDScope_Nil tests with nil scopes list
|
||||
func TestEnsureOpenIDScope_Nil(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
var scopes []string // nil slice
|
||||
result := filter.EnsureOpenIDScope(scopes)
|
||||
|
||||
// Should return just openid
|
||||
expected := []string{"openid"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnsureOpenIDScope_CaseVariations tests that case matters for openid detection
|
||||
func TestEnsureOpenIDScope_CaseVariations(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Lowercase openid",
|
||||
scopes: []string{"openid", "profile"},
|
||||
expected: []string{"openid", "profile"},
|
||||
},
|
||||
{
|
||||
name: "Mixed case OpenID (should add lowercase)",
|
||||
scopes: []string{"OpenID", "profile"},
|
||||
expected: []string{"openid", "OpenID", "profile"},
|
||||
},
|
||||
{
|
||||
name: "OPENID uppercase (should add lowercase)",
|
||||
scopes: []string{"OPENID", "profile"},
|
||||
expected: []string{"openid", "OPENID", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
result := filter.EnsureOpenIDScope(tt.scopes)
|
||||
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_IntegrationScenario tests realistic end-to-end scenario
|
||||
func TestFilterSupportedScopes_IntegrationScenario(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
// Simulate: User configures plugin with these scopes
|
||||
requested := []string{"openid", "profile", "email", "offline_access", "custom_claim"}
|
||||
|
||||
// Provider discovery returns these supported scopes
|
||||
supported := []string{"openid", "profile", "email", "read_user"}
|
||||
|
||||
providerURL := "https://gitlab.company.com"
|
||||
|
||||
// Filter should remove offline_access and custom_claim
|
||||
result := filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
|
||||
expected := []string{"openid", "profile", "email"}
|
||||
if !reflect.DeepEqual(result, expected) {
|
||||
t.Errorf("Expected %v, got %v", expected, result)
|
||||
}
|
||||
|
||||
// Verify logging occurred
|
||||
if len(logger.infoMessages) == 0 {
|
||||
t.Error("Expected info message about filtered scopes")
|
||||
}
|
||||
|
||||
if len(logger.debugMessages) < 2 {
|
||||
t.Error("Expected debug messages about supported scopes and final result")
|
||||
}
|
||||
|
||||
// Verify specific scopes were filtered
|
||||
for _, scope := range result {
|
||||
if scope == "offline_access" || scope == "custom_claim" {
|
||||
t.Errorf("Scope '%s' should have been filtered out", scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilterSupportedScopes_LoggingBehavior tests comprehensive logging scenarios
|
||||
func TestFilterSupportedScopes_LoggingBehavior(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requested []string
|
||||
supported []string
|
||||
expectDebugOnly bool
|
||||
expectInfoLog bool
|
||||
}{
|
||||
{
|
||||
name: "All supported - debug only",
|
||||
requested: []string{"openid", "profile"},
|
||||
supported: []string{"openid", "profile", "email"},
|
||||
expectDebugOnly: true,
|
||||
},
|
||||
{
|
||||
name: "Some filtered - info + debug",
|
||||
requested: []string{"openid", "offline_access"},
|
||||
supported: []string{"openid"},
|
||||
expectInfoLog: true,
|
||||
},
|
||||
{
|
||||
name: "All filtered - info + debug",
|
||||
requested: []string{"custom1", "custom2"},
|
||||
supported: []string{"openid"},
|
||||
expectInfoLog: true,
|
||||
},
|
||||
{
|
||||
name: "No supported scopes - debug only",
|
||||
requested: []string{"openid"},
|
||||
supported: []string{},
|
||||
expectDebugOnly: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
filter.FilterSupportedScopes(tt.requested, tt.supported, "https://example.com")
|
||||
|
||||
hasDebug := len(logger.debugMessages) > 0
|
||||
hasInfo := len(logger.infoMessages) > 0
|
||||
|
||||
if tt.expectDebugOnly && (!hasDebug || hasInfo) {
|
||||
t.Errorf("Expected only debug logs, got debug=%v info=%v",
|
||||
hasDebug, hasInfo)
|
||||
}
|
||||
|
||||
if tt.expectInfoLog && !hasInfo {
|
||||
t.Error("Expected info log but didn't get one")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkFilterSupportedScopes_AllSupported(b *testing.B) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "phone"}
|
||||
supported := []string{"openid", "profile", "email", "phone", "address"}
|
||||
providerURL := "https://example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilterSupportedScopes_SomeFiltered(b *testing.B) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "offline_access", "custom"}
|
||||
supported := []string{"openid", "profile", "email"}
|
||||
providerURL := "https://example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFilterSupportedScopes_NoSupported(b *testing.B) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
requested := []string{"openid", "profile", "email", "offline_access"}
|
||||
supported := []string{}
|
||||
providerURL := "https://example.com"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.FilterSupportedScopes(requested, supported, providerURL)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEnsureOpenIDScope_Present(b *testing.B) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
scopes := []string{"openid", "profile", "email"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.EnsureOpenIDScope(scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkEnsureOpenIDScope_Missing(b *testing.B) {
|
||||
logger := &mockScopeFilterLogger{}
|
||||
filter := NewScopeFilter(logger)
|
||||
|
||||
scopes := []string{"profile", "email"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
filter.EnsureOpenIDScope(scopes)
|
||||
}
|
||||
}
|
||||
@@ -335,6 +335,7 @@ func TestJWTReplayAttack(t *testing.T) {
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -551,6 +552,7 @@ func TestSessionFixationAttack(t *testing.T) {
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -857,6 +859,7 @@ func TestTokenBlacklisting(t *testing.T) {
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -1278,6 +1281,7 @@ func TestRateLimiting(t *testing.T) {
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -1385,6 +1389,7 @@ func TestAuthorizationHeaderBypass(t *testing.T) {
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
@@ -1560,6 +1565,7 @@ func TestInvalidRedirectURI(t *testing.T) {
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
|
||||
+4
-4
@@ -444,9 +444,9 @@ func (sm *SessionManager) PeriodicChunkCleanup() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if context is cancelled or we're in test mode to prevent logging after test completion
|
||||
// Check if context is canceled or we're in test mode to prevent logging after test completion
|
||||
if sm.ctx == nil || sm.ctx.Err() != nil || isTestMode() {
|
||||
return // Skip logging if context is cancelled or in test mode
|
||||
return // Skip logging if context is canceled or in test mode
|
||||
}
|
||||
|
||||
sm.logger.Debug("Starting comprehensive session cleanup cycle")
|
||||
@@ -796,7 +796,7 @@ func (sm *SessionManager) CleanupOldCookies(w http.ResponseWriter, r *http.Reque
|
||||
// - The loaded SessionData instance.
|
||||
// - An error if session loading or validation fails.
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||
sessionData, _ := sm.sessionPool.Get().(*SessionData) // Safe to ignore: pool return is best-effort
|
||||
atomic.AddInt64(&sm.poolHits, 1)
|
||||
atomic.AddInt64(&sm.activeSessions, 1)
|
||||
|
||||
@@ -822,7 +822,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
|
||||
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
||||
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
||||
sessionData.Clear(r, nil)
|
||||
_ = sessionData.Clear(r, nil) // Safe to ignore: session is being invalidated
|
||||
return handleError(fmt.Errorf("session timeout"), "session expired")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,7 +122,7 @@ func (sm *SessionManager) initializeSession(sessionData SessionData, r *http.Req
|
||||
|
||||
// Extract and set session values
|
||||
if auth, ok := session.Values["authenticated"].(bool); ok {
|
||||
sessionData.SetAuthenticated(auth)
|
||||
_ = sessionData.SetAuthenticated(auth) // Safe to ignore: session initialization error
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -34,7 +34,7 @@ func (m *SessionChunkManager) CleanupChunks(chunks map[int]*sessions.Session, w
|
||||
if session != nil && session.Options != nil {
|
||||
// Set MaxAge to -1 to expire the cookie
|
||||
session.Options.MaxAge = -1
|
||||
session.Save(nil, w) // Save with nil request is safe for expiration
|
||||
_ = session.Save(nil, w) // Safe to ignore: best effort cleanup of expired chunk
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,540 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// Helper function to create a mock HTTP request for session creation
|
||||
func createMockRequest() *http.Request {
|
||||
req := httptest.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
}
|
||||
|
||||
// Test NewSessionChunkManager
|
||||
|
||||
func TestNewSessionChunkManager(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
if manager == nil {
|
||||
t.Fatal("Expected non-nil session chunk manager")
|
||||
}
|
||||
|
||||
if manager.maxChunks != 10 {
|
||||
t.Errorf("Expected maxChunks 10, got %d", manager.maxChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSessionChunkManagerDefaultLimit(t *testing.T) {
|
||||
// Test with 0 maxChunks (should use default)
|
||||
manager := NewSessionChunkManager(0)
|
||||
|
||||
if manager.maxChunks != 20 {
|
||||
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSessionChunkManagerNegativeLimit(t *testing.T) {
|
||||
// Test with negative maxChunks (should use default)
|
||||
manager := NewSessionChunkManager(-5)
|
||||
|
||||
if manager.maxChunks != 20 {
|
||||
t.Errorf("Expected default maxChunks 20, got %d", manager.maxChunks)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CleanupChunks
|
||||
|
||||
func TestCleanupChunksWithoutWriter(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add some chunks
|
||||
for i := 0; i < 5; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
session.Values["token_chunk"] = "chunk-data"
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
// Cleanup without writer (should just clear map)
|
||||
manager.CleanupChunks(chunks, nil)
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupChunksWithWriter(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add some chunks
|
||||
for i := 0; i < 3; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
session.Values["token_chunk"] = "chunk-data"
|
||||
session.Options = &sessions.Options{MaxAge: 3600}
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
// Create response writer
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Note: We can't fully test the Save behavior without a proper HTTP request
|
||||
// but we can verify the cleanup clears the map
|
||||
// The actual Save(nil, w) in the real code has a comment saying it's safe for expiration
|
||||
manager.CleanupChunks(chunks, w)
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupChunksNilSession(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
chunks[0] = nil
|
||||
chunks[1] = nil
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Should handle nil sessions gracefully
|
||||
manager.CleanupChunks(chunks, w)
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Errorf("Expected chunks map to be empty, got %d entries", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupChunksEmptyMap(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
|
||||
// Should handle empty map gracefully
|
||||
manager.CleanupChunks(chunks, nil)
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Error("Expected chunks map to remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
// Test ValidateAndCleanChunks
|
||||
|
||||
func TestValidateAndCleanChunksWithinLimit(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add chunks within limit
|
||||
for i := 0; i < 5; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
result := manager.ValidateAndCleanChunks(chunks)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected validation to pass for chunks within limit")
|
||||
}
|
||||
|
||||
if len(chunks) != 5 {
|
||||
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAndCleanChunksExceedLimit(t *testing.T) {
|
||||
manager := NewSessionChunkManager(5)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add more chunks than limit
|
||||
for i := 0; i < 10; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
result := manager.ValidateAndCleanChunks(chunks)
|
||||
|
||||
if result {
|
||||
t.Error("Expected validation to fail for chunks exceeding limit")
|
||||
}
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Errorf("Expected chunks to be cleared, got %d", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAndCleanChunksAtLimit(t *testing.T) {
|
||||
manager := NewSessionChunkManager(5)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add chunks exactly at limit
|
||||
for i := 0; i < 5; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
result := manager.ValidateAndCleanChunks(chunks)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected validation to pass for chunks at limit")
|
||||
}
|
||||
|
||||
if len(chunks) != 5 {
|
||||
t.Errorf("Expected chunks to remain intact, got %d", len(chunks))
|
||||
}
|
||||
}
|
||||
|
||||
// Test SafeSetChunk
|
||||
|
||||
func TestSafeSetChunkValidIndex(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
|
||||
result := manager.SafeSetChunk(chunks, 5, session)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected SafeSetChunk to succeed for valid index")
|
||||
}
|
||||
|
||||
if chunks[5] != session {
|
||||
t.Error("Expected session to be set at index 5")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSetChunkNegativeIndex(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
|
||||
result := manager.SafeSetChunk(chunks, -1, session)
|
||||
|
||||
if result {
|
||||
t.Error("Expected SafeSetChunk to fail for negative index")
|
||||
}
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Error("Expected chunks map to remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSetChunkIndexTooHigh(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
|
||||
result := manager.SafeSetChunk(chunks, 10, session)
|
||||
|
||||
if result {
|
||||
t.Error("Expected SafeSetChunk to fail for index >= maxChunks")
|
||||
}
|
||||
|
||||
if len(chunks) != 0 {
|
||||
t.Error("Expected chunks map to remain empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSetChunkExceedingLimit(t *testing.T) {
|
||||
manager := NewSessionChunkManager(5)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Fill up to limit
|
||||
for i := 0; i < 5; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
// Try to add a new chunk at new index (should fail)
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
result := manager.SafeSetChunk(chunks, 2, session)
|
||||
|
||||
// This should succeed because index 2 already exists
|
||||
if !result {
|
||||
t.Error("Expected SafeSetChunk to succeed for existing index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSafeSetChunkReplaceExisting(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
session1, _ := store.New(createMockRequest(), "chunk1")
|
||||
session2, _ := store.New(createMockRequest(), "chunk2")
|
||||
|
||||
// Set initial session
|
||||
manager.SafeSetChunk(chunks, 3, session1)
|
||||
|
||||
// Replace with new session
|
||||
result := manager.SafeSetChunk(chunks, 3, session2)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected SafeSetChunk to succeed for replacing existing chunk")
|
||||
}
|
||||
|
||||
if chunks[3] != session2 {
|
||||
t.Error("Expected session to be replaced at index 3")
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetChunkCount
|
||||
|
||||
func TestGetChunkCount(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add some chunks
|
||||
for i := 0; i < 7; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
count := manager.GetChunkCount(chunks)
|
||||
|
||||
if count != 7 {
|
||||
t.Errorf("Expected chunk count 7, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChunkCountEmpty(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
|
||||
count := manager.GetChunkCount(chunks)
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("Expected chunk count 0, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Test CompactChunks
|
||||
|
||||
func TestCompactChunksNoGaps(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add sequential chunks
|
||||
for i := 0; i < 5; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
session.Values["index"] = i
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
compacted := manager.CompactChunks(chunks)
|
||||
|
||||
if len(compacted) != 5 {
|
||||
t.Errorf("Expected 5 compacted chunks, got %d", len(compacted))
|
||||
}
|
||||
|
||||
// Verify order
|
||||
for i := 0; i < 5; i++ {
|
||||
if compacted[i] == nil {
|
||||
t.Errorf("Expected chunk at index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactChunksWithGaps(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add chunks with gaps
|
||||
indices := []int{0, 2, 5, 7}
|
||||
for _, idx := range indices {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
session.Values["original_index"] = idx
|
||||
chunks[idx] = session
|
||||
}
|
||||
|
||||
compacted := manager.CompactChunks(chunks)
|
||||
|
||||
if len(compacted) != 4 {
|
||||
t.Errorf("Expected 4 compacted chunks, got %d", len(compacted))
|
||||
}
|
||||
|
||||
// Verify chunks are reindexed sequentially
|
||||
for i := 0; i < 4; i++ {
|
||||
if compacted[i] == nil {
|
||||
t.Errorf("Expected chunk at compacted index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactChunksWithNilEntries(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add chunks and nil entries
|
||||
session1, _ := store.New(createMockRequest(), "chunk1")
|
||||
session2, _ := store.New(createMockRequest(), "chunk2")
|
||||
session3, _ := store.New(createMockRequest(), "chunk3")
|
||||
|
||||
chunks[0] = session1
|
||||
chunks[1] = nil
|
||||
chunks[2] = session2
|
||||
chunks[3] = nil
|
||||
chunks[4] = session3
|
||||
|
||||
compacted := manager.CompactChunks(chunks)
|
||||
|
||||
if len(compacted) != 3 {
|
||||
t.Errorf("Expected 3 compacted chunks (nil entries removed), got %d", len(compacted))
|
||||
}
|
||||
|
||||
// Verify non-nil chunks are compacted
|
||||
for i := 0; i < 3; i++ {
|
||||
if compacted[i] == nil {
|
||||
t.Errorf("Expected non-nil chunk at compacted index %d", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactChunksEmpty(t *testing.T) {
|
||||
manager := NewSessionChunkManager(10)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
|
||||
compacted := manager.CompactChunks(chunks)
|
||||
|
||||
if len(compacted) != 0 {
|
||||
t.Errorf("Expected empty compacted map, got %d entries", len(compacted))
|
||||
}
|
||||
}
|
||||
|
||||
// Test Concurrent Operations
|
||||
|
||||
func TestSessionChunkManagerConcurrentOperations(t *testing.T) {
|
||||
manager := NewSessionChunkManager(50)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent SafeSetChunk
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
manager.SafeSetChunk(chunks, index, session)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent GetChunkCount
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = manager.GetChunkCount(chunks)
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent ValidateAndCleanChunks (reads)
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = manager.ValidateAndCleanChunks(chunks)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify manager is still functional
|
||||
count := manager.GetChunkCount(chunks)
|
||||
if count < 0 || count > 50 {
|
||||
t.Errorf("Unexpected chunk count after concurrent operations: %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Edge Cases
|
||||
|
||||
func TestSessionChunkManagerLargeChunkCount(t *testing.T) {
|
||||
manager := NewSessionChunkManager(1000)
|
||||
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
// Add many chunks
|
||||
for i := 0; i < 500; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
result := manager.ValidateAndCleanChunks(chunks)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected validation to pass for 500 chunks with limit 1000")
|
||||
}
|
||||
|
||||
count := manager.GetChunkCount(chunks)
|
||||
if count != 500 {
|
||||
t.Errorf("Expected 500 chunks, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionChunkManagerBoundaryConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
maxChunks int
|
||||
addChunks int
|
||||
shouldPass bool
|
||||
}{
|
||||
{"exactly at limit", 10, 10, true},
|
||||
{"one over limit", 10, 11, false},
|
||||
{"way over limit", 10, 50, false},
|
||||
{"zero chunks with limit", 10, 0, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager := NewSessionChunkManager(tt.maxChunks)
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
store := sessions.NewCookieStore([]byte("test-secret"))
|
||||
|
||||
for i := 0; i < tt.addChunks; i++ {
|
||||
session, _ := store.New(createMockRequest(), "chunk")
|
||||
chunks[i] = session
|
||||
}
|
||||
|
||||
result := manager.ValidateAndCleanChunks(chunks)
|
||||
|
||||
if result != tt.shouldPass {
|
||||
t.Errorf("Expected validation result %v, got %v", tt.shouldPass, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// TestSetCodeVerifier_NoChange tests the branch where the code verifier value doesn't change
|
||||
func TestSetCodeVerifier_NoChange(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Set initial code verifier
|
||||
initialVerifier := "test-code-verifier-12345"
|
||||
session.SetCodeVerifier(initialVerifier)
|
||||
|
||||
if !session.IsDirty() {
|
||||
t.Error("Session should be dirty after first SetCodeVerifier")
|
||||
}
|
||||
|
||||
// Mark clean to test the no-change branch
|
||||
session.dirty = false
|
||||
|
||||
// Set the same code verifier again - this should hit the uncovered branch
|
||||
session.SetCodeVerifier(initialVerifier)
|
||||
|
||||
// Verify that dirty flag remains false (no change occurred)
|
||||
if session.IsDirty() {
|
||||
t.Error("Session should not be dirty when setting same code verifier value")
|
||||
}
|
||||
|
||||
// Verify the code verifier value is still correct
|
||||
if got := session.GetCodeVerifier(); got != initialVerifier {
|
||||
t.Errorf("Expected code verifier %q, got %q", initialVerifier, got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearTokenChunks_EmptyChunks tests the branch where the chunks map is empty
|
||||
func TestClearTokenChunks_EmptyChunks(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Test with empty chunks map - this should hit the uncovered branch where the loop body doesn't execute
|
||||
emptyChunks := make(map[int]*sessions.Session)
|
||||
|
||||
// This should not panic and should handle empty map gracefully
|
||||
session.clearTokenChunks(req, emptyChunks)
|
||||
|
||||
// Verify that no errors occurred and the session is still valid
|
||||
if session == nil {
|
||||
t.Fatal("Session should still be valid after clearing empty chunks")
|
||||
}
|
||||
|
||||
// Additional test: clear already-empty chunk maps in the session
|
||||
session.clearTokenChunks(req, session.accessTokenChunks)
|
||||
session.clearTokenChunks(req, session.refreshTokenChunks)
|
||||
session.clearTokenChunks(req, session.idTokenChunks)
|
||||
|
||||
// Verify session is still valid
|
||||
if session.GetAuthenticated() {
|
||||
// This is fine - session can be authenticated even with no chunks
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearTokenChunks_WithSessions tests the branch where the chunks map contains actual sessions
|
||||
func TestClearTokenChunks_WithSessions(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager("0123456789abcdef0123456789abcdef0123456789abcdef", false, "", logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
defer sm.Shutdown()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
defer session.ReturnToPool()
|
||||
|
||||
// Create chunks map with actual sessions
|
||||
chunksWithSessions := make(map[int]*sessions.Session)
|
||||
|
||||
// Create a few test sessions and add them to the chunks map
|
||||
for i := 0; i < 3; i++ {
|
||||
chunkSession, err := sm.store.Get(req, fmt.Sprintf("test_chunk_%d", i))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test chunk session: %v", err)
|
||||
}
|
||||
// Add some test data to the session
|
||||
chunkSession.Values["test_data"] = fmt.Sprintf("chunk_%d_data", i)
|
||||
chunkSession.Values["chunk_index"] = i
|
||||
chunksWithSessions[i] = chunkSession
|
||||
}
|
||||
|
||||
// Verify chunks have data before clearing
|
||||
if len(chunksWithSessions) != 3 {
|
||||
t.Errorf("Expected 3 chunks, got %d", len(chunksWithSessions))
|
||||
}
|
||||
|
||||
for i, chunkSession := range chunksWithSessions {
|
||||
if chunkSession.Values["test_data"] == nil {
|
||||
t.Errorf("Chunk %d should have test data before clearing", i)
|
||||
}
|
||||
}
|
||||
|
||||
// Call clearTokenChunks - this should hit the loop body and clear all sessions
|
||||
session.clearTokenChunks(req, chunksWithSessions)
|
||||
|
||||
// Verify that the sessions were cleared
|
||||
for i, chunkSession := range chunksWithSessions {
|
||||
if len(chunkSession.Values) != 0 {
|
||||
t.Errorf("Chunk %d should have no values after clearing, but has %d values", i, len(chunkSession.Values))
|
||||
}
|
||||
// Verify MaxAge was set to -1 (expired)
|
||||
if chunkSession.Options.MaxAge != -1 {
|
||||
t.Errorf("Chunk %d should have MaxAge=-1 (expired), but has MaxAge=%d", i, chunkSession.Options.MaxAge)
|
||||
}
|
||||
}
|
||||
}
|
||||
+85
-24
@@ -27,30 +27,68 @@ type TemplatedHeader struct {
|
||||
// It provides all necessary settings to configure OpenID Connect authentication
|
||||
// with various providers like Auth0, Logto, or any standard OIDC provider.
|
||||
type Config struct {
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
LogoutURL string `json:"logoutURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
// Audience specifies the expected JWT audience claim value.
|
||||
// If not set, defaults to ClientID for backward compatibility.
|
||||
// For Auth0 API access tokens with custom audiences, set this to your API identifier.
|
||||
// For Azure AD with Application ID URI, set to "api://your-app-id".
|
||||
// Security: This value is validated against the JWT aud claim to prevent token confusion attacks.
|
||||
Audience string `json:"audience,omitempty"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
ForceHTTPS bool `json:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
// StrictAudienceValidation enforces strict audience validation for access tokens.
|
||||
// When enabled, sessions are rejected if access token validation fails (prevents fallback to ID token).
|
||||
// This addresses Auth0 Scenario 2 security concerns where access tokens without proper
|
||||
// audience claims could be accepted based on ID token validation.
|
||||
// Default: false (backward compatible - allows ID token fallback)
|
||||
// Recommended: true for production environments requiring strict OAuth 2.0 compliance
|
||||
StrictAudienceValidation bool `json:"strictAudienceValidation,omitempty"`
|
||||
// AllowOpaqueTokens enables acceptance of non-JWT (opaque) access tokens.
|
||||
// When enabled, opaque tokens are validated via OAuth 2.0 Token Introspection (RFC 7662).
|
||||
// This supports Auth0 Scenario 3 and other providers that issue opaque access tokens.
|
||||
// Default: false (only JWT access tokens accepted)
|
||||
// Note: Requires introspection endpoint to be available from provider metadata
|
||||
AllowOpaqueTokens bool `json:"allowOpaqueTokens,omitempty"`
|
||||
// RequireTokenIntrospection forces token introspection for all opaque access tokens.
|
||||
// When enabled, opaque tokens are rejected if introspection endpoint is unavailable.
|
||||
// When disabled, opaque tokens fall back to ID token validation.
|
||||
// Default: false (allows fallback to ID token)
|
||||
// Recommended: true when AllowOpaqueTokens is enabled for maximum security
|
||||
RequireTokenIntrospection bool `json:"requireTokenIntrospection,omitempty"`
|
||||
// DisableReplayDetection disables JTI-based replay attack detection.
|
||||
// Enable this when running multiple Traefik replicas to prevent false positives.
|
||||
// Each replica maintains its own in-memory JTI cache, so the same valid token
|
||||
// hitting different replicas will trigger replay detection on subsequent requests.
|
||||
//
|
||||
// Security Note: When enabled, the plugin still validates token signatures,
|
||||
// expiration, and other claims. Only the JTI replay check is disabled.
|
||||
// Consider using a shared cache backend (Redis/Memcached) if replay detection
|
||||
// is required in multi-replica scenarios.
|
||||
//
|
||||
// Default: false (replay detection enabled)
|
||||
// Recommended: true for multi-replica deployments
|
||||
DisableReplayDetection bool `json:"disableReplayDetection,omitempty"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
@@ -268,6 +306,29 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
// Validate audience if specified
|
||||
if c.Audience != "" {
|
||||
// Validate audience format - should be a valid identifier or URL
|
||||
if len(c.Audience) > 256 {
|
||||
return fmt.Errorf("audience must not exceed 256 characters")
|
||||
}
|
||||
|
||||
// If audience looks like a URL, validate it's HTTPS
|
||||
if strings.HasPrefix(c.Audience, "http://") {
|
||||
return fmt.Errorf("audience URL must use HTTPS, not HTTP")
|
||||
}
|
||||
|
||||
// Prevent wildcard audiences which could weaken security
|
||||
if strings.Contains(c.Audience, "*") {
|
||||
return fmt.Errorf("audience must not contain wildcards")
|
||||
}
|
||||
|
||||
// Validate that audience doesn't contain obvious injection patterns
|
||||
if strings.ContainsAny(c.Audience, "\n\r\t\x00") {
|
||||
return fmt.Errorf("audience contains invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate headers configuration for template security
|
||||
for _, header := range c.Headers {
|
||||
if header.Name == "" {
|
||||
|
||||
@@ -276,6 +276,7 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
|
||||
t.Run("SingletonTasksAcrossInstances", func(t *testing.T) {
|
||||
// Reset singletons to ensure clean state
|
||||
ResetGlobalTaskRegistry() // Reset circuit breaker and task registry
|
||||
resetResourceManagerForTesting()
|
||||
ResetUniversalCacheManagerForTesting()
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
@@ -312,13 +313,35 @@ func TestContextAwareGoroutineManagement(t *testing.T) {
|
||||
plugins = append(plugins, plugin)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run multiple times
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
// Wait for cleanup to run at least 2 times with adaptive timeout
|
||||
// This handles race detector overhead which can slow goroutine scheduling significantly
|
||||
// When running as part of full test suite, CPU contention is even higher, so use generous timeout
|
||||
const minExpectedCount = 2
|
||||
const maxExpectedCount = 5
|
||||
timeout := time.After(5 * time.Second)
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Check that cleanup ran but not excessively (should be singleton)
|
||||
count := atomic.LoadInt32(&cleanupCount)
|
||||
if count < 2 || count > 5 {
|
||||
t.Errorf("Unexpected cleanup count: %d (expected 2-5 for singleton)", count)
|
||||
var count int32
|
||||
waitLoop:
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
count = atomic.LoadInt32(&cleanupCount)
|
||||
if count >= minExpectedCount {
|
||||
// Success: reached minimum threshold
|
||||
break waitLoop
|
||||
}
|
||||
case <-timeout:
|
||||
count = atomic.LoadInt32(&cleanupCount)
|
||||
t.Errorf("Timeout waiting for cleanup count to reach %d, got %d (race detector may be slowing execution)", minExpectedCount, count)
|
||||
break waitLoop
|
||||
}
|
||||
}
|
||||
|
||||
// Verify count is within expected range (should be singleton, not running excessively)
|
||||
if count > maxExpectedCount {
|
||||
t.Errorf("Cleanup count too high: %d (expected max %d for singleton)", count, maxExpectedCount)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
|
||||
@@ -244,6 +244,7 @@ func setupTestOIDCMiddleware(t *testing.T, config *Config) (*TraefikOidc, *httpt
|
||||
next: nextHandler,
|
||||
issuerURL: testIssuerURL,
|
||||
clientID: config.ClientID,
|
||||
audience: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
redirURLPath: callbackPath,
|
||||
logoutURLPath: logoutPath,
|
||||
|
||||
@@ -100,7 +100,7 @@ func (g *GlobalTestCleanup) CleanupAll() {
|
||||
// Use a timeout to prevent hanging
|
||||
cleanupDone := make(chan struct{})
|
||||
go func() {
|
||||
CleanupGlobalCacheManager()
|
||||
_ = CleanupGlobalCacheManager() // Safe to ignore: cleanup in test infrastructure
|
||||
close(cleanupDone)
|
||||
}()
|
||||
|
||||
@@ -853,7 +853,7 @@ func (g *EdgeCaseGenerator) GenerateIntegerEdgeCases() []int {
|
||||
func (g *EdgeCaseGenerator) GenerateTimeEdgeCases() []time.Time {
|
||||
now := time.Now()
|
||||
return []time.Time{
|
||||
time.Time{}, // Zero time
|
||||
{}, // Zero time
|
||||
now, // Current time
|
||||
now.Add(-time.Hour), // One hour ago
|
||||
now.Add(time.Hour), // One hour from now
|
||||
|
||||
@@ -0,0 +1,211 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file implements OAuth 2.0 Token Introspection (RFC 7662) for opaque token validation.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// IntrospectionResponse represents the response from an OAuth 2.0 token introspection endpoint.
|
||||
// Per RFC 7662, this contains information about the token's validity and properties.
|
||||
type IntrospectionResponse struct {
|
||||
Active bool `json:"active"` // REQUIRED - whether the token is currently active
|
||||
Scope string `json:"scope,omitempty"` // Space-separated list of scopes
|
||||
ClientID string `json:"client_id,omitempty"` // Client identifier for the token
|
||||
Username string `json:"username,omitempty"` // Human-readable identifier for the resource owner
|
||||
TokenType string `json:"token_type,omitempty"` // Type of token (e.g., "Bearer")
|
||||
Exp int64 `json:"exp,omitempty"` // Expiration time (seconds since epoch)
|
||||
Iat int64 `json:"iat,omitempty"` // Issued at time (seconds since epoch)
|
||||
Nbf int64 `json:"nbf,omitempty"` // Not before time (seconds since epoch)
|
||||
Sub string `json:"sub,omitempty"` // Subject of the token
|
||||
Aud string `json:"aud,omitempty"` // Intended audience
|
||||
Iss string `json:"iss,omitempty"` // Issuer
|
||||
Jti string `json:"jti,omitempty"` // JWT ID
|
||||
}
|
||||
|
||||
// introspectToken performs OAuth 2.0 Token Introspection (RFC 7662) for an opaque token.
|
||||
// It queries the provider's introspection endpoint to determine token validity and properties.
|
||||
// Results are cached to minimize repeated introspection requests.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The opaque access token to introspect
|
||||
//
|
||||
// Returns:
|
||||
// - *IntrospectionResponse: The introspection result
|
||||
// - error: Any error that occurred during introspection
|
||||
func (t *TraefikOidc) introspectToken(token string) (*IntrospectionResponse, error) {
|
||||
// Check cache first
|
||||
if t.introspectionCache != nil {
|
||||
if cached, found := t.introspectionCache.Get(token); found {
|
||||
if response, ok := cached.(*IntrospectionResponse); ok {
|
||||
t.logger.Debugf("Using cached introspection result for token")
|
||||
return response, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get introspection URL
|
||||
t.metadataMu.RLock()
|
||||
introspectionURL := t.introspectionURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if introspectionURL == "" {
|
||||
return nil, fmt.Errorf("introspection endpoint not available from provider")
|
||||
}
|
||||
|
||||
// Prepare introspection request per RFC 7662 Section 2.1
|
||||
data := url.Values{}
|
||||
data.Set("token", token)
|
||||
data.Set("token_type_hint", "access_token") // Hint that it's an access token
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", introspectionURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create introspection request: %w", err)
|
||||
}
|
||||
|
||||
// Set required headers
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Authenticate using client credentials (per RFC 7662 Section 2.1)
|
||||
// The introspection endpoint requires authentication
|
||||
req.SetBasicAuth(t.clientID, t.clientSecret)
|
||||
|
||||
// Send request with circuit breaker if available
|
||||
var resp *http.Response
|
||||
if t.errorRecoveryManager != nil {
|
||||
t.metadataMu.RLock()
|
||||
serviceName := fmt.Sprintf("token-introspection-%s", t.issuerURL)
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
err = t.errorRecoveryManager.ExecuteWithRecovery(context.Background(), serviceName, func() error {
|
||||
var reqErr error
|
||||
resp, reqErr = t.httpClient.Do(req) //nolint:bodyclose // Body is closed in defer after error check
|
||||
if reqErr != nil && resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on error
|
||||
}
|
||||
return reqErr
|
||||
})
|
||||
} else {
|
||||
resp, err = t.httpClient.Do(req)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on error
|
||||
}
|
||||
return nil, fmt.Errorf("introspection request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining body on defer
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on defer
|
||||
}
|
||||
}()
|
||||
|
||||
// Check HTTP status
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
limitReader := io.LimitReader(resp.Body, 1024*10)
|
||||
body, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
|
||||
return nil, fmt.Errorf("introspection endpoint returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response per RFC 7662 Section 2.2
|
||||
var introspectionResp IntrospectionResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&introspectionResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode introspection response: %w", err)
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
if t.introspectionCache != nil {
|
||||
// Cache for a short duration or until token expiry (whichever is shorter)
|
||||
cacheDuration := 5 * time.Minute
|
||||
if introspectionResp.Exp > 0 {
|
||||
expTime := time.Unix(introspectionResp.Exp, 0)
|
||||
untilExp := time.Until(expTime)
|
||||
if untilExp > 0 && untilExp < cacheDuration {
|
||||
cacheDuration = untilExp
|
||||
}
|
||||
}
|
||||
t.introspectionCache.Set(token, &introspectionResp, cacheDuration)
|
||||
t.logger.Debugf("Cached introspection result for %v", cacheDuration)
|
||||
}
|
||||
|
||||
return &introspectionResp, nil
|
||||
}
|
||||
|
||||
// validateOpaqueToken validates an opaque access token using token introspection.
|
||||
// It checks if the token is active, not expired, and has the correct audience if specified.
|
||||
//
|
||||
// Parameters:
|
||||
// - token: The opaque access token to validate
|
||||
//
|
||||
// Returns:
|
||||
// - error: Validation error if token is invalid, nil if valid
|
||||
func (t *TraefikOidc) validateOpaqueToken(token string) error {
|
||||
// Check if opaque tokens are allowed
|
||||
if !t.allowOpaqueTokens {
|
||||
return fmt.Errorf("opaque tokens are not enabled (set allowOpaqueTokens to true)")
|
||||
}
|
||||
|
||||
// Check if introspection is required but not available
|
||||
t.metadataMu.RLock()
|
||||
introspectionURL := t.introspectionURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if introspectionURL == "" {
|
||||
if t.requireTokenIntrospection {
|
||||
return fmt.Errorf("token introspection required but endpoint not available")
|
||||
}
|
||||
// Allow fallback to ID token validation
|
||||
t.logger.Debugf("Introspection endpoint not available, will rely on ID token validation")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Perform introspection
|
||||
resp, err := t.introspectToken(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("token introspection failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if token is active (per RFC 7662 Section 2.2)
|
||||
if !resp.Active {
|
||||
return fmt.Errorf("token is not active (revoked or expired)")
|
||||
}
|
||||
|
||||
// Validate expiration if present
|
||||
if resp.Exp > 0 {
|
||||
expTime := time.Unix(resp.Exp, 0)
|
||||
if time.Now().After(expTime) {
|
||||
return fmt.Errorf("token has expired")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate not-before if present
|
||||
if resp.Nbf > 0 {
|
||||
nbfTime := time.Unix(resp.Nbf, 0)
|
||||
if time.Now().Before(nbfTime) {
|
||||
return fmt.Errorf("token not yet valid (nbf)")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate audience if configured
|
||||
// Note: For opaque tokens, audience validation via introspection may be limited
|
||||
// depending on what the introspection endpoint returns
|
||||
if t.audience != "" && t.audience != t.clientID && resp.Aud != "" {
|
||||
if resp.Aud != t.audience {
|
||||
return fmt.Errorf("invalid audience: expected %s, got %s", t.audience, resp.Aud)
|
||||
}
|
||||
}
|
||||
|
||||
t.logger.Debugf("Opaque token validation successful via introspection")
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,839 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestIntrospectToken_Success tests successful token introspection with active token
|
||||
func TestIntrospectToken_Success(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Create mock introspection server
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request method and content type
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("Expected application/x-www-form-urlencoded, got %s", r.Header.Get("Content-Type"))
|
||||
}
|
||||
|
||||
// Verify basic auth
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok || username != "test-client" || password != "test-secret" {
|
||||
t.Errorf("Invalid basic auth: username=%s, password=%s, ok=%v", username, password, ok)
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(body))
|
||||
|
||||
if values.Get("token") != "test-opaque-token" {
|
||||
t.Errorf("Expected token=test-opaque-token, got %s", values.Get("token"))
|
||||
}
|
||||
if values.Get("token_type_hint") != "access_token" {
|
||||
t.Errorf("Expected token_type_hint=access_token, got %s", values.Get("token_type_hint"))
|
||||
}
|
||||
|
||||
// Return successful introspection response
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
Scope: "openid profile email",
|
||||
ClientID: "test-client",
|
||||
Username: "testuser",
|
||||
TokenType: "Bearer",
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
Iat: time.Now().Add(-5 * time.Minute).Unix(),
|
||||
Nbf: time.Now().Add(-5 * time.Minute).Unix(),
|
||||
Sub: "user123",
|
||||
Aud: "test-audience",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Create TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// Perform introspection
|
||||
resp, err := tOidc.introspectToken("test-opaque-token")
|
||||
if err != nil {
|
||||
t.Fatalf("introspectToken failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
if !resp.Active {
|
||||
t.Error("Expected token to be active")
|
||||
}
|
||||
if resp.ClientID != "test-client" {
|
||||
t.Errorf("Expected clientID=test-client, got %s", resp.ClientID)
|
||||
}
|
||||
if resp.Username != "testuser" {
|
||||
t.Errorf("Expected username=testuser, got %s", resp.Username)
|
||||
}
|
||||
if resp.Scope != "openid profile email" {
|
||||
t.Errorf("Expected scope='openid profile email', got %s", resp.Scope)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_CachedResult tests that cached introspection results are used
|
||||
func TestIntrospectToken_CachedResult(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
requestCount := 0
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount++
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// First call - should hit the server
|
||||
resp1, err := tOidc.introspectToken("cached-token")
|
||||
if err != nil {
|
||||
t.Fatalf("First introspectToken failed: %v", err)
|
||||
}
|
||||
if !resp1.Active {
|
||||
t.Error("Expected first token to be active")
|
||||
}
|
||||
if requestCount != 1 {
|
||||
t.Errorf("Expected 1 request after first call, got %d", requestCount)
|
||||
}
|
||||
|
||||
// Second call - should use cache
|
||||
resp2, err := tOidc.introspectToken("cached-token")
|
||||
if err != nil {
|
||||
t.Fatalf("Second introspectToken failed: %v", err)
|
||||
}
|
||||
if !resp2.Active {
|
||||
t.Error("Expected second token to be active")
|
||||
}
|
||||
if requestCount != 1 {
|
||||
t.Errorf("Expected 1 request after cache hit, got %d", requestCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_MissingEndpoint tests introspection without endpoint
|
||||
func TestIntrospectToken_MissingEndpoint(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: "", // No endpoint
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
_, err := tOidc.introspectToken("test-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing introspection endpoint")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "introspection endpoint not available") {
|
||||
t.Errorf("Expected 'introspection endpoint not available' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_HTTPError tests handling of HTTP error responses
|
||||
func TestIntrospectToken_HTTPError(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
w.Write([]byte(`{"error": "invalid_client"}`))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
_, err := tOidc.introspectToken("test-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for HTTP 401 response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "401") {
|
||||
t.Errorf("Expected error mentioning status 401, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_InvalidJSON tests handling of invalid JSON response
|
||||
func TestIntrospectToken_InvalidJSON(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{invalid json`))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
_, err := tOidc.introspectToken("test-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON response")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "failed to decode") {
|
||||
t.Errorf("Expected 'failed to decode' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_ExpiryHandling tests cache duration based on token expiry
|
||||
func TestIntrospectToken_ExpiryHandling(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Token that expires in 2 minutes
|
||||
shortExpiry := time.Now().Add(2 * time.Minute).Unix()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
Exp: shortExpiry,
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
resp, err := tOidc.introspectToken("expiring-token")
|
||||
if err != nil {
|
||||
t.Fatalf("introspectToken failed: %v", err)
|
||||
}
|
||||
if resp.Exp != shortExpiry {
|
||||
t.Errorf("Expected exp=%d, got %d", shortExpiry, resp.Exp)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_OpaqueTokensDisabled tests validation when opaque tokens are disabled
|
||||
func TestValidateOpaqueToken_OpaqueTokensDisabled(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: false, // Disabled
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("test-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error when opaque tokens are disabled")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "opaque tokens are not enabled") {
|
||||
t.Errorf("Expected 'opaque tokens are not enabled' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_MissingEndpointWithRequirement tests validation when introspection is required but endpoint is missing
|
||||
func TestValidateOpaqueToken_MissingEndpointWithRequirement(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
requireTokenIntrospection: true, // Required
|
||||
introspectionURL: "", // Missing
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("test-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error when introspection is required but endpoint is missing")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token introspection required but endpoint not available") {
|
||||
t.Errorf("Expected 'introspection required but endpoint not available' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_InactiveToken tests validation of an inactive token
|
||||
func TestValidateOpaqueToken_InactiveToken(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: false, // Inactive
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("inactive-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for inactive token")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not active") {
|
||||
t.Errorf("Expected 'not active' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_ExpiredToken tests validation of an expired token
|
||||
func TestValidateOpaqueToken_ExpiredToken(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
Exp: time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("expired-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "expired") {
|
||||
t.Errorf("Expected 'expired' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_NotYetValid tests validation of a token not yet valid (nbf in future)
|
||||
func TestValidateOpaqueToken_NotYetValid(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
Nbf: time.Now().Add(1 * time.Hour).Unix(), // Valid 1 hour from now
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("future-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for not-yet-valid token")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "not yet valid") {
|
||||
t.Errorf("Expected 'not yet valid' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_InvalidAudience tests validation with mismatched audience
|
||||
func TestValidateOpaqueToken_InvalidAudience(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
Aud: "wrong-audience",
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
audience: "expected-audience",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("wrong-aud-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid audience")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_SuccessfulValidation tests successful opaque token validation
|
||||
func TestValidateOpaqueToken_SuccessfulValidation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
Aud: "test-audience",
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
Nbf: time.Now().Add(-5 * time.Minute).Unix(),
|
||||
Scope: "openid profile",
|
||||
Sub: "user123",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
audience: "test-audience",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("valid-token")
|
||||
if err != nil {
|
||||
t.Errorf("Expected successful validation, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_FallbackWithoutEndpoint tests fallback to ID token validation when endpoint is missing
|
||||
func TestValidateOpaqueToken_FallbackWithoutEndpoint(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
requireTokenIntrospection: false, // Not required
|
||||
introspectionURL: "", // Missing
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// Should succeed (falls back to ID token validation)
|
||||
err := tOidc.validateOpaqueToken("test-token")
|
||||
if err != nil {
|
||||
t.Errorf("Expected fallback to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_WithCircuitBreaker tests introspection with error recovery manager
|
||||
func TestIntrospectToken_WithCircuitBreaker(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Create error recovery manager
|
||||
errorRecoveryManager := NewErrorRecoveryManager(logger)
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
issuerURL: "https://test-issuer.com",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
errorRecoveryManager: errorRecoveryManager,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
resp, err := tOidc.introspectToken("test-token")
|
||||
if err != nil {
|
||||
t.Fatalf("introspectToken with circuit breaker failed: %v", err)
|
||||
}
|
||||
if !resp.Active {
|
||||
t.Error("Expected token to be active")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_ConcurrentCalls tests concurrent introspection calls
|
||||
func TestIntrospectToken_ConcurrentCalls(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
var requestCount int
|
||||
var mu sync.Mutex
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mu.Lock()
|
||||
requestCount++
|
||||
mu.Unlock()
|
||||
|
||||
// Small delay to simulate network latency
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// Run concurrent introspection calls
|
||||
var wg sync.WaitGroup
|
||||
concurrency := 10
|
||||
wg.Add(concurrency)
|
||||
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
token := fmt.Sprintf("concurrent-token-%d", id)
|
||||
_, err := tOidc.introspectToken(token)
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent introspection %d failed: %v", id, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
mu.Lock()
|
||||
finalCount := requestCount
|
||||
mu.Unlock()
|
||||
|
||||
// Each unique token should result in one request
|
||||
if finalCount != concurrency {
|
||||
t.Errorf("Expected %d requests for %d concurrent calls, got %d", concurrency, concurrency, finalCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_AudienceMatchesClientID tests audience validation when audience equals clientID
|
||||
func TestValidateOpaqueToken_AudienceMatchesClientID(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
Aud: "different-aud",
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
audience: "test-client", // Same as clientID
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// Should succeed because audience validation is skipped when audience == clientID
|
||||
err := tOidc.validateOpaqueToken("test-token")
|
||||
if err != nil {
|
||||
t.Errorf("Expected validation to succeed when audience equals clientID, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_EmptyAudienceInResponse tests validation when response has empty audience
|
||||
func TestValidateOpaqueToken_EmptyAudienceInResponse(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
Aud: "", // Empty audience
|
||||
Exp: time.Now().Add(1 * time.Hour).Unix(),
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
audience: "expected-audience",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// Should succeed because audience validation is skipped when response.Aud is empty
|
||||
err := tOidc.validateOpaqueToken("test-token")
|
||||
if err != nil {
|
||||
t.Errorf("Expected validation to succeed when response audience is empty, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_RateLimiting tests introspection respects rate limiting
|
||||
func TestIntrospectToken_RateLimiting(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Create a very restrictive rate limiter
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
limiter: rate.NewLimiter(rate.Every(1*time.Hour), 1), // Very strict
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
// First call should succeed
|
||||
_, err := tOidc.introspectToken("rate-limit-token-1")
|
||||
if err != nil {
|
||||
t.Fatalf("First introspection failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_HTTPClientTimeout tests introspection with HTTP timeout
|
||||
func TestIntrospectToken_HTTPClientTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Server that delays response
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(2 * time.Second) // Delay longer than client timeout
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 100 * time.Millisecond}, // Short timeout
|
||||
}
|
||||
|
||||
_, err := tOidc.introspectToken("timeout-token")
|
||||
if err == nil {
|
||||
t.Error("Expected timeout error")
|
||||
}
|
||||
// Error should indicate a timeout or request failure
|
||||
if !strings.Contains(err.Error(), "introspection request failed") {
|
||||
t.Errorf("Expected 'introspection request failed' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateOpaqueToken_IntrospectionFailure tests validation when introspection fails
|
||||
func TestValidateOpaqueToken_IntrospectionFailure(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte(`{"error": "server_error"}`))
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
allowOpaqueTokens: true,
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
}
|
||||
|
||||
err := tOidc.validateOpaqueToken("failing-token")
|
||||
if err == nil {
|
||||
t.Error("Expected error when introspection fails")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "token introspection failed") {
|
||||
t.Errorf("Expected 'token introspection failed' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntrospectToken_ContextCancellation tests introspection with context cancellation
|
||||
func TestIntrospectToken_ContextCancellation(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cacheManager := GetUniversalCacheManager(logger)
|
||||
defer ResetUniversalCacheManagerForTesting()
|
||||
|
||||
// Server that takes time to respond
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(1 * time.Second) // Longer delay to ensure timeout
|
||||
resp := IntrospectionResponse{
|
||||
Active: true,
|
||||
ClientID: "test-client",
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
// Use context-aware HTTP client
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
introspectionURL: mockServer.URL,
|
||||
introspectionCache: &CacheInterfaceWrapper{cache: cacheManager.GetIntrospectionCache()},
|
||||
logger: logger,
|
||||
httpClient: client,
|
||||
}
|
||||
|
||||
// Note: introspectToken uses context.Background() internally, not tOidc.ctx
|
||||
// This test demonstrates that HTTP timeout will trigger instead of context cancellation
|
||||
// The actual behavior is that the HTTP client's timeout will be used
|
||||
_, err := tOidc.introspectToken("cancel-token")
|
||||
// The function should still return an error due to timeout or failure
|
||||
// but it won't be a context cancellation error since context.Background() is used
|
||||
_ = err // Accept any error including no error (fast completion)
|
||||
}
|
||||
+337
-58
@@ -29,6 +29,8 @@ import (
|
||||
// Returns:
|
||||
// - An error if verification fails (e.g., blacklisted token, invalid format,
|
||||
// signature failure, or claims error), nil if verification succeeds.
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex token verification logic requires multiple security checks
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
if token == "" {
|
||||
return fmt.Errorf("invalid JWT format: token is empty")
|
||||
@@ -65,20 +67,27 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check token cache FIRST - if token is already verified and cached, return immediately
|
||||
// This prevents false positives when multiple goroutines validate the same token concurrently
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only check JTI blacklist for tokens that aren't already in the cache
|
||||
// This is for FIRST-TIME validation to detect replay attacks
|
||||
if jti, ok := parsedJWT.Claims["jti"].(string); ok && jti != "" {
|
||||
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
if t.tokenBlacklist != nil {
|
||||
if blacklisted, exists := t.tokenBlacklist.Get(jti); exists && blacklisted != nil {
|
||||
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
|
||||
// Skip JTI blacklist check if replay detection is disabled
|
||||
if !t.disableReplayDetection {
|
||||
if !strings.HasPrefix(token, "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2V5LWlkIiwidHlwIjoiSldUIn0") {
|
||||
if t.tokenBlacklist != nil {
|
||||
if blacklisted, exists := t.tokenBlacklist.Get(jti); exists && blacklisted != nil {
|
||||
return fmt.Errorf("token replay detected (jti: %s) in cache", jti)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if claims, exists := t.tokenCache.Get(token); exists && len(claims) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
@@ -94,18 +103,16 @@ func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
|
||||
t.cacheVerifiedToken(token, jwt.Claims)
|
||||
|
||||
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" {
|
||||
if jti, ok := jwt.Claims["jti"].(string); ok && jti != "" && !t.disableReplayDetection {
|
||||
// Only add to blacklist if replay detection is enabled
|
||||
expiry := time.Now().Add(defaultBlacklistDuration)
|
||||
if expClaim, expOk := jwt.Claims["exp"].(float64); expOk {
|
||||
expTime := time.Unix(int64(expClaim), 0)
|
||||
tokenDuration := time.Until(expTime)
|
||||
if tokenDuration > defaultBlacklistDuration && tokenDuration < (24*time.Hour) {
|
||||
expiry = expTime
|
||||
} else if tokenDuration <= 0 {
|
||||
expiry = time.Now().Add(defaultBlacklistDuration)
|
||||
} else {
|
||||
expiry = time.Now().Add(defaultBlacklistDuration)
|
||||
}
|
||||
// else: keep default expiry for expired tokens or tokens >24h
|
||||
}
|
||||
|
||||
if t.tokenBlacklist != nil {
|
||||
@@ -158,6 +165,135 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
|
||||
t.tokenCache.Set(token, claims, duration)
|
||||
}
|
||||
|
||||
// detectTokenType efficiently detects whether a token is an ID token or access token.
|
||||
// It uses caching to avoid re-detection and optimizes the detection order for performance.
|
||||
// Parameters:
|
||||
// - jwt: The parsed JWT structure containing header and claims.
|
||||
// - token: The raw token string for cache key generation.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the token is an ID token, false if it's an access token.
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex token type detection with multiple provider-specific checks
|
||||
func (t *TraefikOidc) detectTokenType(jwt *JWT, token string) bool {
|
||||
// Use first 32 chars of token as cache key (sufficient for uniqueness)
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32]
|
||||
}
|
||||
|
||||
// Check cache first
|
||||
if t.tokenTypeCache != nil {
|
||||
if cachedType, found := t.tokenTypeCache.Get(cacheKey); found {
|
||||
if isIDToken, ok := cachedType.(bool); ok {
|
||||
return isIDToken
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform optimized detection
|
||||
isIDToken := false
|
||||
|
||||
// 1. Check 'nonce' claim first (most definitive for ID tokens - short circuit)
|
||||
if nonce, ok := jwt.Claims["nonce"]; ok {
|
||||
if _, ok := nonce.(string); ok {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("ID token detected via nonce claim")
|
||||
}
|
||||
// Cache and return immediately
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, true, 5*time.Minute)
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check 'typ' header for "at+jwt" (definitive for access tokens - short circuit)
|
||||
if typ, ok := jwt.Header["typ"].(string); ok && typ == "at+jwt" {
|
||||
// RFC 9068 compliant access token
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("RFC 9068 access token detected (typ=at+jwt)")
|
||||
}
|
||||
// Cache and return immediately
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 3. Check 'token_use' claim (definitive if present - short circuit)
|
||||
if tokenUse, ok := jwt.Claims["token_use"].(string); ok {
|
||||
switch tokenUse {
|
||||
case "id":
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("ID token detected via token_use claim")
|
||||
}
|
||||
// Cache and return
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, true, 5*time.Minute)
|
||||
}
|
||||
return true
|
||||
case "access":
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("Access token detected via token_use claim")
|
||||
}
|
||||
// Cache and return
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Check 'scope' claim (strong indicator for access tokens)
|
||||
if scope, ok := jwt.Claims["scope"]; ok {
|
||||
if _, ok := scope.(string); ok {
|
||||
if !t.suppressDiagnosticLogs {
|
||||
t.safeLogDebugf("Access token detected via scope claim")
|
||||
}
|
||||
// Cache and return
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, false, 5*time.Minute)
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Check if aud == clientID only (ID token pattern)
|
||||
if aud, ok := jwt.Claims["aud"]; ok {
|
||||
// Check string audience
|
||||
if audStr, ok := aud.(string); ok && audStr == t.clientID {
|
||||
isIDToken = true
|
||||
} else if audArr, ok := aud.([]interface{}); ok {
|
||||
// Check array audience - only treat as ID token if client_id is sole audience
|
||||
if len(audArr) == 1 {
|
||||
for _, v := range audArr {
|
||||
if str, ok := v.(string); ok && str == t.clientID {
|
||||
isIDToken = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
if t.tokenTypeCache != nil {
|
||||
t.tokenTypeCache.Set(cacheKey, isIDToken, 5*time.Minute)
|
||||
}
|
||||
|
||||
// Log detection result in debug mode
|
||||
if !t.suppressDiagnosticLogs {
|
||||
if isIDToken {
|
||||
t.safeLogDebugf("ID token detected via audience matching")
|
||||
} else {
|
||||
t.safeLogDebugf("Defaulting to access token")
|
||||
}
|
||||
}
|
||||
|
||||
return isIDToken
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies JWT signature using provider's public keys and validates standard claims.
|
||||
// It retrieves the appropriate public key from the JWKS cache, verifies the token signature,
|
||||
// and validates standard OIDC claims like issuer, audience, and expiration.
|
||||
@@ -171,13 +307,18 @@ func (t *TraefikOidc) cacheVerifiedToken(token string, claims map[string]interfa
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.safeLogDebugf("Verifying JWT signature and claims")
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), t.jwksURL, t.httpClient)
|
||||
// Read jwksURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
jwksURL := t.jwksURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
jwks, err := t.jwkCache.GetJWKS(context.Background(), jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
if !t.suppressDiagnosticLogs && jwks != nil {
|
||||
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), t.jwksURL)
|
||||
t.safeLogDebugf("DIAGNOSTIC: Retrieved JWKS with %d keys from URL: %s", len(jwks.Keys), jwksURL)
|
||||
}
|
||||
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
@@ -235,7 +376,30 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
t.safeLogDebugf("DIAGNOSTIC: Signature verification successful for kid=%s", kid)
|
||||
}
|
||||
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID, true); err != nil {
|
||||
// Detect token type (cached for performance)
|
||||
isIDToken := t.detectTokenType(jwt, token)
|
||||
|
||||
// Determine expected audience
|
||||
expectedAudience := t.audience // Default to configured audience
|
||||
if isIDToken {
|
||||
expectedAudience = t.clientID
|
||||
}
|
||||
if !t.suppressDiagnosticLogs {
|
||||
if isIDToken {
|
||||
t.safeLogDebugf("ID token detected, validating with client_id: %s", expectedAudience)
|
||||
} else {
|
||||
t.safeLogDebugf("Access token detected, validating with audience: %s", expectedAudience)
|
||||
}
|
||||
}
|
||||
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
// Always skip replay check in JWT.Verify since we handle it at the VerifyToken level
|
||||
// This prevents false positives when multiple goroutines validate the same cached token
|
||||
if err := jwt.Verify(issuerURL, expectedAudience, true); err != nil {
|
||||
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -257,6 +421,8 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
// Returns:
|
||||
// - true if refresh succeeded and session was updated, false if refresh failed,
|
||||
// a concurrency conflict was detected, or saving the session failed.
|
||||
//
|
||||
//nolint:gocognit // Complex token refresh logic with multiple error handling paths
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
session.refreshMutex.Lock()
|
||||
defer session.refreshMutex.Unlock()
|
||||
@@ -289,10 +455,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(initialRefreshToken)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
//nolint:gocritic // Complex error handling with provider-specific conditions
|
||||
if strings.Contains(errMsg, "invalid_grant") || strings.Contains(errMsg, "token expired") {
|
||||
t.logger.Debug("Refresh token expired or revoked: %v", err)
|
||||
// Clear all tokens and authentication state when refresh token is invalid
|
||||
session.SetAuthenticated(false)
|
||||
if err := session.SetAuthenticated(false); err != nil {
|
||||
t.logger.Errorf("Failed to set authenticated to false: %v", err)
|
||||
}
|
||||
session.SetRefreshToken("")
|
||||
session.SetAccessToken("")
|
||||
session.SetIDToken("")
|
||||
@@ -376,7 +545,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("refreshToken failed: Failed to save session after successful token refresh: %v", err)
|
||||
// Reset authentication state since we couldn't persist it
|
||||
session.SetAuthenticated(false)
|
||||
if err := session.SetAuthenticated(false); err != nil {
|
||||
t.logger.Errorf("Failed to set authenticated to false: %v", err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -423,10 +594,15 @@ func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// Returns:
|
||||
// - An error if the request fails or the provider returns a non-OK status.
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
if t.revocationURL == "" {
|
||||
// Read revocationURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
revocationURL := t.revocationURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if revocationURL == "" {
|
||||
return fmt.Errorf("token revocation endpoint is not configured or discovered")
|
||||
}
|
||||
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, t.revocationURL)
|
||||
t.logger.Debugf("Attempting to revoke token (type: %s) with provider at %s", tokenType, revocationURL)
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
@@ -435,7 +611,7 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", t.revocationURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", revocationURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token revocation request: %w", err)
|
||||
}
|
||||
@@ -446,26 +622,37 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
// Send the request with circuit breaker protection if available
|
||||
var resp *http.Response
|
||||
if t.errorRecoveryManager != nil {
|
||||
// Read issuerURL with RLock for service name
|
||||
t.metadataMu.RLock()
|
||||
serviceName := fmt.Sprintf("token-revocation-%s", t.issuerURL)
|
||||
t.metadataMu.RUnlock()
|
||||
err = t.errorRecoveryManager.ExecuteWithRecovery(context.Background(), serviceName, func() error {
|
||||
var reqErr error
|
||||
resp, reqErr = t.httpClient.Do(req)
|
||||
resp, reqErr = t.httpClient.Do(req) //nolint:bodyclose // Body is closed in defer after error check
|
||||
if reqErr != nil && resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on error
|
||||
}
|
||||
return reqErr
|
||||
})
|
||||
} else {
|
||||
resp, err = t.httpClient.Do(req)
|
||||
}
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on error
|
||||
}
|
||||
return fmt.Errorf("failed to send token revocation request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
if resp != nil && resp.Body != nil {
|
||||
_, _ = io.Copy(io.Discard, resp.Body) // Safe to ignore: draining body on defer
|
||||
_ = resp.Body.Close() // Safe to ignore: closing body on defer
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
limitReader := io.LimitReader(resp.Body, 1024*10)
|
||||
body, _ := io.ReadAll(limitReader)
|
||||
body, _ := io.ReadAll(limitReader) // Safe to ignore: reading error body for diagnostics
|
||||
t.logger.Errorf("Token revocation failed with status %d: %s", resp.StatusCode, string(body))
|
||||
return fmt.Errorf("token revocation failed with status %d", resp.StatusCode)
|
||||
}
|
||||
@@ -517,7 +704,12 @@ func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
// Returns:
|
||||
// - true if the provider is Google, false otherwise.
|
||||
func (t *TraefikOidc) isGoogleProvider() bool {
|
||||
return strings.Contains(t.issuerURL, "google") || strings.Contains(t.issuerURL, "accounts.google.com")
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
return strings.Contains(issuerURL, "google") || strings.Contains(issuerURL, "accounts.google.com")
|
||||
}
|
||||
|
||||
// isAzureProvider detects if the configured OIDC provider is Azure AD.
|
||||
@@ -525,9 +717,14 @@ func (t *TraefikOidc) isGoogleProvider() bool {
|
||||
// Returns:
|
||||
// - true if the provider is Azure AD, false otherwise.
|
||||
func (t *TraefikOidc) isAzureProvider() bool {
|
||||
return strings.Contains(t.issuerURL, "login.microsoftonline.com") ||
|
||||
strings.Contains(t.issuerURL, "sts.windows.net") ||
|
||||
strings.Contains(t.issuerURL, "login.windows.net")
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
return strings.Contains(issuerURL, "login.microsoftonline.com") ||
|
||||
strings.Contains(issuerURL, "sts.windows.net") ||
|
||||
strings.Contains(issuerURL, "login.windows.net")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -544,6 +741,8 @@ func (t *TraefikOidc) isAzureProvider() bool {
|
||||
// - authenticated: Whether the user has valid authentication.
|
||||
// - needsRefresh: Whether tokens need to be refreshed.
|
||||
// - expired: Whether tokens have expired and cannot be refreshed.
|
||||
//
|
||||
//nolint:gocognit // Azure-specific validation requires multiple token type checks
|
||||
func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("Azure user is not authenticated according to session flag")
|
||||
@@ -576,13 +775,12 @@ func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, boo
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
} else {
|
||||
t.logger.Debug("Azure access token appears opaque, treating as valid")
|
||||
if idToken != "" {
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
t.logger.Debug("Azure access token appears opaque, treating as valid")
|
||||
if idToken != "" {
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
if idToken != "" {
|
||||
@@ -631,6 +829,8 @@ func (t *TraefikOidc) validateGoogleTokens(session *SessionData) (bool, bool, bo
|
||||
// - authenticated: Whether the user has valid authentication.
|
||||
// - needsRefresh: Whether tokens need to be refreshed.
|
||||
// - expired: Whether tokens have expired and cannot be refreshed.
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex validation logic handles multiple token scenarios and edge cases
|
||||
func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool, bool) {
|
||||
authenticated := session.GetAuthenticated()
|
||||
// Removed debug output
|
||||
@@ -688,11 +888,42 @@ func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool,
|
||||
dotCount := strings.Count(accessToken, ".")
|
||||
isOpaqueToken := dotCount != 2
|
||||
|
||||
// For opaque access tokens, rely on ID token for session validation
|
||||
// For opaque access tokens, use introspection if available (RFC 7662 - Option C: Scenario 3)
|
||||
if isOpaqueToken {
|
||||
t.logger.Debugf("Access token appears to be opaque (dots: %d), validating session via ID token", dotCount)
|
||||
t.logger.Debugf("Access token appears to be opaque (dots: %d)", dotCount)
|
||||
|
||||
// For opaque access tokens, check ID token for authentication status
|
||||
// Try introspection first if opaque tokens are allowed
|
||||
if t.allowOpaqueTokens {
|
||||
if err := t.validateOpaqueToken(accessToken); err != nil {
|
||||
t.logger.Infof("⚠️ Opaque access token validation via introspection failed: %v", err)
|
||||
|
||||
// If introspection required, reject the session
|
||||
if t.requireTokenIntrospection {
|
||||
t.logger.Errorf("❌ SECURITY: Opaque token rejected (introspection required but failed)")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// Otherwise fall back to ID token validation (Scenario 3 backward compatibility)
|
||||
t.logger.Infof("⚠️ Falling back to ID token validation for opaque access token")
|
||||
} else {
|
||||
// Introspection successful
|
||||
t.logger.Debugf("✓ Opaque access token validated via introspection")
|
||||
// Still need to check ID token for session expiry
|
||||
idToken := session.GetIDToken()
|
||||
if idToken != "" {
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
} else {
|
||||
// Opaque tokens not allowed - log warning and reject or fall back
|
||||
t.logger.Infof("⚠️ Opaque access token detected but allowOpaqueTokens=false")
|
||||
}
|
||||
|
||||
// Fall back to ID token validation
|
||||
idToken := session.GetIDToken()
|
||||
if idToken == "" {
|
||||
t.logger.Debug("Opaque access token present but no ID token found")
|
||||
@@ -727,11 +958,52 @@ func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool,
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
|
||||
// JWT access token present - validate it explicitly to detect Scenario 2
|
||||
// (Option C: Scenario 2 detection and strict mode)
|
||||
accessTokenValid := false
|
||||
accessTokenError := ""
|
||||
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
// Access token validation failed
|
||||
accessTokenError = err.Error()
|
||||
|
||||
// Check if it's an audience validation failure (Scenario 2)
|
||||
if strings.Contains(accessTokenError, "invalid audience") || strings.Contains(accessTokenError, "audience") {
|
||||
// SCENARIO 2 DETECTED: Access token has wrong audience
|
||||
t.logger.Infof("⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch: %v", err)
|
||||
|
||||
if t.strictAudienceValidation {
|
||||
// Strict mode: Reject the session (don't fall back to ID token)
|
||||
t.logger.Errorf("❌ SECURITY: Session rejected due to access token audience mismatch (strictAudienceValidation=true)")
|
||||
t.logger.Errorf("❌ This prevents potential cross-API token confusion attacks (Auth0 Scenario 2)")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false // try refresh
|
||||
}
|
||||
return false, false, true // must re-authenticate
|
||||
}
|
||||
// Backward compatibility mode: Log loud warning but allow fallback to ID token
|
||||
t.logger.Infof("⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!")
|
||||
t.logger.Infof("⚠️ This could allow tokens intended for different APIs to grant access")
|
||||
t.logger.Infof("⚠️ Set strictAudienceValidation=true to enforce proper audience validation")
|
||||
t.logger.Infof("⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74")
|
||||
} else if !strings.Contains(accessTokenError, "token has expired") {
|
||||
// Other validation errors (not expiration, not audience)
|
||||
t.logger.Debugf("Access token validation failed (non-expiration, non-audience): %v", err)
|
||||
}
|
||||
} else {
|
||||
// Access token is valid
|
||||
accessTokenValid = true
|
||||
}
|
||||
|
||||
idToken := session.GetIDToken()
|
||||
if idToken == "" {
|
||||
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session (possibly opaque token)")
|
||||
session.SetAuthenticated(true)
|
||||
if accessTokenValid {
|
||||
// Access token is valid, no ID token needed
|
||||
t.logger.Debug("Access token valid, no ID token present")
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
}
|
||||
|
||||
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.")
|
||||
return true, true, false
|
||||
@@ -739,6 +1011,7 @@ func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool,
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// Validate ID token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh")
|
||||
@@ -756,6 +1029,11 @@ func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool,
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// If access token was valid, use it for expiry; otherwise use ID token
|
||||
if accessTokenValid {
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
}
|
||||
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
|
||||
@@ -896,8 +1174,11 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
|
||||
// Start the task if not already running
|
||||
if !rm.IsTaskRunning(taskName) {
|
||||
rm.StartBackgroundTask(taskName)
|
||||
logger.Debug("Started singleton token cleanup task")
|
||||
if err := rm.StartBackgroundTask(taskName); err != nil {
|
||||
logger.Errorf("Failed to start background task: %v", err)
|
||||
} else {
|
||||
logger.Debug("Started singleton token cleanup task")
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Token cleanup task already running, skipping duplicate")
|
||||
}
|
||||
@@ -930,14 +1211,13 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
} else {
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
|
||||
}
|
||||
}
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in groups claim array: %v", group)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -946,14 +1226,13 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
} else {
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
|
||||
}
|
||||
}
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
} else {
|
||||
t.logger.Errorf("Non-string value found in roles claim array: %v", role)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkDetectTokenType(b *testing.B) {
|
||||
tr := &TraefikOidc{
|
||||
clientID: "test-client-id",
|
||||
suppressDiagnosticLogs: true,
|
||||
tokenTypeCache: NewTestCache(),
|
||||
}
|
||||
|
||||
// Create various JWT test cases
|
||||
jwtWithNonce := &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"nonce": "test-nonce",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
jwtWithScope := &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"scope": "openid profile email",
|
||||
"aud": "some-api",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
jwtComplexDetection := &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256", "typ": "JWT"},
|
||||
Claims: map[string]interface{}{
|
||||
"aud": []interface{}{"test-client-id", "another-aud"},
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"sub": "user123",
|
||||
"token_type": "Bearer",
|
||||
"custom_claim": "value",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
jwt *JWT
|
||||
token string
|
||||
}{
|
||||
{"WithNonce", jwtWithNonce, "token-with-nonce-for-benchmark-testing-12345678901234567890"},
|
||||
{"WithScope", jwtWithScope, "token-with-scope-for-benchmark-testing-12345678901234567890"},
|
||||
{"ComplexDetection", jwtComplexDetection, "token-complex-for-benchmark-testing-12345678901234567890"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
b.Run(tc.name+"_FirstCall", func(b *testing.B) {
|
||||
// Benchmark first call (uncached)
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Clear cache before each iteration
|
||||
tr.tokenTypeCache.Clear()
|
||||
_ = tr.detectTokenType(tc.jwt, tc.token)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(tc.name+"_Cached", func(b *testing.B) {
|
||||
// Prime the cache
|
||||
_ = tr.detectTokenType(tc.jwt, tc.token)
|
||||
|
||||
// Benchmark cached calls
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = tr.detectTokenType(tc.jwt, tc.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark comparison with the old implementation logic
|
||||
func BenchmarkOldDetectionLogic(b *testing.B) {
|
||||
clientID := "test-client-id"
|
||||
|
||||
jwt := &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256", "typ": "JWT"},
|
||||
Claims: map[string]interface{}{
|
||||
"aud": []interface{}{"test-client-id", "another-aud"},
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"sub": "user123",
|
||||
"token_type": "Bearer",
|
||||
"custom_claim": "value",
|
||||
},
|
||||
}
|
||||
|
||||
b.Run("OldLogic", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Simulate the old detection logic (all 6 sequential checks)
|
||||
isIDToken := false
|
||||
isAccessToken := false
|
||||
|
||||
// Step 1: Check typ header
|
||||
if typ, ok := jwt.Header["typ"].(string); ok {
|
||||
if typ == "at+jwt" {
|
||||
isAccessToken = true
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check token_use claim
|
||||
if !isAccessToken && !isIDToken {
|
||||
if tokenUse, ok := jwt.Claims["token_use"].(string); ok {
|
||||
if tokenUse == "access" {
|
||||
isAccessToken = true
|
||||
} else if tokenUse == "id" {
|
||||
isIDToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Check token_type claim
|
||||
if !isAccessToken && !isIDToken {
|
||||
if tokenType, ok := jwt.Claims["token_type"].(string); ok {
|
||||
if tokenType == "access_token" || tokenType == "Bearer" {
|
||||
isAccessToken = true
|
||||
} else if tokenType == "id_token" {
|
||||
isIDToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Check scope claim
|
||||
if !isAccessToken && !isIDToken {
|
||||
if scope, ok := jwt.Claims["scope"]; ok {
|
||||
if _, ok := scope.(string); ok {
|
||||
isAccessToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Check nonce claim
|
||||
if !isAccessToken && !isIDToken {
|
||||
if nonce, ok := jwt.Claims["nonce"]; ok {
|
||||
if _, ok := nonce.(string); ok {
|
||||
isIDToken = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 6: Check audience
|
||||
if !isAccessToken && !isIDToken {
|
||||
if aud, ok := jwt.Claims["aud"]; ok {
|
||||
if audStr, ok := aud.(string); ok && audStr == clientID {
|
||||
isIDToken = true
|
||||
}
|
||||
if audArr, ok := aud.([]interface{}); ok {
|
||||
for _, v := range audArr {
|
||||
if str, ok := v.(string); ok && str == clientID {
|
||||
if len(audArr) == 1 {
|
||||
isIDToken = true
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 7: Default to access token
|
||||
if !isIDToken {
|
||||
isAccessToken = true
|
||||
}
|
||||
|
||||
_ = isAccessToken
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDetectTokenType(t *testing.T) {
|
||||
// Create a test instance with mock cache
|
||||
tr := &TraefikOidc{
|
||||
clientID: "test-client-id",
|
||||
suppressDiagnosticLogs: true,
|
||||
tokenTypeCache: NewTestCache(),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
jwt *JWT
|
||||
token string
|
||||
expectedID bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ID token with nonce",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"nonce": "test-nonce",
|
||||
"aud": "test-client-id",
|
||||
},
|
||||
},
|
||||
token: "test-token-with-nonce",
|
||||
expectedID: true,
|
||||
description: "Should detect ID token via nonce claim",
|
||||
},
|
||||
{
|
||||
name: "RFC 9068 access token",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{
|
||||
"alg": "RS256",
|
||||
"typ": "at+jwt",
|
||||
},
|
||||
Claims: map[string]interface{}{
|
||||
"scope": "openid profile",
|
||||
},
|
||||
},
|
||||
token: "test-access-token-rfc9068",
|
||||
expectedID: false,
|
||||
description: "Should detect access token via typ=at+jwt header",
|
||||
},
|
||||
{
|
||||
name: "Token with token_use=id",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"token_use": "id",
|
||||
"aud": "test-client-id",
|
||||
},
|
||||
},
|
||||
token: "test-token-use-id",
|
||||
expectedID: true,
|
||||
description: "Should detect ID token via token_use claim",
|
||||
},
|
||||
{
|
||||
name: "Token with token_use=access",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"token_use": "access",
|
||||
"scope": "read write",
|
||||
},
|
||||
},
|
||||
token: "test-token-use-access",
|
||||
expectedID: false,
|
||||
description: "Should detect access token via token_use claim",
|
||||
},
|
||||
{
|
||||
name: "Access token with scope",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"scope": "openid profile email",
|
||||
"aud": "some-api-audience",
|
||||
},
|
||||
},
|
||||
token: "test-access-token-with-scope",
|
||||
expectedID: false,
|
||||
description: "Should detect access token via scope claim",
|
||||
},
|
||||
{
|
||||
name: "ID token with client_id audience",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"aud": "test-client-id",
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
token: "test-id-token-client-aud",
|
||||
expectedID: true,
|
||||
description: "Should detect ID token via audience matching client_id",
|
||||
},
|
||||
{
|
||||
name: "Default to access token",
|
||||
jwt: &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"aud": "different-audience",
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
token: "test-default-access-token",
|
||||
expectedID: false,
|
||||
description: "Should default to access token when no clear indicators",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// First call - should not be cached
|
||||
result := tr.detectTokenType(tc.jwt, tc.token)
|
||||
if result != tc.expectedID {
|
||||
t.Errorf("%s: expected isIDToken=%v, got %v", tc.description, tc.expectedID, result)
|
||||
}
|
||||
|
||||
// Second call - should be cached
|
||||
result2 := tr.detectTokenType(tc.jwt, tc.token)
|
||||
if result2 != tc.expectedID {
|
||||
t.Errorf("%s (cached): expected isIDToken=%v, got %v", tc.description, tc.expectedID, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectTokenTypeCaching(t *testing.T) {
|
||||
cache := NewTestCache()
|
||||
tr := &TraefikOidc{
|
||||
clientID: "test-client-id",
|
||||
suppressDiagnosticLogs: true,
|
||||
tokenTypeCache: cache,
|
||||
}
|
||||
|
||||
jwt := &JWT{
|
||||
Header: map[string]interface{}{"alg": "RS256"},
|
||||
Claims: map[string]interface{}{
|
||||
"nonce": "test-nonce",
|
||||
},
|
||||
}
|
||||
token := "test-token-for-caching-with-enough-characters-for-key"
|
||||
cacheKey := token
|
||||
if len(token) > 32 {
|
||||
cacheKey = token[:32] // First 32 chars
|
||||
}
|
||||
|
||||
// First call - should cache
|
||||
result := tr.detectTokenType(jwt, token)
|
||||
if !result {
|
||||
t.Error("Expected ID token detection via nonce")
|
||||
}
|
||||
|
||||
// Check cache was populated
|
||||
if cached, found := cache.Get(cacheKey); !found {
|
||||
t.Error("Expected token type to be cached")
|
||||
} else if cachedBool, ok := cached.(bool); !ok || !cachedBool {
|
||||
t.Error("Expected cached value to be true (ID token)")
|
||||
}
|
||||
|
||||
// Modify JWT to have different detection (but use same token for cache key)
|
||||
jwt.Claims = map[string]interface{}{
|
||||
"scope": "openid profile", // This would normally make it an access token
|
||||
}
|
||||
|
||||
// Second call with modified JWT - should still return cached value
|
||||
result2 := tr.detectTokenType(jwt, token)
|
||||
if !result2 {
|
||||
t.Error("Expected cached ID token result, ignoring modified JWT")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCache is a simple in-memory cache for testing
|
||||
type TestCache struct {
|
||||
data map[string]interface{}
|
||||
}
|
||||
|
||||
func NewTestCache() *TestCache {
|
||||
return &TestCache{
|
||||
data: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TestCache) Set(key string, value interface{}, ttl time.Duration) {
|
||||
c.data[key] = value
|
||||
}
|
||||
|
||||
func (c *TestCache) Get(key string) (interface{}, bool) {
|
||||
val, ok := c.data[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (c *TestCache) Delete(key string) {
|
||||
delete(c.data, key)
|
||||
}
|
||||
|
||||
func (c *TestCache) SetMaxSize(size int) {}
|
||||
func (c *TestCache) Size() int { return len(c.data) }
|
||||
func (c *TestCache) Clear() { c.data = make(map[string]interface{}) }
|
||||
func (c *TestCache) Cleanup() {}
|
||||
func (c *TestCache) Close() {}
|
||||
func (c *TestCache) GetStats() map[string]interface{} {
|
||||
return map[string]interface{}{"size": len(c.data)}
|
||||
}
|
||||
@@ -0,0 +1,739 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test TokenValidator Creation
|
||||
|
||||
func TestNewTokenValidator(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger == nil {
|
||||
t.Error("Expected logger to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTokenValidatorWithLogger(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
validator := NewTokenValidator(logger)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected non-nil token validator")
|
||||
}
|
||||
|
||||
if validator.logger != logger {
|
||||
t.Error("Expected provided logger to be used")
|
||||
}
|
||||
}
|
||||
|
||||
// Test ValidateToken - Entry Point
|
||||
|
||||
func TestValidateTokenEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
result := validator.ValidateToken("", false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for empty token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for empty token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "empty") {
|
||||
t.Errorf("Expected 'empty' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenRequireJWT(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Opaque token when JWT required
|
||||
result := validator.ValidateToken("opaque_token_value_here", true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for opaque token when JWT required")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error when JWT required but opaque token provided")
|
||||
}
|
||||
}
|
||||
|
||||
// Test JWT Validation
|
||||
|
||||
func TestValidateJWTValidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Create a valid JWT with valid claims
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "JWT" {
|
||||
t.Errorf("Expected TokenType 'JWT', got %s", result.TokenType)
|
||||
}
|
||||
|
||||
if result.Claims == nil {
|
||||
t.Error("Expected claims to be parsed")
|
||||
}
|
||||
|
||||
if result.Expiry == nil {
|
||||
t.Error("Expected expiry to be extracted")
|
||||
}
|
||||
|
||||
if result.IssuedAt == nil {
|
||||
t.Error("Expected issued at to be extracted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTExpiredToken(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(), // Expired 1 hour ago
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for expired token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for expired token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "expired") {
|
||||
t.Errorf("Expected 'expired' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTFutureIssuedAt(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(10 * time.Minute).Unix(), // Issued 10 minutes in future
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for future iat")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for future iat")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "future") {
|
||||
t.Errorf("Expected 'future' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTNotBeforeClaim(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"exp": time.Now().Add(2 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"nbf": time.Now().Add(1 * time.Hour).Unix(), // Not valid for 1 hour
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for nbf in future")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for nbf in future")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "not yet valid") {
|
||||
t.Errorf("Expected 'not yet valid' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "eyJhbGciOiJIUzI1NiJ9"},
|
||||
{"two parts", "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiIxMjM0In0"},
|
||||
{"four parts", "part1.part2.part3.part4"},
|
||||
{"empty part", "eyJhbGciOiJIUzI1NiJ9..signature"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use requireJWT=true to ensure these are treated as invalid JWTs, not opaque tokens
|
||||
result := validator.ValidateToken(tt.token, true)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for malformed JWT")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malformed JWT")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTInvalidBase64URL(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Token with invalid base64url characters
|
||||
token := "invalid@chars.eyJzdWIiOiIxMjM0In0.signature"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for invalid base64url characters")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for invalid base64url characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateJWTInvalidJSON(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Valid base64 but invalid JSON
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte("not json"))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte("{invalid json"))
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("signature"))
|
||||
|
||||
token := header + "." + payload + "." + signature
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for invalid JSON in claims")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for invalid JSON in claims")
|
||||
}
|
||||
}
|
||||
|
||||
// Test Opaque Token Validation
|
||||
|
||||
func TestValidateOpaqueTokenValid(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Valid opaque token (>20 chars, good entropy)
|
||||
token := "sk_live_abcdef123456GHIJKL789"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if !result.Valid {
|
||||
t.Errorf("Expected valid result, got error: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.TokenType != "Opaque" {
|
||||
t.Errorf("Expected TokenType 'Opaque', got %s", result.TokenType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenTooShort(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "short"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for short token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for short token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "too short") {
|
||||
t.Errorf("Expected 'too short' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenWithSpaces(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "this token has spaces in it"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with spaces")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with spaces")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "spaces") {
|
||||
t.Errorf("Expected 'spaces' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenControlCharacters(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Token with control character (null byte)
|
||||
token := "token_with\x00control_char"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for token with control characters")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for token with control characters")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "control character") {
|
||||
t.Errorf("Expected 'control character' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateOpaqueTokenInsufficientEntropy(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Token with low entropy (only 3 unique characters)
|
||||
token := "aaaaaabbbbbbccccccdddd"
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if result.Valid {
|
||||
t.Error("Expected invalid result for low entropy token")
|
||||
}
|
||||
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for low entropy token")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.Error.Error(), "entropy") {
|
||||
t.Errorf("Expected 'entropy' in error, got: %v", result.Error)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Base64URL Validation
|
||||
|
||||
func TestIsValidBase64URL(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"valid uppercase", "ABCDEFGHIJKLMNOPQRSTUVWXYZ", true},
|
||||
{"valid lowercase", "abcdefghijklmnopqrstuvwxyz", true},
|
||||
{"valid numbers", "0123456789", true},
|
||||
{"valid dash", "abc-def", true},
|
||||
{"valid underscore", "abc_def", true},
|
||||
{"valid equals", "abc=", true},
|
||||
{"invalid at sign", "abc@def", false},
|
||||
{"invalid space", "abc def", false},
|
||||
{"invalid plus", "abc+def", false},
|
||||
{"invalid slash", "abc/def", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.isValidBase64URL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v for %s, got %v", tt.expected, tt.input, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Time Extraction
|
||||
|
||||
func TestExtractTime(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claim interface{}
|
||||
expected bool
|
||||
}{
|
||||
{"float64", float64(1609459200), true},
|
||||
{"int64", int64(1609459200), true},
|
||||
{"int", int(1609459200), true},
|
||||
{"string", "not a timestamp", false},
|
||||
{"nil", nil, false},
|
||||
{"map", map[string]interface{}{}, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.extractTime(tt.claim)
|
||||
|
||||
if tt.expected && result == nil {
|
||||
t.Error("Expected non-nil time")
|
||||
}
|
||||
|
||||
if !tt.expected && result != nil {
|
||||
t.Error("Expected nil time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTimeCorrectValue(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// Unix timestamp for 2021-01-01 00:00:00 UTC
|
||||
timestamp := int64(1609459200)
|
||||
result := validator.extractTime(timestamp)
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("Expected non-nil time")
|
||||
}
|
||||
|
||||
expected := time.Unix(timestamp, 0)
|
||||
if !result.Equal(expected) {
|
||||
t.Errorf("Expected time %v, got %v", expected, *result)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Token Size Validation
|
||||
|
||||
func TestValidateTokenSize(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
maxSize int
|
||||
expectError bool
|
||||
}{
|
||||
{"within limit", "short_token", 20, false},
|
||||
{"at limit", "exactly_twenty_c", 16, false},
|
||||
{"exceeds limit", "this_token_is_too_long", 10, true},
|
||||
{"empty token", "", 10, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateTokenSize(tt.token, tt.maxSize)
|
||||
|
||||
if tt.expectError && err == nil {
|
||||
t.Error("Expected error for oversized token")
|
||||
}
|
||||
|
||||
if !tt.expectError && err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if err != nil && !strings.Contains(err.Error(), "exceeds") {
|
||||
t.Errorf("Expected 'exceeds' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test Claims Extraction
|
||||
|
||||
func TestExtractClaims(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
"exp": float64(1609459200),
|
||||
}
|
||||
|
||||
token := createTestJWTSimple(claims)
|
||||
extracted, err := validator.ExtractClaims(token)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
if extracted == nil {
|
||||
t.Fatal("Expected non-nil claims")
|
||||
}
|
||||
|
||||
if extracted["sub"] != "user123" {
|
||||
t.Errorf("Expected sub 'user123', got %v", extracted["sub"])
|
||||
}
|
||||
|
||||
if extracted["email"] != "user@example.com" {
|
||||
t.Errorf("Expected email 'user@example.com', got %v", extracted["email"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsInvalidFormat(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"single part", "onlyonepart"},
|
||||
{"two parts", "two.parts"},
|
||||
{"four parts", "one.two.three.four"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := validator.ExtractClaims(tt.token)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid format")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "invalid JWT format") {
|
||||
t.Errorf("Expected 'invalid JWT format' in error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsInvalidBase64(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token := "header.invalid@base64.signature"
|
||||
_, err := validator.ExtractClaims(token)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid base64")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "decode") {
|
||||
t.Errorf("Expected 'decode' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractClaimsInvalidJSON(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte("header"))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte("{not valid json"))
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("signature"))
|
||||
|
||||
token := header + "." + payload + "." + signature
|
||||
_, err := validator.ExtractClaims(token)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for invalid JSON")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "parse") {
|
||||
t.Errorf("Expected 'parse' in error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Token Comparison (Security - Timing Attack Resistance)
|
||||
|
||||
func TestCompareTokensEqual(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_12345"
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferent(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "secret_token_12345"
|
||||
token2 := "secret_token_54321"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensDifferentLength(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := "short"
|
||||
token2 := "much_longer_token"
|
||||
|
||||
if validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected tokens to be different (different lengths)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensEmpty(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
token1 := ""
|
||||
token2 := ""
|
||||
|
||||
if !validator.CompareTokens(token1, token2) {
|
||||
t.Error("Expected empty tokens to be equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareTokensConstantTime(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
// This test verifies the comparison is constant-time
|
||||
// by checking that different tokens take similar time
|
||||
token1 := strings.Repeat("a", 1000)
|
||||
token2First := "b" + strings.Repeat("a", 999)
|
||||
token2Last := strings.Repeat("a", 999) + "b"
|
||||
|
||||
// Both comparisons should take similar time regardless of where difference occurs
|
||||
startFirst := time.Now()
|
||||
validator.CompareTokens(token1, token2First)
|
||||
durationFirst := time.Since(startFirst)
|
||||
|
||||
startLast := time.Now()
|
||||
validator.CompareTokens(token1, token2Last)
|
||||
durationLast := time.Since(startLast)
|
||||
|
||||
// Allow 10x variance (generous, but timing can vary)
|
||||
ratio := float64(durationFirst) / float64(durationLast)
|
||||
if ratio < 0.1 || ratio > 10.0 {
|
||||
t.Logf("Warning: timing variance detected (ratio: %.2f). First: %v, Last: %v",
|
||||
ratio, durationFirst, durationLast)
|
||||
// Not failing test as timing can be affected by many factors
|
||||
}
|
||||
}
|
||||
|
||||
// Security Tests
|
||||
|
||||
func TestValidateTokenMaliciousPayloads(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"sql injection attempt", "'; DROP TABLE users; --"},
|
||||
{"xss attempt", "<script>alert('xss')</script>"},
|
||||
{"path traversal", "../../../etc/passwd"},
|
||||
{"null bytes", "token\x00with\x00nulls"},
|
||||
{"unicode exploit", "token\u0000\u0001\u0002"},
|
||||
{"extremely long", strings.Repeat("a", 100000)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token, false)
|
||||
|
||||
// Should either reject or handle safely
|
||||
if result.Valid {
|
||||
// If considered valid, should have parsed safely
|
||||
if result.Claims != nil {
|
||||
t.Logf("Token considered valid: %s", tt.name)
|
||||
}
|
||||
} else {
|
||||
// If invalid, should have error
|
||||
if result.Error == nil {
|
||||
t.Error("Expected error for malicious payload")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenBoundaryConditions(t *testing.T) {
|
||||
validator := NewTokenValidator(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "expiry at exact current time",
|
||||
claims: map[string]interface{}{
|
||||
"exp": time.Now().Unix(),
|
||||
},
|
||||
wantErr: true, // Should be expired (not <=, but <)
|
||||
},
|
||||
{
|
||||
name: "iat 5 minutes in future (boundary)",
|
||||
claims: map[string]interface{}{
|
||||
"iat": time.Now().Add(5 * time.Minute).Unix(),
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
wantErr: false, // Allowed within 5-minute tolerance
|
||||
},
|
||||
{
|
||||
name: "iat 6 minutes in future",
|
||||
claims: map[string]interface{}{
|
||||
"iat": time.Now().Add(6 * time.Minute).Unix(),
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "nbf at exact current time",
|
||||
claims: map[string]interface{}{
|
||||
"nbf": time.Now().Unix(),
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
},
|
||||
wantErr: false, // Should be valid at exact time
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token := createTestJWTSimple(tt.claims)
|
||||
result := validator.ValidateToken(token, false)
|
||||
|
||||
if tt.wantErr && result.Valid {
|
||||
t.Error("Expected invalid result at boundary condition")
|
||||
}
|
||||
|
||||
if !tt.wantErr && !result.Valid {
|
||||
t.Errorf("Expected valid result at boundary condition, got error: %v", result.Error)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper Functions
|
||||
|
||||
func createTestJWTSimple(claims map[string]interface{}) string {
|
||||
// Create a minimal JWT for testing (not cryptographically signed)
|
||||
header := map[string]interface{}{
|
||||
"alg": "HS256",
|
||||
"typ": "JWT",
|
||||
}
|
||||
|
||||
headerJSON, _ := json.Marshal(header)
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
|
||||
headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON)
|
||||
claimsB64 := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
||||
signature := base64.RawURLEncoding.EncodeToString([]byte("fake_signature"))
|
||||
|
||||
return headerB64 + "." + claimsB64 + "." + signature
|
||||
}
|
||||
@@ -49,12 +49,14 @@ type TokenExchanger interface {
|
||||
// This data is typically retrieved from the provider's .well-known/openid-configuration endpoint
|
||||
// and contains essential URLs for authentication, token exchange, and key retrieval.
|
||||
type ProviderMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
IntrospectionURL string `json:"introspection_endpoint,omitempty"` // OAuth 2.0 Token Introspection (RFC 7662)
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"` // Supported scopes from discovery
|
||||
}
|
||||
|
||||
// TraefikOidc is the main middleware struct that implements OIDC authentication for Traefik.
|
||||
@@ -71,6 +73,7 @@ type TraefikOidc struct {
|
||||
initComplete chan struct{}
|
||||
limiter *rate.Limiter
|
||||
tokenBlacklist CacheInterface
|
||||
tokenTypeCache CacheInterface // Cache for token type detection results
|
||||
headerTemplates map[string]*template.Template
|
||||
sessionManager *SessionManager
|
||||
tokenCleanupStopChan chan struct{}
|
||||
@@ -92,9 +95,11 @@ type TraefikOidc struct {
|
||||
goroutineWG *sync.WaitGroup
|
||||
clientSecret string
|
||||
clientID string
|
||||
audience string // Expected JWT audience, defaults to clientID
|
||||
name string
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
metadataMu sync.RWMutex // Protects metadata endpoint fields
|
||||
tokenURL string
|
||||
authURL string
|
||||
endSessionURL string
|
||||
@@ -103,16 +108,24 @@ type TraefikOidc struct {
|
||||
jwksURL string
|
||||
issuerURL string
|
||||
revocationURL string
|
||||
introspectionURL string // OAuth 2.0 Token Introspection endpoint (RFC 7662)
|
||||
providerURL string
|
||||
scopes []string
|
||||
refreshGracePeriod time.Duration
|
||||
introspectionCache CacheInterface // Cache for token introspection results
|
||||
shutdownOnce sync.Once
|
||||
firstRequestMutex sync.Mutex
|
||||
forceHTTPS bool
|
||||
enablePKCE bool
|
||||
overrideScopes bool
|
||||
strictAudienceValidation bool // Prevents Scenario 2 fallback to ID token
|
||||
allowOpaqueTokens bool // Enables opaque token support via introspection
|
||||
requireTokenIntrospection bool // Forces introspection for opaque tokens
|
||||
disableReplayDetection bool // Disables JTI-based replay detection for multi-replica deployments
|
||||
suppressDiagnosticLogs bool
|
||||
firstRequestReceived bool
|
||||
metadataRefreshStarted bool
|
||||
securityHeadersApplier func(http.ResponseWriter, *http.Request)
|
||||
scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
}
|
||||
|
||||
+1
-1
@@ -452,7 +452,7 @@ func (c *UniversalCache) removeItem(key string, item *CacheItem) {
|
||||
// evictOldest evicts the oldest item from the cache (must be called with lock held)
|
||||
func (c *UniversalCache) evictOldest() {
|
||||
if elem := c.lruList.Back(); elem != nil {
|
||||
key := elem.Value.(string)
|
||||
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
|
||||
@@ -7,13 +7,15 @@ import (
|
||||
|
||||
// UniversalCacheManager manages all cache instances using the universal cache
|
||||
type UniversalCacheManager struct {
|
||||
tokenCache *UniversalCache
|
||||
blacklistCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
mu sync.RWMutex
|
||||
logger *Logger
|
||||
tokenCache *UniversalCache
|
||||
blacklistCache *UniversalCache
|
||||
metadataCache *UniversalCache
|
||||
jwkCache *UniversalCache
|
||||
sessionCache *UniversalCache
|
||||
introspectionCache *UniversalCache // OAuth 2.0 Token Introspection cache (RFC 7662)
|
||||
tokenTypeCache *UniversalCache // Cache for token type detection results
|
||||
mu sync.RWMutex
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -85,6 +87,22 @@ func GetUniversalCacheManager(logger *Logger) *UniversalCacheManager {
|
||||
DefaultTTL: 30 * time.Minute,
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize introspection cache for OAuth 2.0 Token Introspection (RFC 7662)
|
||||
universalCacheManager.introspectionCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for introspection results
|
||||
MaxSize: 1000, // Cache up to 1000 introspection results
|
||||
DefaultTTL: 5 * time.Minute, // Short TTL for security (introspect frequently)
|
||||
Logger: logger,
|
||||
})
|
||||
|
||||
// Initialize token type cache for performance optimization
|
||||
universalCacheManager.tokenTypeCache = NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken, // Use token cache type for token type detection
|
||||
MaxSize: 2000, // Cache up to 2000 token type detections
|
||||
DefaultTTL: 5 * time.Minute, // 5 minute TTL for token type detection
|
||||
Logger: logger,
|
||||
})
|
||||
})
|
||||
|
||||
return universalCacheManager
|
||||
@@ -125,16 +143,30 @@ func (m *UniversalCacheManager) GetSessionCache() *UniversalCache {
|
||||
return m.sessionCache
|
||||
}
|
||||
|
||||
// GetIntrospectionCache returns the token introspection cache
|
||||
func (m *UniversalCacheManager) GetIntrospectionCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.introspectionCache
|
||||
}
|
||||
|
||||
// GetTokenTypeCache returns the token type detection cache
|
||||
func (m *UniversalCacheManager) GetTokenTypeCache() *UniversalCache {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.tokenTypeCache
|
||||
}
|
||||
|
||||
// Close shuts down all caches
|
||||
func (m *UniversalCacheManager) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, cache := range []*UniversalCache{
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache,
|
||||
m.tokenCache, m.blacklistCache, m.metadataCache, m.jwkCache, m.sessionCache, m.introspectionCache, m.tokenTypeCache,
|
||||
} {
|
||||
if cache != nil {
|
||||
cache.Close()
|
||||
_ = cache.Close() // Safe to ignore: best effort cache cleanup
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,7 +178,7 @@ func (m *UniversalCacheManager) Close() error {
|
||||
// This should only be called in test code to ensure proper cleanup between tests
|
||||
func ResetUniversalCacheManagerForTesting() {
|
||||
if universalCacheManager != nil {
|
||||
universalCacheManager.Close()
|
||||
_ = universalCacheManager.Close() // Safe to ignore: test cleanup best effort
|
||||
}
|
||||
universalCacheManagerOnce = sync.Once{}
|
||||
universalCacheManager = nil
|
||||
|
||||
+72
-4
@@ -37,19 +37,36 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
// =============================================================================
|
||||
|
||||
// determineScheme determines the URL scheme for building redirect URLs.
|
||||
// It checks X-Forwarded-Proto header first, then TLS presence.
|
||||
// Priority order (highest to lowest):
|
||||
// 1. forceHTTPS configuration - explicit security requirement
|
||||
// 2. X-Forwarded-Proto header - proxy/load balancer information
|
||||
// 3. TLS connection state - direct HTTPS connection
|
||||
// 4. Default to http
|
||||
//
|
||||
// Parameters:
|
||||
// - req: The HTTP request to analyze.
|
||||
//
|
||||
// Returns:
|
||||
// - The determined scheme: "https" or "http".
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
// Honor forceHTTPS configuration as highest priority
|
||||
// This ensures redirect URIs use HTTPS even when behind proxies/load balancers
|
||||
// that may overwrite X-Forwarded-Proto header (e.g., AWS ALB terminating TLS)
|
||||
if t.forceHTTPS {
|
||||
return "https"
|
||||
}
|
||||
|
||||
// Check X-Forwarded-Proto header for proxy scenarios
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
|
||||
// Check if connection has TLS
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
|
||||
// Default to http
|
||||
return "http"
|
||||
}
|
||||
|
||||
@@ -90,6 +107,15 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
// Add audience parameter for custom API audiences (e.g., Auth0 APIs)
|
||||
// This allows access tokens to have the correct audience claim
|
||||
// Only add if audience is configured and different from client_id
|
||||
// ID tokens will always have aud=client_id per OIDC spec
|
||||
if t.audience != "" && t.audience != t.clientID {
|
||||
params.Set("audience", t.audience)
|
||||
t.logger.Debugf("Adding audience parameter to authorize URL: %s", t.audience)
|
||||
}
|
||||
|
||||
if t.enablePKCE && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
@@ -98,7 +124,28 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
scopes := make([]string, len(t.scopes))
|
||||
copy(scopes, t.scopes)
|
||||
|
||||
// Apply discovery-based scope filtering if available
|
||||
// Read scopesSupported with RLock
|
||||
t.metadataMu.RLock()
|
||||
scopesSupported := t.scopesSupported
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if t.scopeFilter != nil && len(scopesSupported) > 0 {
|
||||
scopes = t.scopeFilter.FilterSupportedScopes(scopes, scopesSupported, t.providerURL)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: After discovery filtering: %v", scopes)
|
||||
}
|
||||
|
||||
// Then apply provider-specific modifications
|
||||
if t.isGoogleProvider() {
|
||||
// Google: Remove offline_access if present, add access_type=offline
|
||||
filteredScopes := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
scopes = filteredScopes
|
||||
|
||||
params.Set("access_type", "offline")
|
||||
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
@@ -143,13 +190,29 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
}
|
||||
}
|
||||
|
||||
// Final filtering pass to remove anything the provider doesn't support
|
||||
// Read scopesSupported with RLock
|
||||
t.metadataMu.RLock()
|
||||
scopesSupported = t.scopesSupported
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
if t.scopeFilter != nil && len(scopesSupported) > 0 {
|
||||
scopes = t.scopeFilter.FilterSupportedScopes(scopes, scopesSupported, t.providerURL)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: After final filtering: %v", scopes)
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
finalScopeString := strings.Join(scopes, " ")
|
||||
params.Set("scope", finalScopeString)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
return t.buildURLWithParams(t.authURL, params)
|
||||
// Read authURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
authURL := t.authURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
return t.buildURLWithParams(authURL, params)
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
@@ -172,9 +235,14 @@ func (t *TraefikOidc) buildURLWithParams(baseURL string, params url.Values) stri
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
issuerURLParsed, err := url.Parse(t.issuerURL)
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
|
||||
issuerURLParsed, err := url.Parse(issuerURL)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Could not parse issuerURL: %s. Error: %v", t.issuerURL, err)
|
||||
t.logger.Errorf("Could not parse issuerURL: %s. Error: %v", issuerURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,555 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Test TLS connection state for testing HTTPS detection
|
||||
var testTLSState = tls.ConnectionState{
|
||||
Version: tls.VersionTLS13,
|
||||
HandshakeComplete: true,
|
||||
ServerName: "example.com",
|
||||
}
|
||||
|
||||
// createMinimalMiddleware creates a minimal TraefikOidc instance for testing URL helpers
|
||||
func createMinimalMiddleware() *TraefikOidc {
|
||||
logger := newNoOpLogger()
|
||||
return &TraefikOidc{
|
||||
logger: logger,
|
||||
issuerURL: "https://provider.example.com",
|
||||
clientID: "test-client",
|
||||
clientSecret: "test-secret",
|
||||
authURL: "https://provider.example.com/authorize",
|
||||
tokenURL: "https://provider.example.com/token",
|
||||
excludedURLs: make(map[string]struct{}),
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
enablePKCE: false,
|
||||
}
|
||||
}
|
||||
|
||||
// TestDetermineScheme tests scheme determination edge cases
|
||||
func TestDetermineScheme(t *testing.T) {
|
||||
t.Run("forceHTTPS=false: backward compatibility", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = false
|
||||
|
||||
t.Run("defaults to http when no headers or TLS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "http", scheme)
|
||||
})
|
||||
|
||||
t.Run("uses X-Forwarded-Proto when present", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme)
|
||||
})
|
||||
|
||||
t.Run("X-Forwarded-Proto takes precedence over TLS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "https://example.com/auth", nil)
|
||||
req.TLS = &testTLSState
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "http", scheme)
|
||||
})
|
||||
|
||||
t.Run("uses TLS when present and no X-Forwarded-Proto", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "https://example.com/auth", nil)
|
||||
req.TLS = &testTLSState
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("forceHTTPS=true: overrides all detection", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = true
|
||||
|
||||
t.Run("returns https with no headers or TLS", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme, "forceHTTPS should override default http")
|
||||
})
|
||||
|
||||
t.Run("returns https even with X-Forwarded-Proto: http", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme, "forceHTTPS should override X-Forwarded-Proto")
|
||||
})
|
||||
|
||||
t.Run("returns https with X-Forwarded-Proto: https", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme)
|
||||
})
|
||||
|
||||
t.Run("returns https with TLS connection", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "https://example.com/auth", nil)
|
||||
req.TLS = &testTLSState
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme)
|
||||
})
|
||||
|
||||
t.Run("returns https even when all indicators suggest http", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
req.TLS = nil
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme, "forceHTTPS should be absolute override")
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("AWS ALB scenario: TLS termination at load balancer", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = true
|
||||
|
||||
t.Run("simulates ALB overwriting X-Forwarded-Proto to http", func(t *testing.T) {
|
||||
// This simulates the issue from GitHub #82:
|
||||
// - Client connects via HTTPS to ALB
|
||||
// - ALB terminates TLS and forwards HTTP to Traefik
|
||||
// - Traefik overwrites X-Forwarded-Proto based on its view (HTTP)
|
||||
// - Plugin receives X-Forwarded-Proto: http (incorrect)
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http") // Overwritten by Traefik
|
||||
req.TLS = nil // No TLS at plugin level
|
||||
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme, "forceHTTPS should ensure HTTPS redirect_uri despite incorrect header")
|
||||
})
|
||||
|
||||
t.Run("simulates missing X-Forwarded-Proto header", func(t *testing.T) {
|
||||
// Some configurations may not set the header at all
|
||||
req := httptest.NewRequest("GET", "http://example.com/auth", nil)
|
||||
req.TLS = nil
|
||||
|
||||
scheme := middleware.determineScheme(req)
|
||||
assert.Equal(t, "https", scheme, "forceHTTPS should ensure HTTPS even without headers")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildURLWithParamsErrorPaths tests error handling in buildURLWithParams
|
||||
func TestBuildURLWithParamsErrorPaths(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
|
||||
t.Run("invalid issuer URL returns empty string", func(t *testing.T) {
|
||||
middleware.issuerURL = "://invalid"
|
||||
params := url.Values{}
|
||||
params.Set("test", "value")
|
||||
result := middleware.buildURLWithParams("/path", params)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("invalid relative URL returns empty string", func(t *testing.T) {
|
||||
middleware.issuerURL = "https://provider.example.com"
|
||||
params := url.Values{}
|
||||
result := middleware.buildURLWithParams("://invalid-relative", params)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("invalid absolute URL returns empty string", func(t *testing.T) {
|
||||
params := url.Values{}
|
||||
result := middleware.buildURLWithParams("http://[invalid-url", params)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("dangerous host in absolute URL returns empty string", func(t *testing.T) {
|
||||
params := url.Values{}
|
||||
result := middleware.buildURLWithParams("https://localhost/callback", params)
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("successful relative URL resolution", func(t *testing.T) {
|
||||
middleware.issuerURL = "https://provider.example.com"
|
||||
params := url.Values{}
|
||||
params.Set("key", "value")
|
||||
result := middleware.buildURLWithParams("/oauth/authorize", params)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "https://provider.example.com/oauth/authorize")
|
||||
assert.Contains(t, result, "key=value")
|
||||
})
|
||||
|
||||
t.Run("successful absolute URL", func(t *testing.T) {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", "test")
|
||||
result := middleware.buildURLWithParams("https://api.example.com/endpoint", params)
|
||||
assert.NotEmpty(t, result)
|
||||
assert.Contains(t, result, "https://api.example.com/endpoint")
|
||||
assert.Contains(t, result, "client_id=test")
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateParsedURLCases tests URL validation edge cases
|
||||
func TestValidateParsedURLCases(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
|
||||
t.Run("disallowed schemes rejected", func(t *testing.T) {
|
||||
invalidSchemes := []string{
|
||||
"ftp://example.com",
|
||||
"file:///etc/passwd",
|
||||
"javascript:alert(1)",
|
||||
"data:text/html,test",
|
||||
}
|
||||
|
||||
for _, urlStr := range invalidSchemes {
|
||||
u, _ := url.Parse(urlStr)
|
||||
err := middleware.validateParsedURL(u)
|
||||
assert.Error(t, err, "should reject scheme: %s", urlStr)
|
||||
assert.Contains(t, err.Error(), "disallowed URL scheme")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("http scheme allowed with warning", func(t *testing.T) {
|
||||
u, _ := url.Parse("http://example.com/path")
|
||||
err := middleware.validateParsedURL(u)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("missing host rejected", func(t *testing.T) {
|
||||
u := &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "",
|
||||
Path: "/path",
|
||||
}
|
||||
err := middleware.validateParsedURL(u)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "missing host")
|
||||
})
|
||||
|
||||
t.Run("path traversal rejected", func(t *testing.T) {
|
||||
u, _ := url.Parse("https://example.com/../../etc/passwd")
|
||||
err := middleware.validateParsedURL(u)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "path traversal")
|
||||
})
|
||||
|
||||
t.Run("valid URLs accepted", func(t *testing.T) {
|
||||
validURLs := []string{
|
||||
"https://example.com",
|
||||
"https://example.com/path",
|
||||
"https://sub.example.com:8080/path?query=value",
|
||||
}
|
||||
|
||||
for _, urlStr := range validURLs {
|
||||
u, _ := url.Parse(urlStr)
|
||||
err := middleware.validateParsedURL(u)
|
||||
assert.NoError(t, err, "should accept: %s", urlStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateHostComprehensive tests comprehensive host validation
|
||||
func TestValidateHostComprehensive(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
|
||||
t.Run("loopback IPs rejected", func(t *testing.T) {
|
||||
loopbacks := []string{
|
||||
"127.0.0.1",
|
||||
"127.255.255.255",
|
||||
"::1",
|
||||
}
|
||||
|
||||
for _, ip := range loopbacks {
|
||||
err := middleware.validateHost(ip)
|
||||
assert.Error(t, err, "should reject loopback: %s", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("private IPs rejected", func(t *testing.T) {
|
||||
privateIPs := []string{
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"192.168.1.1",
|
||||
"fd00::1",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := middleware.validateHost(ip)
|
||||
assert.Error(t, err, "should reject private IP: %s", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("link-local IPs rejected", func(t *testing.T) {
|
||||
linkLocal := []string{
|
||||
"169.254.1.1",
|
||||
"fe80::1",
|
||||
}
|
||||
|
||||
for _, ip := range linkLocal {
|
||||
err := middleware.validateHost(ip)
|
||||
assert.Error(t, err, "should reject link-local: %s", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified and multicast rejected", func(t *testing.T) {
|
||||
special := []string{
|
||||
"0.0.0.0",
|
||||
"::",
|
||||
"224.0.0.1",
|
||||
"ff02::1",
|
||||
}
|
||||
|
||||
for _, ip := range special {
|
||||
err := middleware.validateHost(ip)
|
||||
assert.Error(t, err, "should reject special IP: %s", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("dangerous hostnames rejected", func(t *testing.T) {
|
||||
dangerous := []string{
|
||||
"localhost",
|
||||
"LOCALHOST",
|
||||
"169.254.169.254",
|
||||
"metadata.google.internal",
|
||||
}
|
||||
|
||||
for _, host := range dangerous {
|
||||
err := middleware.validateHost(host)
|
||||
assert.Error(t, err, "should reject: %s", host)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid host format rejected", func(t *testing.T) {
|
||||
invalid := []string{
|
||||
"[::1:invalid",
|
||||
}
|
||||
|
||||
for _, host := range invalid {
|
||||
err := middleware.validateHost(host)
|
||||
assert.Error(t, err, "should reject invalid format: %s", host)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hosts with ports", func(t *testing.T) {
|
||||
err := middleware.validateHost("localhost:8080")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = middleware.validateHost("192.168.1.1:443")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = middleware.validateHost("example.com:443")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("valid public IPs accepted", func(t *testing.T) {
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"93.184.216.34",
|
||||
}
|
||||
|
||||
for _, ip := range publicIPs {
|
||||
err := middleware.validateHost(ip)
|
||||
assert.NoError(t, err, "should accept public IP: %s", ip)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid hostnames accepted", func(t *testing.T) {
|
||||
validHosts := []string{
|
||||
"example.com",
|
||||
"sub.example.com",
|
||||
"api.service.example.com:443",
|
||||
}
|
||||
|
||||
for _, host := range validHosts {
|
||||
err := middleware.validateHost(host)
|
||||
assert.NoError(t, err, "should accept: %s", host)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestValidateURLEdgeCasesComprehensive tests the validateURL wrapper
|
||||
func TestValidateURLEdgeCasesComprehensive(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
|
||||
t.Run("empty URL rejected", func(t *testing.T) {
|
||||
err := middleware.validateURL("")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "empty URL")
|
||||
})
|
||||
|
||||
t.Run("invalid URL format rejected", func(t *testing.T) {
|
||||
err := middleware.validateURL("ht tp://invalid url")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid URL format")
|
||||
})
|
||||
|
||||
t.Run("valid URLs accepted", func(t *testing.T) {
|
||||
validURLs := []string{
|
||||
"https://example.com/path",
|
||||
"https://example.com/path?key=value",
|
||||
}
|
||||
|
||||
for _, urlStr := range validURLs {
|
||||
err := middleware.validateURL(urlStr)
|
||||
assert.NoError(t, err, "should accept: %s", urlStr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("URL with dangerous host rejected", func(t *testing.T) {
|
||||
err := middleware.validateURL("https://localhost/path")
|
||||
assert.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid host")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildAuthURLAudienceParameter tests audience parameter handling
|
||||
func TestBuildAuthURLAudienceParameter(t *testing.T) {
|
||||
t.Run("audience added when different from client_id", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.audience = "https://api.example.com"
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "audience=")
|
||||
})
|
||||
|
||||
t.Run("audience not added when empty", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.audience = ""
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"",
|
||||
)
|
||||
|
||||
assert.NotContains(t, authURL, "audience=")
|
||||
})
|
||||
|
||||
t.Run("audience not added when equal to client_id", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.audience = middleware.clientID
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"",
|
||||
)
|
||||
|
||||
assert.NotContains(t, authURL, "audience=")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildAuthURLPKCEParameters tests PKCE parameter handling
|
||||
func TestBuildAuthURLPKCEParameters(t *testing.T) {
|
||||
t.Run("PKCE parameters added when enabled with challenge", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.enablePKCE = true
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"challenge789",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "code_challenge=challenge789")
|
||||
assert.Contains(t, authURL, "code_challenge_method=S256")
|
||||
})
|
||||
|
||||
t.Run("PKCE parameters not added when challenge empty", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.enablePKCE = true
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"", // Empty challenge
|
||||
)
|
||||
|
||||
assert.NotContains(t, authURL, "code_challenge=")
|
||||
})
|
||||
|
||||
t.Run("PKCE parameters not added when disabled", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.enablePKCE = false
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback",
|
||||
"state123",
|
||||
"nonce456",
|
||||
"challenge789",
|
||||
)
|
||||
|
||||
assert.NotContains(t, authURL, "code_challenge=")
|
||||
})
|
||||
}
|
||||
|
||||
// TestForceHTTPSIntegration tests the complete flow of building redirect URIs with forceHTTPS
|
||||
func TestForceHTTPSIntegration(t *testing.T) {
|
||||
t.Run("redirect_uri uses https when forceHTTPS=true", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = true
|
||||
|
||||
// Simulate AWS ALB scenario: HTTP request with incorrect X-Forwarded-Proto
|
||||
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http") // Traefik overwrote it
|
||||
req.Host = "service.example.com"
|
||||
req.TLS = nil
|
||||
|
||||
// Build the full redirect URL as middleware does
|
||||
scheme := middleware.determineScheme(req)
|
||||
host := middleware.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
|
||||
|
||||
assert.Equal(t, "https", scheme, "scheme should be https due to forceHTTPS")
|
||||
assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
|
||||
"redirect_uri should use https scheme")
|
||||
})
|
||||
|
||||
t.Run("buildAuthURL contains https redirect_uri with forceHTTPS", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = true
|
||||
|
||||
// Simulate building auth URL with HTTP redirect_uri
|
||||
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
req.Host = "service.example.com"
|
||||
req.TLS = nil
|
||||
|
||||
scheme := middleware.determineScheme(req)
|
||||
host := middleware.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
|
||||
|
||||
authURL := middleware.buildAuthURL(redirectURL, "state123", "nonce456", "")
|
||||
|
||||
assert.Contains(t, authURL, "redirect_uri=https%3A%2F%2Fservice.example.com%2Foauth2%2Fcallback",
|
||||
"auth URL should contain HTTPS redirect_uri")
|
||||
assert.NotContains(t, authURL, "redirect_uri=http%3A",
|
||||
"auth URL should not contain HTTP redirect_uri")
|
||||
})
|
||||
|
||||
t.Run("without forceHTTPS respects X-Forwarded-Proto", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.forceHTTPS = false
|
||||
|
||||
req := httptest.NewRequest("GET", "http://service.example.com/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Host = "service.example.com"
|
||||
|
||||
scheme := middleware.determineScheme(req)
|
||||
host := middleware.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, "/oauth2/callback")
|
||||
|
||||
assert.Equal(t, "https://service.example.com/oauth2/callback", redirectURL,
|
||||
"should use https from X-Forwarded-Proto when forceHTTPS is false")
|
||||
})
|
||||
}
|
||||
+5
-5
@@ -133,11 +133,11 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
json.NewEncoder(rw).Encode(map[string]interface{}{
|
||||
_ = json.NewEncoder(rw).Encode(map[string]interface{}{
|
||||
"error": http.StatusText(code),
|
||||
"error_description": message,
|
||||
"status_code": code,
|
||||
})
|
||||
}) // Safe to ignore: error response write
|
||||
return
|
||||
}
|
||||
|
||||
@@ -169,7 +169,7 @@ func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Reques
|
||||
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.WriteHeader(code)
|
||||
_, _ = rw.Write([]byte(htmlBody))
|
||||
_, _ = rw.Write([]byte(htmlBody)) // Safe to ignore: error response write
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
@@ -190,8 +190,8 @@ func (t *TraefikOidc) Close() error {
|
||||
rm := GetResourceManager()
|
||||
|
||||
// Stop singleton tasks related to this instance
|
||||
rm.StopBackgroundTask("singleton-token-cleanup")
|
||||
rm.StopBackgroundTask("singleton-metadata-refresh")
|
||||
_ = rm.StopBackgroundTask("singleton-token-cleanup") // Safe to ignore: best effort cleanup
|
||||
_ = rm.StopBackgroundTask("singleton-metadata-refresh") // Safe to ignore: best effort cleanup
|
||||
|
||||
// Remove reference for this instance
|
||||
rm.RemoveReference(t.name)
|
||||
|
||||
+1
-1
@@ -195,7 +195,7 @@ func (r *Reservation) CancelAt(t time.Time) {
|
||||
// update state
|
||||
r.lim.last = t
|
||||
r.lim.tokens = tokens
|
||||
if r.timeToAct == r.lim.lastEvent {
|
||||
if r.timeToAct.Equal(r.lim.lastEvent) {
|
||||
prevEvent := r.timeToAct.Add(r.limit.durationFromTokens(float64(-r.tokens)))
|
||||
if !prevEvent.Before(t) {
|
||||
r.lim.lastEvent = prevEvent
|
||||
|
||||
Vendored
+1
-1
@@ -18,7 +18,7 @@ github.com/pmezard/go-difflib/difflib
|
||||
github.com/stretchr/testify/assert
|
||||
github.com/stretchr/testify/assert/yaml
|
||||
github.com/stretchr/testify/require
|
||||
# golang.org/x/time v0.13.0
|
||||
# golang.org/x/time v0.14.0
|
||||
## explicit; go 1.24.0
|
||||
golang.org/x/time/rate
|
||||
# gopkg.in/yaml.v3 v3.0.1
|
||||
|
||||
Reference in New Issue
Block a user