mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6d893df12b | |||
| 6efb78b7a8 | |||
| d0b920c4f0 | |||
| c474bbafd6 | |||
| 9126c74723 | |||
| a750c4f5b9 | |||
| 56051779ee | |||
| 3f126d50f3 | |||
| 91f0fc9ab8 | |||
| 66b9ed0861 |
@@ -1,622 +0,0 @@
|
||||
name: PR Validation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
checks: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
# Fast feedback - format and basic checks
|
||||
quick-checks:
|
||||
name: Quick Checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
# Exclude vendor directory from format checks
|
||||
UNFORMATTED=$(gofmt -s -l . | grep -v "^vendor/" || true)
|
||||
if [ -n "$UNFORMATTED" ]; then
|
||||
echo "Code is not formatted. Run: gofmt -s -w ."
|
||||
echo "Unformatted files:"
|
||||
echo "$UNFORMATTED"
|
||||
gofmt -s -d $(echo "$UNFORMATTED")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Go mod verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Go mod tidy check
|
||||
run: |
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum
|
||||
|
||||
# Static analysis with golangci-lint (advisory - will not fail the build)
|
||||
golangci-lint:
|
||||
name: golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=10m
|
||||
continue-on-error: true # Allow pipeline to continue even with linting warnings
|
||||
|
||||
# Staticcheck analysis
|
||||
staticcheck:
|
||||
name: Staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install staticcheck
|
||||
run: go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run staticcheck
|
||||
run: staticcheck ./...
|
||||
|
||||
# Security scanning with gosec
|
||||
gosec:
|
||||
name: Gosec Security Scanner
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run Gosec Security Scanner
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -no-fail -fmt sarif -out results.sarif ./... || echo "Gosec completed with warnings"
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: always() && hashFiles('results.sarif') != ''
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
continue-on-error: true
|
||||
|
||||
# Vulnerability scanning
|
||||
govulncheck:
|
||||
name: Vulnerability Scan
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Run govulncheck
|
||||
run: govulncheck ./...
|
||||
|
||||
# CodeQL analysis
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: go
|
||||
continue-on-error: true
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
continue-on-error: true
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
continue-on-error: true
|
||||
|
||||
# Unit tests with race detection
|
||||
test-race:
|
||||
name: Unit Tests (Race Detector)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with race detector
|
||||
run: go test -race -timeout=15m -count=1 -v ./...
|
||||
env:
|
||||
GOMAXPROCS: 4
|
||||
|
||||
# Coverage analysis with threshold check
|
||||
test-coverage:
|
||||
name: Test Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test -coverprofile=coverage.out -covermode=atomic -timeout=15m ./...
|
||||
go tool cover -func=coverage.out -o=coverage.txt
|
||||
|
||||
- name: Calculate coverage
|
||||
id: coverage
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//')
|
||||
echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT
|
||||
echo "Total Coverage: $COVERAGE%"
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
continue-on-error: true
|
||||
|
||||
- name: Comment coverage on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const coverage = '${{ steps.coverage.outputs.coverage }}';
|
||||
const threshold = 70;
|
||||
const coverageNum = parseFloat(coverage);
|
||||
const emoji = coverageNum >= threshold ? '✅' : '⚠️';
|
||||
const status = coverageNum >= threshold ? 'meets' : 'below';
|
||||
|
||||
const body = `## ${emoji} Test Coverage Report
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| **Total Coverage** | ${coverage}% |
|
||||
| **Threshold** | ${threshold}% |
|
||||
| **Status** | ${emoji} Coverage ${status} threshold |`;
|
||||
|
||||
// Find existing comment
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
|
||||
const botComment = comments.find(comment =>
|
||||
comment.user.type === 'Bot' &&
|
||||
comment.body.includes('Test Coverage Report')
|
||||
);
|
||||
|
||||
if (botComment) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: botComment.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
}
|
||||
|
||||
- name: Check coverage threshold
|
||||
run: |
|
||||
COVERAGE=${{ steps.coverage.outputs.coverage }}
|
||||
THRESHOLD=70
|
||||
echo "Coverage: $COVERAGE%"
|
||||
echo "Threshold: $THRESHOLD%"
|
||||
if (( $(echo "$COVERAGE < $THRESHOLD" | bc -l) )); then
|
||||
echo "⚠️ Coverage $COVERAGE% is below threshold $THRESHOLD%"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Coverage $COVERAGE% meets threshold $THRESHOLD%"
|
||||
|
||||
# Memory leak detection
|
||||
test-memory-leaks:
|
||||
name: Memory Leak Detection
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run goroutine leak tests
|
||||
run: |
|
||||
echo "Running goroutine leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Gg]oroutine.*[Ll]eak.*' ./... || echo "No goroutine leak tests found"
|
||||
|
||||
- name: Run memory leak tests
|
||||
run: |
|
||||
echo "Running memory leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Mm]emory.*[Ll]eak.*' ./... || echo "No memory leak tests found"
|
||||
|
||||
- name: Run cleanup tests
|
||||
run: |
|
||||
echo "Running cleanup and resource management tests..."
|
||||
go test -v -timeout=20m -run='.*[Cc]leanup.*' ./... || echo "No cleanup tests found"
|
||||
|
||||
# Integration tests
|
||||
test-integration:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
if [ -d "./integration" ]; then
|
||||
go test -v -timeout=20m ./integration/...
|
||||
else
|
||||
echo "Running integration tests from all packages..."
|
||||
go test -v -timeout=20m -run='.*[Ii]ntegration.*' ./...
|
||||
fi
|
||||
|
||||
# Regression tests
|
||||
test-regression:
|
||||
name: Regression Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
echo "Running regression tests..."
|
||||
go test -v -timeout=20m -run='.*[Rr]egression.*' ./...
|
||||
|
||||
# Provider-specific tests (parallel matrix)
|
||||
test-providers:
|
||||
name: Provider Tests (${{ matrix.provider }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
provider:
|
||||
- google
|
||||
- azure
|
||||
- auth0
|
||||
- okta
|
||||
- keycloak
|
||||
- cognito
|
||||
- gitlab
|
||||
- github
|
||||
- generic
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run ${{ matrix.provider }} provider tests
|
||||
run: |
|
||||
PROVIDER_CAP=$(echo "${{ matrix.provider }}" | sed 's/.*/\u&/')
|
||||
echo "Testing $PROVIDER_CAP provider..."
|
||||
go test -v -timeout=10m -run=".*$PROVIDER_CAP.*" ./internal/providers/... || true
|
||||
go test -v -timeout=10m -run=".*${{ matrix.provider }}.*" ./... || true
|
||||
|
||||
# Benchmark tests with performance tracking
|
||||
benchmark:
|
||||
name: Benchmark Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run benchmarks
|
||||
run: |
|
||||
echo "Running benchmark tests..."
|
||||
go test -bench=. -benchmem -benchtime=1s -run=^$ ./... | tee benchmark.txt
|
||||
|
||||
- name: Upload benchmark results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results
|
||||
path: benchmark.txt
|
||||
retention-days: 30
|
||||
|
||||
- name: Compare benchmarks
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
echo "Benchmark results available in artifacts"
|
||||
echo "To compare with main branch, download previous benchmark results"
|
||||
|
||||
# Build validation across platforms
|
||||
build:
|
||||
name: Build (${{ matrix.os }}/${{ matrix.arch }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [linux, darwin]
|
||||
arch: [amd64, arm64]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Build for ${{ matrix.os }}/${{ matrix.arch }}
|
||||
env:
|
||||
GOOS: ${{ matrix.os }}
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
run: |
|
||||
echo "Building for $GOOS/$GOARCH..."
|
||||
go build -v -ldflags="-s -w" ./...
|
||||
|
||||
# Security-specific edge case tests
|
||||
test-security-edge-cases:
|
||||
name: Security Edge Cases
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run security edge case tests
|
||||
run: |
|
||||
echo "Running security edge case tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ecurity.*' ./...
|
||||
|
||||
# Session management tests
|
||||
test-session:
|
||||
name: Session Management Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run session tests
|
||||
run: |
|
||||
echo "Running session management tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ession.*' ./...
|
||||
|
||||
# Token validation tests
|
||||
test-token:
|
||||
name: Token Validation Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run token validation tests
|
||||
run: |
|
||||
echo "Running token validation tests..."
|
||||
go test -v -timeout=15m -run='.*[Tt]oken.*' ./...
|
||||
|
||||
# CSRF and security tests
|
||||
test-csrf:
|
||||
name: CSRF and Security Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run CSRF tests
|
||||
run: |
|
||||
echo "Running CSRF and security tests..."
|
||||
go test -v -timeout=15m -run='.*[Cc][Ss][Rr][Ff].*' ./...
|
||||
|
||||
# Multi-Go version compatibility
|
||||
test-go-versions:
|
||||
name: Go ${{ matrix.go-version }} Compatibility
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ['1.24']
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Run tests on Go ${{ matrix.go-version }}
|
||||
run: go test -short -timeout=10m ./...
|
||||
|
||||
# Final validation - all checks must pass (golangci-lint is advisory)
|
||||
all-checks-passed:
|
||||
name: ✅ All Checks Passed
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- quick-checks
|
||||
- golangci-lint
|
||||
- staticcheck
|
||||
- gosec
|
||||
- govulncheck
|
||||
- codeql
|
||||
- test-race
|
||||
- test-coverage
|
||||
- test-memory-leaks
|
||||
- test-integration
|
||||
- test-regression
|
||||
- test-providers
|
||||
- benchmark
|
||||
- build
|
||||
- test-security-edge-cases
|
||||
- test-session
|
||||
- test-token
|
||||
- test-csrf
|
||||
- test-go-versions
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
echo "Checking status of all jobs..."
|
||||
|
||||
# Check critical jobs (excluding golangci-lint which is advisory)
|
||||
CRITICAL_FAILURES=false
|
||||
|
||||
if [ "${{ needs.quick-checks.result }}" == "failure" ] || \
|
||||
[ "${{ needs.staticcheck.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-race.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-coverage.result }}" == "failure" ] || \
|
||||
[ "${{ needs.build.result }}" == "failure" ]; then
|
||||
CRITICAL_FAILURES=true
|
||||
fi
|
||||
|
||||
if [ "$CRITICAL_FAILURES" == "true" ]; then
|
||||
echo "❌ Critical checks failed"
|
||||
exit 1
|
||||
elif [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
|
||||
echo "⚠️ Some checks were cancelled"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All critical checks passed successfully!"
|
||||
if [ "${{ needs.golangci-lint.result }}" != "success" ]; then
|
||||
echo "ℹ️ Note: golangci-lint reported issues (advisory only)"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Post summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "# PR Validation Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "## Job Status" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Quick Checks: ${{ needs.quick-checks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Linting (advisory): ${{ needs.golangci-lint.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Static Analysis: ${{ needs.staticcheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Scan (gosec): ${{ needs.gosec.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Vulnerability Scan: ${{ needs.govulncheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CodeQL: ${{ needs.codeql.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Race Detection: ${{ needs.test-race.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Coverage: ${{ needs.test-coverage.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Memory Leaks: ${{ needs.test-memory-leaks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Integration Tests: ${{ needs.test-integration.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Regression Tests: ${{ needs.test-regression.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Provider Tests: ${{ needs.test-providers.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Benchmarks: ${{ needs.benchmark.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Build: ${{ needs.build.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Edge Cases: ${{ needs.test-security-edge-cases.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Session Tests: ${{ needs.test-session.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Token Tests: ${{ needs.test-token.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CSRF Tests: ${{ needs.test-csrf.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Go Version Compatibility: ${{ needs.test-go-versions.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- "**"
|
||||
- "!main"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.go"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
packages: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
+2
-1
@@ -1,2 +1,3 @@
|
||||
docker/
|
||||
.claude/
|
||||
.claude/*.out
|
||||
*.test
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
version: 2
|
||||
|
||||
# Traefik plugins are source-only - no binary builds
|
||||
# Traefik loads plugins via Yaegi interpreter at runtime
|
||||
builds:
|
||||
- skip: true
|
||||
|
||||
# Create source archive for GitHub releases
|
||||
archives:
|
||||
- formats: [tar.gz]
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
|
||||
files:
|
||||
- "*.go"
|
||||
- "**/*.go"
|
||||
- go.mod
|
||||
- go.sum
|
||||
- .traefik.yml
|
||||
- LICENSE*
|
||||
- README*
|
||||
# Exclude test files and vendor from release archive
|
||||
- "!**/*_test.go"
|
||||
- "!vendor/**"
|
||||
- "!docker/**"
|
||||
- "!integration/**"
|
||||
- "!regression/**"
|
||||
- "!examples/**"
|
||||
- "!docs/**"
|
||||
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
|
||||
algorithm: sha256
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
- "^Merge"
|
||||
- "^WIP"
|
||||
- "^chore:"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: lukaszraczylo
|
||||
name: traefikoidc
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
|
||||
signs:
|
||||
- cmd: cosign
|
||||
signature: "${artifact}.sigstore.json"
|
||||
args:
|
||||
- sign-blob
|
||||
- "--bundle=${signature}"
|
||||
- "${artifact}"
|
||||
- "--yes"
|
||||
artifacts: checksum
|
||||
output: true
|
||||
+118
@@ -77,6 +77,7 @@ testData:
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
@@ -120,6 +121,8 @@ testData:
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
@@ -266,6 +269,8 @@ testDataWithRedis:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -287,6 +292,26 @@ testDataWithRedis:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -605,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -903,6 +960,67 @@ configuration:
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
|
||||
-286
@@ -1,286 +0,0 @@
|
||||
# CI/CD Setup Guide
|
||||
|
||||
## 📋 Overview
|
||||
|
||||
This repository now has a comprehensive CI/CD pipeline that runs **20+ parallel checks** on every pull request to ensure code quality, security, and reliability.
|
||||
|
||||
## 🎯 What Was Added
|
||||
|
||||
### GitHub Actions Workflow
|
||||
- **`.github/workflows/pr-validation.yml`** - Main CI/CD pipeline (single file, all parallel)
|
||||
|
||||
### Configuration Files
|
||||
- **`.golangci.yml`** - Linter configuration with 30+ enabled checks
|
||||
- **`.github/dependabot.yml`** - Automated dependency updates
|
||||
- **`.github/CODEOWNERS`** - Automatic PR reviewer assignment
|
||||
- **`.github/PULL_REQUEST_TEMPLATE.md`** - Standardized PR descriptions
|
||||
- **`.github/workflows/README.md`** - Detailed workflow documentation
|
||||
- **`.github/workflows/.gitattributes`** - Consistent line endings
|
||||
|
||||
## ✅ What Gets Tested (All in Parallel)
|
||||
|
||||
### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters including style, complexity, bugs
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning with SARIF reports
|
||||
- **Govulncheck** - Go vulnerability database scanning
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
### Testing (9 test suites)
|
||||
- **Race Detector** - Concurrent access bugs
|
||||
- **Coverage** - 75% threshold with PR comments
|
||||
- **Memory Leaks** - Goroutine and memory leak detection
|
||||
- **Integration Tests** - Full integration suite
|
||||
- **Regression Tests** - Prevent old bugs from returning
|
||||
- **Security Edge Cases** - Security-specific scenarios
|
||||
- **Session Tests** - Session management
|
||||
- **Token Tests** - Token validation
|
||||
- **CSRF Tests** - CSRF protection
|
||||
|
||||
### Provider Testing (9 providers in parallel)
|
||||
- Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, GitHub, Generic
|
||||
|
||||
### Performance & Build (3 checks)
|
||||
- **Benchmarks** - Performance regression detection
|
||||
- **Multi-platform Build** - 4 combinations (linux/darwin × amd64/arm64)
|
||||
- **Go Version Compatibility** - Go 1.23 & 1.24
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
### 1. Push to GitHub
|
||||
```bash
|
||||
git add .github .golangci.yml CI_SETUP.md
|
||||
git commit -m "Add comprehensive CI/CD pipeline"
|
||||
git push origin main
|
||||
```
|
||||
|
||||
### 2. Create a Test PR
|
||||
```bash
|
||||
# Create a feature branch
|
||||
git checkout -b feature/test-ci
|
||||
echo "# Test" >> test.md
|
||||
git add test.md
|
||||
git commit -m "Test CI pipeline"
|
||||
git push origin feature/test-ci
|
||||
|
||||
# Create PR on GitHub
|
||||
# Watch all 20+ checks run in parallel! ⚡
|
||||
```
|
||||
|
||||
### 3. Monitor Results
|
||||
- Go to Actions tab: `https://github.com/{owner}/{repo}/actions`
|
||||
- Click on latest workflow run
|
||||
- See all parallel checks in action
|
||||
- Review coverage comment on PR
|
||||
|
||||
## 📊 Key Features
|
||||
|
||||
### ⚡ Maximum Speed
|
||||
- **Parallel execution** - All checks run simultaneously
|
||||
- **Smart caching** - Go modules and build cache
|
||||
- **Optimized order** - Quick checks first for fast feedback
|
||||
- **Expected runtime**: 5-10 minutes for full suite
|
||||
|
||||
### 🔒 Security First
|
||||
- **3 security scanners** - gosec, govulncheck, CodeQL
|
||||
- **SARIF integration** - Results in GitHub Security tab
|
||||
- **Dependency scanning** - Automated with Dependabot
|
||||
- **Security edge case tests**
|
||||
|
||||
### 📈 Coverage Tracking
|
||||
- **Automatic PR comments** with coverage stats
|
||||
- **Per-package breakdown** included
|
||||
- **75% threshold** enforced (configurable)
|
||||
- **Codecov integration** ready (optional)
|
||||
|
||||
### 🎨 Developer Experience
|
||||
- **Clear PR template** guides contributors
|
||||
- **Auto code owners** assignment
|
||||
- **Detailed error messages** for failures
|
||||
- **Benchmark tracking** for performance
|
||||
|
||||
## 🛠️ Local Development
|
||||
|
||||
### Install Required Tools
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
### Run Checks Locally
|
||||
```bash
|
||||
# Quick validation (before committing)
|
||||
gofmt -s -w . # Format code
|
||||
go vet ./... # Basic checks
|
||||
go mod tidy # Clean dependencies
|
||||
|
||||
# Linting
|
||||
golangci-lint run # Full lint suite
|
||||
staticcheck ./... # Static analysis
|
||||
|
||||
# Testing
|
||||
go test -race -timeout=15m ./... # Tests with race detector
|
||||
go test -coverprofile=coverage.out ./... # Coverage
|
||||
go tool cover -func=coverage.out # View coverage
|
||||
|
||||
# Security
|
||||
gosec ./... # Security scan
|
||||
govulncheck ./... # Vulnerability check
|
||||
|
||||
# Benchmarks
|
||||
go test -bench=. -benchmem ./... # Performance tests
|
||||
```
|
||||
|
||||
### Pre-commit Checklist
|
||||
```bash
|
||||
# Run this before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "✅ Ready to commit!"
|
||||
```
|
||||
|
||||
## 📝 Configuration
|
||||
|
||||
### Adjust Coverage Threshold
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
THRESHOLD=75 # Change to desired percentage
|
||||
```
|
||||
|
||||
### Modify Linter Rules
|
||||
Edit `.golangci.yml`:
|
||||
```yaml
|
||||
linters:
|
||||
enable:
|
||||
- newlinter # Add new linters here
|
||||
```
|
||||
|
||||
### Update Go Version
|
||||
Edit `.github/workflows/pr-validation.yml`:
|
||||
```yaml
|
||||
go-version: '1.24' # Update version
|
||||
```
|
||||
|
||||
## 🐛 Troubleshooting
|
||||
|
||||
### Coverage Below Threshold
|
||||
```bash
|
||||
# See uncovered lines in browser
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Race Condition Found
|
||||
```bash
|
||||
# Run specific test with race detector
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
### Linter Errors
|
||||
```bash
|
||||
# See detailed lint errors
|
||||
golangci-lint run -v
|
||||
|
||||
# Auto-fix some issues
|
||||
golangci-lint run --fix
|
||||
```
|
||||
|
||||
### Provider Test Fails
|
||||
```bash
|
||||
# Test specific provider
|
||||
go test -v -run='.*Azure.*' ./internal/providers/
|
||||
```
|
||||
|
||||
## 📈 Metrics & Monitoring
|
||||
|
||||
### GitHub Actions Dashboard
|
||||
- View all runs: `Actions` tab
|
||||
- Filter by workflow, branch, status
|
||||
- Download logs and artifacts
|
||||
|
||||
### Status Badge
|
||||
Add to README.md:
|
||||
```markdown
|
||||
[](https://github.com/lukaszraczylo/traefikoidc/actions/workflows/pr-validation.yml)
|
||||
```
|
||||
|
||||
### Notifications
|
||||
- Configure in: Settings → Notifications
|
||||
- Email alerts for workflow failures
|
||||
- Slack/Discord webhooks supported
|
||||
|
||||
## 🔄 Continuous Improvement
|
||||
|
||||
### Dependabot Updates
|
||||
- Automatic weekly dependency checks (Mondays 9 AM)
|
||||
- Security updates prioritized
|
||||
- Groups patch updates together
|
||||
|
||||
### Code Owners
|
||||
- Auto-assigns reviewers based on file paths
|
||||
- Ensures expertise reviews changes
|
||||
- Speeds up PR review process
|
||||
|
||||
## 📚 Additional Resources
|
||||
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Dependabot Config](.github/dependabot.yml)
|
||||
|
||||
## 🎉 Benefits
|
||||
|
||||
### For Contributors
|
||||
- Clear expectations via PR template
|
||||
- Fast feedback (5-10 min)
|
||||
- Comprehensive local tooling
|
||||
- Detailed error messages
|
||||
|
||||
### For Maintainers
|
||||
- Automated code review
|
||||
- Security scanning
|
||||
- Performance tracking
|
||||
- Quality gates enforcement
|
||||
|
||||
### For Users
|
||||
- Higher code quality
|
||||
- Fewer bugs in production
|
||||
- Better security
|
||||
- Consistent performance
|
||||
|
||||
## 🚦 Success Criteria
|
||||
|
||||
All PRs must pass:
|
||||
- ✅ All 20+ parallel checks
|
||||
- ✅ 75% test coverage minimum
|
||||
- ✅ Zero security vulnerabilities
|
||||
- ✅ No race conditions
|
||||
- ✅ No memory leaks
|
||||
- ✅ All providers tested
|
||||
- ✅ Builds on all platforms
|
||||
|
||||
## 💡 Tips
|
||||
|
||||
1. **Run checks locally** before pushing to save CI time
|
||||
2. **Watch for PR comments** - coverage stats posted automatically
|
||||
3. **Check Security tab** for gosec/CodeQL findings
|
||||
4. **Review benchmark results** in artifacts
|
||||
5. **Use draft PRs** for work-in-progress to skip some checks
|
||||
|
||||
---
|
||||
|
||||
**Ready to go!** 🚀 Push your changes and create a PR to see it in action.
|
||||
@@ -82,6 +82,19 @@ experimental:
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
|
||||
### Verifying Release Signatures
|
||||
|
||||
All release checksums are signed with [cosign](https://github.com/sigstore/cosign) using keyless signing. To verify:
|
||||
|
||||
```bash
|
||||
# Download the checksum file and its sigstore bundle from the release
|
||||
cosign verify-blob \
|
||||
--certificate-identity-regexp "https://github.com/lukaszraczylo/traefikoidc/.*" \
|
||||
--certificate-oidc-issuer "https://token.actions.githubusercontent.com" \
|
||||
--bundle "traefikoidc_v<version>_checksums.txt.sigstore.json" \
|
||||
traefikoidc_v<version>_checksums.txt
|
||||
```
|
||||
|
||||
### Local Development with Docker Compose
|
||||
|
||||
For local development or testing, you can use the provided Docker Compose setup:
|
||||
@@ -124,6 +137,7 @@ The middleware supports the following configuration options:
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
@@ -138,6 +152,8 @@ The middleware supports the following configuration options:
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
@@ -1241,6 +1257,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1327,8 +1382,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1629,12 +1688,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1862,6 +1948,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
+1376
-1
File diff suppressed because it is too large
Load Diff
@@ -1,931 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// TestConfigAudienceValidation tests the Config.Validate() method for the audience field
|
||||
func TestConfigAudienceValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Empty audience is valid for backward compatibility",
|
||||
audience: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTPS URL audience Auth0 format",
|
||||
audience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid identifier audience",
|
||||
audience: "my-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Azure AD Application ID URI format",
|
||||
audience: "api://12345-guid-67890",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Auth0 API identifier",
|
||||
audience: "https://my-company.auth0.com/api/v2/",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP URL audience should fail",
|
||||
audience: "http://api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must use HTTPS",
|
||||
},
|
||||
{
|
||||
name: "Audience with wildcard should fail",
|
||||
audience: "https://api.*.example.com",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience with single asterisk should fail",
|
||||
audience: "*",
|
||||
wantErr: true,
|
||||
errContains: "must not contain wildcards",
|
||||
},
|
||||
{
|
||||
name: "Audience over 256 characters should fail",
|
||||
audience: strings.Repeat("a", 257),
|
||||
wantErr: true,
|
||||
errContains: "must not exceed 256 characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with newline should fail",
|
||||
audience: "my-api\ninjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with carriage return should fail",
|
||||
audience: "my-api\rinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Audience with tab should fail",
|
||||
audience: "my-api\tinjection",
|
||||
wantErr: true,
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Valid audience exactly 256 characters",
|
||||
audience: strings.Repeat("a", 256),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid simple identifier",
|
||||
audience: "my-service-api",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid URN format",
|
||||
audience: "urn:myservice:api:v1",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceVerification tests JWT verification with custom audience values
|
||||
func TestJWTAudienceVerification(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate RSA key for signing JWTs
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
// Create JWK
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configAudience string
|
||||
tokenAudience interface{}
|
||||
wantErr bool
|
||||
errContains string
|
||||
skipReplayCheck bool
|
||||
}{
|
||||
{
|
||||
name: "JWT with string aud matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://api.example.com",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://api.example.com", "https://another.com"},
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with string aud NOT matching configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "https://wrong-api.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with array aud NOT containing configured audience",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: []interface{}{"https://other.com", "https://another.com"},
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with clientID as aud when no custom audience configured",
|
||||
configAudience: "",
|
||||
tokenAudience: "test-client-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "JWT with empty string aud",
|
||||
configAudience: "https://api.example.com",
|
||||
tokenAudience: "",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Azure AD Application ID URI format",
|
||||
configAudience: "api://12345-app-id",
|
||||
tokenAudience: "api://12345-app-id",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Auth0 custom API audience",
|
||||
configAudience: "https://mycompany.com/api",
|
||||
tokenAudience: "https://mycompany.com/api",
|
||||
wantErr: false,
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
{
|
||||
name: "Token confusion attack - audience for different service",
|
||||
configAudience: "https://service-a.example.com",
|
||||
tokenAudience: "https://service-b.example.com",
|
||||
wantErr: true,
|
||||
errContains: "invalid audience",
|
||||
skipReplayCheck: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create TraefikOidc instance
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Set up the token verifier and JWT verifier
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Determine the expected audience for validation
|
||||
expectedAudience := tt.configAudience
|
||||
if expectedAudience == "" {
|
||||
expectedAudience = tOidc.clientID
|
||||
}
|
||||
|
||||
// Set the audience field on the tOidc instance
|
||||
tOidc.audience = expectedAudience
|
||||
|
||||
// Create JWT with specified audience
|
||||
jti := generateRandomString(16)
|
||||
if tt.skipReplayCheck {
|
||||
// Use a unique JTI for each test to avoid replay detection
|
||||
jti = fmt.Sprintf("test-%s-%s", tt.name, jti)
|
||||
}
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": tt.tokenAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": jti,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifyToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Error message should contain %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAudienceBackwardCompatibility tests that existing behavior is preserved
|
||||
// when the Audience field is not set
|
||||
func TestJWTAudienceBackwardCompatibility(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Test with no custom audience configured - should use clientID
|
||||
jwt, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // Should match clientID
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-subject",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
err = ts.tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Backward compatibility broken: VerifyToken() error = %v, expected nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAuth0Scenario tests Auth0-specific use case
|
||||
func TestAudienceIntegrationAuth0Scenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Auth0 scenario: custom audience for API access
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://mycompany.auth0.com"
|
||||
config.ClientID = "auth0-client-id"
|
||||
config.ClientSecret = "auth0-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "https://api.mycompany.com" // Custom API audience
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Auth0 config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "auth0-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://mycompany.auth0.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Auth0 API access token with custom audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches configured audience
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Auth0 token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Auth0 ACCESS token with clientID instead of API audience should fail", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "auth0-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.ClientID, // Using clientID instead of API audience
|
||||
"scope": "openid profile email", // Mark as access token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "auth0|123456",
|
||||
"email": "user@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Auth0 JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err == nil {
|
||||
t.Error("Auth0 access token with wrong audience should have been rejected")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceIntegrationAzureADScenario tests Azure AD-specific use case
|
||||
func TestAudienceIntegrationAzureADScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Simulate Azure AD scenario: Application ID URI format
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://login.microsoftonline.com/tenant-id/v2.0"
|
||||
config.ClientID = "azure-client-id"
|
||||
config.ClientSecret = "azure-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = "api://12345-abcd-6789-efgh" // Azure AD Application ID URI
|
||||
|
||||
// Validate config
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Fatalf("Azure AD config validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "azure-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: config.ProviderURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
audience: config.Audience, // Set audience from config
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: config.ProviderURL + "/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
|
||||
// Default audience to clientID if not specified
|
||||
if tOidc.audience == "" {
|
||||
tOidc.audience = tOidc.clientID
|
||||
}
|
||||
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
t.Run("Valid Azure AD token with Application ID URI audience", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": config.Audience, // Matches Application ID URI
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure AD token with multiple audiences including correct one", func(t *testing.T) {
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "azure-key-id", map[string]interface{}{
|
||||
"iss": config.ProviderURL,
|
||||
"aud": []interface{}{config.ClientID, config.Audience, "https://graph.microsoft.com"},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "azure-user-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "object-id-12345",
|
||||
"tid": "tenant-id",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Azure AD JWT: %v", err)
|
||||
}
|
||||
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Errorf("Azure AD token with multiple audiences verification failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityTokenConfusionAttack tests security against token confusion attacks
|
||||
func TestAudienceSecurityTokenConfusionAttack(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
// Service A configuration
|
||||
serviceA := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "service-a-client-id",
|
||||
clientSecret: "service-a-secret",
|
||||
audience: "service-a-client-id", // Service A uses its clientID as audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
serviceA.jwtVerifier = serviceA
|
||||
serviceA.tokenVerifier = serviceA
|
||||
|
||||
t.Run("Token confusion - Try to use service B token on service A", func(t *testing.T) {
|
||||
// Create a token intended for service B
|
||||
serviceBToken, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": "https://service-b.example.com", // For service B
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "attacker@example.com",
|
||||
"email": "attacker@example.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create service B token: %v", err)
|
||||
}
|
||||
|
||||
// Try to verify the service B token on service A
|
||||
err = serviceA.VerifyToken(serviceBToken)
|
||||
switch {
|
||||
case err == nil:
|
||||
t.Error("SECURITY VULNERABILITY: Token confusion attack succeeded - service B token was accepted by service A")
|
||||
case !strings.Contains(err.Error(), "invalid audience"):
|
||||
t.Errorf("Expected 'invalid audience' error for token confusion, got: %v", err)
|
||||
default:
|
||||
t.Logf("Token confusion attack correctly prevented: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestAudienceSecurityWildcardInjection tests that wildcards are rejected
|
||||
func TestAudienceSecurityWildcardInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
}{
|
||||
{
|
||||
name: "Single asterisk",
|
||||
audience: "*",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in URL",
|
||||
audience: "https://*.example.com",
|
||||
},
|
||||
{
|
||||
name: "Wildcard in path",
|
||||
audience: "https://api.example.com/*",
|
||||
},
|
||||
{
|
||||
name: "Multiple wildcards",
|
||||
audience: "https://*.*.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Wildcard audience %q was not rejected", tt.audience)
|
||||
} else if !strings.Contains(err.Error(), "must not contain wildcards") {
|
||||
t.Errorf("Expected 'must not contain wildcards' error, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceSecurityInjectionAttempts tests various injection attempts
|
||||
func TestAudienceSecurityInjectionAttempts(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
audience string
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "Newline injection",
|
||||
audience: "api.example.com\nmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Carriage return injection",
|
||||
audience: "api.example.com\rmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Tab injection",
|
||||
audience: "api.example.com\tmalicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
{
|
||||
name: "Null byte injection",
|
||||
audience: "api.example.com\x00malicious.com",
|
||||
errContains: "contains invalid characters",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := CreateConfig()
|
||||
config.ProviderURL = "https://provider.example.com"
|
||||
config.ClientID = "test-client-id"
|
||||
config.ClientSecret = "test-client-secret"
|
||||
config.CallbackURL = "/callback"
|
||||
config.SessionEncryptionKey = strings.Repeat("a", MinSessionEncryptionKeyLength)
|
||||
config.Audience = tt.audience
|
||||
|
||||
err := config.Validate()
|
||||
if err == nil {
|
||||
t.Errorf("SECURITY VULNERABILITY: Injection attempt with %q was not rejected", tt.name)
|
||||
} else if !strings.Contains(err.Error(), tt.errContains) {
|
||||
t.Errorf("Expected error containing %q, got: %v", tt.errContains, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceWithReplayProtection tests that replay protection works correctly with custom audiences
|
||||
func TestAudienceWithReplayProtection(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
issuerURL: "https://auth.example.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.example.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
httpClient: &http.Client{},
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
// Create a token with custom audience and fixed JTI
|
||||
fixedJTI := "replay-test-jti-" + generateRandomString(8)
|
||||
customAudience := "https://api.example.com"
|
||||
|
||||
// Set the audience field to match what we expect
|
||||
tOidc.audience = customAudience
|
||||
|
||||
jwt, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.example.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "user@example.com",
|
||||
"jti": fixedJTI,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create JWT: %v", err)
|
||||
}
|
||||
|
||||
// First verification should succeed
|
||||
err = tOidc.VerifyToken(jwt)
|
||||
if err != nil {
|
||||
t.Fatalf("First verification failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the JTI was blacklisted
|
||||
if blacklisted, exists := tOidc.tokenBlacklist.Get(fixedJTI); !exists || blacklisted == nil {
|
||||
t.Logf("Note: JTI was not added to blacklist (may be due to test token prefix)")
|
||||
} else {
|
||||
t.Logf("Replay protection verified: JTI %s is correctly blacklisted", fixedJTI)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAudienceEndToEndScenario tests a complete end-to-end scenario with middleware
|
||||
func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
// Create cleanup helper
|
||||
tc := newTestCleanup(t)
|
||||
|
||||
// Create a test next handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Authenticated with custom audience"))
|
||||
})
|
||||
|
||||
// Generate test keys
|
||||
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate RSA key: %v", err)
|
||||
}
|
||||
rsaPublicKey := &rsaPrivateKey.PublicKey
|
||||
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1}),
|
||||
}
|
||||
jwks := &JWKSet{
|
||||
Keys: []JWK{jwk},
|
||||
}
|
||||
|
||||
mockJWKCache := &MockJWKCache{
|
||||
JWKS: jwks,
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
|
||||
tokenBlacklist := tc.addCache(NewCache())
|
||||
tokenCache := tc.addTokenCache(NewTokenCache())
|
||||
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
close(tOidc.initComplete)
|
||||
|
||||
t.Run("End-to-end with correct custom audience", func(t *testing.T) {
|
||||
// Create a valid token with the custom audience
|
||||
validJWT, err := createTestJWT(rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://auth.company.com",
|
||||
"aud": customAudience,
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Add(-2 * time.Minute).Unix()),
|
||||
"sub": "user-123",
|
||||
"email": "user@company.com",
|
||||
"jti": generateRandomString(16),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create valid JWT: %v", err)
|
||||
}
|
||||
|
||||
// Create a request with authenticated session
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
|
||||
// Create session with token
|
||||
resp := httptest.NewRecorder()
|
||||
session, err := sm.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
t.Fatalf("Failed to set authenticated: %v", err)
|
||||
}
|
||||
session.SetEmail("user@company.com")
|
||||
session.SetIDToken(validJWT)
|
||||
session.SetAccessToken(validJWT)
|
||||
|
||||
if err := session.Save(req, resp); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Get cookies and add them to a new request
|
||||
cookies := resp.Result().Cookies()
|
||||
req = httptest.NewRequest("GET", "/protected", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "company.com")
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
resp = httptest.NewRecorder()
|
||||
tOidc.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,409 +0,0 @@
|
||||
// Package auth provides authentication-related functionality for the OIDC middleware.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ScopeFilter interface for filtering OAuth scopes based on provider capabilities
|
||||
type ScopeFilter interface {
|
||||
FilterSupportedScopes(requestedScopes, supportedScopes []string, providerURL string) []string
|
||||
}
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
}
|
||||
}
|
||||
|
||||
// InitiateAuthentication initiates the OIDC authentication flow.
|
||||
// It generates CSRF tokens, nonce, PKCE parameters (if enabled), clears the session,
|
||||
// stores authentication state, and redirects the user to the OIDC provider.
|
||||
func (h *Handler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request,
|
||||
session SessionData, redirectURL string,
|
||||
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
||||
h.logger.Debugf("Initiating new OIDC authentication flow for request: %s", req.URL.RequestURI())
|
||||
|
||||
const maxRedirects = 5
|
||||
redirectCount := session.GetRedirectCount()
|
||||
if redirectCount >= maxRedirects {
|
||||
h.logger.Errorf("Maximum redirect limit (%d) exceeded, possible redirect loop detected", maxRedirects)
|
||||
session.ResetRedirectCount()
|
||||
http.Error(rw, "Authentication failed: Too many redirects", http.StatusLoopDetected)
|
||||
return
|
||||
}
|
||||
|
||||
session.IncrementRedirectCount()
|
||||
|
||||
csrfToken := uuid.NewString()
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate nonce: %v", err)
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate PKCE code verifier and challenge if PKCE is enabled
|
||||
var codeVerifier, codeChallenge string
|
||||
if h.enablePKCE {
|
||||
codeVerifier, err = generateCodeVerifier()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code verifier: %v", err)
|
||||
http.Error(rw, "Failed to generate code verifier", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
codeChallenge, err = deriveCodeChallenge()
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to generate code challenge: %v", err)
|
||||
http.Error(rw, "Failed to generate code challenge", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("PKCE enabled, generated code challenge")
|
||||
}
|
||||
|
||||
session.SetAuthenticated(false)
|
||||
session.SetEmail("")
|
||||
session.SetAccessToken("")
|
||||
session.SetRefreshToken("")
|
||||
session.SetIDToken("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
if h.enablePKCE {
|
||||
session.SetCodeVerifier(codeVerifier)
|
||||
}
|
||||
session.SetIncomingPath(req.URL.RequestURI())
|
||||
h.logger.Debugf("Storing incoming path: %s", req.URL.RequestURI())
|
||||
|
||||
session.MarkDirty()
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session before redirecting to provider: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Session saved before redirect. CSRF: %s, Nonce: %s",
|
||||
csrfToken, nonce)
|
||||
|
||||
authURL := h.BuildAuthURL(redirectURL, csrfToken, nonce, codeChallenge)
|
||||
h.logger.Debugf("Redirecting user to OIDC provider: %s", authURL)
|
||||
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildAuthURL constructs the OIDC provider authorization URL.
|
||||
// It builds the URL with all necessary parameters including client_id, scopes,
|
||||
// PKCE parameters, and provider-specific parameters for Google and Azure.
|
||||
func (h *Handler) BuildAuthURL(redirectURL, state, nonce, codeChallenge string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", h.clientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Set("redirect_uri", redirectURL)
|
||||
params.Set("state", state)
|
||||
params.Set("nonce", nonce)
|
||||
|
||||
if h.enablePKCE && codeChallenge != "" {
|
||||
params.Set("code_challenge", codeChallenge)
|
||||
params.Set("code_challenge_method", "S256")
|
||||
}
|
||||
|
||||
scopes := make([]string, len(h.scopes))
|
||||
copy(scopes, h.scopes)
|
||||
|
||||
// Apply discovery-based scope filtering if available
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After discovery filtering: %v", scopes)
|
||||
}
|
||||
|
||||
// Apply provider-specific modifications
|
||||
scopes, params = h.applyProviderSpecificConfig(scopes, params)
|
||||
|
||||
// Final filtering pass to remove anything the provider doesn't support
|
||||
if h.scopeFilter != nil && len(h.scopesSupported) > 0 {
|
||||
scopes = h.scopeFilter.FilterSupportedScopes(scopes, h.scopesSupported, h.issuerURL)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: After final filtering: %v", scopes)
|
||||
}
|
||||
|
||||
if len(scopes) > 0 {
|
||||
finalScopeString := strings.Join(scopes, " ")
|
||||
params.Set("scope", finalScopeString)
|
||||
h.logger.Debugf("AuthHandler.BuildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
return h.buildURLWithParams(h.authURL, params)
|
||||
}
|
||||
|
||||
// applyProviderSpecificConfig applies provider-specific scope and parameter modifications
|
||||
func (h *Handler) applyProviderSpecificConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
switch {
|
||||
case h.isGoogleProv():
|
||||
return h.applyGoogleConfig(scopes, params)
|
||||
case h.isAzureProv():
|
||||
return h.applyAzureConfig(scopes, params)
|
||||
default:
|
||||
return h.applyStandardProviderConfig(scopes, params)
|
||||
}
|
||||
}
|
||||
|
||||
// applyGoogleConfig applies Google-specific configuration
|
||||
func (h *Handler) applyGoogleConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
// Google: Remove offline_access if present, add access_type=offline
|
||||
filteredScopes := make([]string, 0, len(scopes))
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
params.Set("access_type", "offline")
|
||||
h.logger.Debugf("Google OIDC provider detected, added access_type=offline")
|
||||
params.Set("prompt", "consent")
|
||||
h.logger.Debugf("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
return filteredScopes, params
|
||||
}
|
||||
|
||||
// applyAzureConfig applies Azure AD-specific configuration
|
||||
func (h *Handler) applyAzureConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
params.Set("response_mode", "query")
|
||||
h.logger.Debugf("Azure AD provider detected, added response_mode=query")
|
||||
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Azure AD provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Azure AD provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// applyStandardProviderConfig applies configuration for standard OIDC providers
|
||||
func (h *Handler) applyStandardProviderConfig(scopes []string, params url.Values) ([]string, url.Values) {
|
||||
if h.shouldAddOfflineAccess(scopes) {
|
||||
scopes = append(scopes, "offline_access")
|
||||
h.logger.Debugf("Standard provider: Added offline_access scope (overrideScopes: %t, user scopes count: %d)",
|
||||
h.overrideScopes, len(h.scopes))
|
||||
} else {
|
||||
h.logger.Debugf("Standard provider: User is overriding scopes (count: %d), offline_access not automatically added.",
|
||||
len(h.scopes))
|
||||
}
|
||||
return scopes, params
|
||||
}
|
||||
|
||||
// shouldAddOfflineAccess determines if offline_access scope should be added
|
||||
func (h *Handler) shouldAddOfflineAccess(scopes []string) bool {
|
||||
if h.overrideScopes && len(h.scopes) > 0 {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// buildURLWithParams constructs a URL by combining a base URL with query parameters.
|
||||
// It handles both relative and absolute URLs, validates URL security,
|
||||
// and properly encodes query parameters.
|
||||
func (h *Handler) buildURLWithParams(baseURL string, params url.Values) string {
|
||||
if baseURL != "" {
|
||||
if strings.HasPrefix(baseURL, "http://") || strings.HasPrefix(baseURL, "https://") {
|
||||
if err := h.validateURL(baseURL); err != nil {
|
||||
h.logger.Errorf("URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") {
|
||||
issuerURLParsed, err := url.Parse(h.issuerURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse issuerURL: %s. Error: %v", h.issuerURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
baseURLParsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL := issuerURLParsed.ResolveReference(baseURLParsed)
|
||||
|
||||
if err := h.validateURL(resolvedURL.String()); err != nil {
|
||||
h.logger.Errorf("Resolved URL validation failed for %s: %v", resolvedURL.String(), err)
|
||||
return ""
|
||||
}
|
||||
|
||||
resolvedURL.RawQuery = params.Encode()
|
||||
return resolvedURL.String()
|
||||
}
|
||||
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Could not parse absolute baseURL: %s. Error: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
if err := h.validateParsedURL(u); err != nil {
|
||||
h.logger.Errorf("Parsed URL validation failed for %s: %v", baseURL, err)
|
||||
return ""
|
||||
}
|
||||
|
||||
u.RawQuery = params.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// validateURL performs security validation on URLs to prevent SSRF attacks.
|
||||
// It checks for allowed schemes, validates hosts, and prevents access to private networks.
|
||||
func (h *Handler) validateURL(urlStr string) error {
|
||||
if urlStr == "" {
|
||||
return fmt.Errorf("empty URL")
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid URL format: %w", err)
|
||||
}
|
||||
|
||||
return h.validateParsedURL(u)
|
||||
}
|
||||
|
||||
// validateParsedURL validates a parsed URL structure for security.
|
||||
// It checks schemes, hosts, and paths to prevent malicious URLs.
|
||||
func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
allowedSchemes := map[string]bool{
|
||||
"https": true,
|
||||
"http": true,
|
||||
}
|
||||
|
||||
if !allowedSchemes[u.Scheme] {
|
||||
return fmt.Errorf("disallowed URL scheme: %s", u.Scheme)
|
||||
}
|
||||
|
||||
if u.Scheme == "http" {
|
||||
h.logger.Debugf("Warning: Using HTTP scheme for URL: %s", u.String())
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("missing host in URL")
|
||||
}
|
||||
|
||||
if err := h.validateHost(u.Host); err != nil {
|
||||
return fmt.Errorf("invalid host: %w", err)
|
||||
}
|
||||
|
||||
if strings.Contains(u.Path, "..") {
|
||||
return fmt.Errorf("path traversal detected in URL path")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
// Strip port if present
|
||||
if strings.Contains(host, ":") {
|
||||
var err error
|
||||
host, _, err = net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid host:port format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
for _, localhost := range localhostVariations {
|
||||
if strings.EqualFold(host, localhost) {
|
||||
return fmt.Errorf("localhost access not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to parse as IP address
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
return fmt.Errorf("link-local IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsMulticast() {
|
||||
return fmt.Errorf("multicast IP not allowed: %s", host)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SessionData interface for dependency injection
|
||||
type SessionData interface {
|
||||
GetRedirectCount() int
|
||||
ResetRedirectCount()
|
||||
IncrementRedirectCount()
|
||||
SetAuthenticated(bool)
|
||||
SetEmail(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetIDToken(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetCSRF(string)
|
||||
SetIncomingPath(string)
|
||||
MarkDirty()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,562 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAuthHandler_validateURL tests URL validation functionality
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
url: "http://example.com/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
url: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty URL",
|
||||
},
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
url: "not-a-url",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - javascript",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - data",
|
||||
url: "data:text/html,<script>alert('xss')</script>",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - file",
|
||||
url: "file:///etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Disallowed scheme - ftp",
|
||||
url: "ftp://example.com/file",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal attempt",
|
||||
url: "https://example.com/../../../etc/passwd",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Path traversal in middle",
|
||||
url: "https://example.com/path/../sensitive/file",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Localhost attempt",
|
||||
url: "https://localhost/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 attempt",
|
||||
url: "https://127.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost attempt",
|
||||
url: "https://[::1]/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0 attempt",
|
||||
url: "https://0.0.0.0/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 192.168.x.x",
|
||||
url: "https://192.168.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 10.x.x.x",
|
||||
url: "https://10.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP - 172.16.x.x",
|
||||
url: "https://172.16.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
url: "https://169.254.1.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
url: "https://224.0.0.1/auth",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Valid public IP",
|
||||
url: "https://8.8.8.8/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid domain with port",
|
||||
url: "https://example.com:8443/auth",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "localhost with case variation",
|
||||
url: "https://LOCALHOST/auth",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
url: "https://example.com:notanumber/auth",
|
||||
wantErr: true,
|
||||
errMsg: "invalid URL format",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateURL(tt.url)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateURL() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost tests host validation specifically
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid hostname",
|
||||
host: "example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with subdomain",
|
||||
host: "api.example.com",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid hostname with port",
|
||||
host: "example.com:8080",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty host",
|
||||
host: "",
|
||||
wantErr: true,
|
||||
errMsg: "empty host",
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
host: "localhost",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "LOCALHOST (case insensitive)",
|
||||
host: "LOCALHOST",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "localhost with port",
|
||||
host: "localhost:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1",
|
||||
host: "127.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "127.0.0.1 with port",
|
||||
host: "127.0.0.1:8080",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "IPv6 localhost",
|
||||
host: "::1",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "0.0.0.0",
|
||||
host: "0.0.0.0",
|
||||
wantErr: true,
|
||||
errMsg: "localhost access not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 192.168.1.1",
|
||||
host: "192.168.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 10.0.0.1",
|
||||
host: "10.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Private IP 172.16.0.1",
|
||||
host: "172.16.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "private IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Public IP 8.8.8.8",
|
||||
host: "8.8.8.8",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Link-local IP",
|
||||
host: "169.254.1.1",
|
||||
wantErr: true,
|
||||
errMsg: "link-local IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Multicast IP",
|
||||
host: "224.0.0.1",
|
||||
wantErr: true,
|
||||
errMsg: "multicast IP not allowed",
|
||||
},
|
||||
{
|
||||
name: "Invalid host:port format",
|
||||
host: "example.com::",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host:port format",
|
||||
},
|
||||
{
|
||||
name: "Valid international domain",
|
||||
host: "example.org",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid ccTLD",
|
||||
host: "example.co.uk",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := handler.validateHost(tt.host)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateHost() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateHost() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateHost() unexpected error = %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams tests URL building with parameters
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseURL string
|
||||
params url.Values
|
||||
expected string
|
||||
expectEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "Absolute HTTPS URL",
|
||||
baseURL: "https://provider.com/auth",
|
||||
params: url.Values{
|
||||
"client_id": []string{"test-client"},
|
||||
"response_type": []string{"code"},
|
||||
},
|
||||
expected: "https://provider.com/auth?client_id=test-client&response_type=code",
|
||||
},
|
||||
{
|
||||
name: "Absolute HTTP URL",
|
||||
baseURL: "http://provider.com/auth",
|
||||
params: url.Values{
|
||||
"state": []string{"test-state"},
|
||||
},
|
||||
expected: "http://provider.com/auth?state=test-state",
|
||||
},
|
||||
{
|
||||
name: "Relative URL resolved against issuer",
|
||||
baseURL: "/oauth2/authorize",
|
||||
params: url.Values{
|
||||
"scope": []string{"openid"},
|
||||
},
|
||||
expected: "https://example.com/oauth2/authorize?scope=openid",
|
||||
},
|
||||
{
|
||||
name: "Root relative URL",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{
|
||||
"nonce": []string{"test-nonce"},
|
||||
},
|
||||
expected: "https://example.com/auth?nonce=test-nonce",
|
||||
},
|
||||
{
|
||||
name: "Invalid absolute URL",
|
||||
baseURL: "https://localhost/auth",
|
||||
params: url.Values{},
|
||||
expectEmpty: true, // Should return empty string due to validation failure
|
||||
},
|
||||
{
|
||||
name: "Invalid relative URL when resolved",
|
||||
baseURL: "/auth",
|
||||
params: url.Values{},
|
||||
expected: "", // Should be empty because issuer validation would be tested separately
|
||||
expectEmpty: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.buildURLWithParams(tt.baseURL, tt.params)
|
||||
|
||||
if tt.expectEmpty {
|
||||
if result != "" {
|
||||
t.Errorf("buildURLWithParams() expected empty string, got %v", result)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For relative URLs, we expect them to be resolved against the issuer URL
|
||||
if !strings.HasPrefix(tt.baseURL, "http") {
|
||||
// Verify it starts with the issuer URL
|
||||
if !strings.HasPrefix(result, handler.issuerURL) {
|
||||
t.Errorf("buildURLWithParams() relative URL not resolved against issuer URL. Got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the result to verify parameters
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("buildURLWithParams() produced invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify all expected parameters are present
|
||||
resultParams := parsedURL.Query()
|
||||
for key, expectedValues := range tt.params {
|
||||
actualValues := resultParams[key]
|
||||
if len(actualValues) != len(expectedValues) {
|
||||
t.Errorf("Parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
|
||||
continue
|
||||
}
|
||||
for i, expectedValue := range expectedValues {
|
||||
if actualValues[i] != expectedValue {
|
||||
t.Errorf("Parameter %s[%d]: expected %v, got %v", key, i, expectedValue, actualValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_buildURLWithParams_ParameterEncoding tests proper parameter encoding
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
"redirect_uri": []string{"https://example.com/callback?test=value&other=data"},
|
||||
"state": []string{"state with spaces and & special chars"},
|
||||
"scope": []string{"openid profile email"},
|
||||
"special": []string{"value+with+plus&ersand=equals"},
|
||||
}
|
||||
|
||||
result := handler.buildURLWithParams("https://provider.com/auth", params)
|
||||
|
||||
parsedURL, err := url.Parse(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse result URL: %v", err)
|
||||
}
|
||||
|
||||
// Verify parameters are correctly encoded/decoded
|
||||
resultParams := parsedURL.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"redirect_uri": "https://example.com/callback?test=value&other=data",
|
||||
"state": "state with spaces and & special chars",
|
||||
"scope": "openid profile email",
|
||||
"special": "value+with+plus&ersand=equals",
|
||||
}
|
||||
|
||||
for key, expectedValue := range expectedParams {
|
||||
actualValue := resultParams.Get(key)
|
||||
if actualValue != expectedValue {
|
||||
t.Errorf("Parameter %s: expected %v, got %v", key, expectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateParsedURL tests validateParsedURL method
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid HTTPS URL",
|
||||
url: "https://example.com/path",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL with warning",
|
||||
url: "http://example.com/path",
|
||||
wantErr: false, // Should not error but should log warning
|
||||
},
|
||||
{
|
||||
name: "Invalid scheme",
|
||||
url: "javascript:alert('xss')",
|
||||
wantErr: true,
|
||||
errMsg: "disallowed URL scheme",
|
||||
},
|
||||
{
|
||||
name: "Missing host",
|
||||
url: "https:///path",
|
||||
wantErr: true,
|
||||
errMsg: "missing host",
|
||||
},
|
||||
{
|
||||
name: "Path traversal",
|
||||
url: "https://example.com/path/../../../etc",
|
||||
wantErr: true,
|
||||
errMsg: "path traversal detected",
|
||||
},
|
||||
{
|
||||
name: "Invalid host (private IP)",
|
||||
url: "https://192.168.1.1/path",
|
||||
wantErr: true,
|
||||
errMsg: "invalid host",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parsedURL, err := url.Parse(tt.url)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse test URL: %v", err)
|
||||
}
|
||||
|
||||
err = handler.validateParsedURL(parsedURL)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("validateParsedURL() expected error but got none")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.errMsg) {
|
||||
t.Errorf("validateParsedURL() error = %v, expected error containing %v", err, tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("validateParsedURL() unexpected error = %v", err)
|
||||
}
|
||||
|
||||
// Check for HTTP warning in debug logs
|
||||
if parsedURL.Scheme == "http" && len(logger.debugMessages) > 0 {
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if strings.Contains(msg, "Warning: Using HTTP scheme") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected HTTP scheme warning in debug logs")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,428 +0,0 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file contains tests for Auth0-specific audience validation scenarios.
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestAuth0Scenario1WithCustomAudience tests Auth0 scenario 1:
|
||||
// - Custom audience configured in plugin
|
||||
// - Authorize endpoint called WITH audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, custom_audience]
|
||||
// Expected: Both tokens validate correctly
|
||||
func TestAuth0Scenario1WithCustomAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (OIDC standard)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token always has client_id
|
||||
"nonce": "test-nonce-scenario1", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, custom_audience]
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience, // Custom API audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data", // Access tokens have scope
|
||||
"jti": "access-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed (should validate against client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify access token validates against custom audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token validation failed (should validate against custom audience): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL includes audience parameter (URL-encoded)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if !strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should contain audience parameter when custom audience is configured, got: %s", authURL)
|
||||
}
|
||||
// Verify the audience is properly URL-encoded (contains %3A for :, %2F for /)
|
||||
if !strings.Contains(authURL, "audience=https%3A%2F%2Fmy-api.example.com") {
|
||||
t.Errorf("Auth URL should contain URL-encoded custom audience, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2DefaultAudience tests Auth0 scenario 2:
|
||||
// - No custom audience configured (defaults to client_id)
|
||||
// - Authorize endpoint called WITHOUT audience parameter
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: aud = [userinfo, default_audience] (no client_id)
|
||||
// Expected: ID token validates, access token falls back to ID token validation
|
||||
func TestAuth0Scenario2DefaultAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// No custom audience - defaults to client_id
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token with aud = client_id
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario2", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with aud = [userinfo, some_default_audience]
|
||||
// This represents Auth0's default audience behavior
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://test-issuer.com/api/v2/", // Default Auth0 Management API
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-jti-2",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Access token won't have client_id in aud, so it will fail validation
|
||||
// This is expected for scenario 2 - the session validation relies on ID token
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Logf("Access token validation passed (unexpected but OK if client_id is in aud array)")
|
||||
} else {
|
||||
// Expected failure - access token doesn't have client_id in aud
|
||||
t.Logf("Access token validation failed as expected (aud doesn't contain client_id): %v", err)
|
||||
}
|
||||
|
||||
// Verify buildAuthURL does NOT include audience parameter (since audience == client_id)
|
||||
authURL := ts.tOidc.buildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
if strings.Contains(authURL, "audience=") {
|
||||
t.Errorf("Auth URL should NOT contain audience parameter when audience equals client_id, got: %s", authURL)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario3OpaqueAccessToken tests Auth0 scenario 3:
|
||||
// - No custom audience configured
|
||||
// - No default audience in Auth0
|
||||
// - ID token: aud = client_id
|
||||
// - Access token: opaque (not JWT)
|
||||
// Expected: ID token validates, opaque access token is accepted
|
||||
func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable opaque tokens for this scenario (Option C requirement)
|
||||
ts.tOidc.allowOpaqueTokens = true
|
||||
|
||||
// No custom audience
|
||||
ts.tOidc.audience = ts.tOidc.clientID
|
||||
|
||||
// Create ID token
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-scenario3", // ID tokens have nonce per OIDC spec
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-jti-3",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token (not a JWT - just a random string)
|
||||
opaqueAccessToken := "opaque_access_token_random_string_12345"
|
||||
|
||||
// Verify ID token validates
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token validation failed: %v", err)
|
||||
}
|
||||
|
||||
// Opaque access token should fail JWT validation (expected)
|
||||
err = ts.tOidc.VerifyToken(opaqueAccessToken)
|
||||
if err == nil {
|
||||
t.Error("Opaque access token should fail JWT validation")
|
||||
} else {
|
||||
t.Logf("Opaque access token correctly rejected by JWT validator: %v", err)
|
||||
}
|
||||
|
||||
// Test that validateStandardTokens handles opaque tokens correctly
|
||||
// by falling back to ID token validation
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0AudienceArrayValidation tests that audience validation
|
||||
// correctly handles array audiences (common in Auth0)
|
||||
func TestAuth0AudienceArrayValidation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with audience as array containing our custom audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
customAudience,
|
||||
"https://another-api.example.com",
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email read:data write:data",
|
||||
"jti": "array-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - custom audience is in the array
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err != nil {
|
||||
t.Errorf("Access token with audience array should validate when custom audience is present: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0MismatchedAudience tests that tokens with wrong audience fail validation
|
||||
func TestAuth0MismatchedAudience(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Access token with WRONG audience
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://different-api.example.com", // Wrong audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "wrong-aud-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail validation - audience doesn't match
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(accessToken)
|
||||
if err == nil {
|
||||
t.Error("Access token with wrong audience should fail validation")
|
||||
} else if !strings.Contains(err.Error(), "invalid audience") {
|
||||
t.Errorf("Expected 'invalid audience' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuth0Scenario2StrictMode tests strict audience validation mode:
|
||||
// - Scenario 2 (access token with wrong audience) should be REJECTED
|
||||
// - strictAudienceValidation=true prevents fallback to ID token
|
||||
// - This addresses Allan's security concerns about audience bypass
|
||||
func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Enable strict mode to prevent Scenario 2 bypass (Option C)
|
||||
ts.tOidc.strictAudienceValidation = true
|
||||
|
||||
// Configure custom audience
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (valid)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"nonce": "test-nonce-strict",
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Create access token with WRONG audience (doesn't include custom audience)
|
||||
accessToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": []interface{}{
|
||||
"https://test-issuer.com/userinfo",
|
||||
"https://wrong-api.example.com", // Wrong audience - not our custom audience
|
||||
},
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"scope": "openid profile email",
|
||||
"jti": "access-token-strict-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create access token: %v", err)
|
||||
}
|
||||
|
||||
// Test session validation with wrong access token and valid ID token
|
||||
req := httptest.NewRequest("GET", "https://example.com/test", nil)
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken(accessToken)
|
||||
session.SetIDToken(idToken)
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
if !needsRefresh {
|
||||
t.Errorf("Strict mode: Should signal refresh needed after rejection, got needsRefresh=%v", needsRefresh)
|
||||
}
|
||||
if expired {
|
||||
t.Errorf("Strict mode: Should not mark as expired (should try refresh first), got expired=%v", expired)
|
||||
}
|
||||
|
||||
t.Logf("✓ Strict mode correctly rejected Scenario 2 (access token audience mismatch)")
|
||||
}
|
||||
|
||||
// TestIDTokenAlwaysValidatesAgainstClientID verifies that ID tokens
|
||||
// are ALWAYS validated against client_id, regardless of configured audience
|
||||
func TestIDTokenAlwaysValidatesAgainstClientID(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure a custom audience different from client_id
|
||||
customAudience := "https://my-api.example.com"
|
||||
ts.tOidc.audience = customAudience
|
||||
|
||||
// Create ID token with aud = client_id (per OIDC spec)
|
||||
idToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id", // ID token MUST have client_id
|
||||
"nonce": "test-nonce-123", // ID tokens have nonce for replay protection
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "id-token-client-id-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should validate successfully - ID tokens are checked against client_id
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(idToken)
|
||||
if err != nil {
|
||||
t.Errorf("ID token should validate against client_id even when custom audience is configured: %v", err)
|
||||
}
|
||||
|
||||
// Create ID token with WRONG audience (should fail)
|
||||
wrongIDToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": customAudience, // WRONG - should be client_id
|
||||
"nonce": "test-nonce-wrong-456", // ID token has nonce, so it will be detected as ID token
|
||||
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
||||
"iat": float64(time.Now().Unix()),
|
||||
"sub": "test-user",
|
||||
"email": "test@example.com",
|
||||
"jti": "wrong-id-token-jti",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create wrong ID token: %v", err)
|
||||
}
|
||||
|
||||
// Should fail - ID tokens must have client_id as audience
|
||||
cleanupReplayCache()
|
||||
initReplayCache()
|
||||
err = ts.tOidc.VerifyToken(wrongIDToken)
|
||||
if err == nil {
|
||||
t.Error("ID token with custom audience (not client_id) should fail validation")
|
||||
}
|
||||
}
|
||||
+19
-13
@@ -8,10 +8,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// AUTHENTICATION FLOW
|
||||
// ============================================================================
|
||||
|
||||
// validateRedirectCount checks if redirect limit is exceeded and handles the error
|
||||
func (t *TraefikOidc) validateRedirectCount(session *SessionData, rw http.ResponseWriter, req *http.Request) error {
|
||||
const maxRedirects = 5
|
||||
@@ -223,15 +219,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +246,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+18
-15
@@ -222,17 +222,16 @@ func (bt *BackgroundTask) run() {
|
||||
// TaskCircuitBreaker implements circuit breaker pattern for background task creation
|
||||
// It limits concurrent task execution and tracks failures to prevent system overload
|
||||
type TaskCircuitBreaker struct {
|
||||
state int32 // CircuitBreakerState
|
||||
failureCount int32
|
||||
lastFailureTime int64 // Unix timestamp
|
||||
failureThreshold int32
|
||||
timeout time.Duration
|
||||
logger *Logger
|
||||
// Concurrency limiting
|
||||
concurrentTasks int32 // Current number of running tasks
|
||||
maxConcurrent int32 // Maximum concurrent tasks allowed
|
||||
activeTasks map[string]struct{} // Track active task names
|
||||
tasksMu sync.RWMutex // Separate mutex for task tracking
|
||||
activeTasks map[string]struct{}
|
||||
lastFailureTime int64
|
||||
timeout time.Duration
|
||||
tasksMu sync.RWMutex
|
||||
state int32
|
||||
failureCount int32
|
||||
failureThreshold int32
|
||||
concurrentTasks int32
|
||||
maxConcurrent int32
|
||||
}
|
||||
|
||||
// NewTaskCircuitBreaker creates a new circuit breaker for background tasks
|
||||
@@ -266,18 +265,21 @@ func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
|
||||
max := atomic.LoadInt32(&cb.maxConcurrent)
|
||||
|
||||
// For cleanup tasks, be more restrictive (singleton-like behavior)
|
||||
// However, allow distinct realm-specific tasks (e.g., singleton-metadata-refresh-abc123 vs singleton-metadata-refresh-def456)
|
||||
if strings.Contains(taskName, "cleanup") || strings.Contains(taskName, "singleton") {
|
||||
cb.tasksMu.RLock()
|
||||
hasCleanupTask := false
|
||||
hasSameTask := false
|
||||
for activeTask := range cb.activeTasks {
|
||||
if strings.Contains(activeTask, "cleanup") || strings.Contains(activeTask, "singleton") {
|
||||
hasCleanupTask = true
|
||||
// Only block if the EXACT same task is already running
|
||||
// This allows realm-specific tasks like singleton-metadata-refresh-{hash} to run concurrently
|
||||
if activeTask == taskName {
|
||||
hasSameTask = true
|
||||
break
|
||||
}
|
||||
}
|
||||
cb.tasksMu.RUnlock()
|
||||
|
||||
if hasCleanupTask {
|
||||
if hasSameTask {
|
||||
return fmt.Errorf("cleanup/singleton task already running: %s", taskName)
|
||||
}
|
||||
}
|
||||
@@ -377,9 +379,9 @@ func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
|
||||
// TaskRegistry maintains a registry of all active background tasks to prevent duplicates
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
cb *TaskCircuitBreaker
|
||||
logger *Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// GlobalTaskRegistry is the singleton instance for managing all background tasks
|
||||
@@ -787,6 +789,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
+6
-6
@@ -330,12 +330,12 @@ func TestValidateGoogleTokens(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidGoogleTokens",
|
||||
@@ -476,13 +476,13 @@ func TestIsUserAuthenticated(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
providerType string
|
||||
setupSession func() *SessionData
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "AzureProvider",
|
||||
@@ -660,12 +660,12 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
name string
|
||||
description string
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UnauthenticatedWithRefreshToken",
|
||||
|
||||
@@ -97,15 +97,15 @@ func TestMemoryMonitorComprehensive(t *testing.T) {
|
||||
|
||||
t.Run("String method returns pressure name", func(t *testing.T) {
|
||||
pressures := []struct {
|
||||
level MemoryPressureLevel
|
||||
name string
|
||||
level MemoryPressureLevel
|
||||
}{
|
||||
{MemoryPressureNone, "None"},
|
||||
{MemoryPressureLow, "Low"},
|
||||
{MemoryPressureModerate, "Moderate"},
|
||||
{MemoryPressureHigh, "High"},
|
||||
{MemoryPressureCritical, "Critical"},
|
||||
{MemoryPressureLevel(999), "Unknown"},
|
||||
{level: MemoryPressureNone, name: "None"},
|
||||
{level: MemoryPressureLow, name: "Low"},
|
||||
{level: MemoryPressureModerate, name: "Moderate"},
|
||||
{level: MemoryPressureHigh, name: "High"},
|
||||
{level: MemoryPressureCritical, name: "Critical"},
|
||||
{level: MemoryPressureLevel(999), name: "Unknown"},
|
||||
}
|
||||
|
||||
for _, p := range pressures {
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// UNIVERSAL CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheSet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
cache.Get(fmt.Sprintf("key%d", i%1000))
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheSetGet(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCacheLRUEviction(b *testing.B) {
|
||||
config := createTestCacheConfig()
|
||||
config.MaxSize = 100
|
||||
cache := NewUniversalCache(config)
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheConcurrent(b *testing.B) {
|
||||
cache := NewUniversalCache(createTestCacheConfig())
|
||||
defer cache.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
switch i % 3 {
|
||||
case 0:
|
||||
cache.Set(fmt.Sprintf("key%d", i), fmt.Sprintf("value%d", i), 1*time.Hour)
|
||||
case 1:
|
||||
cache.Get(fmt.Sprintf("key%d", i))
|
||||
case 2:
|
||||
cache.Delete(fmt.Sprintf("key%d", i))
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE MANAGER BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// CACHE COMPATIBILITY BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SHARDED CACHE BENCHMARKS
|
||||
// =============================================================================
|
||||
|
||||
func BenchmarkShardedCache(b *testing.B) {
|
||||
b.Run("Set", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("Get", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
for i := 0; i < 10000; i++ {
|
||||
cache.Set(fmt.Sprintf("key-%d", i), i, 5*time.Minute)
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get(fmt.Sprintf("key-%d", i%10000))
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ParallelSetGet", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
cache.Set(key, i, 5*time.Minute)
|
||||
cache.Get(key)
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkShardedVsGlobalMutex compares sharded cache with global mutex approach
|
||||
func BenchmarkShardedVsGlobalMutex(b *testing.B) {
|
||||
b.Run("ShardedCache64", func(b *testing.B) {
|
||||
cache := NewShardedCache(64, 100000)
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
if !cache.Exists(key) {
|
||||
cache.Set(key, true, 5*time.Minute)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
b.Run("GlobalMutexCache", func(b *testing.B) {
|
||||
var mu sync.RWMutex
|
||||
data := make(map[string]bool)
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
key := fmt.Sprintf("jti-%d", i%10000)
|
||||
|
||||
mu.RLock()
|
||||
_, exists := data[key]
|
||||
mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
mu.Lock()
|
||||
data[key] = true
|
||||
mu.Unlock()
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
+3
-3
@@ -155,9 +155,9 @@ type CacheStrategy interface {
|
||||
|
||||
// CacheEntry for backward compatibility
|
||||
type CacheEntry struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
Value interface{}
|
||||
Key string
|
||||
}
|
||||
|
||||
// Cache is an alias for backward compatibility
|
||||
@@ -175,10 +175,10 @@ func NewOptimizedCacheWithConfig(config OptimizedCacheConfig) *CacheInterfaceWra
|
||||
|
||||
// ListNode for backward compatibility
|
||||
type ListNode struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
Next *ListNode
|
||||
Prev *ListNode
|
||||
Key string
|
||||
}
|
||||
|
||||
// NewFixedMetadataCache creates a metadata cache with fixed configuration
|
||||
|
||||
@@ -1,369 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewBoundedCache tests creation of bounded cache
|
||||
func TestNewBoundedCache(t *testing.T) {
|
||||
maxSize := 500
|
||||
cache := NewBoundedCache(maxSize)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify we can use basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultUnifiedCacheConfig tests default configuration
|
||||
func TestDefaultUnifiedCacheConfig(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
|
||||
if config.Type != CacheTypeGeneral {
|
||||
t.Errorf("Expected CacheTypeGeneral, got %v", config.Type)
|
||||
}
|
||||
|
||||
if config.MaxSize != 500 {
|
||||
t.Errorf("Expected MaxSize 500, got %d", config.MaxSize)
|
||||
}
|
||||
|
||||
if config.MaxMemoryBytes != 64*1024*1024 {
|
||||
t.Errorf("Expected MaxMemoryBytes 64MB, got %d", config.MaxMemoryBytes)
|
||||
}
|
||||
|
||||
if config.CleanupInterval != 2*time.Minute {
|
||||
t.Errorf("Expected CleanupInterval 2 minutes, got %v", config.CleanupInterval)
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
t.Error("Expected Logger to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewUnifiedCache tests unified cache creation
|
||||
func TestNewUnifiedCache(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
if cache.UniversalCache == nil {
|
||||
t.Error("Expected UniversalCache to be set")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedCache_SetMaxSize tests SetMaxSize method
|
||||
func TestUnifiedCache_SetMaxSize(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
// Test setting max size
|
||||
newSize := 1000
|
||||
cache.SetMaxSize(newSize)
|
||||
|
||||
// We can't easily verify the size was set without exposing internal fields,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestNewCacheAdapter tests cache adapter creation
|
||||
func TestNewCacheAdapter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cache interface{}
|
||||
expectNil bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UniversalCache",
|
||||
cache: NewUniversalCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UniversalCache",
|
||||
},
|
||||
{
|
||||
name: "UnifiedCache",
|
||||
cache: NewUnifiedCache(DefaultUnifiedCacheConfig()),
|
||||
expectNil: false,
|
||||
description: "Should create adapter for UnifiedCache",
|
||||
},
|
||||
{
|
||||
name: "Invalid cache type",
|
||||
cache: "not-a-cache",
|
||||
expectNil: true,
|
||||
description: "Should return nil for invalid cache type",
|
||||
},
|
||||
{
|
||||
name: "Nil cache",
|
||||
cache: nil,
|
||||
expectNil: true,
|
||||
description: "Should return nil for nil cache",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
adapter := NewCacheAdapter(tt.cache)
|
||||
|
||||
if tt.expectNil {
|
||||
if adapter != nil {
|
||||
t.Errorf("Expected nil adapter, got %v", adapter)
|
||||
}
|
||||
} else {
|
||||
if adapter == nil {
|
||||
t.Error("Expected non-nil adapter")
|
||||
}
|
||||
// Test basic operations
|
||||
adapter.Set("test", "value", time.Hour)
|
||||
value, found := adapter.Get("test")
|
||||
if !found {
|
||||
t.Error("Expected key to be found")
|
||||
}
|
||||
if value != "value" {
|
||||
t.Errorf("Expected 'value', got %v", value)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCache tests optimized cache creation
|
||||
func TestNewOptimizedCache(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewLRUStrategy tests LRU strategy creation
|
||||
func TestNewLRUStrategy(t *testing.T) {
|
||||
maxSize := 100
|
||||
strategy := NewLRUStrategy(maxSize)
|
||||
|
||||
if strategy == nil {
|
||||
t.Fatal("Expected strategy to be created, got nil")
|
||||
}
|
||||
|
||||
lruStrategy, ok := strategy.(*LRUStrategy)
|
||||
if !ok {
|
||||
t.Fatal("Expected LRUStrategy type")
|
||||
}
|
||||
|
||||
if lruStrategy.maxSize != maxSize {
|
||||
t.Errorf("Expected maxSize %d, got %d", maxSize, lruStrategy.maxSize)
|
||||
}
|
||||
|
||||
if lruStrategy.order == nil {
|
||||
t.Error("Expected order list to be initialized")
|
||||
}
|
||||
|
||||
if lruStrategy.elements == nil {
|
||||
t.Error("Expected elements map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_Name tests strategy name
|
||||
func TestLRUStrategy_Name(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
name := strategy.Name()
|
||||
if name != "LRU" {
|
||||
t.Errorf("Expected 'LRU', got %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_ShouldEvict tests eviction logic
|
||||
func TestLRUStrategy_ShouldEvict(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// LRU strategy always returns false for ShouldEvict
|
||||
result := strategy.ShouldEvict("test-item", time.Now())
|
||||
if result != false {
|
||||
t.Error("Expected ShouldEvict to return false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnAccess tests access callback
|
||||
func TestLRUStrategy_OnAccess(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnAccess should not panic
|
||||
strategy.OnAccess("test-key", "test-value")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_OnRemove tests removal callback
|
||||
func TestLRUStrategy_OnRemove(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
// OnRemove should not panic
|
||||
strategy.OnRemove("test-key")
|
||||
}
|
||||
|
||||
// TestLRUStrategy_EstimateSize tests size estimation
|
||||
func TestLRUStrategy_EstimateSize(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
size := strategy.EstimateSize("test-item")
|
||||
if size != 64 {
|
||||
t.Errorf("Expected size 64, got %d", size)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLRUStrategy_GetEvictionCandidate tests eviction candidate retrieval
|
||||
func TestLRUStrategy_GetEvictionCandidate(t *testing.T) {
|
||||
strategy := NewLRUStrategy(100)
|
||||
|
||||
key, found := strategy.GetEvictionCandidate()
|
||||
if found {
|
||||
t.Error("Expected no eviction candidate to be found")
|
||||
}
|
||||
if key != "" {
|
||||
t.Errorf("Expected empty key, got %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewOptimizedCacheWithConfig tests optimized cache with custom config
|
||||
func TestNewOptimizedCacheWithConfig(t *testing.T) {
|
||||
config := UniversalCacheConfig{
|
||||
Type: CacheTypeGeneral,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 128 * 1024 * 1024,
|
||||
EnableMetrics: true,
|
||||
Logger: GetSingletonNoOpLogger(),
|
||||
}
|
||||
|
||||
cache := NewOptimizedCacheWithConfig(config)
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with basic operations
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Error("Expected key to be found in cache")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewFixedMetadataCache tests fixed metadata cache creation
|
||||
func TestNewFixedMetadataCache(t *testing.T) {
|
||||
cache := NewFixedMetadataCache()
|
||||
|
||||
if cache == nil {
|
||||
t.Fatal("Expected cache to be created, got nil")
|
||||
}
|
||||
|
||||
// Verify it works with proper metadata operations
|
||||
metadata := &ProviderMetadata{
|
||||
Issuer: "https://example.com",
|
||||
AuthURL: "https://example.com/auth",
|
||||
TokenURL: "https://example.com/token",
|
||||
JWKSURL: "https://example.com/jwks",
|
||||
}
|
||||
|
||||
err := cache.Set("test-provider", metadata, time.Hour)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error setting metadata: %v", err)
|
||||
}
|
||||
|
||||
// Test that the cache was created (basic verification)
|
||||
// Note: We can't easily test Get without more complex setup
|
||||
}
|
||||
|
||||
// TestNewDoublyLinkedList tests doubly linked list creation
|
||||
func TestNewDoublyLinkedList(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
if list == nil {
|
||||
t.Fatal("Expected list to be created, got nil")
|
||||
}
|
||||
|
||||
// Test it's a proper list structure
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected empty list initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDoublyLinkedList_PopFront tests front element removal
|
||||
func TestDoublyLinkedList_PopFront(t *testing.T) {
|
||||
list := NewDoublyLinkedList()
|
||||
|
||||
// Test popping from empty list
|
||||
element := list.PopFront()
|
||||
if element != nil {
|
||||
t.Error("Expected nil when popping from empty list")
|
||||
}
|
||||
|
||||
// Add an element and test popping
|
||||
added := list.PushBack("test-value")
|
||||
if added == nil {
|
||||
t.Fatal("Expected element to be added")
|
||||
}
|
||||
|
||||
popped := list.PopFront()
|
||||
if popped == nil {
|
||||
t.Error("Expected element to be popped")
|
||||
}
|
||||
|
||||
if list.Len() != 0 {
|
||||
t.Error("Expected list to be empty after popping")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkNewBoundedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewBoundedCache(1000)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewOptimizedCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewOptimizedCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLRUStrategy_EstimateSize(b *testing.B) {
|
||||
strategy := NewLRUStrategy(1000)
|
||||
item := "test-item"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
strategy.EstimateSize(item)
|
||||
}
|
||||
}
|
||||
@@ -1,314 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Helper function to ensure we have a working cache manager for tests
|
||||
func getTestCacheManager(t *testing.T) *CacheManager {
|
||||
cm := GetGlobalCacheManager(&sync.WaitGroup{})
|
||||
if cm == nil {
|
||||
t.Fatal("Failed to get cache manager")
|
||||
}
|
||||
if cm.manager == nil {
|
||||
t.Fatal("Cache manager has nil internal manager")
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
// TestCacheManager_Close tests cache manager close functionality
|
||||
func TestCacheManager_Close(t *testing.T) {
|
||||
// Get a fresh cache manager
|
||||
wg := &sync.WaitGroup{}
|
||||
cm := GetGlobalCacheManager(wg)
|
||||
|
||||
if cm == nil {
|
||||
t.Fatal("Expected cache manager to be created")
|
||||
}
|
||||
|
||||
// Test closing the cache manager
|
||||
err := cm.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error closing cache manager: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanupGlobalCacheManager tests global cleanup
|
||||
func TestCleanupGlobalCacheManager(t *testing.T) {
|
||||
// Test cleanup when no instance exists (should not error)
|
||||
originalInstance := globalCacheManagerInstance
|
||||
globalCacheManagerInstance = nil
|
||||
err := CleanupGlobalCacheManager()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error during cleanup of nil instance: %v", err)
|
||||
}
|
||||
|
||||
// Restore original instance
|
||||
globalCacheManagerInstance = originalInstance
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Delete tests delete functionality
|
||||
func TestCacheInterfaceWrapper_Delete(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item
|
||||
cache.Set("test-key", "test-value", time.Hour)
|
||||
|
||||
// Verify it exists
|
||||
value, found := cache.Get("test-key")
|
||||
if !found {
|
||||
t.Fatal("Expected key to be found after setting")
|
||||
}
|
||||
if value != "test-value" {
|
||||
t.Errorf("Expected 'test-value', got %v", value)
|
||||
}
|
||||
|
||||
// Delete it
|
||||
cache.Delete("test-key")
|
||||
|
||||
// Verify it's gone
|
||||
_, found = cache.Get("test-key")
|
||||
if found {
|
||||
t.Error("Expected key to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Size tests size functionality
|
||||
func TestCacheInterfaceWrapper_Size(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Clear cache first
|
||||
cache.Clear()
|
||||
|
||||
// Check initial size
|
||||
initialSize := cache.Size()
|
||||
if initialSize != 0 {
|
||||
t.Errorf("Expected initial size 0, got %d", initialSize)
|
||||
}
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Check size increased
|
||||
newSize := cache.Size()
|
||||
if newSize != 2 {
|
||||
t.Errorf("Expected size 2, got %d", newSize)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Clear tests clear functionality
|
||||
func TestCacheInterfaceWrapper_Clear(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add some items
|
||||
cache.Set("key1", "value1", time.Hour)
|
||||
cache.Set("key2", "value2", time.Hour)
|
||||
|
||||
// Verify items exist
|
||||
size := cache.Size()
|
||||
if size != 2 {
|
||||
t.Errorf("Expected 2 items before clear, got %d", size)
|
||||
}
|
||||
|
||||
// Clear all
|
||||
cache.Clear()
|
||||
|
||||
// Verify cache is empty
|
||||
size = cache.Size()
|
||||
if size != 0 {
|
||||
t.Errorf("Expected 0 items after clear, got %d", size)
|
||||
}
|
||||
|
||||
// Verify specific items are gone
|
||||
_, found := cache.Get("key1")
|
||||
if found {
|
||||
t.Error("Expected key1 to be cleared")
|
||||
}
|
||||
|
||||
_, found = cache.Get("key2")
|
||||
if found {
|
||||
t.Error("Expected key2 to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Close tests wrapper close functionality
|
||||
func TestCacheInterfaceWrapper_Close(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test close - should not panic
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
wrapper.Close() // Should not panic
|
||||
|
||||
// Test close with nil cache
|
||||
nilWrapper := &CacheInterfaceWrapper{cache: nil}
|
||||
nilWrapper.Close() // Should not panic
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_GetStats tests stats functionality
|
||||
func TestCacheInterfaceWrapper_GetStats(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
wrapper, ok := cache.(*CacheInterfaceWrapper)
|
||||
if !ok {
|
||||
t.Fatal("Expected CacheInterfaceWrapper")
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := wrapper.GetStats()
|
||||
if stats == nil {
|
||||
t.Error("Expected non-nil stats")
|
||||
}
|
||||
|
||||
// Stats should be accessible (len() never returns negative values)
|
||||
// Just verify it's accessible by checking it's not nil (already done above)
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_Cleanup tests cleanup functionality
|
||||
func TestCacheInterfaceWrapper_Cleanup(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Add an item that will expire quickly
|
||||
cache.Set("expire-key", "expire-value", time.Millisecond)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Trigger cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Item should be cleaned up
|
||||
_, found := cache.Get("expire-key")
|
||||
if found {
|
||||
t.Error("Expected expired key to be cleaned up")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCacheInterfaceWrapper_SetMaxSize tests max size setting
|
||||
func TestCacheInterfaceWrapper_SetMaxSize(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Test setting max size (should not panic)
|
||||
cache.SetMaxSize(1000)
|
||||
|
||||
// We can't easily verify the size was set without exposing internals,
|
||||
// but we can ensure the method doesn't panic
|
||||
}
|
||||
|
||||
// TestGetSharedCaches tests getting shared cache instances
|
||||
func TestGetSharedCaches(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
|
||||
// Test getting shared token blacklist
|
||||
blacklist := cm.GetSharedTokenBlacklist()
|
||||
if blacklist == nil {
|
||||
t.Error("Expected non-nil token blacklist")
|
||||
}
|
||||
|
||||
// Test getting shared token cache
|
||||
tokenCache := cm.GetSharedTokenCache()
|
||||
if tokenCache == nil {
|
||||
t.Error("Expected non-nil token cache")
|
||||
}
|
||||
|
||||
// Test getting shared metadata cache
|
||||
metadataCache := cm.GetSharedMetadataCache()
|
||||
if metadataCache == nil {
|
||||
t.Error("Expected non-nil metadata cache")
|
||||
}
|
||||
|
||||
// Test getting shared JWK cache
|
||||
jwkCache := cm.GetSharedJWKCache()
|
||||
if jwkCache == nil {
|
||||
t.Error("Expected non-nil JWK cache")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentCacheAccess tests thread safety
|
||||
func TestConcurrentCacheAccess(t *testing.T) {
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 10
|
||||
|
||||
// Concurrent operations
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("key-%d-%d", id, j)
|
||||
value := fmt.Sprintf("value-%d-%d", id, j)
|
||||
|
||||
cache.Set(key, value, time.Hour)
|
||||
|
||||
retrieved, found := cache.Get(key)
|
||||
if found && retrieved != value {
|
||||
t.Errorf("Concurrent access failed: expected %s, got %v", value, retrieved)
|
||||
}
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Benchmark tests for performance
|
||||
func BenchmarkCacheInterfaceWrapper_Set(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Get(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
// Pre-populate cache
|
||||
cache.Set("benchmark-key", "benchmark-value", time.Hour)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache.Get("benchmark-key")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCacheInterfaceWrapper_Delete(b *testing.B) {
|
||||
t := &testing.T{}
|
||||
cm := getTestCacheManager(t)
|
||||
cache := cm.GetSharedTokenBlacklist()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
key := fmt.Sprintf("benchmark-key-%d", i)
|
||||
cache.Set(key, "value", time.Hour)
|
||||
b.StartTimer()
|
||||
|
||||
cache.Delete(key)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,319 +0,0 @@
|
||||
// Package circuit_breaker provides circuit breaker implementation for resilience
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the current state of a circuit breaker.
|
||||
// The circuit breaker pattern prevents cascading failures by monitoring
|
||||
// error rates and temporarily blocking requests to failing services.
|
||||
type CircuitBreakerState int
|
||||
|
||||
// Circuit breaker states following the standard pattern:
|
||||
// Closed: Normal operation, requests flow through
|
||||
// Open: Circuit is tripped, requests are blocked
|
||||
// HalfOpen: Testing state, limited requests allowed to test recovery
|
||||
const (
|
||||
// CircuitBreakerClosed allows all requests through (normal operation)
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen blocks all requests (service is failing)
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen allows limited requests to test service recovery
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
// String returns a string representation of the circuit breaker state
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Infof(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Debugf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism interface for common functionality
|
||||
type BaseRecoveryMechanism interface {
|
||||
RecordRequest()
|
||||
RecordSuccess()
|
||||
RecordFailure()
|
||||
GetBaseMetrics() map[string]interface{}
|
||||
LogInfo(format string, args ...interface{})
|
||||
LogError(format string, args ...interface{})
|
||||
LogDebug(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for external service calls.
|
||||
// It monitors failure rates and automatically opens the circuit when failures
|
||||
// exceed the threshold, preventing further requests until the service recovers.
|
||||
type CircuitBreaker struct {
|
||||
// baseRecovery provides common functionality
|
||||
baseRecovery BaseRecoveryMechanism
|
||||
// maxFailures is the threshold for opening the circuit
|
||||
maxFailures int
|
||||
// timeout is how long to wait before allowing requests in half-open state
|
||||
timeout time.Duration
|
||||
// resetTimeout is how long to wait before transitioning from open to half-open
|
||||
resetTimeout time.Duration
|
||||
// state tracks the current circuit breaker state
|
||||
state CircuitBreakerState
|
||||
// failures counts consecutive failures
|
||||
failures int64
|
||||
// lastFailureTime records when the last failure occurred
|
||||
lastFailureTime time.Time
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// logger for debugging and monitoring
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration parameters for circuit breakers.
|
||||
// These settings control when the circuit opens and how it recovers.
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of failures before opening the circuit
|
||||
MaxFailures int `json:"max_failures"`
|
||||
// Timeout is how long to wait before trying to recover (open -> half-open)
|
||||
Timeout time.Duration `json:"timeout"`
|
||||
// ResetTimeout is how long to wait before fully closing the circuit
|
||||
ResetTimeout time.Duration `json:"reset_timeout"`
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns sensible default configuration for circuit breakers.
|
||||
// Configured for typical web service scenarios with moderate tolerance for failures.
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the specified configuration.
|
||||
// The circuit breaker starts in the closed state, allowing all requests through.
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger, baseRecovery BaseRecoveryMechanism) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
baseRecovery: baseRecovery,
|
||||
maxFailures: config.MaxFailures,
|
||||
timeout: config.Timeout,
|
||||
resetTimeout: config.ResetTimeout,
|
||||
state: CircuitBreakerClosed,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a function through the circuit breaker with context.
|
||||
// It checks if requests are allowed, executes the function, and updates the circuit state
|
||||
// based on the result. Implements the ErrorRecoveryMechanism interface.
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordRequest()
|
||||
}
|
||||
|
||||
if !cb.allowRequest() {
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.RecordSuccess()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute executes a function through the circuit breaker without context.
|
||||
// This is provided for backward compatibility with existing code.
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest determines whether to allow a request based on the circuit state.
|
||||
// Handles state transitions from open to half-open based on timeout.
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
return true
|
||||
|
||||
case CircuitBreakerOpen:
|
||||
if now.Sub(cb.lastFailureTime) > cb.timeout {
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
if cb.logger != nil {
|
||||
cb.logger.Infof("Circuit breaker transitioning to half-open state")
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit.
|
||||
// Updates failure count and triggers state transitions when thresholds are exceeded.
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.failures++
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerClosed:
|
||||
if cb.failures >= int64(cb.maxFailures) {
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker opened after %d failures", cb.failures)
|
||||
}
|
||||
}
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.state = CircuitBreakerOpen
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogError("Circuit breaker returned to open state after failure in half-open")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a successful request and potentially closes the circuit.
|
||||
// Resets failure count and transitions from half-open to closed state on success.
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
switch cb.state {
|
||||
case CircuitBreakerHalfOpen:
|
||||
cb.failures = 0
|
||||
cb.state = CircuitBreakerClosed
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker closed after successful request in half-open state")
|
||||
}
|
||||
|
||||
case CircuitBreakerClosed:
|
||||
cb.failures = 0
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker.
|
||||
// Thread-safe method for monitoring circuit breaker status.
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to its initial closed state.
|
||||
// Clears failure count and state, effectively recovering from any open state.
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.state = CircuitBreakerClosed
|
||||
atomic.StoreInt64(&cb.failures, 0)
|
||||
if cb.baseRecovery != nil {
|
||||
cb.baseRecovery.LogInfo("Circuit breaker has been reset")
|
||||
}
|
||||
}
|
||||
|
||||
// IsAvailable returns whether the circuit breaker is currently allowing requests.
|
||||
// This provides a quick way to check if the service is available.
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
return cb.allowRequest()
|
||||
}
|
||||
|
||||
// GetMetrics returns comprehensive metrics about the circuit breaker.
|
||||
// Includes state information, failure counts, configuration, and base metrics.
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
cb.mutex.RLock()
|
||||
state := cb.state
|
||||
failures := cb.failures
|
||||
lastFailureTime := cb.lastFailureTime
|
||||
cb.mutex.RUnlock()
|
||||
|
||||
var metrics map[string]interface{}
|
||||
if cb.baseRecovery != nil {
|
||||
metrics = cb.baseRecovery.GetBaseMetrics()
|
||||
} else {
|
||||
metrics = make(map[string]interface{})
|
||||
}
|
||||
|
||||
metrics["state"] = state.String()
|
||||
metrics["current_failures"] = failures
|
||||
metrics["max_failures"] = cb.maxFailures
|
||||
metrics["timeout"] = cb.timeout.String()
|
||||
metrics["reset_timeout"] = cb.resetTimeout.String()
|
||||
|
||||
if !lastFailureTime.IsZero() {
|
||||
metrics["last_failure_time"] = lastFailureTime
|
||||
metrics["time_since_last_failure"] = time.Since(lastFailureTime).String()
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// GetFailureCount returns the current failure count
|
||||
func (cb *CircuitBreaker) GetFailureCount() int64 {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.failures
|
||||
}
|
||||
|
||||
// GetLastFailureTime returns the time of the last failure
|
||||
func (cb *CircuitBreaker) GetLastFailureTime() time.Time {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.lastFailureTime
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is in open state
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// IsClosed returns true if the circuit breaker is in closed state
|
||||
func (cb *CircuitBreaker) IsClosed() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerClosed
|
||||
}
|
||||
|
||||
// IsHalfOpen returns true if the circuit breaker is in half-open state
|
||||
func (cb *CircuitBreaker) IsHalfOpen() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state == CircuitBreakerHalfOpen
|
||||
}
|
||||
@@ -1,981 +0,0 @@
|
||||
package circuit_breaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock implementations for testing
|
||||
type mockLogger struct {
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Infof(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockLogger) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future error log verification tests
|
||||
func (m *mockLogger) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
//lint:ignore U1000 May be needed for future test isolation
|
||||
func (m *mockLogger) reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = nil
|
||||
m.errorLogs = nil
|
||||
m.debugLogs = nil
|
||||
}
|
||||
|
||||
type mockBaseRecoveryMechanism struct {
|
||||
requestCount int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
infoLogs []string
|
||||
errorLogs []string
|
||||
debugLogs []string
|
||||
baseMetrics map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func newMockBaseRecovery() *mockBaseRecoveryMechanism {
|
||||
return &mockBaseRecoveryMechanism{
|
||||
baseMetrics: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&m.requestCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&m.successCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&m.failureCount, 1)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range m.baseMetrics {
|
||||
result[k] = v
|
||||
}
|
||||
result["total_requests"] = atomic.LoadInt64(&m.requestCount)
|
||||
result["total_successes"] = atomic.LoadInt64(&m.successCount)
|
||||
result["total_failures"] = atomic.LoadInt64(&m.failureCount)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.infoLogs = append(m.infoLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogError(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errorLogs = append(m.errorLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLogs = append(m.debugLogs, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getRequestCount() int64 {
|
||||
return atomic.LoadInt64(&m.requestCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getSuccessCount() int64 {
|
||||
return atomic.LoadInt64(&m.successCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getFailureCount() int64 {
|
||||
return atomic.LoadInt64(&m.failureCount)
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getInfoLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.infoLogs))
|
||||
copy(result, m.infoLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *mockBaseRecoveryMechanism) getErrorLogs() []string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
result := make([]string, len(m.errorLogs))
|
||||
copy(result, m.errorLogs)
|
||||
return result
|
||||
}
|
||||
|
||||
func TestCircuitBreakerState_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
state CircuitBreakerState
|
||||
expected string
|
||||
}{
|
||||
{CircuitBreakerClosed, "closed"},
|
||||
{CircuitBreakerOpen, "open"},
|
||||
{CircuitBreakerHalfOpen, "half-open"},
|
||||
{CircuitBreakerState(999), "unknown"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.expected, func(t *testing.T) {
|
||||
result := tt.state.String()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures to be 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout to be 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout to be 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCircuitBreaker(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
if cb == nil {
|
||||
t.Fatal("NewCircuitBreaker returned nil")
|
||||
}
|
||||
|
||||
if cb.maxFailures != 3 {
|
||||
t.Errorf("Expected maxFailures to be 3, got %d", cb.maxFailures)
|
||||
}
|
||||
|
||||
if cb.timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout to be 30s, got %v", cb.timeout)
|
||||
}
|
||||
|
||||
if cb.resetTimeout != 15*time.Second {
|
||||
t.Errorf("Expected resetTimeout to be 15s, got %v", cb.resetTimeout)
|
||||
}
|
||||
|
||||
if cb.state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be Closed, got %v", cb.state)
|
||||
}
|
||||
|
||||
if cb.logger != logger {
|
||||
t.Error("Expected logger to be set")
|
||||
}
|
||||
|
||||
if cb.baseRecovery != baseRecovery {
|
||||
t.Error("Expected baseRecovery to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Success(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getSuccessCount() != 1 {
|
||||
t.Errorf("Expected 1 success recorded, got %d", baseRecovery.getSuccessCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithContext_Failure(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after single failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if baseRecovery.getRequestCount() != 1 {
|
||||
t.Errorf("Expected 1 request recorded, got %d", baseRecovery.getRequestCount())
|
||||
}
|
||||
|
||||
if baseRecovery.getFailureCount() != 1 {
|
||||
t.Errorf("Expected 1 failure recorded, got %d", baseRecovery.getFailureCount())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Execute(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.Execute(testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_OpenAfterMaxFailures(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First failure
|
||||
err := cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on first failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to remain Closed after first failure, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Second failure - should open circuit
|
||||
err = cb.ExecuteWithContext(ctx, testFunc)
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error on second failure, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open after max failures, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Third attempt - should be blocked
|
||||
callCount := 0
|
||||
blockedFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
err = cb.ExecuteWithContext(ctx, blockedFunc)
|
||||
if err == nil {
|
||||
t.Error("Expected error when circuit is open")
|
||||
}
|
||||
if callCount != 0 {
|
||||
t.Errorf("Expected function not to be called when circuit is open, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenTransition(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond, // Very short for testing
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next request should transition to half-open
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err = cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error in half-open state, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called in half-open state, got %d calls", callCount)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after successful half-open request, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_HalfOpenFailureReturnsToOpen(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// First call should transition to half-open, but we'll force it by checking allowRequest
|
||||
if !cb.allowRequest() {
|
||||
t.Error("Expected allowRequest to return true after timeout")
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerHalfOpen {
|
||||
t.Errorf("Expected state to be HalfOpen, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Failure in half-open should return to open
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if err != testError {
|
||||
t.Errorf("Expected test error, got %v", err)
|
||||
}
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to return to Open after half-open failure, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
_ = cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Errorf("Expected state to be Open, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
// Reset circuit
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed after reset, got %v", cb.GetState())
|
||||
}
|
||||
|
||||
if cb.GetFailureCount() != 0 {
|
||||
t.Errorf("Expected failure count to be 0 after reset, got %d", cb.GetFailureCount())
|
||||
}
|
||||
|
||||
// Should allow requests again
|
||||
callCount := 0
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
callCount++
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after reset, got %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called after reset, got %d calls", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially available
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should not be available when open
|
||||
if cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be unavailable when open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should be available again after timeout (half-open)
|
||||
if !cb.IsAvailable() {
|
||||
t.Error("Expected circuit breaker to be available after timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateCheckers(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially closed
|
||||
if !cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker to be closed initially")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open initially")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open initially")
|
||||
}
|
||||
|
||||
// Trigger opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Should be open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when open")
|
||||
}
|
||||
if !cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker to be open")
|
||||
}
|
||||
if cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker not to be half-open when open")
|
||||
}
|
||||
|
||||
// Wait for timeout and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
cb.allowRequest() // This will transition to half-open
|
||||
|
||||
// Should be half-open
|
||||
if cb.IsClosed() {
|
||||
t.Error("Expected circuit breaker not to be closed when half-open")
|
||||
}
|
||||
if cb.IsOpen() {
|
||||
t.Error("Expected circuit breaker not to be open when half-open")
|
||||
}
|
||||
if !cb.IsHalfOpen() {
|
||||
t.Error("Expected circuit breaker to be half-open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 30 * time.Second,
|
||||
ResetTimeout: 15 * time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
baseRecovery.baseMetrics["custom_metric"] = "custom_value"
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Record some activity
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Check circuit breaker specific metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["current_failures"] != int64(1) {
|
||||
t.Errorf("Expected current_failures to be 1, got %v", metrics["current_failures"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
if metrics["timeout"] != "30s" {
|
||||
t.Errorf("Expected timeout to be '30s', got %v", metrics["timeout"])
|
||||
}
|
||||
|
||||
if metrics["reset_timeout"] != "15s" {
|
||||
t.Errorf("Expected reset_timeout to be '15s', got %v", metrics["reset_timeout"])
|
||||
}
|
||||
|
||||
// Check base metrics are included
|
||||
if metrics["total_requests"] != int64(1) {
|
||||
t.Errorf("Expected total_requests to be 1, got %v", metrics["total_requests"])
|
||||
}
|
||||
|
||||
if metrics["custom_metric"] != "custom_value" {
|
||||
t.Errorf("Expected custom_metric to be 'custom_value', got %v", metrics["custom_metric"])
|
||||
}
|
||||
|
||||
// Check failure time metrics
|
||||
if _, exists := metrics["last_failure_time"]; !exists {
|
||||
t.Error("Expected last_failure_time to exist")
|
||||
}
|
||||
|
||||
if _, exists := metrics["time_since_last_failure"]; !exists {
|
||||
t.Error("Expected time_since_last_failure to exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetMetrics_NoBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
|
||||
// Should still have circuit breaker metrics
|
||||
if metrics["state"] != "closed" {
|
||||
t.Errorf("Expected state to be 'closed', got %v", metrics["state"])
|
||||
}
|
||||
|
||||
if metrics["max_failures"] != 2 {
|
||||
t.Errorf("Expected max_failures to be 2, got %v", metrics["max_failures"])
|
||||
}
|
||||
|
||||
// Should not have base metrics
|
||||
if _, exists := metrics["total_requests"]; exists {
|
||||
t.Error("Expected total_requests not to exist without base recovery")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_GetLastFailureTime(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Initially should be zero
|
||||
if !cb.GetLastFailureTime().IsZero() {
|
||||
t.Error("Expected last failure time to be zero initially")
|
||||
}
|
||||
|
||||
// Record a failure
|
||||
before := time.Now()
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
after := time.Now()
|
||||
|
||||
lastFailure := cb.GetLastFailureTime()
|
||||
if lastFailure.IsZero() {
|
||||
t.Error("Expected last failure time to be set after failure")
|
||||
}
|
||||
|
||||
if lastFailure.Before(before) || lastFailure.After(after) {
|
||||
t.Errorf("Expected last failure time to be between %v and %v, got %v",
|
||||
before, after, lastFailure)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ExecuteWithoutBaseRecovery(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
cb := NewCircuitBreaker(config, logger, nil)
|
||||
|
||||
callCount := 0
|
||||
testFunc := func() error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
err := cb.ExecuteWithContext(context.Background(), testFunc)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Expected function to be called once, got %d", callCount)
|
||||
}
|
||||
|
||||
// Should work fine without base recovery
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be Closed, got %v", cb.GetState())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 10, // Higher threshold for concurrent test
|
||||
Timeout: 100 * time.Millisecond,
|
||||
ResetTimeout: 50 * time.Millisecond,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
const numGoroutines = 10
|
||||
const numOperations = 50
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int64(0)
|
||||
errorCount := int64(0)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
err := cb.ExecuteWithContext(context.Background(), func() error {
|
||||
// Simulate some failures
|
||||
if j%10 == 9 { // Every 10th operation fails
|
||||
return fmt.Errorf("simulated error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
atomic.AddInt64(&errorCount, 1)
|
||||
} else {
|
||||
atomic.AddInt64(&successCount, 1)
|
||||
}
|
||||
|
||||
// Intermittently check state and metrics
|
||||
if j%5 == 0 {
|
||||
cb.GetState()
|
||||
cb.GetMetrics()
|
||||
cb.IsAvailable()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify we got both successes and errors
|
||||
finalSuccessCount := atomic.LoadInt64(&successCount)
|
||||
finalErrorCount := atomic.LoadInt64(&errorCount)
|
||||
|
||||
if finalSuccessCount == 0 {
|
||||
t.Error("Expected some successful operations")
|
||||
}
|
||||
|
||||
if finalErrorCount == 0 {
|
||||
t.Error("Expected some failed operations")
|
||||
}
|
||||
|
||||
totalOperations := finalSuccessCount + finalErrorCount
|
||||
expectedMax := int64(numGoroutines * numOperations)
|
||||
|
||||
if totalOperations > expectedMax {
|
||||
t.Errorf("Expected at most %d operations, got %d", expectedMax, totalOperations)
|
||||
}
|
||||
|
||||
t.Logf("Concurrent test completed: %d successes, %d errors, final state: %v",
|
||||
finalSuccessCount, finalErrorCount, cb.GetState())
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_StateTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Trigger circuit opening
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Check that error was logged when circuit opened
|
||||
errorLogs := baseRecovery.getErrorLogs()
|
||||
if len(errorLogs) == 0 {
|
||||
t.Error("Expected error log when circuit breaker opened")
|
||||
} else {
|
||||
if !contains(errorLogs, "Circuit breaker opened after") {
|
||||
t.Errorf("Expected circuit opening log, got %v", errorLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait and trigger half-open
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Successful request should close circuit and log
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Check that success was logged when circuit closed
|
||||
infoLogs := baseRecovery.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log when circuit breaker closed")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker closed after successful request") {
|
||||
t.Errorf("Expected circuit closing log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset should also be logged
|
||||
cb.Reset()
|
||||
infoLogs = baseRecovery.getInfoLogs()
|
||||
if !contains(infoLogs, "Circuit breaker has been reset") {
|
||||
t.Errorf("Expected reset log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCircuitBreaker_LoggerTransitionLogging(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Wait for timeout and check half-open transition logging
|
||||
testError := fmt.Errorf("test error")
|
||||
cb.ExecuteWithContext(context.Background(), func() error {
|
||||
return testError
|
||||
})
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Next allowRequest call should log transition to half-open
|
||||
cb.allowRequest()
|
||||
|
||||
infoLogs := logger.getInfoLogs()
|
||||
if len(infoLogs) == 0 {
|
||||
t.Error("Expected info log for half-open transition")
|
||||
} else {
|
||||
if !contains(infoLogs, "Circuit breaker transitioning to half-open state") {
|
||||
t.Errorf("Expected half-open transition log, got %v", infoLogs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a slice contains a string with substring
|
||||
func contains(slice []string, substr string) bool {
|
||||
for _, s := range slice {
|
||||
if len(s) >= len(substr) && s[:len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Success(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testFunc := func() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithContext_Failure(b *testing.B) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1000, // High threshold to avoid opening during benchmark
|
||||
Timeout: time.Second,
|
||||
ResetTimeout: time.Second,
|
||||
}
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
testError := fmt.Errorf("test error")
|
||||
testFunc := func() error {
|
||||
return testError
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.ExecuteWithContext(ctx, testFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetState(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
cb.GetState()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_GetMetrics(b *testing.B) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := &mockLogger{}
|
||||
baseRecovery := newMockBaseRecovery()
|
||||
cb := NewCircuitBreaker(config, logger, baseRecovery)
|
||||
|
||||
// Add some activity
|
||||
for i := 0; i < 100; i++ {
|
||||
if i%2 == 0 {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return nil })
|
||||
} else {
|
||||
cb.ExecuteWithContext(context.Background(), func() error { return fmt.Errorf("error") })
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.GetMetrics()
|
||||
}
|
||||
}
|
||||
@@ -1,258 +0,0 @@
|
||||
// Package config provides backward compatibility for legacy configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// LegacyAdapter provides backward compatibility for old Config struct
|
||||
type LegacyAdapter struct {
|
||||
unified *UnifiedConfig
|
||||
adapter *compat.ConfigAdapter
|
||||
}
|
||||
|
||||
// NewLegacyAdapter creates a new legacy adapter from unified config
|
||||
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
|
||||
adapter := compat.NewConfigAdapter(unified)
|
||||
|
||||
// Register getters for commonly used fields
|
||||
adapter.RegisterGetter("ProviderURL", func() interface{} {
|
||||
return unified.Provider.IssuerURL
|
||||
})
|
||||
adapter.RegisterGetter("ClientID", func() interface{} {
|
||||
return unified.Provider.ClientID
|
||||
})
|
||||
adapter.RegisterGetter("ClientSecret", func() interface{} {
|
||||
return unified.Provider.ClientSecret
|
||||
})
|
||||
adapter.RegisterGetter("CallbackURL", func() interface{} {
|
||||
return unified.Provider.RedirectURL
|
||||
})
|
||||
adapter.RegisterGetter("LogoutURL", func() interface{} {
|
||||
return unified.Provider.LogoutURL
|
||||
})
|
||||
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
|
||||
return unified.Provider.PostLogoutRedirectURI
|
||||
})
|
||||
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
|
||||
return unified.Session.EncryptionKey
|
||||
})
|
||||
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
|
||||
return unified.Security.ForceHTTPS
|
||||
})
|
||||
adapter.RegisterGetter("LogLevel", func() interface{} {
|
||||
return unified.Logging.Level
|
||||
})
|
||||
adapter.RegisterGetter("Scopes", func() interface{} {
|
||||
return unified.Provider.Scopes
|
||||
})
|
||||
adapter.RegisterGetter("OverrideScopes", func() interface{} {
|
||||
return unified.Provider.OverrideScopes
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUsers", func() interface{} {
|
||||
return unified.Security.AllowedUsers
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
|
||||
return unified.Security.AllowedUserDomains
|
||||
})
|
||||
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
|
||||
return unified.Security.AllowedRolesAndGroups
|
||||
})
|
||||
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
|
||||
return unified.Security.ExcludedURLs
|
||||
})
|
||||
adapter.RegisterGetter("EnablePKCE", func() interface{} {
|
||||
return unified.Security.EnablePKCE
|
||||
})
|
||||
adapter.RegisterGetter("RateLimit", func() interface{} {
|
||||
return unified.RateLimit.RequestsPerSecond
|
||||
})
|
||||
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
|
||||
return int(unified.Token.RefreshGracePeriod.Seconds())
|
||||
})
|
||||
adapter.RegisterGetter("CookieDomain", func() interface{} {
|
||||
return unified.Session.Domain
|
||||
})
|
||||
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
|
||||
return unified.Security.Headers
|
||||
})
|
||||
|
||||
return &LegacyAdapter{
|
||||
unified: unified,
|
||||
adapter: adapter,
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig converts unified config to old Config struct format
|
||||
func (la *LegacyAdapter) ToOldConfig() *Config {
|
||||
// Use feature flags to determine behavior
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Return existing Config if unified config not enabled
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
ProviderURL: la.unified.Provider.IssuerURL,
|
||||
ClientID: la.unified.Provider.ClientID,
|
||||
ClientSecret: la.unified.Provider.ClientSecret,
|
||||
CallbackURL: la.unified.Provider.RedirectURL,
|
||||
LogoutURL: la.unified.Provider.LogoutURL,
|
||||
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
|
||||
SessionEncryptionKey: la.unified.Session.EncryptionKey,
|
||||
ForceHTTPS: la.unified.Security.ForceHTTPS,
|
||||
LogLevel: la.unified.Logging.Level,
|
||||
Scopes: la.unified.Provider.Scopes,
|
||||
OverrideScopes: la.unified.Provider.OverrideScopes,
|
||||
AllowedUsers: la.unified.Security.AllowedUsers,
|
||||
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
|
||||
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
|
||||
ExcludedURLs: la.unified.Security.ExcludedURLs,
|
||||
EnablePKCE: la.unified.Security.EnablePKCE,
|
||||
RateLimit: la.unified.RateLimit.RequestsPerSecond,
|
||||
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
|
||||
Headers: la.convertHeaders(),
|
||||
CookieDomain: la.unified.Session.Domain,
|
||||
SecurityHeaders: la.unified.Security.Headers,
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// convertHeaders converts unified header config to old format
|
||||
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
|
||||
headers := make([]HeaderConfig, 0)
|
||||
|
||||
for name, value := range la.unified.Middleware.CustomHeaders {
|
||||
headers = append(headers, HeaderConfig{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// FromOldConfig creates unified config from old Config struct
|
||||
func FromOldConfig(old *Config) *UnifiedConfig {
|
||||
unified := NewUnifiedConfig()
|
||||
|
||||
// Map provider settings
|
||||
unified.Provider.IssuerURL = old.ProviderURL
|
||||
unified.Provider.ClientID = old.ClientID
|
||||
unified.Provider.ClientSecret = old.ClientSecret
|
||||
unified.Provider.RedirectURL = old.CallbackURL
|
||||
unified.Provider.LogoutURL = old.LogoutURL
|
||||
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
|
||||
unified.Provider.Scopes = old.Scopes
|
||||
unified.Provider.OverrideScopes = old.OverrideScopes
|
||||
|
||||
// Map session settings
|
||||
unified.Session.EncryptionKey = old.SessionEncryptionKey
|
||||
unified.Session.Domain = old.CookieDomain
|
||||
|
||||
// Map security settings
|
||||
unified.Security.ForceHTTPS = old.ForceHTTPS
|
||||
unified.Security.EnablePKCE = old.EnablePKCE
|
||||
unified.Security.AllowedUsers = old.AllowedUsers
|
||||
unified.Security.AllowedUserDomains = old.AllowedUserDomains
|
||||
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
|
||||
unified.Security.ExcludedURLs = old.ExcludedURLs
|
||||
unified.Security.Headers = old.SecurityHeaders
|
||||
|
||||
// Map rate limiting
|
||||
unified.RateLimit.RequestsPerSecond = old.RateLimit
|
||||
unified.RateLimit.Enabled = old.RateLimit > 0
|
||||
|
||||
// Map token settings
|
||||
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
|
||||
|
||||
// Map logging
|
||||
unified.Logging.Level = old.LogLevel
|
||||
|
||||
// Map custom headers
|
||||
if len(old.Headers) > 0 {
|
||||
unified.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, header := range old.Headers {
|
||||
unified.Middleware.CustomHeaders[header.Name] = header.Value
|
||||
}
|
||||
}
|
||||
|
||||
// Store original config in legacy field for reference
|
||||
unified.Legacy["original"] = old
|
||||
|
||||
return unified
|
||||
}
|
||||
|
||||
// timeSecondsToDuration converts seconds to time.Duration
|
||||
func timeSecondsToDuration(seconds int) time.Duration {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// GetConfigInterface returns appropriate config based on feature flag
|
||||
func GetConfigInterface() interface{} {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
// ValidateConfig validates config based on feature flag
|
||||
func ValidateConfig(cfg interface{}) error {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if unified, ok := cfg.(*UnifiedConfig); ok {
|
||||
return unified.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to old validation if available
|
||||
if old, ok := cfg.(*Config); ok {
|
||||
return old.Validate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add Validate method to old Config for compatibility
|
||||
func (c *Config) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Basic validation for old config
|
||||
if c.ProviderURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ProviderURL",
|
||||
Message: "provider URL is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientSecret == "" && !c.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE)",
|
||||
})
|
||||
}
|
||||
|
||||
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "SessionEncryptionKey",
|
||||
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
|
||||
Value: len(c.SessionEncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,363 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// NewLegacyAdapter Tests
|
||||
func TestNewLegacyAdapter(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "test-client"
|
||||
unified.Provider.ClientSecret = "test-secret"
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("Expected NewLegacyAdapter to return non-nil")
|
||||
}
|
||||
|
||||
if adapter.unified != unified {
|
||||
t.Error("Expected adapter to reference the unified config")
|
||||
}
|
||||
|
||||
if adapter.adapter == nil {
|
||||
t.Error("Expected internal adapter to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig Tests
|
||||
func TestLegacyAdapter_ToOldConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://issuer.example.com"
|
||||
unified.Provider.ClientID = "client-123"
|
||||
unified.Provider.ClientSecret = "secret-456"
|
||||
unified.Provider.RedirectURL = "https://app.example.com/callback"
|
||||
unified.Provider.LogoutURL = "/logout"
|
||||
unified.Provider.PostLogoutRedirectURI = "https://app.example.com"
|
||||
unified.Provider.Scopes = []string{"openid", "profile"}
|
||||
unified.Provider.OverrideScopes = true
|
||||
unified.Session.EncryptionKey = "test-encryption-key-32-chars!!"
|
||||
unified.Session.Domain = "example.com"
|
||||
unified.Security.ForceHTTPS = true
|
||||
unified.Security.EnablePKCE = true
|
||||
unified.Security.AllowedUsers = []string{"user@example.com"}
|
||||
unified.Security.AllowedUserDomains = []string{"example.com"}
|
||||
unified.Security.AllowedRolesAndGroups = []string{"admin"}
|
||||
unified.Security.ExcludedURLs = []string{"/health"}
|
||||
unified.RateLimit.RequestsPerSecond = 100
|
||||
unified.Logging.Level = "debug"
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Header-1": "value1",
|
||||
"X-Header-2": "value2",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
oldConfig := adapter.ToOldConfig()
|
||||
|
||||
if oldConfig == nil {
|
||||
t.Fatal("Expected ToOldConfig to return non-nil")
|
||||
}
|
||||
|
||||
// ToOldConfig behavior depends on feature flag
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// When feature is disabled, returns default config
|
||||
if oldConfig.ProviderURL == "" {
|
||||
t.Log("Feature flag disabled - ToOldConfig returns default config")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// When feature is enabled, verify all fields were correctly mapped
|
||||
if oldConfig.ProviderURL != unified.Provider.IssuerURL {
|
||||
t.Errorf("Expected ProviderURL '%s', got '%s'", unified.Provider.IssuerURL, oldConfig.ProviderURL)
|
||||
}
|
||||
|
||||
if oldConfig.ClientID != unified.Provider.ClientID {
|
||||
t.Errorf("Expected ClientID '%s', got '%s'", unified.Provider.ClientID, oldConfig.ClientID)
|
||||
}
|
||||
|
||||
if oldConfig.ClientSecret != unified.Provider.ClientSecret {
|
||||
t.Errorf("Expected ClientSecret '%s', got '%s'", unified.Provider.ClientSecret, oldConfig.ClientSecret)
|
||||
}
|
||||
|
||||
if oldConfig.CallbackURL != unified.Provider.RedirectURL {
|
||||
t.Error("Expected CallbackURL to match RedirectURL")
|
||||
}
|
||||
|
||||
if oldConfig.LogoutURL != unified.Provider.LogoutURL {
|
||||
t.Error("Expected LogoutURL to match")
|
||||
}
|
||||
|
||||
if oldConfig.ForceHTTPS != unified.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to match")
|
||||
}
|
||||
|
||||
if oldConfig.EnablePKCE != unified.Security.EnablePKCE {
|
||||
t.Error("Expected EnablePKCE to match")
|
||||
}
|
||||
|
||||
if oldConfig.RateLimit != unified.RateLimit.RequestsPerSecond {
|
||||
t.Errorf("Expected RateLimit %d, got %d", unified.RateLimit.RequestsPerSecond, oldConfig.RateLimit)
|
||||
}
|
||||
|
||||
if len(oldConfig.Headers) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(oldConfig.Headers))
|
||||
}
|
||||
}
|
||||
|
||||
// convertHeaders Tests
|
||||
func TestLegacyAdapter_convertHeaders(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Custom-Header-1": "value1",
|
||||
"X-Custom-Header-2": "value2",
|
||||
"X-Custom-Header-3": "value3",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 3 {
|
||||
t.Errorf("Expected 3 headers, got %d", len(headers))
|
||||
}
|
||||
|
||||
// Check that headers were converted
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h.Name] = h.Value
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-1"] != "value1" {
|
||||
t.Error("Expected X-Custom-Header-1 to have value 'value1'")
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-2"] != "value2" {
|
||||
t.Error("Expected X-Custom-Header-2 to have value 'value2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAdapter_convertHeaders_Empty(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
// No custom headers
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 0 {
|
||||
t.Errorf("Expected 0 headers, got %d", len(headers))
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigInterface Tests
|
||||
func TestGetConfigInterface(t *testing.T) {
|
||||
cfg := GetConfigInterface()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("Expected GetConfigInterface to return non-nil")
|
||||
}
|
||||
|
||||
// Should return either UnifiedConfig or Config depending on feature flag
|
||||
_, isUnified := cfg.(*UnifiedConfig)
|
||||
_, isOld := cfg.(*Config)
|
||||
|
||||
if !isUnified && !isOld {
|
||||
t.Error("Expected either *UnifiedConfig or *Config")
|
||||
}
|
||||
|
||||
// Verify consistency with feature flag
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if !isUnified {
|
||||
t.Error("Expected *UnifiedConfig when unified config is enabled")
|
||||
}
|
||||
} else {
|
||||
if !isOld {
|
||||
t.Error("Expected *Config when unified config is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig Tests
|
||||
func TestValidateConfig_UnifiedConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "client-id"
|
||||
unified.Provider.ClientSecret = "client-secret"
|
||||
unified.Session.EncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(unified)
|
||||
// Should succeed regardless of feature flag since we're passing the right type
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid unified config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_OldConfig(t *testing.T) {
|
||||
old := CreateConfig()
|
||||
old.ProviderURL = "https://provider.example.com"
|
||||
old.ClientID = "client-id"
|
||||
old.ClientSecret = "client-secret"
|
||||
old.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(old)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid old config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidType(t *testing.T) {
|
||||
// Pass something that's not a config
|
||||
err := ValidateConfig("not a config")
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil for unknown type, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Config.Validate Tests
|
||||
func TestConfig_Validate_Valid(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid config to pass, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingProviderURL(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ProviderURL")
|
||||
}
|
||||
|
||||
// Check if it's a ValidationErrors type
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ProviderURL" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ProviderURL validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientID(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientID")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientID" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientID validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_NoPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = false
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientSecret without PKCE")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientSecret" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientSecret validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_WithPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = true // PKCE enabled, so ClientSecret not required
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with PKCE enabled and no ClientSecret, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ShortEncryptionKey(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "short" // Too short
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for short encryption key")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "SessionEncryptionKey" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected SessionEncryptionKey validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MultipleErrors(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
// Missing ProviderURL, ClientID, and ClientSecret
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected validation errors")
|
||||
}
|
||||
|
||||
verrs, ok := err.(ValidationErrors)
|
||||
if !ok {
|
||||
t.Fatal("Expected ValidationErrors type")
|
||||
}
|
||||
|
||||
if len(verrs) < 2 {
|
||||
t.Errorf("Expected at least 2 validation errors, got %d", len(verrs))
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,276 +0,0 @@
|
||||
// Package config provides default values and initialization for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewUnifiedConfig creates a new unified configuration with sensible defaults
|
||||
func NewUnifiedConfig() *UnifiedConfig {
|
||||
return &UnifiedConfig{
|
||||
Provider: DefaultProviderConfig(),
|
||||
Session: DefaultSessionConfig(),
|
||||
Token: DefaultTokenConfig(),
|
||||
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
|
||||
Security: DefaultSecurityConfig(),
|
||||
Middleware: DefaultMiddlewareConfig(),
|
||||
Cache: DefaultCacheConfig(),
|
||||
RateLimit: DefaultRateLimitConfig(),
|
||||
Logging: DefaultLoggingConfig(),
|
||||
Metrics: DefaultMetricsConfig(),
|
||||
Health: DefaultHealthConfig(),
|
||||
Transport: DefaultTransportConfig(),
|
||||
Pool: DefaultPoolConfig(),
|
||||
Circuit: DefaultCircuitConfig(),
|
||||
Legacy: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultProviderConfig returns default provider configuration
|
||||
func DefaultProviderConfig() ProviderConfig {
|
||||
return ProviderConfig{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
OverrideScopes: false,
|
||||
CustomClaims: make(map[string]string),
|
||||
JWKCachePeriod: 24 * time.Hour,
|
||||
MetadataCacheTTL: 24 * time.Hour,
|
||||
Discovery: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSessionConfig returns default session configuration
|
||||
func DefaultSessionConfig() SessionConfig {
|
||||
return SessionConfig{
|
||||
Name: "oidc_session",
|
||||
MaxAge: 86400, // 24 hours
|
||||
ChunkSize: 4000, // Safe size for cookies
|
||||
MaxChunks: 5,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
StorageType: "cookie",
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTokenConfig returns default token configuration
|
||||
func DefaultTokenConfig() TokenConfig {
|
||||
return TokenConfig{
|
||||
AccessTokenTTL: 1 * time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
RefreshGracePeriod: 60 * time.Second,
|
||||
ValidationMode: "jwt",
|
||||
CacheEnabled: true,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
CacheNegativeTTL: 30 * time.Second,
|
||||
ValidateSignature: true,
|
||||
ValidateExpiry: true,
|
||||
ValidateAudience: true,
|
||||
ValidateIssuer: true,
|
||||
RequiredClaims: []string{"sub", "iat", "exp"},
|
||||
ClockSkew: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSecurityConfig returns default security configuration
|
||||
func DefaultSecurityConfig() SecurityConfig {
|
||||
return SecurityConfig{
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
AllowedUsers: []string{},
|
||||
AllowedUserDomains: []string{},
|
||||
AllowedRolesAndGroups: []string{},
|
||||
ExcludedURLs: []string{
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
"/health",
|
||||
"/.well-known/",
|
||||
"/metrics",
|
||||
"/ping",
|
||||
"/static/",
|
||||
"/assets/",
|
||||
"/js/",
|
||||
"/css/",
|
||||
"/images/",
|
||||
"/fonts/",
|
||||
},
|
||||
Headers: createDefaultSecurityConfig(),
|
||||
CSRFProtection: true,
|
||||
CSRFTokenName: "csrf_token",
|
||||
CSRFTokenTTL: 1 * time.Hour,
|
||||
MaxLoginAttempts: 5,
|
||||
LockoutDuration: 15 * time.Minute,
|
||||
RequireMFA: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMiddlewareConfig returns default middleware configuration
|
||||
func DefaultMiddlewareConfig() MiddlewareConfig {
|
||||
return MiddlewareConfig{
|
||||
Priority: 1000,
|
||||
SkipPaths: []string{},
|
||||
RequirePaths: []string{},
|
||||
PassthroughMode: false,
|
||||
MaxRequestSize: 10 * 1024 * 1024, // 10MB
|
||||
RequestTimeout: 30 * time.Second,
|
||||
IdleTimeout: 90 * time.Second,
|
||||
CustomHeaders: make(map[string]string),
|
||||
RemoveHeaders: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCacheConfig returns default cache configuration
|
||||
func DefaultCacheConfig() CacheConfig {
|
||||
return CacheConfig{
|
||||
Enabled: true,
|
||||
Type: "memory",
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxEntries: 10000,
|
||||
MaxEntrySize: 1024 * 1024, // 1MB
|
||||
EvictionPolicy: "lru",
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
Namespace: "traefikoidc",
|
||||
Compression: false,
|
||||
Serialization: "json",
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRateLimitConfig returns default rate limiting configuration
|
||||
func DefaultRateLimitConfig() RateLimitConfig {
|
||||
return RateLimitConfig{
|
||||
Enabled: false,
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
StorageType: "memory",
|
||||
WindowDuration: 1 * time.Minute,
|
||||
KeyType: "ip",
|
||||
CustomKeyFunc: "",
|
||||
WhitelistIPs: []string{},
|
||||
WhitelistUsers: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLoggingConfig returns default logging configuration
|
||||
func DefaultLoggingConfig() LoggingConfig {
|
||||
return LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
FilePath: "",
|
||||
FilterSensitive: true,
|
||||
MaskFields: []string{
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
},
|
||||
BufferSize: 8192,
|
||||
FlushInterval: 5 * time.Second,
|
||||
AuditEnabled: false,
|
||||
AuditEvents: []string{
|
||||
"login",
|
||||
"logout",
|
||||
"token_refresh",
|
||||
"auth_failure",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMetricsConfig returns default metrics configuration
|
||||
func DefaultMetricsConfig() MetricsConfig {
|
||||
return MetricsConfig{
|
||||
Enabled: false,
|
||||
Provider: "prometheus",
|
||||
Endpoint: "/metrics",
|
||||
Namespace: "traefikoidc",
|
||||
Subsystem: "middleware",
|
||||
CollectInterval: 10 * time.Second,
|
||||
Histograms: true,
|
||||
Labels: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHealthConfig returns default health check configuration
|
||||
func DefaultHealthConfig() HealthConfig {
|
||||
return HealthConfig{
|
||||
Enabled: true,
|
||||
Path: "/health",
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
CheckProvider: true,
|
||||
CheckRedis: true,
|
||||
CheckCache: true,
|
||||
MaxLatency: 1 * time.Second,
|
||||
MinMemory: 100 * 1024 * 1024, // 100MB
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTransportConfig returns default HTTP transport configuration
|
||||
func DefaultTransportConfig() TransportConfig {
|
||||
return TransportConfig{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 0, // No limit
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
TLSMinVersion: "TLS1.2",
|
||||
TLSCipherSuites: []string{},
|
||||
ProxyURL: "",
|
||||
NoProxy: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultPoolConfig returns default connection pool configuration
|
||||
func DefaultPoolConfig() PoolConfig {
|
||||
return PoolConfig{
|
||||
Enabled: true,
|
||||
Size: 10,
|
||||
MinSize: 2,
|
||||
MaxSize: 50,
|
||||
MaxAge: 30 * time.Minute,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
WaitTimeout: 5 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCircuitConfig returns default circuit breaker configuration
|
||||
func DefaultCircuitConfig() CircuitConfig {
|
||||
return CircuitConfig{
|
||||
Enabled: true,
|
||||
MaxRequests: 100,
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 60 * time.Second,
|
||||
ConsecutiveFailures: 5,
|
||||
FailureRatio: 0.5,
|
||||
OnOpen: "reject",
|
||||
OnHalfOpen: "passthrough",
|
||||
MetricsEnabled: true,
|
||||
LogStateChanges: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeWithDefaults merges a partial configuration with defaults
|
||||
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
|
||||
if partial == nil {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
|
||||
// Ensure Legacy field is initialized
|
||||
if partial.Legacy == nil {
|
||||
partial.Legacy = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// TODO: Implement deep merge logic with defaults
|
||||
// For now, just return the partial config
|
||||
return partial
|
||||
}
|
||||
@@ -1,396 +0,0 @@
|
||||
// Package config provides configuration loading and merging logic
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigLoader handles loading configuration from various sources
|
||||
type ConfigLoader struct {
|
||||
migrator *ConfigMigrator
|
||||
envPrefix string
|
||||
configPaths []string
|
||||
}
|
||||
|
||||
// NewConfigLoader creates a new configuration loader
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
migrator: NewConfigMigrator(),
|
||||
envPrefix: "TRAEFIKOIDC_",
|
||||
configPaths: getDefaultConfigPaths(),
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultConfigPaths returns default configuration file paths to check
|
||||
func getDefaultConfigPaths() []string {
|
||||
return []string{
|
||||
"traefik-oidc.yaml",
|
||||
"traefik-oidc.yml",
|
||||
"traefik-oidc.json",
|
||||
"config.yaml",
|
||||
"config.yml",
|
||||
"config.json",
|
||||
"/etc/traefik-oidc/config.yaml",
|
||||
"/etc/traefik-oidc/config.json",
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from all available sources
|
||||
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
|
||||
// Start with defaults
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Try to load from file
|
||||
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
|
||||
config = l.mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
l.LoadFromEnv(config)
|
||||
|
||||
// Validate the final configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadFromFile loads configuration from a file
|
||||
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
|
||||
// Use provided paths or default paths
|
||||
searchPaths := paths
|
||||
if len(searchPaths) == 0 {
|
||||
searchPaths = l.configPaths
|
||||
}
|
||||
|
||||
// Check for config file in environment variable
|
||||
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
|
||||
searchPaths = append([]string{envPath}, searchPaths...)
|
||||
}
|
||||
|
||||
// Try each path
|
||||
for _, path := range searchPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return l.loadFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
// No config file found, not an error (use defaults)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// loadFile loads a specific configuration file
|
||||
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories (current dir or subdirs)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
// Check if unified config is enabled
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
// Use migrator to handle any version
|
||||
config, warnings, err := l.migrator.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, use proper logging
|
||||
fmt.Printf("Config Warning (%s): %s\n", path, warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Legacy path: load old config and convert
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
var oldConfig Config
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
if err := json.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
|
||||
}
|
||||
case ".yaml", ".yml":
|
||||
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
|
||||
}
|
||||
|
||||
return FromOldConfig(&oldConfig), nil
|
||||
}
|
||||
|
||||
// LoadFromEnv loads configuration from environment variables
|
||||
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
|
||||
// Provider configuration
|
||||
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
|
||||
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
|
||||
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
|
||||
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
|
||||
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
|
||||
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
|
||||
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
|
||||
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
|
||||
|
||||
// Session configuration
|
||||
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
|
||||
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
|
||||
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
|
||||
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
|
||||
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
|
||||
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
|
||||
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
|
||||
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
|
||||
|
||||
// Security configuration
|
||||
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
|
||||
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
|
||||
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
|
||||
|
||||
// Cache configuration
|
||||
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
|
||||
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
|
||||
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
|
||||
// MaxEntrySize is int64, skip for now
|
||||
|
||||
// Rate limiting
|
||||
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
|
||||
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
|
||||
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
|
||||
|
||||
// Logging
|
||||
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
|
||||
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
|
||||
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
|
||||
|
||||
// Redis configuration (already handled by its own LoadFromEnv)
|
||||
config.Redis.LoadFromEnv()
|
||||
|
||||
// Feature flags
|
||||
features.GetManager().LoadFromEnv()
|
||||
}
|
||||
|
||||
// Helper methods for environment variable loading
|
||||
|
||||
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitAndTrim(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeConfigs merges two configurations, with source overriding target
|
||||
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
|
||||
if source == nil {
|
||||
return target
|
||||
}
|
||||
if target == nil {
|
||||
return source
|
||||
}
|
||||
|
||||
// Use reflection for deep merge
|
||||
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
// mergeStructs recursively merges two structs
|
||||
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
|
||||
for i := 0; i < source.NumField(); i++ {
|
||||
sourceField := source.Field(i)
|
||||
targetField := target.Field(i)
|
||||
|
||||
// Skip if source field is zero value
|
||||
if isZeroValue(sourceField) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sourceField.Kind() {
|
||||
case reflect.Struct:
|
||||
// Recursively merge structs
|
||||
l.mergeStructs(targetField, sourceField)
|
||||
case reflect.Slice:
|
||||
// Replace slice if source has values
|
||||
if sourceField.Len() > 0 {
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
case reflect.Map:
|
||||
// Merge maps
|
||||
if !sourceField.IsNil() {
|
||||
if targetField.IsNil() {
|
||||
targetField.Set(reflect.MakeMap(sourceField.Type()))
|
||||
}
|
||||
for _, key := range sourceField.MapKeys() {
|
||||
targetField.SetMapIndex(key, sourceField.MapIndex(key))
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Replace value
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValue checks if a reflect.Value is a zero value
|
||||
func isZeroValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return v.IsNil()
|
||||
case reflect.Slice, reflect.Map:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Struct:
|
||||
// Check if all fields are zero
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !isZeroValue(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
zero := reflect.Zero(v.Type())
|
||||
return reflect.DeepEqual(v.Interface(), zero.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// SaveToFile saves the configuration to a file
|
||||
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(absPath))
|
||||
|
||||
var data []byte
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
data, err = json.MarshalIndent(config, "", " ")
|
||||
case ".yaml", ".yml":
|
||||
data, err = yaml.Marshal(config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported file extension: %s", ext)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist with secure permissions
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
// Write file with secure permissions (owner read/write only)
|
||||
if err := os.WriteFile(absPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,832 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConfigLoader tests the config loader functionality
|
||||
func TestConfigLoader(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
if loader == nil {
|
||||
t.Fatal("NewConfigLoader should not return nil")
|
||||
}
|
||||
|
||||
if loader.migrator == nil {
|
||||
t.Error("ConfigLoader should have a migrator")
|
||||
}
|
||||
|
||||
if loader.envPrefix != "TRAEFIKOIDC_" {
|
||||
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
|
||||
}
|
||||
|
||||
if len(loader.configPaths) == 0 {
|
||||
t.Error("ConfigLoader should have default config paths")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFromEnv tests loading configuration from environment variables
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
testEnvVars := map[string]string{
|
||||
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
|
||||
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
|
||||
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ENABLED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
|
||||
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
|
||||
"TRAEFIKOIDC_CACHE_ENABLED": "true",
|
||||
"TRAEFIKOIDC_CACHE_TYPE": "redis",
|
||||
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
|
||||
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
|
||||
}
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range testEnvVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{}
|
||||
loader.LoadFromEnv(config)
|
||||
|
||||
// Verify values were loaded
|
||||
if config.Provider.IssuerURL != "https://test.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client-id" {
|
||||
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
|
||||
}
|
||||
if config.Provider.ClientSecret != "test-secret" {
|
||||
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
|
||||
}
|
||||
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
|
||||
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
|
||||
}
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true")
|
||||
}
|
||||
if !config.Cache.Enabled {
|
||||
t.Error("Expected Cache to be enabled")
|
||||
}
|
||||
if config.Cache.Type != "redis" {
|
||||
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
|
||||
}
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("Expected RateLimit to be enabled")
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond != 100 {
|
||||
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveToFile tests saving configuration to files
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "32-character-encryption-key-12345",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "save as JSON",
|
||||
filename: "config.json",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YAML",
|
||||
filename: "config.yaml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YML",
|
||||
filename: "config.yml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported extension",
|
||||
filename: "config.txt",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/config.json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, tt.filename)
|
||||
err := loader.SaveToFile(config, filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file was created with correct permissions
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to stat saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check file permissions (should be 0600)
|
||||
mode := info.Mode().Perm()
|
||||
if mode != 0600 {
|
||||
t.Errorf("Expected file permissions 0600, got %o", mode)
|
||||
}
|
||||
|
||||
// Verify content can be read back
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify secrets are redacted
|
||||
content := string(data)
|
||||
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
|
||||
t.Error("Secrets should be redacted in saved file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFile tests loading configuration from files
|
||||
func TestLoadFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Test data - using old config format since unified config is not enabled by default
|
||||
jsonConfig := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
|
||||
yamlConfig := `
|
||||
providerurl: https://auth.example.com
|
||||
clientid: test-client
|
||||
clientsecret: secret
|
||||
sessionencryptionkey: 32-character-encryption-key-12345
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "load JSON config",
|
||||
filename: "config.json",
|
||||
content: jsonConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "load YAML config",
|
||||
filename: "config.yaml",
|
||||
content: yamlConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/passwd",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
filename: "does-not-exist.json",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var filePath string
|
||||
if tt.content != "" {
|
||||
filePath = filepath.Join(tmpDir, tt.filename)
|
||||
err := os.WriteFile(filePath, []byte(tt.content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
filePath = tt.filename
|
||||
}
|
||||
|
||||
config, err := loader.loadFile(filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verify loaded config
|
||||
if config == nil {
|
||||
t.Error("Expected config to be loaded")
|
||||
return
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================================
|
||||
// Tests for untested functions (0% coverage)
|
||||
// ====================================================================================
|
||||
|
||||
// TestConfigLoader_Load tests the full Load pipeline
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-load-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test config file
|
||||
configPath := filepath.Join(tmpDir, "traefik-oidc.json")
|
||||
configData := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to temp directory so loader can find the config
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
// Set some environment variables to test merging
|
||||
os.Setenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS", "true")
|
||||
defer os.Unsetenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.Load()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("Load() returned nil config")
|
||||
}
|
||||
|
||||
// Verify file was loaded
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL from file, got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Verify env vars were loaded
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS from env var to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_LoadFromFile tests the LoadFromFile function
|
||||
func TestConfigLoader_LoadFromFile(t *testing.T) {
|
||||
t.Run("NoConfigFile", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-nofile-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
// Should not error when no config file found
|
||||
if err != nil {
|
||||
t.Errorf("LoadFromFile() should not error when no file found: %v", err)
|
||||
}
|
||||
|
||||
// Should return nil config
|
||||
if config != nil {
|
||||
t.Error("LoadFromFile() should return nil config when no file found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadFromEnvPath", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-envpath-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "custom-config.json")
|
||||
configData := `{
|
||||
"providerURL": "https://custom.example.com",
|
||||
"clientID": "custom-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Set env variable pointing to config
|
||||
os.Setenv("TRAEFIKOIDC_CONFIG_FILE", configPath)
|
||||
defer os.Unsetenv("TRAEFIKOIDC_CONFIG_FILE")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://custom.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://custom.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadWithProvidedPaths", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-provided-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "specific.json")
|
||||
configData := `{
|
||||
"providerURL": "https://specific.example.com",
|
||||
"clientID": "specific-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile(configPath)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() with path failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://specific.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://specific.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAndTrim tests the splitAndTrim helper function
|
||||
func TestSplitAndTrim(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Simple comma-separated",
|
||||
input: "a,b,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "With spaces",
|
||||
input: "a, b , c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Empty strings filtered out",
|
||||
input: "a,,b, ,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Leading and trailing spaces",
|
||||
input: " a , b , c ",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "single",
|
||||
expected: []string{"single"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Only commas and spaces",
|
||||
input: " , , , ",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Complex real-world example",
|
||||
input: "openid, profile, email, groups",
|
||||
expected: []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitAndTrim(tt.input)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d items, got %d: %v", len(tt.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expected {
|
||||
if result[i] != expected {
|
||||
t.Errorf("At index %d: expected %q, got %q", i, expected, result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeConfigs tests the mergeConfigs function
|
||||
func TestConfigLoader_MergeConfigs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("MergeNilSource", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, nil)
|
||||
|
||||
if result != target {
|
||||
t.Error("mergeConfigs should return target when source is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeNilTarget", func(t *testing.T) {
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(nil, source)
|
||||
|
||||
if result != source {
|
||||
t.Error("mergeConfigs should return source when target is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSimpleFields", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "",
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
ClientID: "source-client",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if result.Provider.IssuerURL != "https://source.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be overridden, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be set, got %s", result.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSlices", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"openid", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Source slice should replace target slice
|
||||
if len(result.Provider.Scopes) != 2 {
|
||||
t.Errorf("Expected 2 scopes, got %d", len(result.Provider.Scopes))
|
||||
}
|
||||
|
||||
if result.Provider.Scopes[0] != "email" {
|
||||
t.Errorf("Expected first scope 'email', got %s", result.Provider.Scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeMaps", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Target-Header": "target-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Source-Header": "source-value",
|
||||
"X-Target-Header": "overridden-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if len(result.Middleware.CustomHeaders) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(result.Middleware.CustomHeaders))
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Target-Header"] != "overridden-value" {
|
||||
t.Errorf("Expected X-Target-Header to be overridden")
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Source-Header"] != "source-value" {
|
||||
t.Errorf("Expected X-Source-Header to be added")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeStructs tests the mergeStructs function indirectly
|
||||
func TestConfigLoader_MergeStructs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("NestedStructMerge", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "target-client",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "target-session",
|
||||
MaxAge: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "source-client",
|
||||
ClientSecret: "source-secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
MaxAge: 7200,
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Provider.IssuerURL should remain (zero value in source)
|
||||
if result.Provider.IssuerURL != "https://target.example.com" {
|
||||
t.Errorf("Expected IssuerURL to remain, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Provider.ClientID should be overridden
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be overridden, got %s", result.Provider.ClientID)
|
||||
}
|
||||
|
||||
// Provider.ClientSecret should be added
|
||||
if result.Provider.ClientSecret != "source-secret" {
|
||||
t.Errorf("Expected ClientSecret to be added, got %s", result.Provider.ClientSecret)
|
||||
}
|
||||
|
||||
// Session.Name should remain (zero value in source)
|
||||
if result.Session.Name != "target-session" {
|
||||
t.Errorf("Expected Session.Name to remain, got %s", result.Session.Name)
|
||||
}
|
||||
|
||||
// Session.MaxAge should be overridden
|
||||
if result.Session.MaxAge != 7200 {
|
||||
t.Errorf("Expected Session.MaxAge to be overridden, got %d", result.Session.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsZeroValue tests the isZeroValue helper function
|
||||
func TestIsZeroValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Zero string",
|
||||
value: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero string",
|
||||
value: "hello",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero int",
|
||||
value: 0,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero int",
|
||||
value: 42,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero bool",
|
||||
value: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero bool",
|
||||
value: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil pointer",
|
||||
value: (*string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-nil pointer",
|
||||
value: stringPtr("test"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil slice",
|
||||
value: ([]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
value: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty slice",
|
||||
value: []string{"a"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil map",
|
||||
value: (map[string]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty map",
|
||||
value: map[string]string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty map",
|
||||
value: map[string]string{"key": "value"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.value)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected isZeroValue to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsZeroValue_Struct tests isZeroValue with struct types
|
||||
func TestIsZeroValue_Struct(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Field1 string
|
||||
Field2 int
|
||||
}
|
||||
|
||||
t.Run("Zero struct", func(t *testing.T) {
|
||||
s := TestStruct{}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected zero struct to return true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field1 set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test"}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field2 set", func(t *testing.T) {
|
||||
s := TestStruct{Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Both fields set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test", Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function for pointer tests
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
@@ -1,407 +0,0 @@
|
||||
// Package config provides configuration migration from old to new format
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigVersion represents the version of a configuration format
|
||||
type ConfigVersion string
|
||||
|
||||
const (
|
||||
// VersionLegacy represents the original config format
|
||||
VersionLegacy ConfigVersion = "legacy"
|
||||
|
||||
// VersionUnified represents the new unified config format
|
||||
VersionUnified ConfigVersion = "unified"
|
||||
|
||||
// CurrentVersion is the current config version
|
||||
CurrentVersion ConfigVersion = VersionUnified
|
||||
)
|
||||
|
||||
// ConfigMigrator handles migration between config versions
|
||||
type ConfigMigrator struct {
|
||||
compatLayer *compat.CompatibilityLayer
|
||||
migrations map[ConfigVersion]MigrationFunc
|
||||
}
|
||||
|
||||
// MigrationFunc defines a function that migrates configuration
|
||||
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
|
||||
|
||||
// NewConfigMigrator creates a new configuration migrator
|
||||
func NewConfigMigrator() *ConfigMigrator {
|
||||
m := &ConfigMigrator{
|
||||
compatLayer: compat.GetLayer(),
|
||||
migrations: make(map[ConfigVersion]MigrationFunc),
|
||||
}
|
||||
|
||||
// Register migration functions
|
||||
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// DetectVersion detects the version of a configuration
|
||||
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
|
||||
var testMap map[string]interface{}
|
||||
|
||||
// Try JSON first
|
||||
if err := json.Unmarshal(data, &testMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &testMap); err != nil {
|
||||
return VersionLegacy // Default to legacy if can't parse
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unified config markers
|
||||
if _, hasProvider := testMap["provider"]; hasProvider {
|
||||
if _, hasSession := testMap["session"]; hasSession {
|
||||
return VersionUnified
|
||||
}
|
||||
}
|
||||
|
||||
// Check for legacy config markers
|
||||
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
// Migrate migrates configuration data to the current version
|
||||
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
|
||||
warnings := []string{}
|
||||
|
||||
// Detect version
|
||||
version := m.DetectVersion(data)
|
||||
|
||||
// If already current version, just unmarshal
|
||||
if version == CurrentVersion {
|
||||
var config UnifiedConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
|
||||
}
|
||||
}
|
||||
return &config, warnings, nil
|
||||
}
|
||||
|
||||
// Parse to generic map
|
||||
var configMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &configMap); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply migration
|
||||
migrationFunc, exists := m.migrations[version]
|
||||
if !exists {
|
||||
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
|
||||
}
|
||||
|
||||
config, err := migrationFunc(configMap)
|
||||
if err != nil {
|
||||
return nil, warnings, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Collect any deprecation warnings
|
||||
for key := range configMap {
|
||||
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
|
||||
warnings = append(warnings, warning)
|
||||
}
|
||||
}
|
||||
|
||||
return config, warnings, nil
|
||||
}
|
||||
|
||||
// migrateLegacyToUnified migrates legacy config to unified format
|
||||
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Use compatibility layer for field mapping
|
||||
migratedMap, warnings := m.compatLayer.MigrateMap(data)
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, these would be logged
|
||||
_ = warning
|
||||
}
|
||||
|
||||
// Map provider configuration
|
||||
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
|
||||
_ = mapToStruct(provider, &config.Provider)
|
||||
} else {
|
||||
// Direct field mapping for legacy format
|
||||
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
|
||||
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
|
||||
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
|
||||
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
|
||||
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
|
||||
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
|
||||
|
||||
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
|
||||
config.Provider.Scopes = scopes
|
||||
}
|
||||
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
|
||||
}
|
||||
|
||||
// Map session configuration
|
||||
if session, ok := getNestedMap(migratedMap, "Session"); ok {
|
||||
_ = mapToStruct(session, &config.Session)
|
||||
} else {
|
||||
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
|
||||
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
|
||||
}
|
||||
|
||||
// Map security configuration
|
||||
if security, ok := getNestedMap(migratedMap, "Security"); ok {
|
||||
_ = mapToStruct(security, &config.Security)
|
||||
} else {
|
||||
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
|
||||
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
|
||||
|
||||
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
|
||||
config.Security.AllowedUsers = users
|
||||
}
|
||||
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
|
||||
config.Security.AllowedUserDomains = domains
|
||||
}
|
||||
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
|
||||
config.Security.AllowedRolesAndGroups = roles
|
||||
}
|
||||
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
|
||||
config.Security.ExcludedURLs = excluded
|
||||
}
|
||||
|
||||
// Handle security headers
|
||||
if headers := data["securityHeaders"]; headers != nil {
|
||||
// Security headers might be in old format
|
||||
_ = mapToStruct(headers, &config.Security.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
// Map rate limiting
|
||||
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = rateLimit
|
||||
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
|
||||
}
|
||||
|
||||
// Map token configuration
|
||||
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
|
||||
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
|
||||
}
|
||||
|
||||
// Map logging
|
||||
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
|
||||
if config.Logging.Level == "" {
|
||||
config.Logging.Level = "info"
|
||||
}
|
||||
|
||||
// Map custom headers
|
||||
if headers := data["headers"]; headers != nil {
|
||||
if headerList, ok := headers.([]interface{}); ok {
|
||||
config.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, h := range headerList {
|
||||
if headerMap, ok := h.(map[string]interface{}); ok {
|
||||
name := getStringFromInterface(headerMap["name"])
|
||||
value := getStringFromInterface(headerMap["value"])
|
||||
if name != "" {
|
||||
config.Middleware.CustomHeaders[name] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store original data for reference
|
||||
config.Legacy = data
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// MigrateFile migrates a configuration file
|
||||
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
config, warnings, err := m.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
fmt.Printf("Migration Warning: %s\n", warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// AutoMigrate automatically migrates config based on feature flags
|
||||
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Feature not enabled, return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
migrator := NewConfigMigrator()
|
||||
|
||||
// Handle different input types
|
||||
switch v := data.(type) {
|
||||
case []byte:
|
||||
config, _, err := migrator.Migrate(v)
|
||||
return config, err
|
||||
case string:
|
||||
config, _, err := migrator.Migrate([]byte(v))
|
||||
return config, err
|
||||
case *Config:
|
||||
// Convert old config to unified
|
||||
return FromOldConfig(v), nil
|
||||
case *UnifiedConfig:
|
||||
// Already unified
|
||||
return v, nil
|
||||
case map[string]interface{}:
|
||||
// Convert map to JSON then migrate
|
||||
jsonData, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config, _, err := migrator.Migrate(jsonData)
|
||||
return config, err
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||
if val, exists := m[key]; exists {
|
||||
if mapped, ok := val.(map[string]interface{}); ok {
|
||||
return mapped, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getStringValue(m map[string]interface{}, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
return getStringFromInterface(val)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getStringFromInterface(val interface{}) string {
|
||||
if val == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func getBoolValue(m map[string]interface{}, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
// Try string conversion
|
||||
if s, ok := val.(string); ok {
|
||||
return strings.ToLower(s) == "true"
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getIntValue(m map[string]interface{}, keys ...string) int {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
// Try to parse
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
|
||||
// If parsing fails, return default
|
||||
return 0
|
||||
}
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getArrayValue(m map[string]interface{}, keys ...string) []string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if arr, ok := val.([]interface{}); ok {
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
result = append(result, getStringFromInterface(item))
|
||||
}
|
||||
return result
|
||||
}
|
||||
if strArr, ok := val.([]string); ok {
|
||||
return strArr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapToStruct(m interface{}, target interface{}) error {
|
||||
// Simple mapping using JSON as intermediate
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, target)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,297 +0,0 @@
|
||||
// Package config provides configuration structures for the Traefik OIDC plugin.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedisMode represents the Redis deployment mode
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
// RedisModeStandalone represents a single Redis instance
|
||||
RedisModeStandalone RedisMode = "standalone"
|
||||
|
||||
// RedisModeCluster represents Redis cluster mode
|
||||
RedisModeCluster RedisMode = "cluster"
|
||||
|
||||
// RedisModeSentinel represents Redis sentinel mode
|
||||
RedisModeSentinel RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
// RedisConfig holds Redis cache backend configuration
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis backend should be used
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// Mode specifies the Redis deployment mode
|
||||
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
|
||||
|
||||
// === Standalone Configuration ===
|
||||
// Addr is the Redis server address (host:port)
|
||||
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
|
||||
|
||||
// Password for Redis authentication
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the database number (0-15)
|
||||
DB int `json:"db,omitempty" yaml:"db,omitempty"`
|
||||
|
||||
// === Cluster Configuration ===
|
||||
// ClusterAddrs is the list of cluster node addresses
|
||||
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
|
||||
|
||||
// === Sentinel Configuration ===
|
||||
// MasterName is the name of the master instance
|
||||
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
|
||||
|
||||
// SentinelAddrs is the list of sentinel addresses
|
||||
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
|
||||
|
||||
// SentinelPassword is the password for sentinel authentication
|
||||
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
|
||||
|
||||
// === Connection Pool Settings ===
|
||||
// PoolSize is the maximum number of socket connections
|
||||
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections
|
||||
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up
|
||||
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
|
||||
|
||||
// === Timeouts ===
|
||||
// DialTimeout is the timeout for establishing new connections
|
||||
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
|
||||
|
||||
// ReadTimeout is the timeout for socket reads
|
||||
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
|
||||
|
||||
// WriteTimeout is the timeout for socket writes
|
||||
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
|
||||
|
||||
// PoolTimeout is the timeout for connection pool
|
||||
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
|
||||
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
|
||||
|
||||
// ConnMaxLifetime is the maximum lifetime of a connection
|
||||
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
|
||||
|
||||
// === Key Management ===
|
||||
// KeyPrefix is the prefix for all Redis keys
|
||||
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
|
||||
|
||||
// === TLS Configuration ===
|
||||
// TLSEnabled enables TLS for Redis connections
|
||||
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
|
||||
|
||||
// TLSInsecureSkipVerify skips TLS certificate verification
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
|
||||
|
||||
// === Resilience Settings ===
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis operations
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
|
||||
|
||||
// CircuitBreakerMaxFailures is the number of failures before opening circuit
|
||||
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
|
||||
|
||||
// CircuitBreakerTimeout is how long the circuit stays open
|
||||
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks
|
||||
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
|
||||
|
||||
// HealthCheckInterval is how often to check Redis health
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns default Redis configuration
|
||||
func DefaultRedisConfig() *RedisConfig {
|
||||
return &RedisConfig{
|
||||
Enabled: false,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 2,
|
||||
MaxRetries: 3,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
KeyPrefix: "traefikoidc:",
|
||||
TLSEnabled: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
EnableCircuitBreaker: true,
|
||||
CircuitBreakerMaxFailures: 5,
|
||||
CircuitBreakerTimeout: 30 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Redis configuration from environment variables
|
||||
func (c *RedisConfig) LoadFromEnv() {
|
||||
// Enable Redis if environment variable is set
|
||||
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
|
||||
c.Enabled = strings.ToLower(enabled) == "true"
|
||||
}
|
||||
|
||||
// Mode
|
||||
if mode := os.Getenv("REDIS_MODE"); mode != "" {
|
||||
c.Mode = RedisMode(strings.ToLower(mode))
|
||||
}
|
||||
|
||||
// Standalone configuration
|
||||
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
|
||||
c.Addr = addr
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
c.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster configuration
|
||||
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
|
||||
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
|
||||
for i := range c.ClusterAddrs {
|
||||
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel configuration
|
||||
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
|
||||
c.MasterName = masterName
|
||||
}
|
||||
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
|
||||
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
|
||||
for i := range c.SentinelAddrs {
|
||||
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
|
||||
}
|
||||
}
|
||||
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
|
||||
c.SentinelPassword = sentinelPassword
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if size, err := strconv.Atoi(poolSize); err == nil {
|
||||
c.PoolSize = size
|
||||
}
|
||||
}
|
||||
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
|
||||
if conns, err := strconv.Atoi(minIdleConns); err == nil {
|
||||
c.MinIdleConns = conns
|
||||
}
|
||||
}
|
||||
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
|
||||
if retries, err := strconv.Atoi(maxRetries); err == nil {
|
||||
c.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
// Timeouts
|
||||
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
|
||||
c.DialTimeout = timeout
|
||||
}
|
||||
}
|
||||
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(readTimeout); err == nil {
|
||||
c.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
|
||||
c.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Key prefix
|
||||
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
|
||||
c.KeyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
|
||||
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
|
||||
}
|
||||
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
|
||||
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
|
||||
}
|
||||
|
||||
// Resilience settings
|
||||
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
|
||||
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
|
||||
}
|
||||
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
|
||||
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
|
||||
c.CircuitBreakerMaxFailures = failures
|
||||
}
|
||||
}
|
||||
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
|
||||
c.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
|
||||
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
|
||||
}
|
||||
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
|
||||
if interval, err := time.ParseDuration(hcInterval); err == nil {
|
||||
c.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *RedisConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case RedisModeStandalone:
|
||||
if c.Addr == "" {
|
||||
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
|
||||
}
|
||||
case RedisModeCluster:
|
||||
if len(c.ClusterAddrs) == 0 {
|
||||
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
|
||||
}
|
||||
case RedisModeSentinel:
|
||||
if c.MasterName == "" {
|
||||
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
|
||||
}
|
||||
if len(c.SentinelAddrs) == 0 {
|
||||
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
|
||||
}
|
||||
default:
|
||||
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ConfigError) Error() string {
|
||||
return "redis config error: " + e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -1,511 +0,0 @@
|
||||
// Package config provides configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
minEncryptionKeyLength = 16
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
//lint:ignore U1000 May be referenced for default exclusion patterns
|
||||
var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon.ico": {},
|
||||
"/robots.txt": {},
|
||||
"/health": {},
|
||||
"/.well-known/": {},
|
||||
"/metrics": {},
|
||||
"/ping": {},
|
||||
"/api/": {},
|
||||
"/static/": {},
|
||||
"/assets/": {},
|
||||
"/js/": {},
|
||||
"/css/": {},
|
||||
"/images/": {},
|
||||
"/fonts/": {},
|
||||
}
|
||||
|
||||
// Settings manages configuration and initialization for the OIDC middleware
|
||||
type Settings struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debug(msg string)
|
||||
Debugf(format string, args ...interface{})
|
||||
Info(msg string)
|
||||
Infof(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// Config represents the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerUrl"`
|
||||
ClientID string `json:"clientId"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackUrl"`
|
||||
LogoutURL string `json:"logoutUrl"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectUri"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
ForceHTTPS bool `json:"forceHttps"`
|
||||
LogLevel string `json:"logLevel"`
|
||||
Scopes []string `json:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedUrls"`
|
||||
EnablePKCE bool `json:"enablePkce"`
|
||||
RateLimit int `json:"rateLimit"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
Headers []HeaderConfig `json:"headers"`
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
|
||||
// Dynamic Client Registration (RFC 7591) configuration
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrationConfig struct {
|
||||
// Enabled enables automatic client registration with the OIDC provider
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// InitialAccessToken is an optional bearer token for protected registration endpoints
|
||||
// Some providers require this token to authorize new client registrations
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
|
||||
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
|
||||
// If empty, uses the registration_endpoint from .well-known/openid-configuration
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
|
||||
// ClientMetadata contains the client metadata to register
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
|
||||
// PersistCredentials determines whether to save registered credentials to a file
|
||||
// This allows reusing the same client_id/client_secret across restarts
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
|
||||
// CredentialsFile is the path to store/load registered client credentials
|
||||
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
type ClientRegistrationMetadata struct {
|
||||
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
|
||||
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
|
||||
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
|
||||
// ApplicationType is either "web" (default) or "native"
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
|
||||
// Contacts is an array of email addresses for responsible parties
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
|
||||
// ClientName is a human-readable name for the client
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
|
||||
// LogoURI is a URL pointing to a logo for the client
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
|
||||
// ClientURI is a URL of the home page of the client
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
|
||||
// PolicyURI is a URL pointing to the client's privacy policy
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
|
||||
// TOSURI is a URL pointing to the client's terms of service
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
|
||||
// JWKSURI is a URL for the client's JSON Web Key Set
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
|
||||
// SubjectType is "pairwise" or "public" (provider-specific)
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
|
||||
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
|
||||
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
|
||||
// DefaultMaxAge is the default maximum authentication age in seconds
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
|
||||
// RequireAuthTime specifies whether auth_time claim is required in ID token
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
|
||||
// DefaultACRValues specifies default ACR values
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
|
||||
// Scope is a space-separated list of scopes (alternative to config.Scopes)
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// HeaderConfig represents header template configuration
|
||||
type HeaderConfig struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
// SecurityHeadersConfig configures security headers for the plugin
|
||||
type SecurityHeadersConfig struct {
|
||||
// Enable security headers (default: true)
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Security profile: "default", "strict", "development", "api", or "custom"
|
||||
Profile string `json:"profile"`
|
||||
|
||||
// Content Security Policy
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
|
||||
// HSTS settings
|
||||
StrictTransportSecurity bool `json:"strictTransportSecurity"`
|
||||
StrictTransportSecurityMaxAge int `json:"strictTransportSecurityMaxAge"` // seconds
|
||||
StrictTransportSecuritySubdomains bool `json:"strictTransportSecuritySubdomains"`
|
||||
StrictTransportSecurityPreload bool `json:"strictTransportSecurityPreload"`
|
||||
|
||||
// Frame options: "DENY", "SAMEORIGIN", or "ALLOW-FROM uri"
|
||||
FrameOptions string `json:"frameOptions,omitempty"`
|
||||
|
||||
// Content type options (default: "nosniff")
|
||||
ContentTypeOptions string `json:"contentTypeOptions,omitempty"`
|
||||
|
||||
// XSS protection (default: "1; mode=block")
|
||||
XSSProtection string `json:"xssProtection,omitempty"`
|
||||
|
||||
// Referrer policy
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
|
||||
// Permissions policy
|
||||
PermissionsPolicy string `json:"permissionsPolicy,omitempty"`
|
||||
|
||||
// Cross-origin settings
|
||||
CrossOriginEmbedderPolicy string `json:"crossOriginEmbedderPolicy,omitempty"`
|
||||
CrossOriginOpenerPolicy string `json:"crossOriginOpenerPolicy,omitempty"`
|
||||
CrossOriginResourcePolicy string `json:"crossOriginResourcePolicy,omitempty"`
|
||||
|
||||
// CORS settings
|
||||
CORSEnabled bool `json:"corsEnabled"`
|
||||
CORSAllowedOrigins []string `json:"corsAllowedOrigins,omitempty"`
|
||||
CORSAllowedMethods []string `json:"corsAllowedMethods,omitempty"`
|
||||
CORSAllowedHeaders []string `json:"corsAllowedHeaders,omitempty"`
|
||||
CORSAllowCredentials bool `json:"corsAllowCredentials"`
|
||||
CORSMaxAge int `json:"corsMaxAge"` // seconds
|
||||
|
||||
// Custom headers (in addition to standard security headers)
|
||||
CustomHeaders map[string]string `json:"customHeaders,omitempty"`
|
||||
|
||||
// Security features
|
||||
DisableServerHeader bool `json:"disableServerHeader"`
|
||||
DisablePoweredByHeader bool `json:"disablePoweredByHeader"`
|
||||
}
|
||||
|
||||
// NewSettings creates a new Settings instance
|
||||
func NewSettings(logger Logger) *Settings {
|
||||
return &Settings{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateConfig creates a default configuration
|
||||
func CreateConfig() *Config {
|
||||
return &Config{
|
||||
LogLevel: "INFO",
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
RateLimit: 10,
|
||||
RefreshGracePeriodSeconds: 60,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
Headers: []HeaderConfig{},
|
||||
SecurityHeaders: createDefaultSecurityConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// createDefaultSecurityConfig creates a default security headers configuration
|
||||
func createDefaultSecurityConfig() *SecurityHeadersConfig {
|
||||
return &SecurityHeadersConfig{
|
||||
Enabled: true,
|
||||
Profile: "default",
|
||||
|
||||
// Default security headers
|
||||
StrictTransportSecurity: true,
|
||||
StrictTransportSecurityMaxAge: 31536000, // 1 year
|
||||
StrictTransportSecuritySubdomains: true,
|
||||
StrictTransportSecurityPreload: true,
|
||||
|
||||
FrameOptions: "DENY",
|
||||
ContentTypeOptions: "nosniff",
|
||||
XSSProtection: "1; mode=block",
|
||||
ReferrerPolicy: "strict-origin-when-cross-origin",
|
||||
|
||||
// CORS disabled by default
|
||||
CORSEnabled: false,
|
||||
CORSAllowedMethods: []string{"GET", "POST", "OPTIONS"},
|
||||
CORSAllowedHeaders: []string{"Authorization", "Content-Type"},
|
||||
CORSAllowCredentials: false,
|
||||
CORSMaxAge: 86400, // 24 hours
|
||||
|
||||
// Security features
|
||||
DisableServerHeader: true,
|
||||
DisablePoweredByHeader: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ToInternalSecurityConfig converts plugin SecurityHeadersConfig to internal security config
|
||||
func (c *SecurityHeadersConfig) ToInternalSecurityConfig() interface{} {
|
||||
if c == nil || !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create the internal security config structure
|
||||
config := map[string]interface{}{
|
||||
"DevelopmentMode": false,
|
||||
}
|
||||
|
||||
// Apply profile-based defaults
|
||||
switch strings.ToLower(c.Profile) {
|
||||
case "strict":
|
||||
applyStrictProfile(config)
|
||||
case "development":
|
||||
applyDevelopmentProfile(config)
|
||||
case "api":
|
||||
applyAPIProfile(config)
|
||||
case "custom":
|
||||
// No defaults, use only what's explicitly configured
|
||||
default: // "default"
|
||||
applyDefaultProfile(config)
|
||||
}
|
||||
|
||||
// Override with explicit configuration
|
||||
if c.ContentSecurityPolicy != "" {
|
||||
config["ContentSecurityPolicy"] = c.ContentSecurityPolicy
|
||||
}
|
||||
|
||||
// HSTS configuration
|
||||
if c.StrictTransportSecurity {
|
||||
config["StrictTransportSecurityMaxAge"] = c.StrictTransportSecurityMaxAge
|
||||
config["StrictTransportSecuritySubdomains"] = c.StrictTransportSecuritySubdomains
|
||||
config["StrictTransportSecurityPreload"] = c.StrictTransportSecurityPreload
|
||||
}
|
||||
|
||||
// Frame options
|
||||
if c.FrameOptions != "" {
|
||||
config["FrameOptions"] = c.FrameOptions
|
||||
}
|
||||
|
||||
// Content type and XSS protection
|
||||
if c.ContentTypeOptions != "" {
|
||||
config["ContentTypeOptions"] = c.ContentTypeOptions
|
||||
}
|
||||
if c.XSSProtection != "" {
|
||||
config["XSSProtection"] = c.XSSProtection
|
||||
}
|
||||
|
||||
// Referrer and permissions policies
|
||||
if c.ReferrerPolicy != "" {
|
||||
config["ReferrerPolicy"] = c.ReferrerPolicy
|
||||
}
|
||||
if c.PermissionsPolicy != "" {
|
||||
config["PermissionsPolicy"] = c.PermissionsPolicy
|
||||
}
|
||||
|
||||
// Cross-origin policies
|
||||
if c.CrossOriginEmbedderPolicy != "" {
|
||||
config["CrossOriginEmbedderPolicy"] = c.CrossOriginEmbedderPolicy
|
||||
}
|
||||
if c.CrossOriginOpenerPolicy != "" {
|
||||
config["CrossOriginOpenerPolicy"] = c.CrossOriginOpenerPolicy
|
||||
}
|
||||
if c.CrossOriginResourcePolicy != "" {
|
||||
config["CrossOriginResourcePolicy"] = c.CrossOriginResourcePolicy
|
||||
}
|
||||
|
||||
// CORS configuration
|
||||
config["CORSEnabled"] = c.CORSEnabled
|
||||
if len(c.CORSAllowedOrigins) > 0 {
|
||||
config["CORSAllowedOrigins"] = c.CORSAllowedOrigins
|
||||
}
|
||||
if len(c.CORSAllowedMethods) > 0 {
|
||||
config["CORSAllowedMethods"] = c.CORSAllowedMethods
|
||||
}
|
||||
if len(c.CORSAllowedHeaders) > 0 {
|
||||
config["CORSAllowedHeaders"] = c.CORSAllowedHeaders
|
||||
}
|
||||
config["CORSAllowCredentials"] = c.CORSAllowCredentials
|
||||
if c.CORSMaxAge > 0 {
|
||||
config["CORSMaxAge"] = c.CORSMaxAge
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
if len(c.CustomHeaders) > 0 {
|
||||
config["CustomHeaders"] = c.CustomHeaders
|
||||
}
|
||||
|
||||
// Security features
|
||||
config["DisableServerHeader"] = c.DisableServerHeader
|
||||
config["DisablePoweredByHeader"] = c.DisablePoweredByHeader
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
// applyDefaultProfile applies default security settings
|
||||
func applyDefaultProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-origin"
|
||||
}
|
||||
|
||||
// applyStrictProfile applies strict security settings
|
||||
func applyStrictProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; script-src 'self'; style-src 'self'; img-src 'self'; font-src 'self'; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["PermissionsPolicy"] = "geolocation=(), microphone=(), camera=(), payment=(), usb=(), magnetometer=(), gyroscope=(), speaker=()"
|
||||
config["CrossOriginEmbedderPolicy"] = "require-corp"
|
||||
config["CrossOriginOpenerPolicy"] = "same-origin"
|
||||
config["CrossOriginResourcePolicy"] = "same-site"
|
||||
}
|
||||
|
||||
// applyDevelopmentProfile applies development-friendly settings
|
||||
func applyDevelopmentProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'self' 'unsafe-inline' 'unsafe-eval'; img-src 'self' data: https: http:; connect-src 'self' ws: wss:;"
|
||||
config["FrameOptions"] = "SAMEORIGIN"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginOpenerPolicy"] = "unsafe-none"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
config["DevelopmentMode"] = true
|
||||
}
|
||||
|
||||
// applyAPIProfile applies API-friendly settings
|
||||
func applyAPIProfile(config map[string]interface{}) {
|
||||
config["ContentSecurityPolicy"] = "default-src 'none'; frame-ancestors 'none';"
|
||||
config["FrameOptions"] = "DENY"
|
||||
config["ContentTypeOptions"] = "nosniff"
|
||||
config["XSSProtection"] = "1; mode=block"
|
||||
config["ReferrerPolicy"] = "strict-origin-when-cross-origin"
|
||||
config["CrossOriginResourcePolicy"] = "cross-origin"
|
||||
}
|
||||
|
||||
// GetSecurityHeadersApplier returns a function that applies security headers
|
||||
func (c *Config) GetSecurityHeadersApplier() func(http.ResponseWriter, *http.Request) {
|
||||
if c.SecurityHeaders == nil || !c.SecurityHeaders.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// This would need to import the internal security package
|
||||
// For now, return a basic implementation
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
headers := rw.Header()
|
||||
|
||||
// Apply basic security headers based on configuration
|
||||
if c.SecurityHeaders.FrameOptions != "" {
|
||||
headers.Set("X-Frame-Options", c.SecurityHeaders.FrameOptions)
|
||||
}
|
||||
if c.SecurityHeaders.ContentTypeOptions != "" {
|
||||
headers.Set("X-Content-Type-Options", c.SecurityHeaders.ContentTypeOptions)
|
||||
}
|
||||
if c.SecurityHeaders.XSSProtection != "" {
|
||||
headers.Set("X-XSS-Protection", c.SecurityHeaders.XSSProtection)
|
||||
}
|
||||
if c.SecurityHeaders.ReferrerPolicy != "" {
|
||||
headers.Set("Referrer-Policy", c.SecurityHeaders.ReferrerPolicy)
|
||||
}
|
||||
if c.SecurityHeaders.ContentSecurityPolicy != "" {
|
||||
headers.Set("Content-Security-Policy", c.SecurityHeaders.ContentSecurityPolicy)
|
||||
}
|
||||
|
||||
// HSTS for HTTPS
|
||||
if (req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https") && c.SecurityHeaders.StrictTransportSecurity {
|
||||
hstsValue := fmt.Sprintf("max-age=%d", c.SecurityHeaders.StrictTransportSecurityMaxAge)
|
||||
if c.SecurityHeaders.StrictTransportSecuritySubdomains {
|
||||
hstsValue += "; includeSubDomains"
|
||||
}
|
||||
if c.SecurityHeaders.StrictTransportSecurityPreload {
|
||||
hstsValue += "; preload"
|
||||
}
|
||||
headers.Set("Strict-Transport-Security", hstsValue)
|
||||
}
|
||||
|
||||
// CORS headers
|
||||
if c.SecurityHeaders.CORSEnabled {
|
||||
origin := req.Header.Get("Origin")
|
||||
if origin != "" && isOriginAllowed(origin, c.SecurityHeaders.CORSAllowedOrigins) {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
if len(c.SecurityHeaders.CORSAllowedMethods) > 0 {
|
||||
headers.Set("Access-Control-Allow-Methods", strings.Join(c.SecurityHeaders.CORSAllowedMethods, ", "))
|
||||
}
|
||||
if len(c.SecurityHeaders.CORSAllowedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(c.SecurityHeaders.CORSAllowedHeaders, ", "))
|
||||
}
|
||||
if c.SecurityHeaders.CORSAllowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if c.SecurityHeaders.CORSMaxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.SecurityHeaders.CORSMaxAge))
|
||||
}
|
||||
}
|
||||
|
||||
// Custom headers
|
||||
for name, value := range c.SecurityHeaders.CustomHeaders {
|
||||
headers.Set(name, value)
|
||||
}
|
||||
|
||||
// Remove server headers
|
||||
if c.SecurityHeaders.DisableServerHeader {
|
||||
headers.Del("Server")
|
||||
}
|
||||
if c.SecurityHeaders.DisablePoweredByHeader {
|
||||
headers.Del("X-Powered-By")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if an origin is in the allowed list
|
||||
func isOriginAllowed(origin string, allowedOrigins []string) bool {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if origin == allowed || allowed == "*" {
|
||||
return true
|
||||
}
|
||||
// Simple wildcard matching for subdomains
|
||||
if strings.Contains(allowed, "*") {
|
||||
if strings.HasPrefix(allowed, "https://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "https://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "https://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(allowed, "http://*.") {
|
||||
domain := strings.TrimPrefix(allowed, "http://*.")
|
||||
if strings.HasSuffix(origin, "."+domain) || origin == "http://"+domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -1,287 +0,0 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnifiedConfig is the master configuration structure consolidating all config aspects
|
||||
// This replaces 45 duplicate config structs across the codebase
|
||||
type UnifiedConfig struct {
|
||||
// Core Configuration
|
||||
Provider ProviderConfig `json:"provider" yaml:"provider"`
|
||||
Session SessionConfig `json:"session" yaml:"session"`
|
||||
Token TokenConfig `json:"token" yaml:"token"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Security SecurityConfig `json:"security" yaml:"security"`
|
||||
|
||||
// Middleware Configuration
|
||||
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
|
||||
Cache CacheConfig `json:"cache" yaml:"cache"`
|
||||
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
|
||||
|
||||
// Operational Configuration
|
||||
Logging LoggingConfig `json:"logging" yaml:"logging"`
|
||||
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
|
||||
Health HealthConfig `json:"health" yaml:"health"`
|
||||
|
||||
// Advanced Configuration
|
||||
Transport TransportConfig `json:"transport" yaml:"transport"`
|
||||
Pool PoolConfig `json:"pool" yaml:"pool"`
|
||||
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
|
||||
|
||||
// Compatibility field for migration
|
||||
Legacy map[string]interface{} `json:"-" yaml:"-"`
|
||||
}
|
||||
|
||||
// ProviderConfig contains OIDC provider settings
|
||||
type ProviderConfig struct {
|
||||
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
|
||||
ClientID string `json:"clientID" yaml:"clientID"`
|
||||
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
|
||||
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
|
||||
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
|
||||
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
|
||||
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
|
||||
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
|
||||
Discovery bool `json:"discovery" yaml:"discovery"`
|
||||
|
||||
// Provider-specific endpoints
|
||||
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
|
||||
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
|
||||
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
|
||||
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
|
||||
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
|
||||
}
|
||||
|
||||
// SessionConfig contains session management settings
|
||||
type SessionConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
MaxAge int `json:"maxAge" yaml:"maxAge"`
|
||||
Secret string `json:"secret" yaml:"secret"`
|
||||
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
|
||||
SigningKey string `json:"signingKey" yaml:"signingKey"`
|
||||
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
|
||||
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
|
||||
|
||||
// Cookie settings
|
||||
Domain string `json:"domain" yaml:"domain"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
Secure bool `json:"secure" yaml:"secure"`
|
||||
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
|
||||
SameSite string `json:"sameSite" yaml:"sameSite"`
|
||||
CookiePrefix string `json:"cookiePrefix" yaml:"cookiePrefix"` // Prefix for cookie names (e.g., "_oidc_myapp_")
|
||||
|
||||
// Storage settings
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
}
|
||||
|
||||
// TokenConfig contains token handling settings
|
||||
type TokenConfig struct {
|
||||
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
|
||||
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
|
||||
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
|
||||
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
|
||||
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
|
||||
|
||||
// Token caching
|
||||
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
|
||||
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
|
||||
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
|
||||
|
||||
// Token validation
|
||||
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
|
||||
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
|
||||
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
|
||||
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
|
||||
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
|
||||
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
|
||||
}
|
||||
|
||||
// SecurityConfig contains security-related settings
|
||||
type SecurityConfig struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
|
||||
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
|
||||
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
|
||||
|
||||
// CSRF protection
|
||||
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
|
||||
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
|
||||
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
|
||||
|
||||
// Additional security
|
||||
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
|
||||
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
|
||||
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
|
||||
}
|
||||
|
||||
// MiddlewareConfig contains middleware-specific settings
|
||||
type MiddlewareConfig struct {
|
||||
Priority int `json:"priority" yaml:"priority"`
|
||||
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
|
||||
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
|
||||
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
|
||||
|
||||
// Request handling
|
||||
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
|
||||
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
|
||||
// Response handling
|
||||
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
|
||||
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache configuration
|
||||
type CacheConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
|
||||
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
|
||||
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
|
||||
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
|
||||
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
|
||||
|
||||
// Memory cache settings
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
|
||||
// Distributed cache settings
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Compression bool `json:"compression" yaml:"compression"`
|
||||
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
|
||||
Burst int `json:"burst" yaml:"burst"`
|
||||
|
||||
// Rate limit storage
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
|
||||
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
|
||||
|
||||
// Rate limit keys
|
||||
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
|
||||
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
|
||||
|
||||
// Whitelisting
|
||||
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
|
||||
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
|
||||
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
|
||||
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
|
||||
FilePath string `json:"filePath" yaml:"filePath"`
|
||||
|
||||
// Log filtering
|
||||
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
|
||||
MaskFields []string `json:"maskFields" yaml:"maskFields"`
|
||||
|
||||
// Performance
|
||||
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
|
||||
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
|
||||
|
||||
// Audit logging
|
||||
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
|
||||
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
|
||||
}
|
||||
|
||||
// MetricsConfig contains metrics collection configuration
|
||||
type MetricsConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
|
||||
Endpoint string `json:"endpoint" yaml:"endpoint"`
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Subsystem string `json:"subsystem" yaml:"subsystem"`
|
||||
|
||||
// Collection settings
|
||||
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
|
||||
Histograms bool `json:"histograms" yaml:"histograms"`
|
||||
|
||||
// Custom labels
|
||||
Labels map[string]string `json:"labels" yaml:"labels"`
|
||||
}
|
||||
|
||||
// HealthConfig contains health check configuration
|
||||
type HealthConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
|
||||
// Checks to perform
|
||||
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
|
||||
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
|
||||
CheckCache bool `json:"checkCache" yaml:"checkCache"`
|
||||
|
||||
// Thresholds
|
||||
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
|
||||
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
|
||||
}
|
||||
|
||||
// TransportConfig contains HTTP transport configuration
|
||||
type TransportConfig struct {
|
||||
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
|
||||
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
|
||||
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
|
||||
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
|
||||
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
|
||||
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
|
||||
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
|
||||
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
|
||||
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
|
||||
|
||||
// TLS configuration
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
|
||||
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
|
||||
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
|
||||
|
||||
// Proxy settings
|
||||
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
|
||||
NoProxy []string `json:"noProxy" yaml:"noProxy"`
|
||||
}
|
||||
|
||||
// PoolConfig contains connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Size int `json:"size" yaml:"size"`
|
||||
MinSize int `json:"minSize" yaml:"minSize"`
|
||||
MaxSize int `json:"maxSize" yaml:"maxSize"`
|
||||
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
|
||||
|
||||
// Health checking
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// CircuitConfig contains circuit breaker configuration
|
||||
type CircuitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
|
||||
Interval time.Duration `json:"interval" yaml:"interval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
|
||||
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
|
||||
|
||||
// Circuit states
|
||||
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
|
||||
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
|
||||
|
||||
// Monitoring
|
||||
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
|
||||
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
|
||||
}
|
||||
@@ -1,263 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
|
||||
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Session.Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("Session.EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("Session.SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Redis.Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("Redis.SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
|
||||
t.Error("IssuerURL should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
|
||||
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
yamlBytes, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "secret: '[REDACTED]'") {
|
||||
t.Error("Session.Secret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
|
||||
t.Error("Session.EncryptionKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
|
||||
t.Error("Session.SigningKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "password: '[REDACTED]'") {
|
||||
t.Error("Redis.Password should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
|
||||
t.Error("Redis.SentinelPassword should be redacted in YAML output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
|
||||
t.Error("IssuerURL should be preserved in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderConfigMarshalling tests individual struct marshalling
|
||||
func TestProviderConfigMarshalling(t *testing.T) {
|
||||
provider := ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
|
||||
// Test YAML marshalling
|
||||
yamlBytes, err := yaml.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionConfigMarshalling tests session config marshalling
|
||||
func TestSessionConfigMarshalling(t *testing.T) {
|
||||
session := SessionConfig{
|
||||
Name: "session-cookie",
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
Domain: "example.com",
|
||||
Secure: true,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"name":"session-cookie"`) {
|
||||
t.Error("Name should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"domain":"example.com"`) {
|
||||
t.Error("Domain should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisConfigMarshalling tests Redis config marshalling
|
||||
func TestRedisConfigMarshalling(t *testing.T) {
|
||||
redis := RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
Addr: "localhost:6379",
|
||||
DB: 1,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(redis)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"addr":"localhost:6379"`) {
|
||||
t.Error("Addr should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"db":1`) {
|
||||
t.Error("DB should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
|
||||
func TestEmptySecretsNotRedacted(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "", // Empty secret
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "", // Empty secret
|
||||
EncryptionKey: "", // Empty secret
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "", // Empty secret
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify empty secrets are not shown as redacted
|
||||
if contains(jsonStr, "[REDACTED]") {
|
||||
t.Error("Empty secrets should not be shown as [REDACTED]")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
@@ -1,652 +0,0 @@
|
||||
// Package config provides validation for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationError represents a configuration validation error
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ValidationError) Error() string {
|
||||
if e.Value != nil {
|
||||
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
|
||||
}
|
||||
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidationErrors represents multiple validation errors
|
||||
type ValidationErrors []ValidationError
|
||||
|
||||
// Error implements the error interface
|
||||
func (e ValidationErrors) Error() string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var messages []string
|
||||
for _, err := range e {
|
||||
messages = append(messages, err.Error())
|
||||
}
|
||||
return strings.Join(messages, "; ")
|
||||
}
|
||||
|
||||
// Validate performs comprehensive validation on the unified configuration
|
||||
func (c *UnifiedConfig) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate Provider configuration
|
||||
if err := c.validateProvider(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Session configuration
|
||||
if err := c.validateSession(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Token configuration
|
||||
if err := c.validateToken(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Redis configuration (uses existing validation)
|
||||
if err := c.Redis.Validate(); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Redis",
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate Security configuration
|
||||
if err := c.validateSecurity(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Middleware configuration
|
||||
if err := c.validateMiddleware(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Cache configuration
|
||||
if err := c.validateCache(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate RateLimit configuration
|
||||
if err := c.validateRateLimit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Logging configuration
|
||||
if err := c.validateLogging(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Metrics configuration
|
||||
if err := c.validateMetrics(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Transport configuration
|
||||
if err := c.validateTransport(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Circuit configuration
|
||||
if err := c.validateCircuit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateProvider validates provider configuration
|
||||
func (c *UnifiedConfig) validateProvider() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// IssuerURL is required and must be a valid URL
|
||||
if c.Provider.IssuerURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "issuer URL is required",
|
||||
})
|
||||
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "invalid issuer URL",
|
||||
Value: c.Provider.IssuerURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ClientID is required
|
||||
if c.Provider.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
// ClientSecret is required (except for public clients with PKCE)
|
||||
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE for public clients)",
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectURL must be valid if provided
|
||||
if c.Provider.RedirectURL != "" {
|
||||
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.RedirectURL",
|
||||
Message: "invalid redirect URL",
|
||||
Value: c.Provider.RedirectURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes must include 'openid' for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range c.Provider.Scopes {
|
||||
if scope == "openid" {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID && !c.Provider.OverrideScopes {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.Scopes",
|
||||
Message: "scopes must include 'openid' for OIDC",
|
||||
Value: c.Provider.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// JWK cache period must be positive
|
||||
if c.Provider.JWKCachePeriod < 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.JWKCachePeriod",
|
||||
Message: "JWK cache period must be positive",
|
||||
Value: c.Provider.JWKCachePeriod,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSession validates session configuration
|
||||
func (c *UnifiedConfig) validateSession() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Session name must not be empty
|
||||
if c.Session.Name == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.Name",
|
||||
Message: "session name is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Session secret or encryption key is required
|
||||
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session",
|
||||
Message: "either session secret or encryption key is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Encryption key must be at least 32 bytes for security
|
||||
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "encryption key must be at least 32 characters for proper security",
|
||||
Value: len(c.Session.EncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
// ChunkSize must be reasonable (between 1KB and 10KB)
|
||||
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.ChunkSize",
|
||||
Message: "chunk size must be between 1000 and 10000 bytes",
|
||||
Value: c.Session.ChunkSize,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxChunks must be reasonable (between 1 and 100)
|
||||
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.MaxChunks",
|
||||
Message: "max chunks must be between 1 and 100",
|
||||
Value: c.Session.MaxChunks,
|
||||
})
|
||||
}
|
||||
|
||||
// SameSite must be valid
|
||||
validSameSite := map[string]bool{
|
||||
"": true,
|
||||
"Lax": true,
|
||||
"Strict": true,
|
||||
"None": true,
|
||||
}
|
||||
if !validSameSite[c.Session.SameSite] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.SameSite",
|
||||
Message: "invalid SameSite value (must be Lax, Strict, or None)",
|
||||
Value: c.Session.SameSite,
|
||||
})
|
||||
}
|
||||
|
||||
// StorageType must be valid
|
||||
validStorage := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"cookie": true,
|
||||
}
|
||||
if !validStorage[c.Session.StorageType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.StorageType",
|
||||
Message: "invalid storage type (must be memory, redis, or cookie)",
|
||||
Value: c.Session.StorageType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateToken validates token configuration
|
||||
func (c *UnifiedConfig) validateToken() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Token TTLs must be positive
|
||||
if c.Token.AccessTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.AccessTokenTTL",
|
||||
Message: "access token TTL must be positive",
|
||||
Value: c.Token.AccessTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Token.RefreshTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.RefreshTokenTTL",
|
||||
Message: "refresh token TTL must be positive",
|
||||
Value: c.Token.RefreshTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
// Validation mode must be valid
|
||||
validModes := map[string]bool{
|
||||
"jwt": true,
|
||||
"introspect": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validModes[c.Token.ValidationMode] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ValidationMode",
|
||||
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
|
||||
Value: c.Token.ValidationMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect URL required for introspect or hybrid mode
|
||||
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.IntrospectURL",
|
||||
Message: "introspect URL is required for introspect or hybrid validation mode",
|
||||
})
|
||||
}
|
||||
|
||||
// Clock skew must be reasonable (0 to 10 minutes)
|
||||
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ClockSkew",
|
||||
Message: "clock skew must be between 0 and 10 minutes",
|
||||
Value: c.Token.ClockSkew,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSecurity validates security configuration
|
||||
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate allowed user domains are valid domains
|
||||
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
|
||||
for _, domain := range c.Security.AllowedUserDomains {
|
||||
if !domainRegex.MatchString(domain) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.AllowedUserDomains",
|
||||
Message: "invalid domain format",
|
||||
Value: domain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Max login attempts must be reasonable
|
||||
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.MaxLoginAttempts",
|
||||
Message: "max login attempts must be between 0 and 100",
|
||||
Value: c.Security.MaxLoginAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
// Lockout duration must be reasonable
|
||||
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.LockoutDuration",
|
||||
Message: "lockout duration must be between 0 and 24 hours",
|
||||
Value: c.Security.LockoutDuration,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMiddleware validates middleware configuration
|
||||
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max request size must be reasonable (1KB to 100MB)
|
||||
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.MaxRequestSize",
|
||||
Message: "max request size must be between 1KB and 100MB",
|
||||
Value: c.Middleware.MaxRequestSize,
|
||||
})
|
||||
}
|
||||
|
||||
// Request timeout must be reasonable
|
||||
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.RequestTimeout",
|
||||
Message: "request timeout must be between 1 second and 5 minutes",
|
||||
Value: c.Middleware.RequestTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCache validates cache configuration
|
||||
func (c *UnifiedConfig) validateCache() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Cache.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Cache type must be valid
|
||||
validTypes := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validTypes[c.Cache.Type] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.Type",
|
||||
Message: "invalid cache type (must be memory, redis, or hybrid)",
|
||||
Value: c.Cache.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// Max entries must be reasonable
|
||||
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.MaxEntries",
|
||||
Message: "max entries must be between 10 and 1000000",
|
||||
Value: c.Cache.MaxEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// Eviction policy must be valid
|
||||
validEviction := map[string]bool{
|
||||
"lru": true,
|
||||
"lfu": true,
|
||||
"fifo": true,
|
||||
}
|
||||
if !validEviction[c.Cache.EvictionPolicy] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.EvictionPolicy",
|
||||
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
|
||||
Value: c.Cache.EvictionPolicy,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateRateLimit validates rate limiting configuration
|
||||
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.RateLimit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Requests per second must be reasonable
|
||||
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.RequestsPerSecond",
|
||||
Message: "requests per second must be between 1 and 10000",
|
||||
Value: c.RateLimit.RequestsPerSecond,
|
||||
})
|
||||
}
|
||||
|
||||
// Burst must be at least as large as requests per second
|
||||
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.Burst",
|
||||
Message: "burst must be at least as large as requests per second",
|
||||
Value: c.RateLimit.Burst,
|
||||
})
|
||||
}
|
||||
|
||||
// Key type must be valid
|
||||
validKeyTypes := map[string]bool{
|
||||
"ip": true,
|
||||
"user": true,
|
||||
"token": true,
|
||||
"custom": true,
|
||||
}
|
||||
if !validKeyTypes[c.RateLimit.KeyType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.KeyType",
|
||||
Message: "invalid key type (must be ip, user, token, or custom)",
|
||||
Value: c.RateLimit.KeyType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateLogging validates logging configuration
|
||||
func (c *UnifiedConfig) validateLogging() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Log level must be valid
|
||||
validLevels := map[string]bool{
|
||||
"debug": true,
|
||||
"info": true,
|
||||
"warn": true,
|
||||
"error": true,
|
||||
}
|
||||
if !validLevels[c.Logging.Level] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Level",
|
||||
Message: "invalid log level (must be debug, info, warn, or error)",
|
||||
Value: c.Logging.Level,
|
||||
})
|
||||
}
|
||||
|
||||
// Format must be valid
|
||||
validFormats := map[string]bool{
|
||||
"json": true,
|
||||
"text": true,
|
||||
"structured": true,
|
||||
}
|
||||
if !validFormats[c.Logging.Format] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Format",
|
||||
Message: "invalid log format (must be json, text, or structured)",
|
||||
Value: c.Logging.Format,
|
||||
})
|
||||
}
|
||||
|
||||
// Output must be valid
|
||||
validOutputs := map[string]bool{
|
||||
"stdout": true,
|
||||
"stderr": true,
|
||||
"file": true,
|
||||
}
|
||||
if !validOutputs[c.Logging.Output] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Output",
|
||||
Message: "invalid log output (must be stdout, stderr, or file)",
|
||||
Value: c.Logging.Output,
|
||||
})
|
||||
}
|
||||
|
||||
// File path required if output is file
|
||||
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.FilePath",
|
||||
Message: "file path is required when output is 'file'",
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMetrics validates metrics configuration
|
||||
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Metrics.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Provider must be valid
|
||||
validProviders := map[string]bool{
|
||||
"prometheus": true,
|
||||
"statsd": true,
|
||||
"otlp": true,
|
||||
}
|
||||
if !validProviders[c.Metrics.Provider] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Provider",
|
||||
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
|
||||
Value: c.Metrics.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
// Endpoint required for some providers
|
||||
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Endpoint",
|
||||
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateTransport validates transport configuration
|
||||
func (c *UnifiedConfig) validateTransport() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max connections must be reasonable
|
||||
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.MaxIdleConns",
|
||||
Message: "max idle connections must be between 0 and 10000",
|
||||
Value: c.Transport.MaxIdleConns,
|
||||
})
|
||||
}
|
||||
|
||||
// TLS min version must be valid
|
||||
validTLSVersions := map[string]bool{
|
||||
"TLS1.0": true,
|
||||
"TLS1.1": true,
|
||||
"TLS1.2": true,
|
||||
"TLS1.3": true,
|
||||
}
|
||||
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.TLSMinVersion",
|
||||
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
|
||||
Value: c.Transport.TLSMinVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// Proxy URL must be valid if provided
|
||||
if c.Transport.ProxyURL != "" {
|
||||
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.ProxyURL",
|
||||
Message: "invalid proxy URL",
|
||||
Value: c.Transport.ProxyURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCircuit validates circuit breaker configuration
|
||||
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Circuit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Consecutive failures must be reasonable
|
||||
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.ConsecutiveFailures",
|
||||
Message: "consecutive failures must be between 1 and 100",
|
||||
Value: c.Circuit.ConsecutiveFailures,
|
||||
})
|
||||
}
|
||||
|
||||
// Failure ratio must be between 0 and 1
|
||||
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.FailureRatio",
|
||||
Message: "failure ratio must be between 0 and 1",
|
||||
Value: c.Circuit.FailureRatio,
|
||||
})
|
||||
}
|
||||
|
||||
// OnOpen action must be valid
|
||||
validActions := map[string]bool{
|
||||
"reject": true,
|
||||
"fallback": true,
|
||||
"passthrough": true,
|
||||
}
|
||||
if !validActions[c.Circuit.OnOpen] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.OnOpen",
|
||||
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
|
||||
Value: c.Circuit.OnOpen,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
@@ -1,588 +0,0 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
|
||||
func TestValidateUnifiedConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *UnifiedConfig
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "valid config with minimum requirements",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "oidc_session",
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 5,
|
||||
StorageType: "cookie",
|
||||
},
|
||||
Token: TokenConfig{
|
||||
AccessTokenTTL: time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
ValidationMode: "jwt",
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
MaxRequestSize: 10 * 1024 * 1024,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.IssuerURL",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.ClientID",
|
||||
},
|
||||
{
|
||||
name: "encryption key too short",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "too-short",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.EncryptionKey",
|
||||
},
|
||||
{
|
||||
name: "invalid chunk size",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 500, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.ChunkSize",
|
||||
},
|
||||
{
|
||||
name: "invalid max chunks",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 0, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.MaxChunks",
|
||||
},
|
||||
{
|
||||
name: "invalid TLS min version",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Transport: TransportConfig{
|
||||
TLSMinVersion: "1.0", // Too old
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Transport.TLSMinVersion",
|
||||
},
|
||||
{
|
||||
name: "invalid circuit breaker failure ratio",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Circuit: CircuitConfig{
|
||||
Enabled: true,
|
||||
FailureRatio: 1.5, // Too high
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Circuit.FailureRatio",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
|
||||
} else if validationErrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, e := range validationErrs {
|
||||
if e.Field == tt.errorField {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected validation error for field %s, but got errors for: %v",
|
||||
tt.errorField, validationErrs)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidationErrorMessage tests validation error formatting
|
||||
func TestValidationErrorMessage(t *testing.T) {
|
||||
errs := ValidationErrors{
|
||||
{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "is required",
|
||||
Value: nil,
|
||||
},
|
||||
{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "must be at least 32 characters",
|
||||
Value: 16,
|
||||
},
|
||||
}
|
||||
|
||||
errMsg := errs.Error()
|
||||
|
||||
if !strings.Contains(errMsg, "Provider.IssuerURL") {
|
||||
t.Error("Error message should contain field name Provider.IssuerURL")
|
||||
}
|
||||
if !strings.Contains(errMsg, "is required") {
|
||||
t.Error("Error message should contain 'is required'")
|
||||
}
|
||||
if !strings.Contains(errMsg, "Session.EncryptionKey") {
|
||||
t.Error("Error message should contain field name Session.EncryptionKey")
|
||||
}
|
||||
if !strings.Contains(errMsg, "must be at least 32 characters") {
|
||||
t.Error("Error message should contain 'must be at least 32 characters'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedisConfig tests Redis configuration validation
|
||||
func TestValidateRedisConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *RedisConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid standalone config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing address for standalone",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Redis address is required",
|
||||
},
|
||||
{
|
||||
name: "valid cluster config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing cluster addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "cluster address is required",
|
||||
},
|
||||
{
|
||||
name: "valid sentinel config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing master name for sentinel",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Master name is required",
|
||||
},
|
||||
{
|
||||
name: "missing sentinel addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "sentinel address is required",
|
||||
},
|
||||
{
|
||||
name: "disabled redis needs no validation",
|
||||
config: &RedisConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid redis mode",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: "invalid-mode",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Invalid Redis mode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateRateLimit Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateRateLimit_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = false
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when rate limiting is disabled")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidConfig(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid rate limit config")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooLow(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0
|
||||
config.RateLimit.Burst = 100
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooHigh(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 15000
|
||||
config.RateLimit.Burst = 20000
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_BurstTooSmall(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 50 // Less than RequestsPerSecond
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.Burst", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "at least as large as requests per second")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_InvalidKeyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType string
|
||||
}{
|
||||
{"empty key type", ""},
|
||||
{"invalid key type", "invalid"},
|
||||
{"random string", "foobar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = tt.keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.KeyType", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid key type")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidKeyTypes(t *testing.T) {
|
||||
validKeyTypes := []string{"ip", "user", "token", "custom"}
|
||||
|
||||
for _, keyType := range validKeyTypes {
|
||||
t.Run(keyType, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid key type: %s", keyType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0 // Too low
|
||||
config.RateLimit.Burst = 50 // Will pass (0 < 50)
|
||||
config.RateLimit.KeyType = "invalid" // Invalid
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
// Should have 2 errors (rps and keyType)
|
||||
assert.Len(t, errors, 2)
|
||||
|
||||
// Check each error is present
|
||||
fields := make(map[string]bool)
|
||||
for _, err := range errors {
|
||||
fields[err.Field] = true
|
||||
}
|
||||
assert.True(t, fields["RateLimit.RequestsPerSecond"])
|
||||
assert.True(t, fields["RateLimit.KeyType"])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateMetrics Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateMetrics_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = false
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when metrics are disabled")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidPrometheus(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "prometheus"
|
||||
config.Metrics.Endpoint = "" // Prometheus doesn't require endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid prometheus config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidStatsd(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "localhost:8125"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid statsd config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidOTLP(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "localhost:4317"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid otlp config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_InvalidProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
}{
|
||||
{"empty provider", ""},
|
||||
{"invalid provider", "invalid"},
|
||||
{"datadog", "datadog"},
|
||||
{"influx", "influx"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = tt.provider
|
||||
config.Metrics.Endpoint = "localhost:8080"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid metrics provider")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMetrics_StatsdMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for statsd provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_OTLPMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for otlp provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "invalid" // Invalid provider
|
||||
config.Metrics.Endpoint = "" // Would be missing if provider was statsd/otlp
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
// Should have at least 1 error for invalid provider
|
||||
assert.NotEmpty(t, errors)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,456 @@
|
||||
# Configuration Reference
|
||||
|
||||
Complete reference for all Traefik OIDC middleware configuration options.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Required Parameters](#required-parameters)
|
||||
- [Optional Parameters](#optional-parameters)
|
||||
- [Security Options](#security-options)
|
||||
- [Session Management](#session-management)
|
||||
- [Access Control](#access-control)
|
||||
- [Headers Configuration](#headers-configuration)
|
||||
- [Security Headers](#security-headers)
|
||||
- [Scope Configuration](#scope-configuration)
|
||||
- [Advanced Options](#advanced-options)
|
||||
|
||||
---
|
||||
|
||||
## Required Parameters
|
||||
|
||||
| Parameter | Type | Description | Example |
|
||||
|-----------|------|-------------|---------|
|
||||
| `providerURL` | string | Base URL of the OIDC provider | `https://accounts.google.com` |
|
||||
| `clientID` | string | OAuth 2.0 client identifier | `1234567890.apps.googleusercontent.com` |
|
||||
| `clientSecret` | string | OAuth 2.0 client secret | `your-client-secret` |
|
||||
| `sessionEncryptionKey` | string | Key for encrypting session data (min 32 bytes) | `your-32-byte-encryption-key-here` |
|
||||
| `callbackURL` | string | Path where provider redirects after authentication | `/oauth2/callback` |
|
||||
|
||||
### Basic Configuration Example
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-32-byte-encryption-key-here
|
||||
callbackURL: /oauth2/callback
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Optional Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `logoutURL` | string | `callbackURL + "/logout"` | Path for logout requests |
|
||||
| `postLogoutRedirectURI` | string | `/` | Redirect URL after logout |
|
||||
| `logLevel` | string | `info` | Logging verbosity (`debug`, `info`, `error`) |
|
||||
| `forceHTTPS` | bool | `false` | Force HTTPS for redirect URIs |
|
||||
| `rateLimit` | int | `100` | Maximum requests per second |
|
||||
| `excludedURLs` | []string | none | Paths that bypass authentication |
|
||||
| `revocationURL` | string | auto-discovered | Token revocation endpoint |
|
||||
| `oidcEndSessionURL` | string | auto-discovered | Provider's end session endpoint |
|
||||
| `enablePKCE` | bool | `false` | Enable PKCE for authorization code flow |
|
||||
| `minimalHeaders` | bool | `false` | Reduce forwarded headers |
|
||||
|
||||
### TLS Termination at Load Balancer
|
||||
|
||||
If running Traefik behind a load balancer (AWS ALB, Google Cloud LB, Azure App Gateway) that terminates TLS:
|
||||
|
||||
```yaml
|
||||
forceHTTPS: true # Required for correct redirect URIs
|
||||
```
|
||||
|
||||
Without this setting, redirect URIs will use `http://` instead of `https://`, causing OAuth callback failures.
|
||||
|
||||
---
|
||||
|
||||
## Security Options
|
||||
|
||||
### Audience Validation
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `audience` | string | `clientID` | Expected audience for access token validation |
|
||||
| `strictAudienceValidation` | bool | `false` | Reject sessions with audience mismatch |
|
||||
| `allowOpaqueTokens` | bool | `false` | Enable opaque token support via RFC 7662 |
|
||||
| `requireTokenIntrospection` | bool | `false` | Require introspection for opaque tokens |
|
||||
|
||||
#### Production Security Configuration
|
||||
|
||||
```yaml
|
||||
audience: "https://my-api.example.com"
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
#### Opaque Token Support
|
||||
|
||||
```yaml
|
||||
allowOpaqueTokens: true
|
||||
requireTokenIntrospection: true
|
||||
strictAudienceValidation: true
|
||||
```
|
||||
|
||||
### Other Security Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `disableReplayDetection` | bool | `false` | Disable JTI-based replay attack detection |
|
||||
| `allowPrivateIPAddresses` | bool | `false` | Allow private IPs in provider URLs |
|
||||
|
||||
---
|
||||
|
||||
## Session Management
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `sessionMaxAge` | int | `86400` (24h) | Maximum session age in seconds |
|
||||
| `refreshGracePeriodSeconds` | int | `60` | Seconds before expiry to attempt refresh |
|
||||
| `cookieDomain` | string | auto-detected | Domain for session cookies |
|
||||
| `cookiePrefix` | string | `_oidc_raczylo_` | Prefix for cookie names |
|
||||
|
||||
### Multi-Subdomain Setup
|
||||
|
||||
```yaml
|
||||
cookieDomain: .example.com # Share cookies across subdomains
|
||||
```
|
||||
|
||||
### Multiple Middleware Instances
|
||||
|
||||
When running multiple middleware instances with different authorization requirements, use unique prefixes:
|
||||
|
||||
```yaml
|
||||
# User authentication middleware
|
||||
---
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-userauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_userauth_"
|
||||
sessionEncryptionKey: user-encryption-key-min-32-bytes
|
||||
# ... other config
|
||||
---
|
||||
# Admin authentication middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-adminauth
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
cookiePrefix: "_oidc_adminauth_"
|
||||
sessionEncryptionKey: admin-encryption-key-min-32-bytes
|
||||
allowedUsers:
|
||||
- admin@example.com
|
||||
# ... other config
|
||||
```
|
||||
|
||||
### Extended Session Duration
|
||||
|
||||
```yaml
|
||||
sessionMaxAge: 604800 # 7 days
|
||||
# Common values:
|
||||
# 3600 - 1 hour (high security)
|
||||
# 86400 - 1 day (default)
|
||||
# 259200 - 3 days
|
||||
# 604800 - 7 days
|
||||
# 2592000 - 30 days
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Access Control
|
||||
|
||||
### User Restrictions
|
||||
|
||||
| Parameter | Type | Description |
|
||||
|-----------|------|-------------|
|
||||
| `allowedUserDomains` | []string | Restrict to specific email domains |
|
||||
| `allowedUsers` | []string | Specific email addresses allowed |
|
||||
| `allowedRolesAndGroups` | []string | Required roles or groups |
|
||||
| `roleClaimName` | string | JWT claim for roles (default: `roles`) |
|
||||
| `groupClaimName` | string | JWT claim for groups (default: `groups`) |
|
||||
| `userIdentifierClaim` | string | Claim for user ID (default: `email`) |
|
||||
|
||||
### Domain Restriction
|
||||
|
||||
```yaml
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
- subsidiary.com
|
||||
```
|
||||
|
||||
### Specific User Access
|
||||
|
||||
```yaml
|
||||
allowedUsers:
|
||||
- user@example.com
|
||||
- contractor@external.org
|
||||
```
|
||||
|
||||
### Role-Based Access Control
|
||||
|
||||
```yaml
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- developer
|
||||
roleClaimName: "https://myapp.com/roles" # For namespaced claims (Auth0)
|
||||
```
|
||||
|
||||
### Access Control Logic
|
||||
|
||||
- If only `allowedUsers` is set: Only specified emails can access
|
||||
- If only `allowedUserDomains` is set: Only specified domains can access
|
||||
- If both are set: Access granted if email is in `allowedUsers` OR domain is in `allowedUserDomains`
|
||||
- If neither is set: Any authenticated user can access
|
||||
|
||||
### Users Without Email (Azure AD)
|
||||
|
||||
For Azure AD service accounts or users without email:
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # User object ID
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Headers Configuration
|
||||
|
||||
### Default Headers
|
||||
|
||||
The middleware sets these headers for downstream services:
|
||||
|
||||
| Header | Description |
|
||||
|--------|-------------|
|
||||
| `X-Forwarded-User` | User's email address |
|
||||
| `X-User-Groups` | Comma-separated user groups |
|
||||
| `X-User-Roles` | Comma-separated user roles |
|
||||
| `X-Auth-Request-Redirect` | Original request URI |
|
||||
| `X-Auth-Request-User` | User's email address |
|
||||
| `X-Auth-Request-Token` | User's ID token |
|
||||
|
||||
### Minimal Headers Mode
|
||||
|
||||
For "431 Request Header Fields Too Large" errors:
|
||||
|
||||
```yaml
|
||||
minimalHeaders: true # Only forwards X-Forwarded-User
|
||||
```
|
||||
|
||||
### Custom Templated Headers
|
||||
|
||||
```yaml
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{{{.Claims.email}}}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{{{.Claims.sub}}}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
```
|
||||
|
||||
**Template Variables:**
|
||||
- `{{.Claims.field}}` - ID token claims
|
||||
- `{{.AccessToken}}` - Raw access token
|
||||
- `{{.IdToken}}` - Raw ID token
|
||||
- `{{.RefreshToken}}` - Raw refresh token
|
||||
|
||||
**Important:** Use double curly braces (`{{{{` and `}}}}`) to escape templates in YAML.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers
|
||||
|
||||
### Security Profiles
|
||||
|
||||
| Profile | Use Case | Security Level |
|
||||
|---------|----------|----------------|
|
||||
| `default` | Standard web apps | High |
|
||||
| `strict` | Maximum security | Very High |
|
||||
| `development` | Local development | Medium |
|
||||
| `api` | API endpoints | High |
|
||||
| `custom` | Custom requirements | Configurable |
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
### API with CORS
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.example.com"
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self'"
|
||||
|
||||
# HSTS
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and Content Protection
|
||||
frameOptions: "DENY"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# CORS
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400
|
||||
|
||||
# Custom Headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "value"
|
||||
|
||||
# Server Identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### CORS Origin Patterns
|
||||
|
||||
```yaml
|
||||
corsAllowedOrigins:
|
||||
- "https://example.com" # Exact match
|
||||
- "https://*.example.com" # Subdomain wildcard
|
||||
- "http://localhost:*" # Port wildcard (development)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
### Default Behavior (Append Mode)
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- roles
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "email", "roles", "custom_scope"]
|
||||
```
|
||||
|
||||
### Override Mode
|
||||
|
||||
```yaml
|
||||
overrideScopes: true
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- custom_scope
|
||||
# Result: ["openid", "profile", "custom_scope"]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Advanced Options
|
||||
|
||||
### Dynamic Client Registration (RFC 7591)
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
initialAccessToken: "your-token" # Optional
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-credentials.json"
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
### Multi-Replica Deployment
|
||||
|
||||
Without Redis, disable replay detection:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true
|
||||
```
|
||||
|
||||
With Redis (recommended):
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
See [REDIS.md](REDIS.md) for complete Redis configuration.
|
||||
|
||||
---
|
||||
|
||||
## Kubernetes Secrets
|
||||
|
||||
Reference secrets instead of hardcoding sensitive values:
|
||||
|
||||
```yaml
|
||||
providerURL: urn:k8s:secret:oidc-secret:ISSUER
|
||||
clientID: urn:k8s:secret:oidc-secret:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:oidc-secret:SECRET
|
||||
```
|
||||
|
||||
Create the secret:
|
||||
|
||||
```bash
|
||||
kubectl create secret generic oidc-secret \
|
||||
--from-literal=ISSUER=https://accounts.google.com \
|
||||
--from-literal=CLIENT_ID=your-client-id \
|
||||
--from-literal=SECRET=your-client-secret \
|
||||
-n traefik
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Environment Variable Naming
|
||||
|
||||
**Important:** Avoid using "API" as a substring in environment variable names when using `${VAR}` syntax in Traefik configuration. Traefik reserves `TRAEFIK_API_*` variables and the substring may cause conflicts.
|
||||
|
||||
```yaml
|
||||
# Bad - may cause issues
|
||||
sessionEncryptionKey: ${OIDC_SECRET_API}
|
||||
|
||||
# Good
|
||||
sessionEncryptionKey: ${OIDC_SECRET_SVC}
|
||||
```
|
||||
@@ -0,0 +1,455 @@
|
||||
# Development Guide
|
||||
|
||||
Guide for local development, testing, and contributing to the Traefik OIDC middleware.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Local Development Setup](#local-development-setup)
|
||||
- [Running Tests](#running-tests)
|
||||
- [Test Categories](#test-categories)
|
||||
- [CI/CD Pipeline](#cicd-pipeline)
|
||||
- [Code Quality](#code-quality)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Go 1.23+** for plugin compilation
|
||||
- **Docker & Docker Compose** for local testing
|
||||
- **OIDC Provider** credentials (Google, Azure, etc.)
|
||||
|
||||
### Required Development Tools
|
||||
|
||||
```bash
|
||||
# golangci-lint (comprehensive linting)
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# staticcheck (static analysis)
|
||||
go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
# gosec (security scanning)
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
|
||||
# govulncheck (vulnerability scanning)
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Local Development Setup
|
||||
|
||||
### Docker Compose Environment
|
||||
|
||||
The repository includes a Docker Compose setup for testing the plugin locally.
|
||||
|
||||
#### 1. Host Configuration
|
||||
|
||||
Add to `/etc/hosts`:
|
||||
|
||||
```bash
|
||||
127.0.0.1 hello.localhost
|
||||
127.0.0.1 traefik.localhost
|
||||
```
|
||||
|
||||
#### 2. Plugin Configuration
|
||||
|
||||
The plugin is loaded using Traefik's **local plugins mode**:
|
||||
|
||||
- Plugin source: Parent directory (`../`)
|
||||
- Mount path: `/plugins-local/src/github.com/lukaszraczylo/traefikoidc`
|
||||
- Configuration: `experimental.localPlugins` in `traefik.yml`
|
||||
|
||||
#### 3. OIDC Provider Setup
|
||||
|
||||
Edit `docker/dynamic.yml` with your provider details:
|
||||
|
||||
**Google:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
logoutURL: "/oauth2/logout"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
**Azure AD:**
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
oidc-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/your-tenant-id/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key"
|
||||
callbackURL: "/oauth2/callback"
|
||||
scopes:
|
||||
- "openid"
|
||||
- "email"
|
||||
- "profile"
|
||||
```
|
||||
|
||||
#### 4. Start Environment
|
||||
|
||||
```bash
|
||||
cd docker
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
#### 5. Test Plugin
|
||||
|
||||
- **Protected App**: http://hello.localhost (redirects to OIDC)
|
||||
- **Traefik Dashboard**: http://traefik.localhost:8080
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Edit plugin code** in the project root
|
||||
2. **Build and test** (optional syntax check):
|
||||
```bash
|
||||
go mod tidy
|
||||
go build .
|
||||
go test ./...
|
||||
```
|
||||
3. **Restart Traefik** to reload plugin:
|
||||
```bash
|
||||
docker-compose restart traefik
|
||||
```
|
||||
4. **Test changes** at http://hello.localhost
|
||||
|
||||
### Debugging
|
||||
|
||||
**View plugin logs:**
|
||||
```bash
|
||||
docker-compose logs -f traefik | grep traefikoidc
|
||||
```
|
||||
|
||||
**Check plugin loading:**
|
||||
```bash
|
||||
docker-compose logs traefik | grep -i plugin
|
||||
```
|
||||
|
||||
**Verify plugin directory:**
|
||||
```bash
|
||||
docker-compose exec traefik ls -la /plugins-local/src/github.com/lukaszraczylo/traefikoidc/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# Fast development testing (< 30 seconds)
|
||||
go test ./... -short
|
||||
|
||||
# Standard tests with race detector
|
||||
go test -race -timeout=15m ./...
|
||||
|
||||
# With coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
```
|
||||
|
||||
### Test Modes
|
||||
|
||||
| Mode | Command | Duration | Use Case |
|
||||
|------|---------|----------|----------|
|
||||
| Quick | `go test ./... -short` | < 30s | During development |
|
||||
| Extended | `RUN_EXTENDED_TESTS=1 go test ./...` | 2-5 min | Before commits |
|
||||
| Long | `RUN_LONG_TESTS=1 go test ./...` | 5-15 min | Release validation |
|
||||
| Stress | `RUN_STRESS_TESTS=1 go test ./...` | 10-30 min | Performance testing |
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1
|
||||
export RUN_LONG_TESTS=1
|
||||
export RUN_STRESS_TESTS=1
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
|
||||
# Customize test parameters
|
||||
export TEST_MAX_CONCURRENCY=10
|
||||
export TEST_MAX_ITERATIONS=50
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Quick Tests (Default)
|
||||
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### Extended Tests
|
||||
|
||||
- Comprehensive testing before commits
|
||||
- More iterations (5-10)
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### Long Tests
|
||||
|
||||
- Performance validation
|
||||
- High iteration counts (50-100)
|
||||
- Large data sets
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### Stress Tests
|
||||
|
||||
- Maximum load testing
|
||||
- Edge case validation
|
||||
- Extreme parameters
|
||||
|
||||
**Configuration:**
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Timeout: 120 seconds
|
||||
|
||||
### Running Specific Test Suites
|
||||
|
||||
```bash
|
||||
# Memory leak tests
|
||||
go test -v -run='.*Leak.*' ./...
|
||||
|
||||
# Integration tests
|
||||
go test -v -run='.*Integration.*' ./...
|
||||
|
||||
# Regression tests
|
||||
go test -v -run='.*Regression.*' ./...
|
||||
|
||||
# Provider-specific tests
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
go test -v -run='.*Google.*' ./...
|
||||
```
|
||||
|
||||
### Benchmarks
|
||||
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## CI/CD Pipeline
|
||||
|
||||
The repository uses GitHub Actions for comprehensive validation with 20+ parallel checks.
|
||||
|
||||
### Triggered On
|
||||
|
||||
- Pull requests to `main` branch
|
||||
- Pushes to `main` branch
|
||||
|
||||
### Parallel Jobs
|
||||
|
||||
#### Code Quality (3 checks)
|
||||
- **Format & Basic Checks** - gofmt, go vet, go mod
|
||||
- **golangci-lint** - 30+ linters
|
||||
- **Staticcheck** - Advanced static analysis
|
||||
|
||||
#### Security (3 checks)
|
||||
- **Gosec** - Security vulnerability scanning
|
||||
- **Govulncheck** - Go vulnerability database
|
||||
- **CodeQL** - GitHub's semantic code analysis
|
||||
|
||||
#### Testing (9 suites)
|
||||
- Race Detector
|
||||
- Coverage (75% threshold)
|
||||
- Memory Leaks
|
||||
- Integration Tests
|
||||
- Regression Tests
|
||||
- Security Edge Cases
|
||||
- Session Tests
|
||||
- Token Tests
|
||||
- CSRF Tests
|
||||
|
||||
#### Provider Testing (9 providers)
|
||||
Tests run in parallel for:
|
||||
- Google
|
||||
- Azure AD
|
||||
- Auth0
|
||||
- Okta
|
||||
- Keycloak
|
||||
- AWS Cognito
|
||||
- GitLab
|
||||
- GitHub
|
||||
- Generic OIDC
|
||||
|
||||
#### Performance & Build (3 checks)
|
||||
- Benchmarks
|
||||
- Multi-platform Build (linux/darwin x amd64/arm64)
|
||||
- Go Version Compatibility (Go 1.23 & 1.24)
|
||||
|
||||
### Quality Gates
|
||||
|
||||
All PRs must pass:
|
||||
- All parallel checks
|
||||
- 75% test coverage minimum
|
||||
- Zero security vulnerabilities
|
||||
- No race conditions
|
||||
- No memory leaks
|
||||
- All providers tested
|
||||
- Builds on all platforms
|
||||
|
||||
---
|
||||
|
||||
## Code Quality
|
||||
|
||||
### Pre-Commit Checklist
|
||||
|
||||
```bash
|
||||
# Run before every commit
|
||||
gofmt -s -w . && \
|
||||
go mod tidy && \
|
||||
golangci-lint run && \
|
||||
go test -race -short ./... && \
|
||||
echo "Ready to commit!"
|
||||
```
|
||||
|
||||
### Local Validation
|
||||
|
||||
```bash
|
||||
# Format code
|
||||
gofmt -s -w .
|
||||
|
||||
# Run linter
|
||||
golangci-lint run
|
||||
|
||||
# Static analysis
|
||||
staticcheck ./...
|
||||
|
||||
# Security scan
|
||||
gosec ./...
|
||||
|
||||
# Vulnerability check
|
||||
govulncheck ./...
|
||||
|
||||
# Tests with race detector
|
||||
go test -race -timeout=15m -count=1 ./...
|
||||
|
||||
# Coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -func=coverage.out
|
||||
|
||||
# View coverage in browser
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Coverage Below Threshold:**
|
||||
```bash
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out # See uncovered lines
|
||||
```
|
||||
|
||||
**Race Condition Found:**
|
||||
```bash
|
||||
go test -race -v -run=TestName ./...
|
||||
```
|
||||
|
||||
**Linter Errors:**
|
||||
```bash
|
||||
golangci-lint run -v
|
||||
golangci-lint run --fix # Auto-fix some issues
|
||||
```
|
||||
|
||||
**Provider Test Fails:**
|
||||
```bash
|
||||
go test -v -run='.*Azure.*' ./...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
### Development Guidelines
|
||||
|
||||
1. **Memory Management**: Ensure all goroutines can be cancelled and resources are bounded
|
||||
2. **Testing**: Add tests for new features, including memory leak tests where appropriate
|
||||
3. **Race Conditions**: Run tests with `-race` flag to detect race conditions
|
||||
4. **Documentation**: Update README and configuration files for new options
|
||||
|
||||
### Pull Request Template
|
||||
|
||||
PRs should include:
|
||||
- Description of changes
|
||||
- Type of change (bug fix, feature, breaking change, etc.)
|
||||
- Related issues
|
||||
- Provider impact (which providers are affected)
|
||||
- Testing performed
|
||||
- Security considerations
|
||||
- Performance impact
|
||||
- Breaking changes (if any)
|
||||
|
||||
### Checklist
|
||||
|
||||
Before submitting:
|
||||
- [ ] Code follows project style
|
||||
- [ ] Self-review completed
|
||||
- [ ] Tests added for new functionality
|
||||
- [ ] All tests pass locally
|
||||
- [ ] Documentation updated
|
||||
- [ ] No new warnings generated
|
||||
|
||||
### Code Owners
|
||||
|
||||
The repository uses CODEOWNERS for automatic PR reviewer assignment based on file paths.
|
||||
|
||||
### Dependabot
|
||||
|
||||
Automated dependency updates run weekly (Mondays 9 AM) with security updates prioritized.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [golangci-lint Rules](.golangci.yml)
|
||||
- [PR Template](.github/PULL_REQUEST_TEMPLATE.md)
|
||||
- [Workflow Documentation](.github/workflows/README.md)
|
||||
- [GitHub Actions Documentation](https://docs.github.com/en/actions)
|
||||
@@ -0,0 +1,580 @@
|
||||
# OIDC Provider Configuration Guide
|
||||
|
||||
Configuration reference for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Provider Support Matrix](#provider-support-matrix)
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [Okta](#okta)
|
||||
- [Keycloak](#keycloak)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [GitLab](#gitlab)
|
||||
- [GitHub](#github)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
- [Automatic Scope Filtering](#automatic-scope-filtering)
|
||||
|
||||
---
|
||||
|
||||
## Provider Support Matrix
|
||||
|
||||
| Provider | OIDC Support | Refresh Tokens | Auto-Detection | ID Tokens |
|
||||
|----------|-------------|----------------|----------------|-----------|
|
||||
| Google | Full | Yes | `accounts.google.com` | Yes |
|
||||
| Azure AD | Full | Yes | `login.microsoftonline.com` | Yes |
|
||||
| Auth0 | Full | Yes | `*.auth0.com` | Yes |
|
||||
| Okta | Full | Yes | `*.okta.com` | Yes |
|
||||
| Keycloak | Full | Yes | `/auth/realms/` path | Yes |
|
||||
| AWS Cognito | Full | Yes | `cognito-idp.*.amazonaws.com` | Yes |
|
||||
| GitLab | Full | Yes | `gitlab.com` | Yes |
|
||||
| GitHub | OAuth 2.0 Only | No | `github.com` | No |
|
||||
| Generic | Full | Yes | Any OIDC endpoint | Yes |
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://accounts.google.com"
|
||||
clientID: "your-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
allowedUserDomains:
|
||||
- "your-gsuite-domain.com" # Optional: Workspace restriction
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
|
||||
- **Automatic offline access**: Middleware adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes unsupported `offline_access` scope
|
||||
- **Workspace domains**: Restrict to specific Google Workspace domains via `hd` claim
|
||||
|
||||
### Google Cloud Console Setup
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Navigate to APIs & Services > Credentials
|
||||
4. Create OAuth 2.0 Client ID (Web application)
|
||||
5. Add authorized redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
6. Configure OAuth consent screen (must be "Published" for production)
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# Single tenant
|
||||
providerURL: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# Multi-tenant
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- "App.Users"
|
||||
- "Admin.Group"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### With Application ID URI (API Access)
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientID: "your-azure-client-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
audience: "api://your-azure-client-id" # Application ID URI
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Users Without Email
|
||||
|
||||
```yaml
|
||||
userIdentifierClaim: sub # Options: sub, oid, upn, preferred_username
|
||||
allowedUsers:
|
||||
- "user-object-id-1"
|
||||
- "user-object-id-2"
|
||||
```
|
||||
|
||||
### Azure AD Setup
|
||||
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to Azure Active Directory > App registrations
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Create client secret in Certificates & secrets
|
||||
6. Configure Token Configuration for group claims
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access
|
||||
postLogoutRedirectUri: "https://your-app.com"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### With Custom API Audience
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth0-api
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.auth0.com"
|
||||
clientID: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
audience: "https://api.your-domain.com" # API identifier
|
||||
strictAudienceValidation: true
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
roleClaimName: "https://your-app.com/roles" # Namespaced claim
|
||||
groupClaimName: "https://your-app.com/groups"
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
```
|
||||
|
||||
### Auth0 Action for Custom Claims
|
||||
|
||||
```javascript
|
||||
exports.onExecutePostLogin = async (event, api) => {
|
||||
const namespace = 'https://your-app.com/';
|
||||
if (event.authorization) {
|
||||
api.idToken.setCustomClaim(namespace + 'roles', event.authorization.roles);
|
||||
api.idToken.setCustomClaim('email', event.user.email);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
### Auth0 Setup
|
||||
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create Regular Web Application
|
||||
3. Configure Allowed Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
4. Configure Allowed Logout URLs: `https://your-domain.com/oauth2/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
6. Create API in APIs section for custom audiences
|
||||
|
||||
See [AUTH0_AUDIENCE_GUIDE.md](AUTH0_AUDIENCE_GUIDE.md) for detailed audience configuration.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
# Or with custom authorization server:
|
||||
providerURL: "https://your-domain.okta.com/oauth2/default"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-okta
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://your-domain.okta.com"
|
||||
clientID: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- "Everyone"
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Setup
|
||||
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect > Web Application
|
||||
4. Configure Sign-in redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
5. Configure Sign-out redirect URIs: `https://your-domain.com/oauth2/logout`
|
||||
6. Enable Authorization Code and Refresh Token grant types
|
||||
7. Configure Groups claim in authorization server
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://keycloak.your-domain.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-keycloak
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://keycloak.company.com/realms/your-realm"
|
||||
clientID: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- roles
|
||||
- groups
|
||||
- offline_access
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- editor
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
For private IP addresses (Docker networks, Kubernetes):
|
||||
|
||||
```yaml
|
||||
providerURL: "https://192.168.1.100:8443/realms/your-realm"
|
||||
allowPrivateIPAddresses: true # Required for private IPs
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select your realm
|
||||
3. Go to Clients > Create client
|
||||
4. Set Client Protocol: openid-connect
|
||||
5. Set Access Type: confidential
|
||||
6. Add Valid Redirect URIs: `https://your-domain.com/oauth2/callback`
|
||||
7. Generate client secret in Credentials tab
|
||||
8. Configure mappers to add claims to ID Token:
|
||||
- Email: User Property mapper with "Add to ID token" enabled
|
||||
- Roles: User Client Role mapper with "Add to ID token" enabled
|
||||
- Groups: Group Membership mapper with "Add to ID token" enabled
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-cognito
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientID: "your-cognito-client-id"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- aws.cognito.signin.user.admin
|
||||
allowedRolesAndGroups:
|
||||
- admin
|
||||
- users
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/oauth2/callback`
|
||||
- Sign out URLs: `https://your-domain.com/oauth2/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
5. Set up groups for role-based access
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerURL: "https://gitlab.com"
|
||||
|
||||
# Self-hosted
|
||||
providerURL: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-gitlab
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://gitlab.com"
|
||||
clientID: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
# Note: GitLab doesn't require offline_access scope
|
||||
# Refresh tokens are issued automatically with openid
|
||||
allowedRolesAndGroups:
|
||||
- developers
|
||||
- maintainers
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Setup
|
||||
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/oauth2/callback`
|
||||
5. Save and note Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
|
||||
```yaml
|
||||
providerURL: "https://github.com"
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oauth-github
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://github.com/login/oauth"
|
||||
clientID: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- user:email
|
||||
- read:user
|
||||
allowedUsers:
|
||||
- "github-username"
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Limitations
|
||||
|
||||
- **OAuth 2.0 only** - Not OpenID Connect
|
||||
- **No ID tokens** - Only access tokens for API calls
|
||||
- **No refresh tokens** - Users must re-authenticate on expiry
|
||||
- **No standard claims** - User info requires API calls
|
||||
|
||||
Use GitHub only for API access, not for user authentication with claims.
|
||||
|
||||
### GitHub Setup
|
||||
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/oauth2/callback`
|
||||
4. Note Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
For any OIDC-compliant provider not listed above.
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-generic
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: "https://oidc.your-provider.com"
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-32-char-encryption-key-here"
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Requirements
|
||||
|
||||
- Provider must expose `.well-known/openid-configuration` endpoint
|
||||
- Must support authorization code flow
|
||||
- ID tokens must contain required claims (email, sub, etc.)
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's declared capabilities.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. Fetches provider's `.well-known/openid-configuration`
|
||||
2. Extracts `scopes_supported` field
|
||||
3. Filters requested scopes to only include supported ones
|
||||
4. Falls back to all requested scopes if provider doesn't declare supported scopes
|
||||
|
||||
### Example: Self-Hosted GitLab
|
||||
|
||||
Self-hosted GitLab may reject `offline_access` scope:
|
||||
|
||||
```yaml
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- email
|
||||
- offline_access # Will be automatically filtered out if unsupported
|
||||
```
|
||||
|
||||
The middleware will:
|
||||
1. Read GitLab's discovery document
|
||||
2. Detect `offline_access` is NOT in `scopes_supported`
|
||||
3. Filter it out automatically
|
||||
4. Authentication succeeds
|
||||
|
||||
### Logging
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes: [offline_access]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If a provider rejects scopes even after filtering:
|
||||
1. Check the provider's discovery document: `curl https://provider/.well-known/openid-configuration`
|
||||
2. Use `overrideScopes: true` with only supported scopes
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
@@ -1,955 +0,0 @@
|
||||
# Provider-Specific Configuration Guide
|
||||
|
||||
This guide covers the configuration requirements and best practices for each supported OIDC provider.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Google](#google)
|
||||
- [Microsoft Azure AD](#microsoft-azure-ad)
|
||||
- [Auth0](#auth0)
|
||||
- [GitHub](#github)
|
||||
- [GitLab](#gitlab)
|
||||
- [AWS Cognito](#aws-cognito)
|
||||
- [Keycloak](#keycloak)
|
||||
- [Okta](#okta)
|
||||
- [Generic OIDC](#generic-oidc)
|
||||
|
||||
---
|
||||
|
||||
## Google
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://accounts.google.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-google-client-id.apps.googleusercontent.com"
|
||||
clientSecret: "your-google-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Google-Specific Features
|
||||
- **Automatic offline access**: Google provider automatically adds `access_type=offline` and `prompt=consent`
|
||||
- **Scope filtering**: Automatically removes `offline_access` scope (not used by Google)
|
||||
- **Refresh token support**: Fully supported
|
||||
- **Domain restrictions**: Can restrict by Google Workspace domains
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedUserDomains: ["example.com", "company.org"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Google OAuth Console Setup
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create or select a project
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URIs: `https://your-domain.com/auth/callback`
|
||||
|
||||
---
|
||||
|
||||
## Microsoft Azure AD
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# For Azure AD (single tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/{tenant-id}/v2.0"
|
||||
|
||||
# For Azure AD (multi-tenant)
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-azure-application-id"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Azure-Specific Features
|
||||
- **Response mode**: Automatically adds `response_mode=query`
|
||||
- **Offline access**: Requires `offline_access` scope for refresh tokens
|
||||
- **Access token validation**: Supports both JWT and opaque access tokens
|
||||
- **Tenant isolation**: Can restrict to specific Azure AD tenants
|
||||
- **Application ID URI**: Supports custom audience for protected APIs
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["App.Users", "Admin.Group"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### Azure AD API Configuration (Application ID URI)
|
||||
|
||||
When exposing your application as an API with a custom Application ID URI, you need to specify the `audience` parameter. Azure AD includes the Application ID URI in the JWT `aud` claim.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
azure-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://login.microsoftonline.com/common/v2.0"
|
||||
clientId: "12345678-1234-1234-1234-123456789abc"
|
||||
clientSecret: "your-azure-client-secret"
|
||||
# Specify the Application ID URI as audience
|
||||
audience: "api://12345678-1234-1234-1234-123456789abc"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Application ID URI (typically `api://{app-id}`)
|
||||
- Find your Application ID URI in Azure Portal → App Registration → Expose an API → Application ID URI
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected
|
||||
- For ID token validation only (no API access), you can omit the `audience` parameter
|
||||
|
||||
### Azure App Registration Setup
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Navigate to "Azure Active Directory" > "App registrations"
|
||||
3. Create new registration
|
||||
4. Add redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Create client secret in "Certificates & secrets"
|
||||
6. Configure API permissions for required scopes
|
||||
|
||||
### Azure AD API Exposure Setup (for custom audiences)
|
||||
1. In your App Registration, go to "Expose an API"
|
||||
2. Set the Application ID URI (e.g., `api://12345678-1234-1234-1234-123456789abc`)
|
||||
3. Add any custom scopes your API exposes
|
||||
4. Update the middleware configuration to include the `audience` parameter with this URI
|
||||
|
||||
---
|
||||
|
||||
## Auth0
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.auth0.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-auth0-client-id"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Auth0-Specific Features
|
||||
- **Custom domains**: Supports Auth0 custom domains
|
||||
- **Rules and hooks**: Leverages Auth0's extensibility
|
||||
- **Social connections**: Works with Auth0's social identity providers
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
- **API audiences**: Supports custom audience for API access tokens
|
||||
|
||||
### Example Configuration (Basic)
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedUsers: ["user@example.com", "admin@company.com"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Auth0 API Configuration (Custom Audience)
|
||||
|
||||
When using Auth0 APIs with custom audience parameters, you need to specify the `audience` field. Auth0 includes the API identifier in the JWT `aud` claim instead of the `clientId`.
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
auth0-api-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.auth0.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-auth0-client-secret"
|
||||
# Specify the Auth0 API identifier as audience
|
||||
audience: "https://api.company.com"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
**Important**:
|
||||
- The `audience` parameter should match your Auth0 API identifier (not the client ID)
|
||||
- Find your API identifier in Auth0 Dashboard → APIs → Your API → Settings → Identifier
|
||||
- Without the `audience` parameter, access tokens with custom audiences will be rejected with "invalid audience" error
|
||||
- For ID token validation only (no APIs), you can omit the `audience` parameter
|
||||
|
||||
### Auth0 Application Setup
|
||||
1. Go to [Auth0 Dashboard](https://manage.auth0.com/)
|
||||
2. Create new application (Regular Web Application)
|
||||
3. Configure allowed callback URLs: `https://your-domain.com/auth/callback`
|
||||
4. Configure allowed logout URLs: `https://your-domain.com/auth/logout`
|
||||
5. Enable OIDC Conformant in Advanced Settings
|
||||
|
||||
### Auth0 API Setup (for custom audiences)
|
||||
1. Go to Auth0 Dashboard → APIs
|
||||
2. Create a new API or select existing API
|
||||
3. Note the "Identifier" field (e.g., `https://api.company.com`) - this is your `audience` value
|
||||
4. In API Settings → Machine to Machine Applications, authorize your application
|
||||
5. Configure API permissions/scopes as needed
|
||||
6. Use the API identifier as the `audience` parameter in your configuration
|
||||
|
||||
---
|
||||
|
||||
## GitHub
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://github.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-github-client-id"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["read:user", "user:email"]
|
||||
```
|
||||
|
||||
### GitHub-Specific Features
|
||||
- **Organization membership**: Can restrict by GitHub organization
|
||||
- **Team membership**: Can restrict by specific teams
|
||||
- **Limited OIDC**: GitHub has limited OIDC support
|
||||
- **Email verification**: Requires verified email addresses
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
github-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://github.com"
|
||||
clientId: "Iv1.abcdef123456"
|
||||
clientSecret: "your-github-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["read:user", "user:email"]
|
||||
allowedUsers: ["octocat", "github-user"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### GitHub OAuth App Setup
|
||||
1. Go to GitHub Settings > Developer settings > OAuth Apps
|
||||
2. Create new OAuth App
|
||||
3. Set Authorization callback URL: `https://your-domain.com/auth/callback`
|
||||
4. Note the Client ID and generate Client Secret
|
||||
|
||||
---
|
||||
|
||||
## GitLab
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
# GitLab.com
|
||||
providerUrl: "https://gitlab.com"
|
||||
|
||||
# Self-hosted GitLab
|
||||
providerUrl: "https://gitlab.your-company.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### GitLab-Specific Features
|
||||
- **Self-hosted support**: Works with self-hosted GitLab instances
|
||||
- **Group membership**: Can restrict by GitLab groups
|
||||
- **Project access**: Can validate project permissions
|
||||
- **Offline access**: Supports refresh tokens without requiring `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.com"
|
||||
clientId: "abcdef123456789"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
# Note: GitLab doesn't support the offline_access scope.
|
||||
# Refresh tokens are issued automatically for the openid scope.
|
||||
allowedRolesAndGroups: ["developers", "maintainers"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### GitLab Application Setup
|
||||
1. Go to GitLab Settings > Applications
|
||||
2. Create new application
|
||||
3. Add scopes: `openid`, `profile`, `email`
|
||||
4. Set redirect URI: `https://your-domain.com/auth/callback`
|
||||
5. Save and note the Application ID and Secret
|
||||
|
||||
---
|
||||
|
||||
## AWS Cognito
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://cognito-idp.{region}.amazonaws.com/{user-pool-id}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-cognito-app-client-id"
|
||||
clientSecret: "your-cognito-app-client-secret" # If app client has secret
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Cognito-Specific Features
|
||||
- **User pools**: Integrates with Cognito User Pools
|
||||
- **Custom attributes**: Supports custom user attributes
|
||||
- **Groups**: Can validate Cognito user group membership
|
||||
- **Regional endpoints**: Requires region-specific URLs
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
cognito-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_ABCDEF123"
|
||||
clientId: "1234567890abcdefghij"
|
||||
clientSecret: "your-cognito-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
allowedRolesAndGroups: ["admin", "users"]
|
||||
forceHttps: true
|
||||
```
|
||||
|
||||
### AWS Cognito Setup
|
||||
1. Create Cognito User Pool
|
||||
2. Create App Client with OIDC scopes
|
||||
3. Configure App Client settings:
|
||||
- Callback URLs: `https://your-domain.com/auth/callback`
|
||||
- Sign out URLs: `https://your-domain.com/auth/logout`
|
||||
- OAuth flows: Authorization code grant
|
||||
4. Configure hosted UI domain (optional)
|
||||
|
||||
---
|
||||
|
||||
## Keycloak
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://keycloak.your-company.com/realms/{realm-name}"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-keycloak-client-id"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Keycloak-Specific Features
|
||||
- **Realm support**: Multi-realm deployments
|
||||
- **Custom mappers**: Rich claim mapping capabilities
|
||||
- **Role-based access**: Fine-grained role management
|
||||
- **Offline access**: Full refresh token support
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
keycloak-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://keycloak.company.com/realms/employees"
|
||||
clientId: "traefik-app"
|
||||
clientSecret: "your-keycloak-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["app-users", "administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Keycloak Client Setup
|
||||
1. Access Keycloak Admin Console
|
||||
2. Select appropriate realm
|
||||
3. Create new client:
|
||||
- Client Protocol: openid-connect
|
||||
- Access Type: confidential
|
||||
- Valid Redirect URIs: `https://your-domain.com/auth/callback`
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-domain.okta.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-okta-client-id"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
```
|
||||
|
||||
### Okta-Specific Features
|
||||
- **Custom authorization servers**: Supports custom auth servers
|
||||
- **Group claims**: Rich group membership information
|
||||
- **Universal Directory**: Integrates with Okta's user store
|
||||
- **Offline access**: Requires `offline_access` scope
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
okta-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://company.okta.com"
|
||||
clientId: "0oa123456789abcdef"
|
||||
clientSecret: "your-okta-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
postLogoutRedirectUri: "https://app.example.com"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
allowedRolesAndGroups: ["Everyone", "Administrators"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
### Okta Application Setup
|
||||
1. Access Okta Admin Console
|
||||
2. Go to Applications > Create App Integration
|
||||
3. Select OIDC - OpenID Connect
|
||||
4. Choose Web Application
|
||||
5. Configure:
|
||||
- Sign-in redirect URIs: `https://your-domain.com/auth/callback`
|
||||
- Sign-out redirect URIs: `https://your-domain.com/auth/logout`
|
||||
- Grant types: Authorization Code, Refresh Token
|
||||
6. Assign users or groups
|
||||
|
||||
---
|
||||
|
||||
## Generic OIDC
|
||||
|
||||
### Provider URL
|
||||
```yaml
|
||||
providerUrl: "https://your-oidc-provider.com"
|
||||
```
|
||||
|
||||
### Required Configuration
|
||||
```yaml
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
```
|
||||
|
||||
### Generic Features
|
||||
- **Standards compliance**: Works with any OIDC-compliant provider
|
||||
- **Auto-discovery**: Uses `.well-known/openid-configuration` endpoint
|
||||
- **Flexible scopes**: Supports custom scope requirements
|
||||
- **Custom claims**: Works with provider-specific claims
|
||||
|
||||
### Example Configuration
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
generic-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://oidc.your-provider.com"
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
logoutUrl: "https://app.example.com/auth/logout"
|
||||
scopes: ["openid", "profile", "email"]
|
||||
forceHttps: true
|
||||
enablePkce: true
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Automatic Scope Filtering
|
||||
|
||||
### Overview
|
||||
|
||||
The middleware automatically filters OAuth scopes based on the provider's capabilities declared in their OIDC discovery document (`.well-known/openid-configuration`). This prevents authentication failures when providers reject unsupported scopes.
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Discovery Document Parsing**: The middleware fetches the provider's discovery document and extracts the `scopes_supported` field
|
||||
2. **Intelligent Filtering**: Requested scopes are filtered to only include those the provider supports
|
||||
3. **Fallback Behavior**: If the provider doesn't declare `scopes_supported`, all requested scopes are used (backward compatible)
|
||||
4. **Provider-Specific Handling**: Special logic for Google and Azure is preserved and applied after filtering
|
||||
|
||||
### Example Scenarios
|
||||
|
||||
#### Self-Hosted GitLab
|
||||
|
||||
**Problem**: Self-hosted GitLab instances reject the `offline_access` scope with error:
|
||||
```
|
||||
The requested scope is invalid, unknown, or malformed.
|
||||
```
|
||||
|
||||
**Solution**: The middleware automatically detects this by:
|
||||
1. Reading GitLab's discovery document at `https://gitlab.example.com/.well-known/openid-configuration`
|
||||
2. Observing that `offline_access` is NOT in the `scopes_supported` list
|
||||
3. Filtering out `offline_access` from the request
|
||||
4. Authentication succeeds
|
||||
|
||||
**Configuration**:
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
gitlab-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
providerUrl: "https://gitlab.example.com"
|
||||
clientId: "your-gitlab-application-id"
|
||||
clientSecret: "your-gitlab-application-secret"
|
||||
callbackUrl: "https://app.example.com/auth/callback"
|
||||
scopes: ["openid", "profile", "email", "offline_access"]
|
||||
# Even though offline_access is listed, it will be automatically
|
||||
# filtered out if GitLab doesn't support it
|
||||
```
|
||||
|
||||
#### Auth0 or Keycloak
|
||||
|
||||
These providers typically support `offline_access` and it will be included:
|
||||
|
||||
```yaml
|
||||
# Auth0 scopes_supported: ["openid", "profile", "email", "offline_access", ...]
|
||||
# Result: All requested scopes are sent
|
||||
```
|
||||
|
||||
### Benefits
|
||||
|
||||
1. **Self-Hosted Support**: Works seamlessly with self-hosted provider instances
|
||||
2. **No Manual Configuration**: No need to know which scopes each provider supports
|
||||
3. **Error Prevention**: Eliminates "invalid scope" authentication failures
|
||||
4. **Standards Compliant**: Uses official OIDC discovery specification (RFC 8414)
|
||||
5. **Backward Compatible**: Existing configurations continue to work
|
||||
|
||||
### Logging
|
||||
|
||||
The middleware provides detailed logging for scope filtering:
|
||||
|
||||
```
|
||||
INFO: ScopeFilter: Filtered unsupported scopes for https://gitlab.example.com: [offline_access]
|
||||
DEBUG: ScopeFilter: Provider https://gitlab.example.com supported scopes: [openid profile email read_user read_api]
|
||||
DEBUG: ScopeFilter: Final filtered scopes: [openid profile email]
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
**Issue**: Provider rejects scope even after filtering
|
||||
|
||||
**Possible Causes**:
|
||||
1. Provider's discovery document is outdated
|
||||
2. Provider doesn't properly implement `scopes_supported`
|
||||
3. Custom authorization server with non-standard behavior
|
||||
|
||||
**Solutions**:
|
||||
1. Use `overrideScopes: true` and explicitly list only supported scopes
|
||||
2. Check the provider's discovery document manually: `curl https://your-provider/.well-known/openid-configuration`
|
||||
3. Review middleware debug logs for filtering decisions
|
||||
|
||||
---
|
||||
|
||||
## Common Configuration Options
|
||||
|
||||
### Audience Configuration
|
||||
|
||||
The `audience` parameter specifies the expected JWT audience claim value. This is particularly important when using Auth0 APIs, Azure AD Application ID URIs, or other providers with custom audience requirements.
|
||||
|
||||
```yaml
|
||||
# Optional: Custom audience for JWT validation
|
||||
# If not set, defaults to clientID for backward compatibility
|
||||
audience: "https://api.example.com" # Auth0 API identifier
|
||||
# OR
|
||||
audience: "api://12345-guid" # Azure AD Application ID URI
|
||||
```
|
||||
|
||||
**When to use**:
|
||||
- **Auth0**: When using Auth0 APIs with custom audience parameters
|
||||
- **Azure AD**: When exposing your app as an API with Application ID URI
|
||||
- **Keycloak**: When using audience-restricted tokens
|
||||
- **Okta**: When using custom authorization servers with API audiences
|
||||
|
||||
**When to omit**:
|
||||
- For standard ID token validation (default behavior)
|
||||
- When the provider sets `aud` claim to your `clientID`
|
||||
- For backward compatibility with existing configurations
|
||||
|
||||
**Security Note**: The `audience` parameter prevents token confusion attacks by ensuring tokens issued for one service cannot be used at another service.
|
||||
|
||||
### Security Settings
|
||||
```yaml
|
||||
# Force HTTPS (recommended for production)
|
||||
forceHttps: true
|
||||
|
||||
# Enable PKCE (recommended for security)
|
||||
enablePkce: true
|
||||
|
||||
# Session encryption key (32+ characters)
|
||||
sessionEncryptionKey: "your-very-long-encryption-key-here"
|
||||
```
|
||||
|
||||
### Access Control
|
||||
```yaml
|
||||
# Restrict by email addresses
|
||||
allowedUsers: ["user1@example.com", "user2@example.com"]
|
||||
|
||||
# Restrict by email domains
|
||||
allowedUserDomains: ["company.com", "partner.org"]
|
||||
|
||||
# Restrict by roles/groups (provider-specific)
|
||||
allowedRolesAndGroups: ["admin", "users", "developers"]
|
||||
```
|
||||
|
||||
### URLs and Endpoints
|
||||
```yaml
|
||||
# OAuth callback URL (must match provider config)
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
|
||||
# Logout endpoint
|
||||
logoutUrl: "https://your-domain.com/auth/logout"
|
||||
|
||||
# Post-logout redirect (optional)
|
||||
postLogoutRedirectUri: "https://your-domain.com"
|
||||
|
||||
# URLs to exclude from authentication
|
||||
excludedUrls: ["/health", "/metrics", "/public"]
|
||||
```
|
||||
|
||||
### Advanced Settings
|
||||
```yaml
|
||||
# Override default scopes
|
||||
overrideScopes: true
|
||||
scopes: ["openid", "custom_scope"]
|
||||
|
||||
# Rate limiting (requests per second)
|
||||
rateLimit: 10
|
||||
|
||||
# Token refresh grace period (seconds)
|
||||
refreshGracePeriodSeconds: 60
|
||||
|
||||
# Cookie domain (for subdomain sharing)
|
||||
cookieDomain: ".example.com"
|
||||
|
||||
# Custom headers to inject
|
||||
headers:
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
- name: "X-User-Name"
|
||||
value: "{{.Claims.name}}"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Invalid redirect URI**
|
||||
- Ensure callback URL exactly matches provider configuration
|
||||
- Check for HTTP vs HTTPS mismatches
|
||||
|
||||
2. **Scope errors**
|
||||
- Verify required scopes are configured in provider
|
||||
- Some providers require specific scopes for refresh tokens
|
||||
|
||||
3. **Token validation failures**
|
||||
- Check provider URL format and accessibility
|
||||
- Verify `.well-known/openid-configuration` endpoint is reachable
|
||||
|
||||
4. **Session issues**
|
||||
- Ensure session encryption key is properly configured
|
||||
- Check cookie domain settings for subdomain scenarios
|
||||
|
||||
### Debug Mode
|
||||
Enable debug logging to troubleshoot configuration issues:
|
||||
```yaml
|
||||
logLevel: "debug"
|
||||
```
|
||||
|
||||
This will provide detailed logs of the authentication flow and help identify configuration problems.
|
||||
|
||||
---
|
||||
|
||||
## Security Headers Configuration
|
||||
|
||||
The plugin includes comprehensive security headers support to protect your applications against common web vulnerabilities.
|
||||
|
||||
### Default Security Headers
|
||||
|
||||
By default, the plugin applies these security headers:
|
||||
|
||||
- `X-Frame-Options: DENY` - Prevents clickjacking
|
||||
- `X-Content-Type-Options: nosniff` - Prevents MIME sniffing
|
||||
- `X-XSS-Protection: 1; mode=block` - Enables XSS protection
|
||||
- `Referrer-Policy: strict-origin-when-cross-origin` - Controls referrer information
|
||||
- `Strict-Transport-Security` - Forces HTTPS (when HTTPS is detected)
|
||||
|
||||
### Security Profiles
|
||||
|
||||
Choose from predefined security profiles or create custom configurations:
|
||||
|
||||
#### Default Profile (Recommended)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "default"
|
||||
```
|
||||
|
||||
#### Strict Profile (Maximum Security)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
# Additional strict CSP and cross-origin policies
|
||||
```
|
||||
|
||||
#### Development Profile (Local Development)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "development"
|
||||
# Relaxed policies for local development
|
||||
```
|
||||
|
||||
#### API Profile (API Endpoints)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "api"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://your-frontend.com"]
|
||||
```
|
||||
|
||||
### Custom Security Configuration
|
||||
|
||||
For complete control, use the custom profile:
|
||||
|
||||
```yaml
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "custom"
|
||||
|
||||
# Content Security Policy
|
||||
contentSecurityPolicy: "default-src 'self'; script-src 'self' 'unsafe-inline'"
|
||||
|
||||
# HSTS Configuration
|
||||
strictTransportSecurity: true
|
||||
strictTransportSecurityMaxAge: 31536000 # 1 year
|
||||
strictTransportSecuritySubdomains: true
|
||||
strictTransportSecurityPreload: true
|
||||
|
||||
# Frame and content protection
|
||||
frameOptions: "DENY" # or "SAMEORIGIN", "ALLOW-FROM uri"
|
||||
contentTypeOptions: "nosniff"
|
||||
xssProtection: "1; mode=block"
|
||||
referrerPolicy: "strict-origin-when-cross-origin"
|
||||
|
||||
# Permissions policy (feature policy)
|
||||
permissionsPolicy: "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Cross-origin policies
|
||||
crossOriginEmbedderPolicy: "require-corp"
|
||||
crossOriginOpenerPolicy: "same-origin"
|
||||
crossOriginResourcePolicy: "same-origin"
|
||||
|
||||
# CORS configuration
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://app.example.com"
|
||||
- "https://*.api.example.com"
|
||||
corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
corsAllowedHeaders: ["Authorization", "Content-Type", "X-Requested-With"]
|
||||
corsAllowCredentials: true
|
||||
corsMaxAge: 86400 # 24 hours
|
||||
|
||||
# Custom headers
|
||||
customHeaders:
|
||||
X-Custom-Header: "custom-value"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Server identification
|
||||
disableServerHeader: true
|
||||
disablePoweredByHeader: true
|
||||
```
|
||||
|
||||
### Complete Example with Security Headers
|
||||
|
||||
Here's a complete configuration example for Google OIDC with custom security headers:
|
||||
|
||||
```yaml
|
||||
# Traefik dynamic configuration
|
||||
http:
|
||||
middlewares:
|
||||
secure-google-oidc:
|
||||
plugin:
|
||||
traefik-oidc:
|
||||
# OIDC Configuration
|
||||
providerUrl: "https://accounts.google.com"
|
||||
clientId: "123456789-abcdef.apps.googleusercontent.com"
|
||||
clientSecret: "GOCSPX-your-client-secret"
|
||||
callbackUrl: "https://your-domain.com/auth/callback"
|
||||
sessionEncryptionKey: "your-32-character-encryption-key-here"
|
||||
|
||||
# Domain restrictions
|
||||
allowedUserDomains: ["your-company.com"]
|
||||
|
||||
# Security Headers
|
||||
securityHeaders:
|
||||
enabled: true
|
||||
profile: "strict"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "https://your-frontend.com"
|
||||
- "https://*.your-domain.com"
|
||||
corsAllowCredentials: true
|
||||
customHeaders:
|
||||
X-Company: "YourCompany"
|
||||
X-Environment: "production"
|
||||
|
||||
routers:
|
||||
secure-app:
|
||||
rule: "Host(`your-domain.com`)"
|
||||
middlewares:
|
||||
- secure-google-oidc
|
||||
service: your-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
```
|
||||
|
||||
### CORS Configuration Details
|
||||
|
||||
For applications with frontend-backend separation, configure CORS properly:
|
||||
|
||||
#### Simple CORS (Single Origin)
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://app.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Wildcard Subdomains
|
||||
```yaml
|
||||
securityHeaders:
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins: ["https://*.example.com"]
|
||||
corsAllowCredentials: true
|
||||
```
|
||||
|
||||
#### Development with Multiple Ports
|
||||
```yaml
|
||||
securityHeaders:
|
||||
profile: "development"
|
||||
corsEnabled: true
|
||||
corsAllowedOrigins:
|
||||
- "http://localhost:*"
|
||||
- "http://127.0.0.1:*"
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
- Set `forceHttps: true`
|
||||
- Configure proper TLS certificates
|
||||
|
||||
2. **Implement proper CSP**
|
||||
- Start with strict policy
|
||||
- Add exceptions only when necessary
|
||||
- Test thoroughly
|
||||
|
||||
3. **Configure CORS restrictively**
|
||||
- Only allow necessary origins
|
||||
- Use specific domains instead of wildcards when possible
|
||||
|
||||
4. **Enable HSTS**
|
||||
- Use long max-age values (1 year minimum)
|
||||
- Include subdomains when appropriate
|
||||
|
||||
5. **Monitor security headers**
|
||||
- Use browser developer tools to verify headers
|
||||
- Test with security scanning tools
|
||||
- Regularly review and update policies
|
||||
|
||||
### Testing Security Headers
|
||||
|
||||
Use browser developer tools or online tools to verify your security headers:
|
||||
|
||||
1. **Browser DevTools**: Check Network tab → Response Headers
|
||||
2. **Online scanners**: Use securityheaders.com or observatory.mozilla.org
|
||||
3. **Command line**: Use `curl -I https://your-domain.com`
|
||||
|
||||
Example verification:
|
||||
```bash
|
||||
curl -I https://your-domain.com
|
||||
# Should show security headers in response
|
||||
```
|
||||
+546
@@ -0,0 +1,546 @@
|
||||
# Redis Cache for Distributed Deployments
|
||||
|
||||
Redis cache support for multi-replica Traefik deployments with shared state.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Why Use Redis Cache?](#why-use-redis-cache)
|
||||
- [Configuration](#configuration)
|
||||
- [Cache Modes](#cache-modes)
|
||||
- [Deployment Examples](#deployment-examples)
|
||||
- [Performance Tuning](#performance-tuning)
|
||||
- [Monitoring](#monitoring)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
- [Migration Guide](#migration-guide)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
The Redis cache feature provides distributed caching for the Traefik OIDC plugin, enabling seamless operation across multiple Traefik instances.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Distributed JTI Replay Detection**: Prevents token replay attacks across all instances
|
||||
- **Shared Session Management**: Consistent user sessions across replicas
|
||||
- **Circuit Breaker**: Automatic fallback to memory cache during Redis outages
|
||||
- **Health Checking**: Continuous monitoring of Redis connectivity
|
||||
- **Flexible Cache Modes**: Memory, Redis, or hybrid caching strategies
|
||||
- **Pure-Go Implementation**: Yaegi-compatible, works with dynamic plugin loading
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌──────────────┐ ┌──────────────┐
|
||||
│ Traefik #1 │ │ Traefik #2 │ │ Traefik #3 │
|
||||
│ (Plugin) │ │ (Plugin) │ │ (Plugin) │
|
||||
└──────┬───────┘ └──────┬───────┘ └──────┬───────┘
|
||||
│ │ │
|
||||
└────────────────────┼────────────────────┘
|
||||
│
|
||||
┌──────▼──────┐
|
||||
│ Redis │
|
||||
│ (Shared │
|
||||
│ Cache) │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Why Use Redis Cache?
|
||||
|
||||
### The Problem
|
||||
|
||||
When running multiple Traefik instances without shared cache:
|
||||
|
||||
1. **False Positive Replay Detection**
|
||||
- User authenticates → Token stored in Instance A's JTI cache
|
||||
- Next request → Load balancer routes to Instance B
|
||||
- Instance B doesn't have the JTI → Falsely detects replay attack
|
||||
|
||||
2. **Session Inconsistency**
|
||||
- User session created on Instance A
|
||||
- Subsequent request routed to Instance B
|
||||
- Instance B has no knowledge of the session
|
||||
|
||||
3. **Token Metadata Fragmentation**
|
||||
- Token refresh happens on Instance A
|
||||
- Other instances continue using old tokens
|
||||
|
||||
### The Solution
|
||||
|
||||
Redis provides centralized cache that all instances share, ensuring:
|
||||
|
||||
- **Consistent Authentication**: All instances share authentication state
|
||||
- **True Replay Detection**: JTI cache shared across all instances
|
||||
- **Seamless Scaling**: Add/remove instances without affecting sessions
|
||||
- **High Availability**: Circuit breaker with automatic fallback
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### All Configuration Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `enabled` | bool | `false` | Enable Redis caching |
|
||||
| `address` | string | - | Redis server address (`host:port`) |
|
||||
| `password` | string | - | Redis password (optional) |
|
||||
| `db` | int | `0` | Redis database number (0-15) |
|
||||
| `keyPrefix` | string | `traefikoidc:` | Prefix for all Redis keys |
|
||||
| `cacheMode` | string | `redis` | Cache mode: `memory`, `redis`, `hybrid` |
|
||||
| `poolSize` | int | `10` | Connection pool size |
|
||||
| `connectTimeout` | int | `5` | Connection timeout (seconds) |
|
||||
| `readTimeout` | int | `3` | Read timeout (seconds) |
|
||||
| `writeTimeout` | int | `3` | Write timeout (seconds) |
|
||||
| `enableTLS` | bool | `false` | Enable TLS for connections |
|
||||
| `tlsSkipVerify` | bool | `false` | Skip TLS certificate verification |
|
||||
| `enableCircuitBreaker` | bool | `true` | Enable circuit breaker |
|
||||
| `circuitBreakerThreshold` | int | `5` | Failures before circuit opens |
|
||||
| `circuitBreakerTimeout` | int | `60` | Circuit reset timeout (seconds) |
|
||||
| `enableHealthCheck` | bool | `true` | Enable periodic health checks |
|
||||
| `healthCheckInterval` | int | `30` | Health check interval (seconds) |
|
||||
| `hybridL1Size` | int | `500` | Max items in L1 cache (hybrid mode) |
|
||||
| `hybridL1MemoryMB` | int64 | `10` | Max memory for L1 cache in MB |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables are used:
|
||||
|
||||
```bash
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=your-password
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=10
|
||||
REDIS_CONNECT_TIMEOUT=5
|
||||
REDIS_READ_TIMEOUT=3
|
||||
REDIS_WRITE_TIMEOUT=3
|
||||
REDIS_ENABLE_TLS=false
|
||||
REDIS_TLS_SKIP_VERIFY=false
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Modes
|
||||
|
||||
### Memory Mode (Default without Redis)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "memory"
|
||||
```
|
||||
|
||||
- Uses only in-memory cache
|
||||
- Suitable for single-instance deployments
|
||||
- No Redis dependency
|
||||
- Fastest performance
|
||||
|
||||
### Redis Mode
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "redis"
|
||||
```
|
||||
|
||||
- All operations go directly to Redis
|
||||
- Ensures consistency across replicas
|
||||
- Slightly higher latency
|
||||
|
||||
### Hybrid Mode (Recommended)
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
Two-tier caching strategy:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────┐
|
||||
│ Client Request │
|
||||
└────────────────┬────────────────────────┘
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Local Cache │ ← L1 Cache (Fast)
|
||||
│ (Memory) │
|
||||
└────────┬───────┘
|
||||
│ Miss
|
||||
▼
|
||||
┌────────────────┐
|
||||
│ Remote Cache │ ← L2 Cache (Shared)
|
||||
│ (Redis) │
|
||||
└────────────────┘
|
||||
```
|
||||
|
||||
**Read Path:**
|
||||
1. Check local memory cache (L1)
|
||||
2. On miss, check Redis (L2)
|
||||
3. On hit in Redis, populate L1
|
||||
4. Return value
|
||||
|
||||
**Write Path:**
|
||||
1. Write to Redis (L2) for durability
|
||||
2. Write to local cache (L1) for speed
|
||||
|
||||
### Performance Comparison
|
||||
|
||||
| Operation | Memory Mode | Redis Mode | Hybrid Mode |
|
||||
|-----------|------------|------------|-------------|
|
||||
| Read (p50) | 0.1ms | 2ms | 0.2ms |
|
||||
| Read (p99) | 0.5ms | 10ms | 5ms |
|
||||
| Write (p50) | 0.2ms | 3ms | 3ms |
|
||||
| Throughput | 100k/s | 20k/s | 80k/s |
|
||||
|
||||
---
|
||||
|
||||
## Deployment Examples
|
||||
|
||||
### Docker Compose
|
||||
|
||||
```yaml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass ${REDIS_PASSWORD}
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 30s
|
||||
timeout: 3s
|
||||
retries: 3
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
deploy:
|
||||
replicas: 3
|
||||
labels:
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=${REDIS_PASSWORD}"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
```
|
||||
|
||||
### Kubernetes
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
```
|
||||
|
||||
### AWS ElastiCache
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "your-cache.abc123.cache.amazonaws.com:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableTLS: true
|
||||
password: "your-elasticache-auth-token"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Connection Pool Sizing
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
poolSize: 20 # Formula: 2 * CPU cores * replicas
|
||||
# For 4 cores, 3 replicas: poolSize = 24
|
||||
```
|
||||
|
||||
### TTL Strategy
|
||||
|
||||
The plugin automatically sets TTLs based on token lifetimes:
|
||||
|
||||
- **JTI Cache**: Matches token lifetime (typically 1 hour)
|
||||
- **Session**: Matches `sessionMaxAge` configuration
|
||||
- **Token Metadata**: 5 minutes (short-lived)
|
||||
|
||||
### Redis Server Configuration
|
||||
|
||||
```bash
|
||||
# Recommended Redis settings for cache
|
||||
maxmemory 512mb
|
||||
maxmemory-policy allkeys-lru # Evict least recently used
|
||||
|
||||
# For cache data, disable persistence for better performance
|
||||
save ""
|
||||
appendonly no
|
||||
```
|
||||
|
||||
### Hybrid Mode Tuning
|
||||
|
||||
```yaml
|
||||
redis:
|
||||
cacheMode: "hybrid"
|
||||
hybridL1Size: 500 # Max items in local cache
|
||||
hybridL1MemoryMB: 10 # Max memory for local cache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Monitoring
|
||||
|
||||
### Key Metrics
|
||||
|
||||
- **Cache hit rate** (target: >90% for hybrid mode)
|
||||
- **Redis latency** (target: <10ms p99)
|
||||
- **Circuit breaker state**
|
||||
- **Connection pool utilization
|
||||
|
||||
### Redis Commands for Monitoring
|
||||
|
||||
```bash
|
||||
# Monitor commands in real-time
|
||||
redis-cli MONITOR
|
||||
|
||||
# Check slow queries
|
||||
redis-cli SLOWLOG GET 10
|
||||
|
||||
# Memory usage
|
||||
redis-cli INFO memory
|
||||
|
||||
# Key statistics
|
||||
redis-cli DBSIZE
|
||||
|
||||
# List keys with prefix
|
||||
redis-cli --scan --pattern "traefikoidc:*"
|
||||
|
||||
# Check key TTL
|
||||
redis-cli TTL "traefikoidc:session:abc123"
|
||||
```
|
||||
|
||||
### Health Check Endpoint
|
||||
|
||||
The plugin provides health information including:
|
||||
|
||||
```json
|
||||
{
|
||||
"status": "healthy",
|
||||
"cache": {
|
||||
"mode": "hybrid",
|
||||
"redis": {
|
||||
"connected": true,
|
||||
"latency": "2ms"
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": "closed",
|
||||
"failures": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Symptoms:** `dial tcp: connection refused`
|
||||
|
||||
**Solutions:**
|
||||
1. Verify Redis is running: `redis-cli ping`
|
||||
2. Check network connectivity: `telnet redis-host 6379`
|
||||
3. Verify address configuration
|
||||
|
||||
### Authentication Failure
|
||||
|
||||
**Symptoms:** `NOAUTH Authentication required`
|
||||
|
||||
**Solutions:**
|
||||
1. Set Redis password in configuration
|
||||
2. Verify password is correct
|
||||
|
||||
### Circuit Breaker Open
|
||||
|
||||
**Symptoms:** `Circuit breaker is open`, falling back to memory
|
||||
|
||||
**Solutions:**
|
||||
1. Check Redis health: `redis-cli INFO server`
|
||||
2. Review network latency: `redis-cli --latency`
|
||||
3. Adjust circuit breaker thresholds if needed
|
||||
|
||||
### High Memory Usage
|
||||
|
||||
**Symptoms:** Redis memory constantly growing, OOM errors
|
||||
|
||||
**Solutions:**
|
||||
1. Configure eviction policy:
|
||||
```bash
|
||||
CONFIG SET maxmemory 512mb
|
||||
CONFIG SET maxmemory-policy allkeys-lru
|
||||
```
|
||||
2. Review key count: `redis-cli DBSIZE`
|
||||
3. Check for large keys: `redis-cli --bigkeys`
|
||||
|
||||
### Inconsistent Cache State
|
||||
|
||||
**Symptoms:** Different responses from different replicas
|
||||
|
||||
**Solutions:**
|
||||
1. Verify all instances use the same Redis address
|
||||
2. Check cache mode consistency across instances
|
||||
3. Verify time synchronization on all hosts
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Memory-Only to Redis
|
||||
|
||||
#### Phase 1: Preparation
|
||||
|
||||
1. Deploy Redis infrastructure
|
||||
2. Test Redis connectivity
|
||||
3. Configure monitoring
|
||||
|
||||
#### Phase 2: Gradual Rollout
|
||||
|
||||
1. Enable Redis on one instance:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
2. Monitor for errors
|
||||
3. Gradually enable on more instances
|
||||
|
||||
#### Phase 3: Full Migration
|
||||
|
||||
1. Enable Redis on all instances
|
||||
2. Remove `disableReplayDetection: true` if set
|
||||
3. Monitor for issues
|
||||
|
||||
### Rollback Plan
|
||||
|
||||
If issues occur:
|
||||
1. Set `redis.enabled: false`
|
||||
2. Plugin falls back to memory cache automatically
|
||||
3. Investigate and resolve issues
|
||||
|
||||
### Migration Checklist
|
||||
|
||||
- [ ] Redis deployed and accessible
|
||||
- [ ] Redis password configured
|
||||
- [ ] Network connectivity verified
|
||||
- [ ] Monitoring configured
|
||||
- [ ] Backup plan prepared
|
||||
- [ ] Test environment validated
|
||||
- [ ] Gradual rollout planned
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Security
|
||||
|
||||
- Always use Redis password authentication
|
||||
- Enable TLS for production deployments
|
||||
- Use network segmentation (private subnets)
|
||||
- Rotate Redis passwords regularly
|
||||
|
||||
### High Availability
|
||||
|
||||
- Use Redis Sentinel or Cluster for HA
|
||||
- Configure appropriate circuit breaker thresholds
|
||||
- Implement proper health checks
|
||||
- Use connection pooling
|
||||
|
||||
### Performance
|
||||
|
||||
- Use hybrid cache mode for best performance
|
||||
- Monitor cache hit rates
|
||||
- Size Redis memory appropriately
|
||||
- Disable persistence for cache-only usage
|
||||
|
||||
### Operations
|
||||
|
||||
- Implement comprehensive monitoring
|
||||
- Set up alerting for circuit breaker state
|
||||
- Document Redis configuration
|
||||
- Test failover scenarios
|
||||
|
||||
---
|
||||
|
||||
## FAQ
|
||||
|
||||
### Is Redis required?
|
||||
|
||||
No, Redis is optional. The plugin works with in-memory cache for single-instance deployments.
|
||||
|
||||
### What happens if Redis goes down?
|
||||
|
||||
The circuit breaker opens after threshold failures, and the plugin falls back to in-memory cache. It periodically attempts to reconnect.
|
||||
|
||||
### Which cache mode should I use?
|
||||
|
||||
For production multi-replica deployments, use `hybrid` mode for best performance and consistency.
|
||||
|
||||
### How much memory does Redis need?
|
||||
|
||||
Depends on active sessions and token sizes:
|
||||
- Small (1-1000 users): 128MB
|
||||
- Medium (1000-10000 users): 256-512MB
|
||||
- Large (10000+ users): 1GB+
|
||||
|
||||
### Can I use managed Redis services?
|
||||
|
||||
Yes, the plugin works with AWS ElastiCache, Azure Cache for Redis, Google Cloud Memorystore, and Redis Enterprise Cloud.
|
||||
|
||||
### Is data encrypted in Redis?
|
||||
|
||||
Session data is encrypted before storing using `sessionEncryptionKey`. Additionally, you can enable TLS for Redis connections.
|
||||
-1125
File diff suppressed because it is too large
Load Diff
@@ -1,413 +0,0 @@
|
||||
# Redis Cache Backend Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Directory Organization
|
||||
|
||||
```
|
||||
internal/cache/
|
||||
├── backend/
|
||||
│ ├── interface.go # CacheBackend interface definition
|
||||
│ ├── interface_test.go # Contract tests for all backends
|
||||
│ ├── memory.go # In-memory backend implementation
|
||||
│ ├── memory_test.go # Memory backend unit tests
|
||||
│ ├── redis.go # Redis backend implementation
|
||||
│ ├── redis_test.go # Redis backend unit tests
|
||||
│ ├── errors.go # Error definitions
|
||||
│ └── test_helpers_test.go # Test infrastructure and helpers
|
||||
│
|
||||
└── resilience/
|
||||
├── circuit_breaker.go # Circuit breaker implementation
|
||||
├── circuit_breaker_test.go # Circuit breaker tests
|
||||
├── health_check.go # Health checker implementation
|
||||
└── health_check_test.go # Health check tests
|
||||
|
||||
redis_integration_test.go # End-to-end integration tests
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Interface Contract Tests (`interface_test.go`)
|
||||
|
||||
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
|
||||
|
||||
**Test Cases:**
|
||||
- `TestCacheBackendContract` - Runs all contract tests against each backend type
|
||||
- `testBasicSetGet` - Verifies basic set/get operations
|
||||
- `testGetNonExistent` - Tests behavior for non-existent keys
|
||||
- `testUpdateExisting` - Validates updating existing keys
|
||||
- `testDelete` - Tests delete operations
|
||||
- `testDeleteNonExistent` - Delete non-existent keys
|
||||
- `testExists` - Key existence checking
|
||||
- `testTTLExpiration` - TTL and expiration behavior
|
||||
- `testClear` - Clear all keys operation
|
||||
- `testPing` - Health check functionality
|
||||
- `testStats` - Statistics tracking
|
||||
- `testConcurrentAccess` - Thread safety with 10+ goroutines
|
||||
- `testLargeValues` - Handling of 1MB+ values
|
||||
- `testEmptyValues` - Empty byte array handling
|
||||
- `testSpecialCharactersInKeys` - Special characters in key names
|
||||
|
||||
**Coverage:** ~95% of interface methods
|
||||
|
||||
### 2. Memory Backend Tests (`memory_test.go`)
|
||||
|
||||
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (6 tests)
|
||||
- `TestMemoryBackend_BasicOperations` - CRUD operations
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- DeleteNonExistent
|
||||
- Exists
|
||||
- Clear
|
||||
|
||||
#### TTL and Expiration (3 tests)
|
||||
- `TestMemoryBackend_TTLExpiration`
|
||||
- ShortTTL (100ms)
|
||||
- TTLDecrement over time
|
||||
- CleanupExpiredItems
|
||||
|
||||
#### LRU Eviction (2 tests)
|
||||
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
|
||||
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
|
||||
|
||||
#### Edge Cases (6 tests)
|
||||
- `TestMemoryBackend_UpdateExisting` - Overwriting values
|
||||
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
|
||||
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
|
||||
- `TestMemoryBackend_LargeValues` - 1MB values
|
||||
- `TestMemoryBackend_Close` - Proper cleanup
|
||||
- `TestMemoryBackend_Ping` - Health checks
|
||||
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
|
||||
|
||||
**Coverage:** ~92% of memory backend code
|
||||
|
||||
### 3. Redis Backend Tests (`redis_test.go`)
|
||||
|
||||
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (4 tests)
|
||||
- `TestRedisBackend_BasicOperations`
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- Exists
|
||||
|
||||
#### Redis-Specific Features (6 tests)
|
||||
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
|
||||
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
|
||||
- `TestRedisBackend_Clear` - Bulk delete with SCAN
|
||||
- `TestRedisBackend_NoPrefix` - Operation without prefix
|
||||
|
||||
#### Error Handling (2 tests)
|
||||
- `TestRedisBackend_ConnectionFailure` - Connection errors
|
||||
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
|
||||
|
||||
#### Advanced Features (3 tests)
|
||||
- `TestRedisBackend_PipelineOperations`
|
||||
- SetMany (batch writes)
|
||||
- GetMany (batch reads)
|
||||
- GetManyWithNonExistent
|
||||
|
||||
#### Edge Cases (5 tests)
|
||||
- `TestRedisBackend_Stats` - Statistics tracking
|
||||
- `TestRedisBackend_Ping` - Connection health
|
||||
- `TestRedisBackend_Close` - Resource cleanup
|
||||
- `TestRedisBackend_UpdateExisting` - Overwrite handling
|
||||
- `TestRedisBackend_LargeValues` - 1MB values
|
||||
- `TestRedisBackend_EmptyValues` - Empty arrays
|
||||
|
||||
**Coverage:** ~88% of Redis backend code
|
||||
|
||||
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
|
||||
- All basic Redis commands
|
||||
- TTL and expiration
|
||||
- Time manipulation (FastForward)
|
||||
- Error simulation
|
||||
- No external Redis server required
|
||||
|
||||
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
|
||||
|
||||
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### State Transitions (5 tests)
|
||||
- `TestCircuitBreaker_StateTransitions`
|
||||
- Initial state (Closed)
|
||||
- Closed → Open (after max failures)
|
||||
- Open → HalfOpen (after timeout)
|
||||
- HalfOpen → Closed (after successful requests)
|
||||
- HalfOpen → Open (on failure)
|
||||
|
||||
#### Behavior Tests (5 tests)
|
||||
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
|
||||
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
|
||||
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
|
||||
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
|
||||
- `TestCircuitBreaker_Stats` - Statistics tracking
|
||||
|
||||
#### Advanced Tests (7 tests)
|
||||
- `TestCircuitBreaker_Reset` - Manual reset
|
||||
- `TestCircuitBreaker_StateChangeCallback` - Notifications
|
||||
- `TestCircuitBreaker_IsAvailable` - Availability check
|
||||
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
|
||||
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
|
||||
- `TestCircuitBreaker_DefaultConfig` - Default configuration
|
||||
- `TestCircuitBreaker_StateString` - String representation
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkCircuitBreaker_Execute` - Successful operations
|
||||
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
|
||||
|
||||
**Coverage:** ~95% of circuit breaker code
|
||||
|
||||
### 5. Health Check Tests (`health_check_test.go`)
|
||||
|
||||
**Purpose:** Validate periodic health checking and status management.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Status Transitions (4 tests)
|
||||
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
|
||||
- `TestHealthChecker_InitialState` - Default healthy state
|
||||
- `TestHealthChecker_ForceCheck` - Manual health check trigger
|
||||
- `TestHealthChecker_StatusChangeCallback` - Change notifications
|
||||
|
||||
#### Behavior Tests (6 tests)
|
||||
- `TestHealthChecker_Stats` - Statistics tracking
|
||||
- `TestHealthChecker_Timeout` - Check timeout handling
|
||||
- `TestHealthChecker_ConcurrentAccess` - Thread safety
|
||||
- `TestHealthChecker_StopAndStart` - Lifecycle management
|
||||
- `TestHealthChecker_DegradedState` - Degraded status detection
|
||||
- `TestHealthChecker_DefaultConfig` - Default settings
|
||||
|
||||
#### Advanced Tests (2 tests)
|
||||
- `TestHealthChecker_StatusString` - String representation
|
||||
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkHealthChecker_ForceCheck` - Check performance
|
||||
- `BenchmarkHealthChecker_Status` - Status read performance
|
||||
|
||||
**Coverage:** ~90% of health checker code
|
||||
|
||||
### 6. Integration Tests (`redis_integration_test.go`)
|
||||
|
||||
**Purpose:** End-to-end testing of real-world scenarios.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Multi-Instance Tests (3 tests)
|
||||
- `TestRedisIntegration_MultipleInstances`
|
||||
- ShareTokenBlacklist - JTI sharing across Traefik replicas
|
||||
- ShareTokenCache - Token cache sharing
|
||||
- ShareMetadataCache - Provider metadata sharing
|
||||
|
||||
#### Replay Detection (2 tests)
|
||||
- `TestRedisIntegration_JTIReplayDetection`
|
||||
- PreventReplayAcrossInstances - Block used JTIs
|
||||
- ConcurrentJTIChecks - Race condition handling
|
||||
|
||||
#### Resilience (1 test)
|
||||
- `TestRedisIntegration_Failover`
|
||||
- RedisTemporaryFailure - Recovery from temporary failures
|
||||
|
||||
#### Performance (1 test)
|
||||
- `TestRedisIntegration_HighLoad`
|
||||
- HighConcurrency - 50 goroutines × 100 operations
|
||||
|
||||
#### Consistency (2 tests)
|
||||
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
|
||||
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
|
||||
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
|
||||
|
||||
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
|
||||
|
||||
## Test Helpers and Infrastructure
|
||||
|
||||
### Test Helpers (`test_helpers_test.go`)
|
||||
|
||||
**Utilities:**
|
||||
- `TestLogger` - Logging for tests
|
||||
- `MiniredisServer` - Miniredis setup/teardown
|
||||
- `TestConfig` - Default test configurations
|
||||
- `GenerateTestData` - Test data generation
|
||||
- `GenerateLargeValue` - Large value creation
|
||||
- `AssertCacheStats` - Statistics validation
|
||||
- `WaitForCondition` - Async condition waiting
|
||||
- `AssertEventuallyExpires` - TTL expiration verification
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -v
|
||||
go test ./internal/cache/resilience/... -v
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Memory backend only
|
||||
go test ./internal/cache/backend -run TestMemoryBackend -v
|
||||
|
||||
# Redis backend only
|
||||
go test ./internal/cache/backend -run TestRedisBackend -v
|
||||
|
||||
# Circuit breaker only
|
||||
go test ./internal/cache/resilience -run TestCircuitBreaker -v
|
||||
|
||||
# Integration tests only
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -coverprofile=coverage.out
|
||||
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
go test ./internal/cache/backend -bench=. -benchmem
|
||||
go test ./internal/cache/resilience -bench=. -benchmem
|
||||
```
|
||||
|
||||
### Run with Race Detector
|
||||
```bash
|
||||
go test ./internal/cache/... -race -v
|
||||
```
|
||||
|
||||
## Test Patterns Used
|
||||
|
||||
### 1. Table-Driven Tests
|
||||
Used for testing multiple scenarios with similar structure.
|
||||
|
||||
### 2. Subtests (t.Run)
|
||||
Organized test cases into logical groups with clear names.
|
||||
|
||||
### 3. Parallel Tests
|
||||
Tests marked with `t.Parallel()` for faster execution.
|
||||
|
||||
### 4. Test Fixtures
|
||||
Reusable setup functions for common test data.
|
||||
|
||||
### 5. Mocking
|
||||
- `miniredis` for Redis operations
|
||||
- Mock functions for callbacks and health checks
|
||||
|
||||
### 6. Assertion Helpers
|
||||
Using `testify/assert` and `testify/require` for clear assertions.
|
||||
|
||||
## Test Coverage Summary
|
||||
|
||||
| Component | Coverage | Tests | Lines of Code |
|
||||
|-----------|----------|-------|---------------|
|
||||
| Interface Contract | 95% | 14 | ~200 |
|
||||
| Memory Backend | 92% | 18 | ~350 |
|
||||
| Redis Backend | 88% | 21 | ~400 |
|
||||
| Circuit Breaker | 95% | 17 | ~250 |
|
||||
| Health Checker | 90% | 12 | ~200 |
|
||||
| Integration Tests | 80% | 9 | ~300 |
|
||||
| **Total** | **90%** | **91** | **~1,700** |
|
||||
|
||||
## Edge Cases Tested
|
||||
|
||||
1. **Empty values** - Zero-length byte arrays
|
||||
2. **Large values** - 1MB+ data
|
||||
3. **Special characters** - Keys with :, /, -, _, ., |
|
||||
4. **Concurrent access** - 10-50 goroutines
|
||||
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
|
||||
6. **Connection failures** - Network errors, timeouts
|
||||
7. **Redis errors** - Simulated Redis failures
|
||||
8. **Memory limits** - Eviction under memory pressure
|
||||
9. **Race conditions** - Concurrent JTI checks
|
||||
10. **State transitions** - All circuit breaker and health check states
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Benchmarks included for:
|
||||
- Cache operations (Set, Get, Delete)
|
||||
- Circuit breaker execution
|
||||
- Health check operations
|
||||
- Concurrent access patterns
|
||||
- Large datasets (10,000+ items)
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Testing Libraries
|
||||
- `github.com/stretchr/testify` - Assertions and test utilities
|
||||
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
|
||||
- `github.com/redis/go-redis/v9` - Redis client
|
||||
|
||||
### Why Miniredis?
|
||||
- **No external dependencies** - No Redis server required
|
||||
- **Fast** - In-memory, perfect for unit tests
|
||||
- **Full Redis API** - Supports all operations we need
|
||||
- **Time manipulation** - FastForward for TTL testing
|
||||
- **Error simulation** - Test failure scenarios
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Tests
|
||||
1. Hybrid backend tests (L1/L2 cache)
|
||||
2. Network partition scenarios
|
||||
3. Redis cluster support
|
||||
4. Persistence and recovery tests
|
||||
5. Metrics and monitoring integration
|
||||
|
||||
### Test Infrastructure Improvements
|
||||
1. Test containers for real Redis integration
|
||||
2. Performance regression tracking
|
||||
3. Chaos engineering tests
|
||||
4. Load testing framework
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### Recommended CI Configuration
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- go test ./internal/cache/... -race -cover -v
|
||||
- go test -run TestRedisIntegration -v
|
||||
- go test ./internal/cache/... -bench=. -benchmem
|
||||
```
|
||||
|
||||
## Maintenance Guidelines
|
||||
|
||||
1. **Add tests for new features** - Maintain >85% coverage
|
||||
2. **Update contract tests** - When interface changes
|
||||
3. **Test edge cases** - Always test error paths
|
||||
4. **Document test purpose** - Clear comments explaining what each test validates
|
||||
5. **Keep tests fast** - Use t.Parallel() where possible
|
||||
6. **Mock external dependencies** - Use miniredis, not real Redis
|
||||
|
||||
## Conclusion
|
||||
|
||||
This comprehensive test suite provides:
|
||||
- **High confidence** in cache backend correctness
|
||||
- **Fast feedback** - Tests run in seconds
|
||||
- **Good coverage** - 90% overall
|
||||
- **Clear documentation** - Each test is well-documented
|
||||
- **Maintainability** - Clear structure and patterns
|
||||
|
||||
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
|
||||
+390
@@ -0,0 +1,390 @@
|
||||
# Testing Guide
|
||||
|
||||
Comprehensive testing infrastructure for traefikoidc.
|
||||
|
||||
## Overview
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| Test files | 99 |
|
||||
| Lines of test code | ~65,500 |
|
||||
| Code coverage | 71.0% |
|
||||
| Race conditions | None (all pass with `-race`) |
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run with race detection
|
||||
go test -race ./...
|
||||
|
||||
# Run with coverage
|
||||
go test -cover ./...
|
||||
|
||||
# Run specific test suite
|
||||
go test -v -run "TokenValidationSuite" .
|
||||
|
||||
# Run edge case tests
|
||||
go test -v -run "ClockSkewEdgeCasesSuite|UnicodeClaimsSuite" .
|
||||
```
|
||||
|
||||
## Test Infrastructure
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
internal/testutil/
|
||||
├── compat.go # Re-exports for main package access
|
||||
├── mocks/
|
||||
│ ├── interfaces.go # JWKCache, TokenExchanger, TokenVerifier, etc.
|
||||
│ ├── session.go # SessionManager, SessionData
|
||||
│ ├── cache.go # Cache, TokenCache, Blacklist
|
||||
│ └── interfaces_test.go # Mock verification tests
|
||||
├── fixtures/
|
||||
│ └── tokens.go # JWT token generation fixtures
|
||||
└── servers/
|
||||
├── oidc.go # Mock OIDC server factory
|
||||
└── oidc_test.go # Server tests
|
||||
```
|
||||
|
||||
### Test Suites
|
||||
|
||||
| Suite | File | Description |
|
||||
|-------|------|-------------|
|
||||
| TokenValidationSuite | `token_validation_suite_test.go` | Token validation happy path and error cases |
|
||||
| JWKCacheTestSuite | `token_validation_suite_test.go` | JWK cache behavior tests |
|
||||
| TokenExchangerTestSuite | `token_validation_suite_test.go` | Token exchange scenarios |
|
||||
| ClockSkewEdgeCasesSuite | `edge_cases_suite_test.go` | Expiry boundary testing |
|
||||
| UnicodeClaimsSuite | `edge_cases_suite_test.go` | Unicode/emoji handling in claims |
|
||||
| LargeClaimsSuite | `edge_cases_suite_test.go` | Large data handling (100s of claims) |
|
||||
| URLPathEdgeCasesSuite | `edge_cases_suite_test.go` | URL parsing edge cases |
|
||||
| ConcurrencyEdgeCasesSuite | `edge_cases_suite_test.go` | Concurrent token validation |
|
||||
| ExampleTestSuite | `testutil_example_test.go` | Example demonstrating patterns |
|
||||
| AuthFlowBehaviourSuite | `auth_flow_behaviour_test.go` | Authentication flow behavior tests |
|
||||
| SessionBehaviourSuite | `session_behaviour_test.go` | Session management behavior tests |
|
||||
| EnhancedMocksSuite | `enhanced_mocks_suite_test.go` | Enhanced mock usage demonstration |
|
||||
|
||||
## Mock Types
|
||||
|
||||
The project provides two mocking patterns:
|
||||
|
||||
### State-Based Mocks (Basic)
|
||||
|
||||
Located in `main_test.go`, `mocks_test.go`. Simple mocks that store data in struct fields.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `MockJWKCache` | `JWKCacheInterface` | Simple state-based mock with JWKS/Err fields |
|
||||
| `MockTokenVerifier` | `TokenVerifier` | Function-based mock for token verification |
|
||||
| `MockTokenExchanger` | `TokenExchanger` | Function-based mock for token exchange |
|
||||
| `MockOAuthProvider` | `http.Handler` | Full HTTP handler mock for OAuth provider simulation |
|
||||
| `MockSessionManager` | `SessionManager` | State-based mock for session management |
|
||||
| `MockHTTPClient` | N/A | Mock HTTP client with customizable responses |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
jwkCache: mock,
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
### Enhanced State-Based Mocks (with Call Tracking)
|
||||
|
||||
Located in `enhanced_mocks_test.go`. State-based mocks with built-in call tracking and assertion helpers.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `EnhancedMockJWKCache` | `JWKCacheInterface` | State-based with call tracking |
|
||||
| `EnhancedMockTokenVerifier` | `TokenVerifier` | State-based with call tracking |
|
||||
| `EnhancedMockTokenExchanger` | `TokenExchanger` | State-based with call tracking |
|
||||
| `EnhancedMockCacheInterface` | `CacheInterface` | Functional cache with call tracking |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
}
|
||||
|
||||
// Make calls
|
||||
result, err := mock.GetJWKS(ctx, "https://example.com/jwks", nil)
|
||||
|
||||
// Verify calls were made
|
||||
mock.AssertGetJWKSCalled(t)
|
||||
mock.AssertGetJWKSCalledWith(t, "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(t, 1)
|
||||
|
||||
// Access call details
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Track all calls with parameters and timestamps
|
||||
- Built-in assertion helpers using testify
|
||||
- Thread-safe for concurrent tests
|
||||
- `Reset()` method to clear state between tests
|
||||
- `LastCall()` to inspect most recent call
|
||||
|
||||
### Testify-Based Mocks
|
||||
|
||||
Located in `testify_mocks_test.go`. Mocks using testify's `.On()/.Return()` pattern for behavior verification.
|
||||
|
||||
| Mock | Interface | Description |
|
||||
|------|-----------|-------------|
|
||||
| `TestifyJWKCache` | `JWKCacheInterface` | Testify mock with `.On()/.Return()` |
|
||||
| `TestifyTokenVerifier` | `TokenVerifier` | Testify mock for token verification |
|
||||
| `TestifyTokenExchanger` | `TokenExchanger` | Testify mock for token exchange |
|
||||
| `TestifyCacheInterface` | `CacheInterface` | Testify mock for cache operations |
|
||||
| `TestifyHTTPClient` | N/A | Testify mock for HTTP client |
|
||||
| `TestifyRoundTripper` | `http.RoundTripper` | Testify mock for HTTP transport |
|
||||
|
||||
**Usage:**
|
||||
```go
|
||||
mock := &TestifyJWKCache{}
|
||||
mock.On("GetJWKS", mock.Anything, "https://example.com/jwks", mock.Anything).
|
||||
Return(&JWKSet{Keys: []JWK{jwk}}, nil)
|
||||
|
||||
// After test
|
||||
mock.AssertExpectations(t)
|
||||
```
|
||||
|
||||
### Testutil Package Mocks
|
||||
|
||||
Located in `internal/testutil/mocks/`. Generic mocks for testing the test infrastructure itself.
|
||||
|
||||
```go
|
||||
import "github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
|
||||
mock := testutil.NewJWKCacheMock()
|
||||
mock.On("GetJWKS", mock.Anything, mock.Anything, mock.Anything).
|
||||
Return(&mocks.JWKSet{Keys: []mocks.JWK{{Kty: "RSA"}}}, nil)
|
||||
```
|
||||
|
||||
### Choosing the Right Mock
|
||||
|
||||
| Use Case | Recommended Mock |
|
||||
|----------|-----------------|
|
||||
| Simple return values only | Basic state-based (`MockJWKCache`) |
|
||||
| Return values + verify calls made | Enhanced state-based (`EnhancedMockJWKCache`) |
|
||||
| Complex call expectations | Testify-based (`TestifyJWKCache`) |
|
||||
| Verify call order/sequence | Testify-based |
|
||||
| HTTP endpoint simulation | `MockOAuthProvider` |
|
||||
| New testify suite tests | Enhanced or Testify-based |
|
||||
|
||||
**Decision Guide:**
|
||||
|
||||
1. **Basic State-Based**: Use when you only need to control return values and don't care about verifying interactions.
|
||||
|
||||
2. **Enhanced State-Based**: Use when you want to verify calls were made with specific parameters, but prefer simpler setup than testify's `.On()/.Return()` pattern.
|
||||
|
||||
3. **Testify-Based**: Use when you need complex behavior like different returns per call, strict call ordering, or detailed expectation matching.
|
||||
|
||||
## Token Fixtures
|
||||
|
||||
The `testutil.TokenFixture` generates JWT tokens for testing:
|
||||
|
||||
```go
|
||||
fixture, err := testutil.NewTokenFixture()
|
||||
|
||||
// Valid token with default claims
|
||||
token, _ := fixture.ValidToken(nil)
|
||||
|
||||
// Token with custom claims
|
||||
token, _ := fixture.ValidToken(map[string]interface{}{
|
||||
"email": "test@example.com",
|
||||
"roles": []string{"admin"},
|
||||
})
|
||||
|
||||
// Expired token
|
||||
token, _ := fixture.ExpiredToken()
|
||||
|
||||
// Token with specific roles/groups
|
||||
token, _ := fixture.TokenWithRoles([]string{"admin", "user"})
|
||||
token, _ := fixture.TokenWithGroups([]string{"developers"})
|
||||
|
||||
// Token with clock skew
|
||||
token, _ := fixture.TokenWithSkew(-2 * time.Minute) // expired 2 min ago
|
||||
token, _ := fixture.TokenWithSkew(5 * time.Minute) // expires in 5 min
|
||||
|
||||
// Token missing specific claims
|
||||
token, _ := fixture.TokenMissingClaim("email", "sub")
|
||||
|
||||
// Malformed token
|
||||
token := fixture.MalformedToken() // "not.a.valid.jwt"
|
||||
|
||||
// Get JWKS for verification
|
||||
jwks := fixture.GetJWKS()
|
||||
```
|
||||
|
||||
## Mock OIDC Server
|
||||
|
||||
The `testutil.OIDCServer` provides a fully functional mock OIDC provider:
|
||||
|
||||
```go
|
||||
// Default configuration
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
defer server.Close()
|
||||
|
||||
// Custom configuration
|
||||
config := testutil.DefaultServerConfig()
|
||||
config.Issuer = "https://custom-issuer.com"
|
||||
config.TokenError = &testutil.OIDCError{
|
||||
Error: "invalid_grant",
|
||||
Description: "Authorization code expired",
|
||||
}
|
||||
server := testutil.NewOIDCServer(config)
|
||||
|
||||
// Provider-specific configurations
|
||||
googleConfig := testutil.GoogleServerConfig()
|
||||
azureConfig := testutil.AzureServerConfig()
|
||||
auth0Config := testutil.Auth0ServerConfig()
|
||||
keycloakConfig := testutil.KeycloakServerConfig()
|
||||
|
||||
// Behavior configurations
|
||||
slowConfig := testutil.SlowServerConfig(100 * time.Millisecond)
|
||||
rateLimitedConfig := testutil.RateLimitedServerConfig(5) // Limit after 5 requests
|
||||
```
|
||||
|
||||
### Server Endpoints
|
||||
|
||||
| Endpoint | Description |
|
||||
|----------|-------------|
|
||||
| `/.well-known/openid-configuration` | OIDC discovery document |
|
||||
| `/authorize` | Authorization endpoint |
|
||||
| `/token` | Token exchange endpoint |
|
||||
| `/jwks` | JSON Web Key Set |
|
||||
| `/userinfo` | User information endpoint |
|
||||
| `/introspect` | Token introspection |
|
||||
| `/revoke` | Token revocation |
|
||||
| `/logout` | End session endpoint |
|
||||
|
||||
### Request Tracking
|
||||
|
||||
```go
|
||||
server := testutil.NewOIDCServer(nil)
|
||||
|
||||
// Make requests...
|
||||
|
||||
count := server.GetRequestCount()
|
||||
requests := server.GetRequests()
|
||||
server.Reset() // Clear tracking
|
||||
```
|
||||
|
||||
## Writing Test Suites
|
||||
|
||||
### Basic Suite Structure
|
||||
|
||||
```go
|
||||
type MyTestSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) SetupTest() {
|
||||
// Per-test setup
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
// ...
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TearDownTest() {
|
||||
// Per-test cleanup
|
||||
}
|
||||
|
||||
func (s *MyTestSuite) TestSomething() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func TestMyTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(MyTestSuite))
|
||||
}
|
||||
```
|
||||
|
||||
### Table-Driven Tests
|
||||
|
||||
```go
|
||||
func (s *MyTestSuite) TestClockSkewEdgeCases() {
|
||||
testCases := []struct {
|
||||
name string
|
||||
skew time.Duration
|
||||
shouldPass bool
|
||||
}{
|
||||
{"valid_token", 5 * time.Minute, true},
|
||||
{"expired_within_tolerance", -1 * time.Minute, true},
|
||||
{"expired_beyond_tolerance", -10 * time.Minute, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
s.Run(tc.name, func() {
|
||||
token, err := s.fixture.TokenWithSkew(tc.skew)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
if tc.shouldPass {
|
||||
s.NoError(err)
|
||||
} else {
|
||||
s.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Happy Path Tests
|
||||
|
||||
Test the expected successful scenarios:
|
||||
|
||||
- Valid token verification
|
||||
- Successful token exchange
|
||||
- Session creation and retrieval
|
||||
- Cache operations
|
||||
|
||||
### Error Case Tests
|
||||
|
||||
Test failure scenarios:
|
||||
|
||||
- Expired tokens
|
||||
- Invalid signatures
|
||||
- Wrong issuer/audience
|
||||
- Network failures
|
||||
- Rate limiting
|
||||
|
||||
### Edge Case Tests
|
||||
|
||||
Test boundary conditions:
|
||||
|
||||
- Clock skew tolerance boundaries
|
||||
- Unicode/emoji in claims
|
||||
- Very large claim values
|
||||
- Concurrent access
|
||||
- Special characters in URLs
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use fixtures for token generation** - Don't manually construct JWTs
|
||||
2. **Use mock servers for integration tests** - Test against realistic OIDC behavior
|
||||
3. **Always run with `-race`** - Catch concurrency issues early
|
||||
4. **Use testify assertions** - Better error messages and cleaner code
|
||||
5. **Clean up resources** - Use `t.Cleanup()` or `TearDownTest()`
|
||||
6. **Test edge cases systematically** - Use table-driven tests
|
||||
@@ -1,308 +0,0 @@
|
||||
# Test Execution Guide
|
||||
|
||||
This guide explains how to run tests efficiently with the new test categorization and optimization system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Fast Development Testing (Default - Target: < 30 seconds)
|
||||
```bash
|
||||
# Run quick smoke tests only
|
||||
go test ./...
|
||||
|
||||
# Or explicitly run in short mode
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Extended Testing (Target: 2-5 minutes)
|
||||
```bash
|
||||
# Enable extended tests with more iterations and concurrency
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Or use the flag equivalent (if using test runner that supports it)
|
||||
go test ./... -extended
|
||||
```
|
||||
|
||||
### Long-Running Performance Tests (Target: 5-15 minutes)
|
||||
```bash
|
||||
# Enable comprehensive performance and stress tests
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Full Stress Testing (Target: 10-30 minutes)
|
||||
```bash
|
||||
# Enable all stress tests with maximum parameters
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Quick Tests (Default)
|
||||
- **Purpose**: Fast feedback during development
|
||||
- **Duration**: < 30 seconds total
|
||||
- **Features**:
|
||||
- Basic functionality verification
|
||||
- Limited iterations (1-3)
|
||||
- Small data sets
|
||||
- Minimal concurrency
|
||||
- Essential memory leak checks
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 3
|
||||
- Max Concurrency: 5
|
||||
- Memory Threshold: 2.0 MB
|
||||
- Cache Size: 50
|
||||
- Timeout: 10 seconds
|
||||
|
||||
### 2. Extended Tests
|
||||
- **Purpose**: Comprehensive testing before commits
|
||||
- **Duration**: 2-5 minutes
|
||||
- **Features**:
|
||||
- Increased test coverage
|
||||
- More iterations (5-10)
|
||||
- Medium concurrency tests
|
||||
- Enhanced memory leak detection
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 10
|
||||
- Max Concurrency: 20
|
||||
- Memory Threshold: 10.0 MB
|
||||
- Cache Size: 200
|
||||
- Timeout: 30 seconds
|
||||
|
||||
### 3. Long Tests
|
||||
- **Purpose**: Performance validation and stress testing
|
||||
- **Duration**: 5-15 minutes
|
||||
- **Features**:
|
||||
- High iteration counts (50-100)
|
||||
- High concurrency scenarios
|
||||
- Large data sets
|
||||
- Comprehensive memory testing
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 100
|
||||
- Max Concurrency: 50
|
||||
- Memory Threshold: 50.0 MB
|
||||
- Cache Size: 1000
|
||||
- Timeout: 60 seconds
|
||||
|
||||
### 4. Stress Tests
|
||||
- **Purpose**: Maximum load testing and edge case validation
|
||||
- **Duration**: 10-30 minutes
|
||||
- **Features**:
|
||||
- Extreme iteration counts (100-500)
|
||||
- Maximum concurrency (100+)
|
||||
- Large memory allocations
|
||||
- Edge case combinations
|
||||
|
||||
**Configuration**:
|
||||
- Max Iterations: 500
|
||||
- Max Concurrency: 100
|
||||
- Memory Threshold: 100.0 MB
|
||||
- Cache Size: 2000
|
||||
- Timeout: 120 seconds
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Test Execution Control
|
||||
```bash
|
||||
# Enable specific test types
|
||||
export RUN_EXTENDED_TESTS=1 # Enable extended tests
|
||||
export RUN_LONG_TESTS=1 # Enable long-running tests
|
||||
export RUN_STRESS_TESTS=1 # Enable stress tests
|
||||
|
||||
# Disable specific features
|
||||
export DISABLE_LEAK_DETECTION=1 # Skip memory leak detection
|
||||
```
|
||||
|
||||
### Parameter Customization
|
||||
```bash
|
||||
# Customize concurrency limits
|
||||
export TEST_MAX_CONCURRENCY=10 # Override max concurrent operations
|
||||
|
||||
# Customize iteration limits
|
||||
export TEST_MAX_ITERATIONS=50 # Override max test iterations
|
||||
|
||||
# Customize memory thresholds
|
||||
export TEST_MEMORY_THRESHOLD_MB=25.5 # Override memory growth limit (in MB)
|
||||
```
|
||||
|
||||
## Test-Specific Behavior
|
||||
|
||||
### Memory Leak Tests
|
||||
- **Quick Mode**: 1-3 iterations, small data sets, strict memory limits
|
||||
- **Extended Mode**: 5-10 iterations, medium data sets, relaxed limits
|
||||
- **Long Mode**: 50-100 iterations, large data sets, performance focus
|
||||
- **Stress Mode**: 100-500 iterations, maximum data sets, stress focus
|
||||
|
||||
### Concurrency Tests
|
||||
- **Quick Mode**: 2-5 concurrent operations, basic race detection
|
||||
- **Extended Mode**: 10-20 concurrent operations, moderate stress
|
||||
- **Long Mode**: 20-50 concurrent operations, high contention
|
||||
- **Stress Mode**: 50-100+ concurrent operations, maximum stress
|
||||
|
||||
### Cache Tests
|
||||
- **Quick Mode**: Small caches (50 items), basic operations
|
||||
- **Extended Mode**: Medium caches (200 items), varied operations
|
||||
- **Long Mode**: Large caches (1000 items), performance testing
|
||||
- **Stress Mode**: Very large caches (2000+ items), stress testing
|
||||
|
||||
## Integration with CI/CD
|
||||
|
||||
### GitHub Actions Example
|
||||
```yaml
|
||||
# Quick tests for every push/PR
|
||||
- name: Quick Tests
|
||||
run: go test ./... -short
|
||||
|
||||
# Extended tests for main branch
|
||||
- name: Extended Tests
|
||||
if: github.ref == 'refs/heads/main'
|
||||
run: RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Nightly comprehensive testing
|
||||
- name: Nightly Stress Tests
|
||||
if: github.event_name == 'schedule'
|
||||
run: RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Local Development Workflow
|
||||
```bash
|
||||
# During active development
|
||||
go test ./... -short
|
||||
|
||||
# Before committing
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
|
||||
# Before major releases
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Performance validation
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
## Performance Optimization Features
|
||||
|
||||
### Dynamic Test Scaling
|
||||
The test system automatically adjusts parameters based on:
|
||||
- Test mode (quick/extended/long/stress)
|
||||
- Available resources
|
||||
- Environment variables
|
||||
- Previous test performance
|
||||
|
||||
### Memory Management
|
||||
- **Garbage Collection**: Forced GC between test iterations
|
||||
- **Memory Monitoring**: Real-time memory growth tracking
|
||||
- **Leak Detection**: Goroutine and memory leak prevention
|
||||
- **Resource Cleanup**: Automatic cleanup of test resources
|
||||
|
||||
### Timeout Management
|
||||
- **Adaptive Timeouts**: Timeouts scale with test complexity
|
||||
- **Graceful Degradation**: Tests adapt to slower environments
|
||||
- **Early Termination**: Failed tests terminate quickly
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Tests Taking Too Long
|
||||
```bash
|
||||
# Check if running in extended mode accidentally
|
||||
echo $RUN_EXTENDED_TESTS $RUN_LONG_TESTS
|
||||
|
||||
# Force quick mode
|
||||
unset RUN_EXTENDED_TESTS RUN_LONG_TESTS RUN_STRESS_TESTS
|
||||
go test ./... -short
|
||||
```
|
||||
|
||||
### Memory Issues
|
||||
```bash
|
||||
# Reduce memory limits for constrained environments
|
||||
export TEST_MEMORY_THRESHOLD_MB=5.0
|
||||
export TEST_MAX_CONCURRENCY=2
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Concurrency Issues
|
||||
```bash
|
||||
# Reduce concurrency for slower systems
|
||||
export TEST_MAX_CONCURRENCY=5
|
||||
export TEST_MAX_ITERATIONS=10
|
||||
go test ./...
|
||||
```
|
||||
|
||||
### Skip Specific Test Types
|
||||
```bash
|
||||
# Skip memory leak detection if problematic
|
||||
export DISABLE_LEAK_DETECTION=1
|
||||
go test ./...
|
||||
```
|
||||
|
||||
## Benchmarking
|
||||
|
||||
### Running Benchmarks
|
||||
```bash
|
||||
# Quick benchmarks
|
||||
go test -bench=. -short
|
||||
|
||||
# Extended benchmarks
|
||||
RUN_EXTENDED_TESTS=1 go test -bench=.
|
||||
|
||||
# Memory profiling
|
||||
go test -bench=. -memprofile=mem.prof
|
||||
go tool pprof mem.prof
|
||||
```
|
||||
|
||||
### Benchmark Categories
|
||||
- **Basic Operations**: Set/Get performance
|
||||
- **Concurrency**: Multi-threaded performance
|
||||
- **Memory**: Allocation and cleanup performance
|
||||
- **Cache**: Eviction and cleanup performance
|
||||
|
||||
## Best Practices
|
||||
|
||||
### For Developers
|
||||
1. Always run quick tests during development (`go test ./... -short`)
|
||||
2. Run extended tests before committing (`RUN_EXTENDED_TESTS=1 go test ./...`)
|
||||
3. Use appropriate test categories for your use case
|
||||
4. Monitor test execution time and adjust if needed
|
||||
|
||||
### For CI/CD
|
||||
1. Use quick tests for fast feedback on PRs
|
||||
2. Use extended tests for main branch validation
|
||||
3. Use long tests for release validation
|
||||
4. Use stress tests for nightly/weekly validation
|
||||
|
||||
### For Performance Testing
|
||||
1. Use consistent environment variables
|
||||
2. Run tests multiple times for statistical significance
|
||||
3. Monitor both execution time and resource usage
|
||||
4. Use profiling tools for detailed analysis
|
||||
|
||||
## Examples
|
||||
|
||||
### Daily Development
|
||||
```bash
|
||||
# Fast tests while coding
|
||||
go test ./... -short
|
||||
|
||||
# Before git commit
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Release Testing
|
||||
```bash
|
||||
# Comprehensive validation
|
||||
RUN_LONG_TESTS=1 go test ./...
|
||||
|
||||
# Stress testing
|
||||
RUN_STRESS_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
```bash
|
||||
# Custom limits for specific environment
|
||||
export TEST_MAX_CONCURRENCY=8
|
||||
export TEST_MAX_ITERATIONS=25
|
||||
export TEST_MEMORY_THRESHOLD_MB=15.0
|
||||
RUN_EXTENDED_TESTS=1 go test ./...
|
||||
```
|
||||
|
||||
This test system provides flexible, scalable test execution that adapts to your development workflow and infrastructure constraints while maintaining comprehensive test coverage.
|
||||
@@ -1,163 +0,0 @@
|
||||
# Google OAuth Integration Fix
|
||||
|
||||
## Problem Overview
|
||||
|
||||
The Traefik OIDC plugin encountered an authentication issue when using Google as an OAuth provider. Authentication would fail with the following error:
|
||||
|
||||
```
|
||||
Some requested scopes were invalid. {valid=[openid, https://www.googleapis.com/auth/userinfo.email, https://www.googleapis.com/auth/userinfo.profile], invalid=[offline_access]}
|
||||
```
|
||||
|
||||
This occurred because Google's OAuth implementation differs from the standard OIDC specification in how it handles refresh tokens and offline access.
|
||||
|
||||
## Technical Details of the Issue
|
||||
|
||||
### Standard OIDC Provider Behavior
|
||||
|
||||
Most OpenID Connect (OIDC) providers follow the standard specification, where:
|
||||
- To obtain a refresh token, clients include the `offline_access` scope in their authorization request
|
||||
- This allows authenticated sessions to persist beyond the initial access token expiration
|
||||
|
||||
### Google's Non-Standard Approach
|
||||
|
||||
Google's OAuth implementation deviates from the standard by:
|
||||
1. Not supporting the `offline_access` scope, instead rejecting it as an invalid scope
|
||||
2. Requiring the `access_type=offline` query parameter for requesting refresh tokens
|
||||
3. Needing the `prompt=consent` parameter to consistently issue refresh tokens (especially for repeat authentications)
|
||||
|
||||
This difference caused the plugin to fail when configured for Google OAuth, as it was using a standard approach that didn't work with Google's implementation.
|
||||
|
||||
## Solution Implementation
|
||||
|
||||
The fix involved modifying the authentication flow to specifically handle Google providers:
|
||||
|
||||
1. **Google Provider Detection**: Added code to detect if the OIDC provider is Google based on the issuer URL:
|
||||
|
||||
```go
|
||||
// Check if we're dealing with a Google OIDC provider
|
||||
isGoogleProvider := strings.Contains(t.issuerURL, "google") ||
|
||||
strings.Contains(t.issuerURL, "accounts.google.com")
|
||||
```
|
||||
|
||||
2. **Provider-Specific Auth URL Building**: Modified the `buildAuthURL` function to handle Google and non-Google providers differently:
|
||||
|
||||
```go
|
||||
// Handle offline access differently for Google vs other providers
|
||||
if isGoogleProvider {
|
||||
// For Google, use access_type=offline parameter instead of offline_access scope
|
||||
params.Set("access_type", "offline")
|
||||
t.logger.Debug("Google OIDC provider detected, added access_type=offline for refresh tokens")
|
||||
|
||||
// Add prompt=consent for Google to ensure refresh token is issued
|
||||
params.Set("prompt", "consent")
|
||||
t.logger.Debug("Google OIDC provider detected, added prompt=consent to ensure refresh tokens")
|
||||
} else {
|
||||
// For non-Google providers, use the offline_access scope
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. **Token Refresh Enhancement**: Improved the token refresh logic to better handle Google's behavior, particularly when refresh tokens aren't returned in refresh responses (as Google often uses the same refresh token for multiple requests).
|
||||
|
||||
## Why This Approach Works
|
||||
|
||||
This solution aligns with Google's OAuth 2.0 documentation which specifies:
|
||||
|
||||
1. **Access Type Parameter**: Google's [OAuth 2.0 documentation](https://developers.google.com/identity/protocols/oauth2/web-server#offline) states that to request a refresh token, applications must include `access_type=offline` in the authorization request.
|
||||
|
||||
2. **Prompt Parameter**: The [`prompt=consent`](https://developers.google.com/identity/protocols/oauth2/web-server#forceapprovalprompt) parameter forces the consent screen to appear, ensuring a refresh token is issued even if the user has previously granted access.
|
||||
|
||||
3. **Scope Validation**: Google strictly validates scopes and rejects non-standard ones like `offline_access`, instead relying on the `access_type` parameter to indicate whether a refresh token should be issued.
|
||||
|
||||
By adapting to these Google-specific requirements, the OIDC plugin can now seamlessly work with both standard OIDC providers and Google's OAuth implementation.
|
||||
|
||||
## Testing and Verification
|
||||
|
||||
Comprehensive tests were implemented to verify the solution:
|
||||
|
||||
1. **Provider Detection Test**: Ensures the code correctly identifies Google providers and applies the appropriate parameters.
|
||||
|
||||
2. **Auth URL Parameter Tests**: Verifies that:
|
||||
- For Google providers: `access_type=offline` and `prompt=consent` are included; `offline_access` scope is NOT included
|
||||
- For non-Google providers: `offline_access` scope IS included; `access_type` parameter is NOT added
|
||||
|
||||
3. **Token Refresh Tests**: Validates that Google's token refresh process works correctly, including the preservation of refresh tokens when Google doesn't return a new one.
|
||||
|
||||
4. **Integration Test**: Tests the complete authentication flow with a mocked Google provider to ensure all components work together seamlessly.
|
||||
|
||||
Sample test case (simplified):
|
||||
|
||||
```go
|
||||
t.Run("Google provider detection adds required parameters", func(t *testing.T) {
|
||||
// Test buildAuthURL to ensure it adds access_type=offline and prompt=consent for Google
|
||||
authURL := tOidc.buildAuthURL("https://example.com/callback", "state123", "nonce123", "")
|
||||
|
||||
// Check that access_type=offline was added (not offline_access scope for Google)
|
||||
if !strings.Contains(authURL, "access_type=offline") {
|
||||
t.Errorf("access_type=offline not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Verify offline_access scope is NOT included for Google providers
|
||||
if strings.Contains(authURL, "offline_access") {
|
||||
t.Errorf("offline_access scope incorrectly added to Google auth URL: %s", authURL)
|
||||
}
|
||||
|
||||
// Check that prompt=consent was added
|
||||
if !strings.Contains(authURL, "prompt=consent") {
|
||||
t.Errorf("prompt=consent not added to Google auth URL: %s", authURL)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Usage Guidance for Developers
|
||||
|
||||
When configuring the Traefik OIDC middleware for Google:
|
||||
|
||||
1. **Provider URL**: Use `https://accounts.google.com` as the `providerURL` value
|
||||
|
||||
2. **Client Configuration**: Create OAuth 2.0 credentials in the Google Cloud Console:
|
||||
- Configure the authorized redirect URI to match your `callbackURL` setting
|
||||
- Ensure your OAuth consent screen is properly configured (especially if you want long-lived refresh tokens)
|
||||
|
||||
3. **Configuration Example**:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-google
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-google-client-id.apps.googleusercontent.com
|
||||
clientSecret: your-google-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
# Note: DO NOT manually add offline_access scope for Google
|
||||
# The middleware handles this automatically and correctly
|
||||
```
|
||||
|
||||
4. **Troubleshooting**: If sessions still expire prematurely with Google (typically after 1 hour):
|
||||
- Ensure your Google Cloud OAuth consent screen is set to "External" and "Production" mode (not "Testing" mode, which limits refresh token validity)
|
||||
- Review your application logs with `logLevel: debug` to check for refresh token errors
|
||||
- Verify you're using a version of the middleware that includes this fix
|
||||
|
||||
## Conclusion
|
||||
|
||||
This fix ensures that the Traefik OIDC plugin works seamlessly with Google's OAuth implementation without requiring users to make provider-specific configuration changes. The middleware now intelligently adapts to the provider's requirements, making it more robust and user-friendly while maintaining compatibility with the standard OIDC specification for other providers.
|
||||
@@ -16,35 +16,26 @@ import (
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
// Required fields
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// Conditional - only for confidential clients
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
|
||||
// Optional - for managing registration
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
|
||||
// Expiration
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
|
||||
// Echo back of registered metadata
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
ClientID string `json:"client_id"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationError represents an error response from client registration (RFC 7591)
|
||||
@@ -55,14 +46,12 @@ type ClientRegistrationError struct {
|
||||
|
||||
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrar struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
providerURL string
|
||||
|
||||
// Cached registration response
|
||||
mu sync.RWMutex
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
providerURL string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrar creates a new dynamic client registrar
|
||||
@@ -348,6 +337,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
||||
@@ -223,10 +223,10 @@ func TestRegisterClientWithInitialAccessToken(t *testing.T) {
|
||||
// TestRegisterClientError tests error handling during registration
|
||||
func TestRegisterClientError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse func(w http.ResponseWriter, r *http.Request)
|
||||
expectError bool
|
||||
name string
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "invalid_redirect_uri error",
|
||||
@@ -321,8 +321,8 @@ func TestRegisterClientError(t *testing.T) {
|
||||
// TestRegisterClientDisabled tests that registration fails when not enabled
|
||||
func TestRegisterClientDisabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dcrConfig *DynamicClientRegistrationConfig
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
@@ -521,8 +521,8 @@ func TestCredentialsValidation(t *testing.T) {
|
||||
registrar := NewDynamicClientRegistrar(&http.Client{}, NewLogger("DEBUG"), dcrConfig, "https://example.com")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
response *ClientRegistrationResponse
|
||||
name string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
@@ -584,9 +584,9 @@ func TestCredentialsValidation(t *testing.T) {
|
||||
// TestBuildRegistrationRequest tests the request body construction
|
||||
func TestBuildRegistrationRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata *ClientRegistrationMetadata
|
||||
expectedFields map[string]interface{}
|
||||
name string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/testutil"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// ClockSkewEdgeCasesSuite tests clock skew tolerance scenarios
|
||||
type ClockSkewEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) SetupTest() {
|
||||
// Create JWK for the test key
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error") // Reduce noise
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestExactlyAtExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(0)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// Token at exact expiry - behavior is implementation-defined
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.T().Logf("Exact expiry result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondBeforeExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token should be valid 1 second before expiry")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestOneSecondAfterExpiry() {
|
||||
token, err := s.fixture.TokenWithSkew(-1 * time.Second)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// With default 2-minute clock skew tolerance, 1 second past expiry should still be valid
|
||||
s.NoError(err, "Token 1 second past expiry should be valid within clock skew tolerance")
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestWithinSkewTolerance() {
|
||||
// Most implementations allow 5-minute clock skew
|
||||
token, err := s.fixture.TokenWithSkew(-4 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
// May pass or fail depending on implementation
|
||||
s.T().Logf("4-minute expired token result: %v", err)
|
||||
}
|
||||
|
||||
func (s *ClockSkewEdgeCasesSuite) TestBeyondSkewTolerance() {
|
||||
token, err := s.fixture.TokenWithSkew(-10 * time.Minute)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.Error(err, "Token should be invalid 10 minutes after expiry")
|
||||
}
|
||||
|
||||
func TestClockSkewEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClockSkewEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// UnicodeClaimsSuite tests Unicode handling in JWT claims
|
||||
type UnicodeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeEmail() {
|
||||
token, err := s.fixture.TokenWithEmail("用户@example.com")
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode email should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestUnicodeName() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "田中太郎",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Unicode name should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestEmojiInClaims() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test User 😀",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Emoji in claims should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestRTLText() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "مستخدم اختبار",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "RTL text should be handled correctly")
|
||||
}
|
||||
|
||||
func (s *UnicodeClaimsSuite) TestMixedScripts() {
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"name": "Test 测试 テスト",
|
||||
"roles": []string{"admin", "管理者", "管理员"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Mixed scripts should be handled correctly")
|
||||
}
|
||||
|
||||
func TestUnicodeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(UnicodeClaimsSuite))
|
||||
}
|
||||
|
||||
// LargeClaimsSuite tests large claim values
|
||||
type LargeClaimsSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyRoles() {
|
||||
roles := make([]string, 100)
|
||||
for i := 0; i < 100; i++ {
|
||||
roles[i] = strings.Repeat("role", 10) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithRoles(roles)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 100 roles should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestManyGroups() {
|
||||
groups := make([]string, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
groups[i] = strings.Repeat("group", 5) + string(rune('A'+i%26))
|
||||
}
|
||||
|
||||
token, err := s.fixture.TokenWithGroups(groups)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with 50 groups should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongEmail() {
|
||||
// RFC 5321 allows up to 254 characters
|
||||
localPart := strings.Repeat("a", 64)
|
||||
domain := strings.Repeat("b", 63) + ".com"
|
||||
email := localPart + "@" + domain
|
||||
|
||||
token, err := s.fixture.TokenWithEmail(email)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long email should be handled")
|
||||
}
|
||||
|
||||
func (s *LargeClaimsSuite) TestLongSubject() {
|
||||
longSub := strings.Repeat("subject", 100)
|
||||
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"sub": longSub,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.tOidc.VerifyToken(token)
|
||||
s.NoError(err, "Token with long subject should be handled")
|
||||
}
|
||||
|
||||
func TestLargeClaimsSuite(t *testing.T) {
|
||||
suite.Run(t, new(LargeClaimsSuite))
|
||||
}
|
||||
|
||||
// URLPathEdgeCasesSuite tests URL handling edge cases
|
||||
type URLPathEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestVeryLongPath() {
|
||||
longPath := "/" + strings.Repeat("segment/", 100)
|
||||
req := httptest.NewRequest("GET", longPath, nil)
|
||||
|
||||
s.NotNil(req)
|
||||
s.Contains(req.URL.Path, "segment")
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestSpecialCharactersInPath() {
|
||||
paths := []string{
|
||||
"/path%20with%20spaces",
|
||||
"/path/with/日本語",
|
||||
"/path?query=value&another=test",
|
||||
"/path#fragment",
|
||||
"/path/../traversal",
|
||||
"/path/./current",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
s.Run(path, func() {
|
||||
req := httptest.NewRequest("GET", path, nil)
|
||||
s.NotNil(req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestEmptyPath() {
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
s.Equal("/", req.URL.Path)
|
||||
}
|
||||
|
||||
func (s *URLPathEdgeCasesSuite) TestDoubleSlashes() {
|
||||
req := httptest.NewRequest("GET", "//double//slashes//", nil)
|
||||
s.NotNil(req)
|
||||
}
|
||||
|
||||
func TestURLPathEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(URLPathEdgeCasesSuite))
|
||||
}
|
||||
|
||||
// ConcurrencyEdgeCasesSuite tests concurrency scenarios
|
||||
type ConcurrencyEdgeCasesSuite struct {
|
||||
suite.Suite
|
||||
|
||||
fixture *testutil.TokenFixture
|
||||
tOidc *TraefikOidc
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupSuite() {
|
||||
var err error
|
||||
s.fixture, err = testutil.NewTokenFixture()
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) SetupTest() {
|
||||
jwk := JWK{
|
||||
Kty: "RSA",
|
||||
Kid: s.fixture.KeyID,
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(s.fixture.RSAPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(s.fixture.RSAPublicKey.E)))),
|
||||
}
|
||||
|
||||
jwkCache := &MockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{jwk}},
|
||||
Err: nil,
|
||||
}
|
||||
|
||||
tokenBlacklist := NewCache()
|
||||
tokenCacheInternal := NewCache()
|
||||
tokenCache := &TokenCache{}
|
||||
if tokenCache.cache == nil {
|
||||
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
||||
tokenCache.cache = wrapper.cache
|
||||
}
|
||||
}
|
||||
|
||||
logger := NewLogger("error")
|
||||
|
||||
s.tOidc = &TraefikOidc{
|
||||
issuerURL: s.fixture.Issuer,
|
||||
clientID: s.fixture.Audience,
|
||||
audience: s.fixture.Audience,
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles",
|
||||
groupClaimName: "groups",
|
||||
userIdentifierClaim: "email",
|
||||
jwkCache: jwkCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 100), // Higher limit for concurrency tests
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
goroutineWG: &sync.WaitGroup{},
|
||||
ctx: context.Background(),
|
||||
}
|
||||
close(s.tOidc.initComplete)
|
||||
s.tOidc.tokenVerifier = s.tOidc
|
||||
s.tOidc.jwtVerifier = s.tOidc
|
||||
|
||||
s.T().Cleanup(func() {
|
||||
if s.tOidc.tokenBlacklist != nil {
|
||||
s.tOidc.tokenBlacklist.Close()
|
||||
}
|
||||
if s.tOidc.tokenCache != nil && s.tOidc.tokenCache.cache != nil {
|
||||
s.tOidc.tokenCache.cache.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentTokenValidation() {
|
||||
token, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 50
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentDifferentTokens() {
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
token, err := s.fixture.TokenWithCustomClaims(map[string]interface{}{
|
||||
"custom": idx,
|
||||
})
|
||||
if err != nil {
|
||||
errors <- err
|
||||
return
|
||||
}
|
||||
if err := s.tOidc.VerifyToken(token); err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
var errCount int
|
||||
for err := range errors {
|
||||
s.T().Logf("Concurrent different token error: %v", err)
|
||||
errCount++
|
||||
}
|
||||
|
||||
s.Equal(0, errCount, "All concurrent different token validations should succeed")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyEdgeCasesSuite) TestConcurrentMixedValidInvalid() {
|
||||
validToken, err := s.fixture.ValidToken(nil)
|
||||
s.Require().NoError(err)
|
||||
expiredToken, err := s.fixture.ExpiredToken()
|
||||
s.Require().NoError(err)
|
||||
|
||||
const goroutines = 40
|
||||
var wg sync.WaitGroup
|
||||
validCount := int32(0)
|
||||
expiredCount := int32(0)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
var token string
|
||||
if idx%2 == 0 {
|
||||
token = validToken
|
||||
} else {
|
||||
token = expiredToken
|
||||
}
|
||||
|
||||
err := s.tOidc.VerifyToken(token)
|
||||
if idx%2 == 0 {
|
||||
if err == nil {
|
||||
atomic.AddInt32(&validCount, 1)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
atomic.AddInt32(&expiredCount, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Valid passed: %d, Expired rejected: %d", validCount, expiredCount)
|
||||
}
|
||||
|
||||
func TestConcurrencyEdgeCasesSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyEdgeCasesSuite))
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// EnhancedMocksSuite demonstrates improved state-based mocks with call tracking
|
||||
type EnhancedMocksSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheCallTracking() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Make some calls
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.NoError(err)
|
||||
s.NotNil(result)
|
||||
|
||||
// Another call with different URL
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://other.com/jwks", nil)
|
||||
|
||||
// Verify calls were tracked
|
||||
s.Equal(2, mock.GetJWKSCallCount())
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
mock.AssertGetJWKSCalledWith(s.T(), "https://example.com/jwks")
|
||||
mock.AssertGetJWKSCallCount(s.T(), 2)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheWithError() {
|
||||
expectedErr := errors.New("network error")
|
||||
mock := &EnhancedMockJWKCache{
|
||||
Err: expectedErr,
|
||||
}
|
||||
|
||||
result, err := mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
|
||||
s.Nil(result)
|
||||
s.Equal(expectedErr, err)
|
||||
mock.AssertGetJWKSCalled(s.T())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedJWKCacheReset() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
s.Equal(1, mock.GetJWKSCallCount())
|
||||
|
||||
mock.Reset()
|
||||
|
||||
s.Equal(0, mock.GetJWKSCallCount())
|
||||
s.Nil(mock.JWKS)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierCallTracking() {
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
Err: nil, // Valid tokens
|
||||
}
|
||||
|
||||
// Verify a token
|
||||
err := mock.VerifyToken("test-token-1")
|
||||
s.NoError(err)
|
||||
|
||||
// Verify another token
|
||||
err = mock.VerifyToken("test-token-2")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
mock.AssertVerifyTokenCalled(s.T())
|
||||
mock.AssertVerifyTokenCalledWith(s.T(), "test-token-1")
|
||||
|
||||
// Check last call
|
||||
lastCall := mock.LastCall()
|
||||
s.NotNil(lastCall)
|
||||
s.Equal("test-token-2", lastCall.Token)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenVerifierWithDynamicFunc() {
|
||||
callCount := 0
|
||||
mock := &EnhancedMockTokenVerifier{
|
||||
VerifyFunc: func(token string) error {
|
||||
callCount++
|
||||
if token == "invalid" {
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
// Valid token
|
||||
err := mock.VerifyToken("valid-token")
|
||||
s.NoError(err)
|
||||
|
||||
// Invalid token
|
||||
err = mock.VerifyToken("invalid")
|
||||
s.Error(err)
|
||||
|
||||
s.Equal(2, callCount)
|
||||
s.Equal(2, mock.GetVerifyTokenCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerCallTracking() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeResponse: &TokenResponse{
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
RefreshResponse: &TokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Exchange code
|
||||
resp, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "auth-code", "https://redirect.com", "verifier")
|
||||
s.NoError(err)
|
||||
s.Equal("access-token", resp.AccessToken)
|
||||
|
||||
// Refresh token
|
||||
resp, err = mock.GetNewTokenWithRefreshToken("refresh-token")
|
||||
s.NoError(err)
|
||||
s.Equal("new-access-token", resp.AccessToken)
|
||||
|
||||
// Revoke token
|
||||
err = mock.RevokeTokenWithProvider("access-token", "access_token")
|
||||
s.NoError(err)
|
||||
|
||||
// Check tracking
|
||||
mock.AssertExchangeCalled(s.T())
|
||||
mock.AssertExchangeCalledWith(s.T(), "authorization_code")
|
||||
mock.AssertRefreshCalled(s.T())
|
||||
mock.AssertRevokeCalled(s.T())
|
||||
|
||||
s.Equal(1, mock.GetExchangeCallCount())
|
||||
s.Equal(1, mock.GetRefreshCallCount())
|
||||
s.Equal(1, mock.GetRevokeCallCount())
|
||||
|
||||
// Check last exchange call details
|
||||
lastExchange := mock.LastExchangeCall()
|
||||
s.NotNil(lastExchange)
|
||||
s.Equal("authorization_code", lastExchange.GrantType)
|
||||
s.Equal("auth-code", lastExchange.CodeOrToken)
|
||||
s.Equal("https://redirect.com", lastExchange.RedirectURL)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedTokenExchangerWithErrors() {
|
||||
mock := &EnhancedMockTokenExchanger{
|
||||
ExchangeErr: errors.New("invalid_grant"),
|
||||
RefreshErr: errors.New("refresh_expired"),
|
||||
RevokeErr: errors.New("revoke_failed"),
|
||||
}
|
||||
|
||||
_, err := mock.ExchangeCodeForToken(context.Background(), "authorization_code", "code", "", "")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "invalid_grant")
|
||||
|
||||
_, err = mock.GetNewTokenWithRefreshToken("token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "refresh_expired")
|
||||
|
||||
err = mock.RevokeTokenWithProvider("token", "access_token")
|
||||
s.Error(err)
|
||||
s.Contains(err.Error(), "revoke_failed")
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheCallTracking() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// Set some values
|
||||
mock.Set("key1", "value1", 5*time.Minute)
|
||||
mock.Set("key2", "value2", 10*time.Minute)
|
||||
|
||||
// Get values
|
||||
val, found := mock.Get("key1")
|
||||
s.True(found)
|
||||
s.Equal("value1", val)
|
||||
|
||||
_, found = mock.Get("nonexistent")
|
||||
s.False(found)
|
||||
|
||||
// Delete
|
||||
mock.Delete("key1")
|
||||
|
||||
// Verify tracking
|
||||
mock.AssertSetCalled(s.T(), "key1")
|
||||
mock.AssertSetCalled(s.T(), "key2")
|
||||
mock.AssertGetCalled(s.T(), "key1")
|
||||
mock.AssertGetCalled(s.T(), "nonexistent")
|
||||
mock.AssertDeleteCalled(s.T(), "key1")
|
||||
|
||||
s.Equal(2, mock.SetCallCount())
|
||||
s.Equal(2, mock.GetCallCount())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheActualStorage() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
// The enhanced mock actually stores data
|
||||
mock.Set("key", "value", time.Hour)
|
||||
s.Equal(1, mock.Size())
|
||||
|
||||
val, found := mock.Get("key")
|
||||
s.True(found)
|
||||
s.Equal("value", val)
|
||||
|
||||
mock.Delete("key")
|
||||
s.Equal(0, mock.Size())
|
||||
|
||||
_, found = mock.Get("key")
|
||||
s.False(found)
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestEnhancedCacheClear() {
|
||||
mock := NewEnhancedMockCache()
|
||||
|
||||
mock.Set("key1", "value1", time.Hour)
|
||||
mock.Set("key2", "value2", time.Hour)
|
||||
s.Equal(2, mock.Size())
|
||||
|
||||
mock.Clear()
|
||||
s.Equal(0, mock.Size())
|
||||
}
|
||||
|
||||
func (s *EnhancedMocksSuite) TestConcurrentAccess() {
|
||||
mock := &EnhancedMockJWKCache{
|
||||
JWKS: &JWKSet{Keys: []JWK{{Kid: "test-key"}}},
|
||||
}
|
||||
|
||||
// Concurrent calls should be safe
|
||||
done := make(chan bool)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func() {
|
||||
_, _ = mock.GetJWKS(context.Background(), "https://example.com/jwks", nil)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
s.Equal(10, mock.GetJWKSCallCount())
|
||||
}
|
||||
|
||||
func TestEnhancedMocksSuite(t *testing.T) {
|
||||
suite.Run(t, new(EnhancedMocksSuite))
|
||||
}
|
||||
@@ -0,0 +1,577 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// EnhancedMockJWKCache is an improved state-based mock with call tracking
|
||||
type EnhancedMockJWKCache struct {
|
||||
Err error
|
||||
JWKS *JWKSet
|
||||
GetJWKSCalls []JWKSCall
|
||||
mu sync.RWMutex
|
||||
getJWKSCallsMu sync.Mutex
|
||||
CleanupCalls int32
|
||||
CloseCalls int32
|
||||
}
|
||||
|
||||
// JWKSCall records parameters from a GetJWKS call
|
||||
type JWKSCall struct {
|
||||
Timestamp time.Time
|
||||
URL string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = append(m.GetJWKSCalls, JWKSCall{
|
||||
URL: jwksURL,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.JWKS, m.Err
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Cleanup() {
|
||||
atomic.AddInt32(&m.CleanupCalls, 1)
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
}
|
||||
|
||||
func (m *EnhancedMockJWKCache) Close() {
|
||||
atomic.AddInt32(&m.CloseCalls, 1)
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetJWKSCalled verifies GetJWKS was called
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalled(t assert.TestingT) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.GetJWKSCalls, "GetJWKS should have been called")
|
||||
}
|
||||
|
||||
// AssertGetJWKSCalledWith verifies GetJWKS was called with specific URL
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCalledWith(t assert.TestingT, expectedURL string) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
for _, call := range m.GetJWKSCalls {
|
||||
if call.URL == expectedURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "GetJWKS was not called with URL: "+expectedURL)
|
||||
}
|
||||
|
||||
// AssertGetJWKSCallCount verifies the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) AssertGetJWKSCallCount(t assert.TestingT, expected int) bool {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.GetJWKSCalls), "GetJWKS call count mismatch")
|
||||
}
|
||||
|
||||
// GetJWKSCallCount returns the number of GetJWKS calls
|
||||
func (m *EnhancedMockJWKCache) GetJWKSCallCount() int {
|
||||
m.getJWKSCallsMu.Lock()
|
||||
defer m.getJWKSCallsMu.Unlock()
|
||||
return len(m.GetJWKSCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockJWKCache) Reset() {
|
||||
m.mu.Lock()
|
||||
m.JWKS = nil
|
||||
m.Err = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getJWKSCallsMu.Lock()
|
||||
m.GetJWKSCalls = nil
|
||||
m.getJWKSCallsMu.Unlock()
|
||||
|
||||
atomic.StoreInt32(&m.CleanupCalls, 0)
|
||||
atomic.StoreInt32(&m.CloseCalls, 0)
|
||||
}
|
||||
|
||||
// EnhancedMockTokenVerifier is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenVerifier struct {
|
||||
Err error
|
||||
VerifyFunc func(token string) error
|
||||
VerifyCalls []TokenVerifyCall
|
||||
mu sync.RWMutex
|
||||
verifyCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// TokenVerifyCall records parameters from a VerifyToken call
|
||||
type TokenVerifyCall struct {
|
||||
Timestamp time.Time
|
||||
Result error
|
||||
Token string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenVerifier) VerifyToken(token string) error {
|
||||
var result error
|
||||
|
||||
m.mu.RLock()
|
||||
if m.VerifyFunc != nil {
|
||||
result = m.VerifyFunc(token)
|
||||
} else {
|
||||
result = m.Err
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = append(m.VerifyCalls, TokenVerifyCall{
|
||||
Token: token,
|
||||
Timestamp: time.Now(),
|
||||
Result: result,
|
||||
})
|
||||
m.verifyCallsMu.Unlock()
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertVerifyTokenCalled verifies VerifyToken was called
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalled(t assert.TestingT) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.VerifyCalls, "VerifyToken should have been called")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCalledWith verifies VerifyToken was called with specific token
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCalledWith(t assert.TestingT, expectedToken string) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
for _, call := range m.VerifyCalls {
|
||||
if call.Token == expectedToken {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "VerifyToken was not called with expected token")
|
||||
}
|
||||
|
||||
// AssertVerifyTokenCallCount verifies the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) AssertVerifyTokenCallCount(t assert.TestingT, expected int) bool {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return assert.Equal(t, expected, len(m.VerifyCalls), "VerifyToken call count mismatch")
|
||||
}
|
||||
|
||||
// GetVerifyTokenCallCount returns the number of VerifyToken calls
|
||||
func (m *EnhancedMockTokenVerifier) GetVerifyTokenCallCount() int {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
return len(m.VerifyCalls)
|
||||
}
|
||||
|
||||
// LastCall returns the most recent VerifyToken call
|
||||
func (m *EnhancedMockTokenVerifier) LastCall() *TokenVerifyCall {
|
||||
m.verifyCallsMu.Lock()
|
||||
defer m.verifyCallsMu.Unlock()
|
||||
if len(m.VerifyCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.VerifyCalls[len(m.VerifyCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenVerifier) Reset() {
|
||||
m.mu.Lock()
|
||||
m.Err = nil
|
||||
m.VerifyFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.verifyCallsMu.Lock()
|
||||
m.VerifyCalls = nil
|
||||
m.verifyCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockTokenExchanger is an improved state-based mock with call tracking
|
||||
type EnhancedMockTokenExchanger struct {
|
||||
RefreshErr error
|
||||
RevokeErr error
|
||||
ExchangeErr error
|
||||
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
||||
RefreshResponse *TokenResponse
|
||||
ExchangeResponse *TokenResponse
|
||||
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
||||
RevokeTokenFunc func(token, tokenType string) error
|
||||
ExchangeCalls []ExchangeCall
|
||||
RefreshCalls []RefreshCall
|
||||
RevokeCalls []RevokeCall
|
||||
mu sync.RWMutex
|
||||
exchangeCallsMu sync.Mutex
|
||||
refreshCallsMu sync.Mutex
|
||||
revokeCallsMu sync.Mutex
|
||||
}
|
||||
|
||||
// ExchangeCall records parameters from an ExchangeCodeForToken call
|
||||
type ExchangeCall struct {
|
||||
Timestamp time.Time
|
||||
GrantType string
|
||||
CodeOrToken string
|
||||
RedirectURL string
|
||||
CodeVerifier string
|
||||
}
|
||||
|
||||
// RefreshCall records parameters from a GetNewTokenWithRefreshToken call
|
||||
type RefreshCall struct {
|
||||
Timestamp time.Time
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// RevokeCall records parameters from a RevokeTokenWithProvider call
|
||||
type RevokeCall struct {
|
||||
Timestamp time.Time
|
||||
Token string
|
||||
TokenType string
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = append(m.ExchangeCalls, ExchangeCall{
|
||||
GrantType: grantType,
|
||||
CodeOrToken: codeOrToken,
|
||||
RedirectURL: redirectURL,
|
||||
CodeVerifier: codeVerifier,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.ExchangeCodeFunc != nil {
|
||||
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
||||
}
|
||||
return m.ExchangeResponse, m.ExchangeErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = append(m.RefreshCalls, RefreshCall{
|
||||
RefreshToken: refreshToken,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RefreshTokenFunc != nil {
|
||||
return m.RefreshTokenFunc(refreshToken)
|
||||
}
|
||||
return m.RefreshResponse, m.RefreshErr
|
||||
}
|
||||
|
||||
func (m *EnhancedMockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = append(m.RevokeCalls, RevokeCall{
|
||||
Token: token,
|
||||
TokenType: tokenType,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.revokeCallsMu.Unlock()
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.RevokeTokenFunc != nil {
|
||||
return m.RevokeTokenFunc(token, tokenType)
|
||||
}
|
||||
return m.RevokeErr
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertExchangeCalled verifies ExchangeCodeForToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalled(t assert.TestingT) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.ExchangeCalls, "ExchangeCodeForToken should have been called")
|
||||
}
|
||||
|
||||
// AssertExchangeCalledWith verifies ExchangeCodeForToken was called with specific grant type
|
||||
func (m *EnhancedMockTokenExchanger) AssertExchangeCalledWith(t assert.TestingT, grantType string) bool {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
for _, call := range m.ExchangeCalls {
|
||||
if call.GrantType == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "ExchangeCodeForToken was not called with grant type: "+grantType)
|
||||
}
|
||||
|
||||
// AssertRefreshCalled verifies GetNewTokenWithRefreshToken was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRefreshCalled(t assert.TestingT) bool {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RefreshCalls, "GetNewTokenWithRefreshToken should have been called")
|
||||
}
|
||||
|
||||
// AssertRevokeCalled verifies RevokeTokenWithProvider was called
|
||||
func (m *EnhancedMockTokenExchanger) AssertRevokeCalled(t assert.TestingT) bool {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return assert.NotEmpty(t, m.RevokeCalls, "RevokeTokenWithProvider should have been called")
|
||||
}
|
||||
|
||||
// GetExchangeCallCount returns the number of ExchangeCodeForToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetExchangeCallCount() int {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
return len(m.ExchangeCalls)
|
||||
}
|
||||
|
||||
// GetRefreshCallCount returns the number of GetNewTokenWithRefreshToken calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRefreshCallCount() int {
|
||||
m.refreshCallsMu.Lock()
|
||||
defer m.refreshCallsMu.Unlock()
|
||||
return len(m.RefreshCalls)
|
||||
}
|
||||
|
||||
// GetRevokeCallCount returns the number of RevokeTokenWithProvider calls
|
||||
func (m *EnhancedMockTokenExchanger) GetRevokeCallCount() int {
|
||||
m.revokeCallsMu.Lock()
|
||||
defer m.revokeCallsMu.Unlock()
|
||||
return len(m.RevokeCalls)
|
||||
}
|
||||
|
||||
// LastExchangeCall returns the most recent ExchangeCodeForToken call
|
||||
func (m *EnhancedMockTokenExchanger) LastExchangeCall() *ExchangeCall {
|
||||
m.exchangeCallsMu.Lock()
|
||||
defer m.exchangeCallsMu.Unlock()
|
||||
if len(m.ExchangeCalls) == 0 {
|
||||
return nil
|
||||
}
|
||||
return &m.ExchangeCalls[len(m.ExchangeCalls)-1]
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockTokenExchanger) Reset() {
|
||||
m.mu.Lock()
|
||||
m.ExchangeResponse = nil
|
||||
m.ExchangeErr = nil
|
||||
m.RefreshResponse = nil
|
||||
m.RefreshErr = nil
|
||||
m.RevokeErr = nil
|
||||
m.ExchangeCodeFunc = nil
|
||||
m.RefreshTokenFunc = nil
|
||||
m.RevokeTokenFunc = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
m.exchangeCallsMu.Lock()
|
||||
m.ExchangeCalls = nil
|
||||
m.exchangeCallsMu.Unlock()
|
||||
|
||||
m.refreshCallsMu.Lock()
|
||||
m.RefreshCalls = nil
|
||||
m.refreshCallsMu.Unlock()
|
||||
|
||||
m.revokeCallsMu.Lock()
|
||||
m.RevokeCalls = nil
|
||||
m.revokeCallsMu.Unlock()
|
||||
}
|
||||
|
||||
// EnhancedMockCacheInterface is an improved state-based mock for CacheInterface
|
||||
type EnhancedMockCacheInterface struct {
|
||||
data map[string]cacheEntry
|
||||
GetCalls []CacheGetCall
|
||||
SetCalls []CacheSetCall
|
||||
DeleteCalls []string
|
||||
maxSize int
|
||||
mu sync.RWMutex
|
||||
getCalls sync.Mutex
|
||||
setCalls sync.Mutex
|
||||
deleteCalls sync.Mutex
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
value any
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
// CacheGetCall records parameters from a Get call
|
||||
type CacheGetCall struct {
|
||||
Timestamp time.Time
|
||||
Key string
|
||||
Found bool
|
||||
}
|
||||
|
||||
// CacheSetCall records parameters from a Set call
|
||||
type CacheSetCall struct {
|
||||
Timestamp time.Time
|
||||
Value any
|
||||
Key string
|
||||
TTL time.Duration
|
||||
}
|
||||
|
||||
// NewEnhancedMockCache creates a new enhanced cache mock
|
||||
func NewEnhancedMockCache() *EnhancedMockCacheInterface {
|
||||
return &EnhancedMockCacheInterface{
|
||||
data: make(map[string]cacheEntry),
|
||||
maxSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Set(key string, value any, ttl time.Duration) {
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = append(m.SetCalls, CacheSetCall{
|
||||
Key: key,
|
||||
Value: value,
|
||||
TTL: ttl,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
m.data[key] = cacheEntry{value: value, ttl: ttl}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Get(key string) (any, bool) {
|
||||
m.mu.RLock()
|
||||
entry, found := m.data[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = append(m.GetCalls, CacheGetCall{
|
||||
Key: key,
|
||||
Found: found,
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
m.getCalls.Unlock()
|
||||
|
||||
if found {
|
||||
return entry.value, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Delete(key string) {
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = append(m.DeleteCalls, key)
|
||||
m.deleteCalls.Unlock()
|
||||
|
||||
m.mu.Lock()
|
||||
delete(m.data, key)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) SetMaxSize(size int) {
|
||||
m.mu.Lock()
|
||||
m.maxSize = size
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Size() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.data)
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Clear() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Cleanup() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) Close() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
func (m *EnhancedMockCacheInterface) GetStats() map[string]any {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return map[string]any{
|
||||
"size": len(m.data),
|
||||
"max_size": m.maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Assertion helpers
|
||||
|
||||
// AssertGetCalled verifies Get was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertGetCalled(t assert.TestingT, key string) bool {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
for _, call := range m.GetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Get was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertSetCalled verifies Set was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertSetCalled(t assert.TestingT, key string) bool {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
for _, call := range m.SetCalls {
|
||||
if call.Key == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Set was not called with key: "+key)
|
||||
}
|
||||
|
||||
// AssertDeleteCalled verifies Delete was called with specific key
|
||||
func (m *EnhancedMockCacheInterface) AssertDeleteCalled(t assert.TestingT, key string) bool {
|
||||
m.deleteCalls.Lock()
|
||||
defer m.deleteCalls.Unlock()
|
||||
for _, k := range m.DeleteCalls {
|
||||
if k == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return assert.Fail(t, "Delete was not called with key: "+key)
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of Get calls
|
||||
func (m *EnhancedMockCacheInterface) GetCallCount() int {
|
||||
m.getCalls.Lock()
|
||||
defer m.getCalls.Unlock()
|
||||
return len(m.GetCalls)
|
||||
}
|
||||
|
||||
// SetCallCount returns the number of Set calls
|
||||
func (m *EnhancedMockCacheInterface) SetCallCount() int {
|
||||
m.setCalls.Lock()
|
||||
defer m.setCalls.Unlock()
|
||||
return len(m.SetCalls)
|
||||
}
|
||||
|
||||
// Reset clears all state and call tracking
|
||||
func (m *EnhancedMockCacheInterface) Reset() {
|
||||
m.mu.Lock()
|
||||
m.data = make(map[string]cacheEntry)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.getCalls.Lock()
|
||||
m.GetCalls = nil
|
||||
m.getCalls.Unlock()
|
||||
|
||||
m.setCalls.Lock()
|
||||
m.SetCalls = nil
|
||||
m.setCalls.Unlock()
|
||||
|
||||
m.deleteCalls.Lock()
|
||||
m.DeleteCalls = nil
|
||||
m.deleteCalls.Unlock()
|
||||
}
|
||||
+146
-36
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -538,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
@@ -594,14 +642,10 @@ func (e *HTTPError) Error() string {
|
||||
// OIDCError represents OIDC-specific errors with context information.
|
||||
// It provides structured error reporting for authentication and authorization failures.
|
||||
type OIDCError struct {
|
||||
// Code identifies the specific error type
|
||||
Code string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Context contains additional error context (e.g., provider, session details)
|
||||
Cause error
|
||||
Context map[string]interface{}
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the OIDC error.
|
||||
@@ -621,14 +665,10 @@ func (e *OIDCError) Unwrap() error {
|
||||
// SessionError represents session-related errors with context.
|
||||
// Used for session management, validation, and storage errors.
|
||||
type SessionError struct {
|
||||
// Operation describes what session operation failed
|
||||
Cause error
|
||||
Operation string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// SessionID identifies the session (if available)
|
||||
Message string
|
||||
SessionID string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
}
|
||||
|
||||
// Error returns the string representation of the session error.
|
||||
@@ -648,14 +688,10 @@ func (e *SessionError) Unwrap() error {
|
||||
// TokenError represents token-related errors with validation context.
|
||||
// Used for JWT validation, token refresh, and token format errors.
|
||||
type TokenError struct {
|
||||
// TokenType identifies the type of token (id_token, access_token, refresh_token)
|
||||
Cause error
|
||||
TokenType string
|
||||
// Reason describes why the token is invalid
|
||||
Reason string
|
||||
// Message provides a human-readable description
|
||||
Message string
|
||||
// Cause is the underlying error that caused this error
|
||||
Cause error
|
||||
Reason string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error returns the string representation of the token error.
|
||||
@@ -717,24 +753,15 @@ func NewTokenError(tokenType, reason, message string, cause error) *TokenError {
|
||||
// It provides fallback mechanisms when primary services are unavailable and monitors
|
||||
// service health to automatically recover when services become available again.
|
||||
type GracefulDegradation struct {
|
||||
// BaseRecoveryMechanism provides common functionality
|
||||
*BaseRecoveryMechanism
|
||||
// fallbacks stores service-specific fallback implementations
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
// healthChecks stores service health check functions
|
||||
healthChecks map[string]func() bool
|
||||
// degradedServices tracks which services are currently degraded
|
||||
fallbacks map[string]func() (interface{}, error)
|
||||
healthChecks map[string]func() bool
|
||||
degradedServices map[string]time.Time
|
||||
// config contains graceful degradation configuration
|
||||
config GracefulDegradationConfig
|
||||
// mutex protects shared state
|
||||
mutex sync.RWMutex
|
||||
// healthCheckTask manages background health checking
|
||||
healthCheckTask *BackgroundTask
|
||||
// stopChan signals shutdown
|
||||
stopChan chan struct{}
|
||||
// shutdownOnce ensures shutdown happens only once
|
||||
shutdownOnce sync.Once
|
||||
healthCheckTask *BackgroundTask
|
||||
stopChan chan struct{}
|
||||
config GracefulDegradationConfig
|
||||
mutex sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
}
|
||||
|
||||
// GracefulDegradationConfig holds configuration for graceful degradation behavior.
|
||||
@@ -1087,3 +1114,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestDefaultCircuitBreakerConfig tests the default configuration function
|
||||
func TestDefaultCircuitBreakerConfig(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
|
||||
// Test default values
|
||||
if config.MaxFailures != 2 {
|
||||
t.Errorf("Expected MaxFailures 2, got %d", config.MaxFailures)
|
||||
}
|
||||
|
||||
if config.Timeout != 60*time.Second {
|
||||
t.Errorf("Expected Timeout 60s, got %v", config.Timeout)
|
||||
}
|
||||
|
||||
if config.ResetTimeout != 30*time.Second {
|
||||
t.Errorf("Expected ResetTimeout 30s, got %v", config.ResetTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_GetBaseMetrics tests getting base metrics
|
||||
func TestBaseRecoveryMechanism_GetBaseMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
metrics := base.GetBaseMetrics()
|
||||
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Check expected metric fields
|
||||
expectedFields := []string{
|
||||
"total_requests",
|
||||
"total_failures",
|
||||
"total_successes",
|
||||
"uptime_seconds",
|
||||
"name",
|
||||
}
|
||||
|
||||
for _, field := range expectedFields {
|
||||
if _, exists := metrics[field]; !exists {
|
||||
t.Errorf("Expected metric field %s to exist", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordRequest tests request recording
|
||||
func TestBaseRecoveryMechanism_RecordRequest(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some requests
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
base.RecordRequest()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalRequests := metrics["total_requests"].(int64)
|
||||
|
||||
if totalRequests != 3 {
|
||||
t.Errorf("Expected 3 total requests, got %d", totalRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordSuccess tests success recording
|
||||
func TestBaseRecoveryMechanism_RecordSuccess(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some successes
|
||||
base.RecordSuccess()
|
||||
base.RecordSuccess()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalSuccesses := metrics["total_successes"].(int64)
|
||||
|
||||
if totalSuccesses != 2 {
|
||||
t.Errorf("Expected 2 successful requests, got %d", totalSuccesses)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_RecordFailure tests failure recording
|
||||
func TestBaseRecoveryMechanism_RecordFailure(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Record some failures
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
base.RecordFailure()
|
||||
|
||||
// Get metrics to verify
|
||||
metrics := base.GetBaseMetrics()
|
||||
totalFailures := metrics["total_failures"].(int64)
|
||||
|
||||
if totalFailures != 3 {
|
||||
t.Errorf("Expected 3 failed requests, got %d", totalFailures)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogInfo tests info logging
|
||||
func TestBaseRecoveryMechanism_LogInfo(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogInfo("test message")
|
||||
base.LogInfo("test message with args: %s %d", "arg1", 42)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogInfo("test message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogError tests error logging
|
||||
func TestBaseRecoveryMechanism_LogError(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogError("error message")
|
||||
base.LogError("error message with args: %s %d", "error", 500)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogError("error message") // Should not panic
|
||||
}
|
||||
|
||||
// TestBaseRecoveryMechanism_LogDebug tests debug logging
|
||||
func TestBaseRecoveryMechanism_LogDebug(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
// Test logging doesn't panic
|
||||
base.LogDebug("debug message")
|
||||
base.LogDebug("debug message with args: %s %d", "debug", 123)
|
||||
|
||||
// Test with nil logger
|
||||
baseNoLogger := NewBaseRecoveryMechanism("test", nil)
|
||||
baseNoLogger.LogDebug("debug message") // Should not panic
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetState tests getting circuit breaker state
|
||||
func TestCircuitBreaker_GetState(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initial state should be closed
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected initial state to be closed, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests resetting circuit breaker
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Reset should not panic
|
||||
cb.Reset()
|
||||
|
||||
// State should be closed after reset
|
||||
state := cb.GetState()
|
||||
if state != CircuitBreakerClosed {
|
||||
t.Errorf("Expected state to be closed after reset, got %d", state)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsAvailable tests availability check
|
||||
func TestCircuitBreaker_IsAvailable(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Initially should be available
|
||||
available := cb.IsAvailable()
|
||||
if !available {
|
||||
t.Error("Expected circuit breaker to be available initially")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_GetMetrics tests getting circuit breaker metrics
|
||||
func TestCircuitBreaker_GetMetrics(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
logger := GetSingletonNoOpLogger()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
metrics := cb.GetMetrics()
|
||||
if metrics == nil {
|
||||
t.Fatal("Expected non-nil metrics")
|
||||
}
|
||||
|
||||
// Should include base metrics
|
||||
if _, exists := metrics["total_requests"]; !exists {
|
||||
t.Error("Expected total_requests in metrics")
|
||||
}
|
||||
|
||||
// Should include circuit breaker specific metrics
|
||||
if _, exists := metrics["state"]; !exists {
|
||||
t.Error("Expected state in metrics")
|
||||
}
|
||||
}
|
||||
|
||||
// Retry mechanism tests removed due to complex dependencies
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -1,560 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRetryExecutorReset tests the Reset method
|
||||
func TestRetryExecutorReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
require.NotNil(t, executor)
|
||||
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
executor.Reset()
|
||||
})
|
||||
|
||||
// Multiple resets should be safe
|
||||
executor.Reset()
|
||||
executor.Reset()
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsAvailable tests the IsAvailable method
|
||||
func TestRetryExecutorIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
executor := NewRetryExecutor(DefaultRetryConfig(), logger)
|
||||
|
||||
// Retry executor should always be available
|
||||
assert.True(t, executor.IsAvailable())
|
||||
|
||||
// Should remain available after operations
|
||||
ctx := context.Background()
|
||||
executor.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.True(t, executor.IsAvailable())
|
||||
}
|
||||
|
||||
// TestSessionErrorUnwrap tests SessionError.Unwrap
|
||||
func TestSessionErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("root cause")
|
||||
sessionErr := NewSessionError("save", "failed to save session", rootErr)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
sessionErr := NewSessionError("load", "failed to load session", nil)
|
||||
|
||||
unwrapped := sessionErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("database error")
|
||||
sessionErr := NewSessionError("delete", "failed to delete session", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(sessionErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestTokenErrorUnwrap tests TokenError.Unwrap
|
||||
func TestTokenErrorUnwrap(t *testing.T) {
|
||||
t.Run("unwrap with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature verification failed")
|
||||
tokenErr := NewTokenError("id_token", "invalid", "token is invalid", rootErr)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Equal(t, rootErr, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("unwrap without cause", func(t *testing.T) {
|
||||
tokenErr := NewTokenError("access_token", "expired", "token has expired", nil)
|
||||
|
||||
unwrapped := tokenErr.Unwrap()
|
||||
assert.Nil(t, unwrapped)
|
||||
})
|
||||
|
||||
t.Run("error chain", func(t *testing.T) {
|
||||
rootErr := errors.New("crypto error")
|
||||
tokenErr := NewTokenError("refresh_token", "malformed", "token is malformed", rootErr)
|
||||
|
||||
// Verify error chain works
|
||||
assert.True(t, errors.Is(tokenErr, rootErr))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterFallback tests fallback registration
|
||||
func TestGracefulDegradationRegisterFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register single fallback", func(t *testing.T) {
|
||||
fallback := func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
}
|
||||
|
||||
gd.RegisterFallback("service1", fallback)
|
||||
|
||||
// Verify fallback was registered (indirectly)
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return nil, errors.New("service failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("register multiple fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback2", nil
|
||||
})
|
||||
gd.RegisterFallback("service3", func() (interface{}, error) {
|
||||
return "fallback3", nil
|
||||
})
|
||||
|
||||
result2, _ := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
result3, _ := gd.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "fallback2", result2)
|
||||
assert.Equal(t, "fallback3", result3)
|
||||
})
|
||||
|
||||
t.Run("override existing fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "old fallback", nil
|
||||
})
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return "new fallback", nil
|
||||
})
|
||||
|
||||
result, _ := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fail")
|
||||
})
|
||||
|
||||
assert.Equal(t, "new fallback", result)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationRegisterHealthCheck tests health check registration
|
||||
func TestGracefulDegradationRegisterHealthCheck(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("register health check", func(t *testing.T) {
|
||||
healthy := true
|
||||
healthCheck := func() bool {
|
||||
return healthy
|
||||
}
|
||||
|
||||
gd.RegisterHealthCheck("service1", healthCheck)
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
|
||||
// Set healthy and wait for health check to run
|
||||
healthy = true
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (may still be degraded due to timing, but health check was registered)
|
||||
})
|
||||
|
||||
t.Run("multiple health checks", func(t *testing.T) {
|
||||
gd.RegisterHealthCheck("service2", func() bool { return true })
|
||||
gd.RegisterHealthCheck("service3", func() bool { return false })
|
||||
|
||||
// Health checks are registered and will be called periodically
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithContext tests ExecuteWithContext
|
||||
func TestGracefulDegradationExecuteWithContext(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("successful execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("failed execution", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
testErr := errors.New("operation failed")
|
||||
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return testErr
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("uses fallback on failure", func(t *testing.T) {
|
||||
gd.RegisterFallback("default", func() (interface{}, error) {
|
||||
return nil, nil // Success fallback
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := gd.ExecuteWithContext(ctx, func() error {
|
||||
return errors.New("primary failed")
|
||||
})
|
||||
|
||||
// With fallback succeeding, overall operation succeeds
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteWithFallback tests ExecuteWithFallback
|
||||
func TestGracefulDegradationExecuteWithFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("primary succeeds", func(t *testing.T) {
|
||||
result, err := gd.ExecuteWithFallback("service1", func() (interface{}, error) {
|
||||
return "primary result", nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "primary result", result)
|
||||
})
|
||||
|
||||
t.Run("fallback succeeds when primary fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return "fallback result", nil
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback result", result)
|
||||
})
|
||||
|
||||
t.Run("error when no fallback available", func(t *testing.T) {
|
||||
config.EnableFallbacks = false
|
||||
gdNoFallback := NewGracefulDegradation(config, logger)
|
||||
defer gdNoFallback.Close()
|
||||
|
||||
result, err := gdNoFallback.ExecuteWithFallback("service3", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("fallback also fails", func(t *testing.T) {
|
||||
gd.RegisterFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback also failed")
|
||||
})
|
||||
|
||||
result, err := gd.ExecuteWithFallback("service4", func() (interface{}, error) {
|
||||
return nil, errors.New("primary failed")
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback also failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsServiceDegraded tests service degradation status
|
||||
func TestGracefulDegradationIsServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 100 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("service not degraded initially", func(t *testing.T) {
|
||||
assert.False(t, gd.isServiceDegraded("new-service"))
|
||||
})
|
||||
|
||||
t.Run("service degraded after marking", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.isServiceDegraded("service1"))
|
||||
})
|
||||
|
||||
t.Run("service recovers after timeout", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
assert.True(t, gd.isServiceDegraded("service2"))
|
||||
|
||||
// Wait for recovery timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("service2"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationMarkServiceDegraded tests marking services as degraded
|
||||
func TestGracefulDegradationMarkServiceDegraded(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("mark single service", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service1")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service1")
|
||||
})
|
||||
|
||||
t.Run("mark multiple services", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
degraded := gd.GetDegradedServices()
|
||||
assert.Contains(t, degraded, "service2")
|
||||
assert.Contains(t, degraded, "service3")
|
||||
})
|
||||
|
||||
t.Run("marking same service multiple times updates timestamp", func(t *testing.T) {
|
||||
gd.markServiceDegraded("service4")
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
// Service should still be marked as degraded
|
||||
assert.True(t, gd.isServiceDegraded("service4"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationExecuteFallback tests fallback execution
|
||||
func TestGracefulDegradationExecuteFallback(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("execute registered fallback", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) {
|
||||
return "fallback value", nil
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "fallback value", result)
|
||||
})
|
||||
|
||||
t.Run("error when fallback not registered", func(t *testing.T) {
|
||||
result, err := gd.executeFallback("non-existent-service")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "no fallback available")
|
||||
})
|
||||
|
||||
t.Run("propagate fallback errors", func(t *testing.T) {
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) {
|
||||
return nil, errors.New("fallback error")
|
||||
})
|
||||
|
||||
result, err := gd.executeFallback("service2")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "fallback error")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationReset tests Reset method
|
||||
func TestGracefulDegradationReset(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("reset clears degraded services", func(t *testing.T) {
|
||||
// Mark several services as degraded
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
gd.markServiceDegraded("service3")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 3)
|
||||
|
||||
// Reset
|
||||
gd.Reset()
|
||||
|
||||
// All should be cleared
|
||||
assert.Len(t, gd.GetDegradedServices(), 0)
|
||||
})
|
||||
|
||||
t.Run("can mark services degraded after reset", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service4")
|
||||
|
||||
assert.Len(t, gd.GetDegradedServices(), 1)
|
||||
assert.Contains(t, gd.GetDegradedServices(), "service4")
|
||||
})
|
||||
|
||||
t.Run("multiple resets are safe", func(t *testing.T) {
|
||||
assert.NotPanics(t, func() {
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
gd.Reset()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationIsAvailable tests IsAvailable method
|
||||
func TestGracefulDegradationIsAvailable(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Should always return true
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even with degraded services
|
||||
gd.markServiceDegraded("service1")
|
||||
assert.True(t, gd.IsAvailable())
|
||||
|
||||
// Even after reset
|
||||
gd.Reset()
|
||||
assert.True(t, gd.IsAvailable())
|
||||
}
|
||||
|
||||
// TestGracefulDegradationGetMetrics tests GetMetrics method
|
||||
func TestGracefulDegradationGetMetrics(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
t.Run("basic metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
require.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "degraded_services_count")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
assert.Contains(t, metrics, "registered_fallbacks_count")
|
||||
assert.Contains(t, metrics, "registered_health_checks_count")
|
||||
assert.Contains(t, metrics, "health_check_interval_seconds")
|
||||
assert.Contains(t, metrics, "recovery_timeout_seconds")
|
||||
assert.Contains(t, metrics, "fallbacks_enabled")
|
||||
})
|
||||
|
||||
t.Run("metrics reflect degraded services", func(t *testing.T) {
|
||||
gd.Reset()
|
||||
gd.markServiceDegraded("service1")
|
||||
gd.markServiceDegraded("service2")
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.Equal(t, 2, metrics["degraded_services_count"])
|
||||
degradedList := metrics["degraded_services"].([]string)
|
||||
assert.Len(t, degradedList, 2)
|
||||
})
|
||||
|
||||
t.Run("metrics reflect registered fallbacks", func(t *testing.T) {
|
||||
gd.RegisterFallback("service1", func() (interface{}, error) { return nil, nil })
|
||||
gd.RegisterFallback("service2", func() (interface{}, error) { return nil, nil })
|
||||
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
assert.GreaterOrEqual(t, metrics["registered_fallbacks_count"], 2)
|
||||
})
|
||||
|
||||
t.Run("metrics include base metrics", func(t *testing.T) {
|
||||
metrics := gd.GetMetrics()
|
||||
|
||||
// Should include base recovery mechanism metrics
|
||||
assert.Contains(t, metrics, "name")
|
||||
assert.Contains(t, metrics, "uptime_seconds")
|
||||
assert.Contains(t, metrics, "total_requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationFullScenario tests a complete degradation scenario
|
||||
func TestGracefulDegradationFullScenario(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping full scenario test in short mode")
|
||||
}
|
||||
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
config.RecoveryTimeout = 200 * time.Millisecond
|
||||
config.HealthCheckInterval = 50 * time.Millisecond
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register fallback
|
||||
gd.RegisterFallback("critical-service", func() (interface{}, error) {
|
||||
return "fallback data", nil
|
||||
})
|
||||
|
||||
// Register health check
|
||||
serviceHealthy := false
|
||||
gd.RegisterHealthCheck("critical-service", func() bool {
|
||||
return serviceHealthy
|
||||
})
|
||||
|
||||
// First call - primary succeeds
|
||||
result1, err1 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "primary data", nil
|
||||
})
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, "primary data", result1)
|
||||
|
||||
// Second call - primary fails, fallback succeeds
|
||||
result2, err2 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return nil, errors.New("service down")
|
||||
})
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, "fallback data", result2)
|
||||
|
||||
// Service is now degraded
|
||||
assert.True(t, gd.isServiceDegraded("critical-service"))
|
||||
|
||||
// Third call - should use fallback immediately
|
||||
result3, err3 := gd.ExecuteWithFallback("critical-service", func() (interface{}, error) {
|
||||
return "should not be called", nil
|
||||
})
|
||||
assert.NoError(t, err3)
|
||||
assert.Equal(t, "fallback data", result3)
|
||||
|
||||
// Mark service as healthy and wait for health check
|
||||
serviceHealthy = true
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Service should be recovered
|
||||
// (timing-dependent, so we don't assert)
|
||||
|
||||
// Get metrics
|
||||
metrics := gd.GetMetrics()
|
||||
assert.NotNil(t, metrics)
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package traefikoidc
|
||||
|
||||
import "testing"
|
||||
|
||||
func BenchmarkDefaultCircuitBreakerConfig(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
DefaultCircuitBreakerConfig()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_GetBaseMetrics(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.GetBaseMetrics()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseRecoveryMechanism_RecordRequest(b *testing.B) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
base := NewBaseRecoveryMechanism("test-mechanism", logger)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
base.RecordRequest()
|
||||
}
|
||||
}
|
||||
@@ -1,663 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreakerAllowRequestEdgeCases tests edge cases in circuit breaker request allowing
|
||||
func TestCircuitBreakerAllowRequestEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("invalid state returns false", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Force invalid state
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerState(999) // Invalid state
|
||||
cb.mutex.Unlock()
|
||||
|
||||
// Should return false for invalid state
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "invalid state should not allow requests")
|
||||
})
|
||||
|
||||
t.Run("open to half-open transition on timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: baseTimeout,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Verify circuit is open
|
||||
assert.Equal(t, CircuitBreakerOpen, cb.GetState())
|
||||
assert.False(t, cb.allowRequest())
|
||||
|
||||
// Wait for timeout (longer than timeout to ensure transition)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should transition to half-open
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "should allow request after timeout")
|
||||
assert.Equal(t, CircuitBreakerHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("half-open allows requests", func(t *testing.T) {
|
||||
config := DefaultCircuitBreakerConfig()
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Manually set to half-open
|
||||
cb.mutex.Lock()
|
||||
cb.state = CircuitBreakerHalfOpen
|
||||
cb.mutex.Unlock()
|
||||
|
||||
allowed := cb.allowRequest()
|
||||
assert.True(t, allowed, "half-open should allow requests")
|
||||
})
|
||||
|
||||
t.Run("open blocks requests before timeout", func(t *testing.T) {
|
||||
config := CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: 1 * time.Hour, // Long timeout
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}
|
||||
cb := NewCircuitBreaker(config, logger)
|
||||
|
||||
// Trip the circuit
|
||||
cb.Execute(func() error { return errors.New("fail") })
|
||||
|
||||
// Should be blocked
|
||||
allowed := cb.allowRequest()
|
||||
assert.False(t, allowed, "open circuit should block requests")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorIsRetryableErrorEdgeCases tests edge cases for retry decision
|
||||
func TestRetryExecutorIsRetryableErrorEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
config := DefaultRetryConfig()
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
t.Run("nil error is not retryable", func(t *testing.T) {
|
||||
retryable := re.isRetryableError(nil)
|
||||
assert.False(t, retryable)
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 429 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 429,
|
||||
Message: "Too Many Requests",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "429 Too Many Requests should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 500 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 500,
|
||||
Message: "Internal Server Error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "500 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 503 is retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 503,
|
||||
Message: "Service Unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.True(t, retryable, "503 errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("HTTPError with 400 is not retryable", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 400,
|
||||
Message: "Bad Request",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(httpErr)
|
||||
assert.False(t, retryable, "400 errors should not be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with timeout is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: true,
|
||||
temporary: false,
|
||||
msg: "timeout error",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "timeout errors should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection refused is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection refused",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection refused should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with connection reset is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "connection reset by peer",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "connection reset should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with network unreachable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "network is unreachable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "network unreachable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with no route to host is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "no route to host",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "no route to host should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with temporary failure is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "temporary failure in name resolution",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "temporary failure should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with try again is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "try again later",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "try again should be retryable")
|
||||
})
|
||||
|
||||
t.Run("net.Error with resource temporarily unavailable is retryable", func(t *testing.T) {
|
||||
netErr := &mockNetError{
|
||||
timeout: false,
|
||||
temporary: false,
|
||||
msg: "resource temporarily unavailable",
|
||||
}
|
||||
|
||||
retryable := re.isRetryableError(netErr)
|
||||
assert.True(t, retryable, "resource temporarily unavailable should be retryable")
|
||||
})
|
||||
|
||||
t.Run("configured retryable error patterns", func(t *testing.T) {
|
||||
err := errors.New("connection refused by server")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.True(t, retryable, "configured pattern should be retryable")
|
||||
})
|
||||
|
||||
t.Run("non-retryable error", func(t *testing.T) {
|
||||
err := errors.New("invalid input data")
|
||||
|
||||
retryable := re.isRetryableError(err)
|
||||
assert.False(t, retryable, "non-configured error should not be retryable")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRetryExecutorCalculateDelayEdgeCases tests delay calculation edge cases
|
||||
func TestRetryExecutorCalculateDelayEdgeCases(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("delay calculation without jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false, // Jitter disabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 1: 100ms * 2^0 = 100ms
|
||||
delay1 := re.calculateDelay(1)
|
||||
assert.Equal(t, 100*time.Millisecond, delay1)
|
||||
|
||||
// Attempt 2: 100ms * 2^1 = 200ms
|
||||
delay2 := re.calculateDelay(2)
|
||||
assert.Equal(t, 200*time.Millisecond, delay2)
|
||||
|
||||
// Attempt 3: 100ms * 2^2 = 400ms
|
||||
delay3 := re.calculateDelay(3)
|
||||
assert.Equal(t, 400*time.Millisecond, delay3)
|
||||
})
|
||||
|
||||
t.Run("delay calculation with jitter", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: true, // Jitter enabled
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// With jitter, delay should be within 10% of expected
|
||||
delay := re.calculateDelay(2)
|
||||
expectedBase := 200 * time.Millisecond
|
||||
minDelay := time.Duration(float64(expectedBase) * 0.9)
|
||||
maxDelay := time.Duration(float64(expectedBase) * 1.1)
|
||||
|
||||
assert.GreaterOrEqual(t, delay, minDelay, "delay should be >= 90% of base")
|
||||
assert.LessOrEqual(t, delay, maxDelay, "delay should be <= 110% of base")
|
||||
})
|
||||
|
||||
t.Run("delay capped at max delay", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 10,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 500 * time.Millisecond, // Low max delay
|
||||
BackoffFactor: 2.0,
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 10: would be 100ms * 2^9 = 51200ms, but capped at 500ms
|
||||
delay := re.calculateDelay(10)
|
||||
assert.Equal(t, 500*time.Millisecond, delay, "delay should be capped at max")
|
||||
})
|
||||
|
||||
t.Run("delay with large backoff factor", func(t *testing.T) {
|
||||
config := RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffFactor: 3.0, // Larger backoff
|
||||
EnableJitter: false,
|
||||
}
|
||||
re := NewRetryExecutor(config, logger)
|
||||
|
||||
// Attempt 3: 50ms * 3^2 = 450ms
|
||||
delay := re.calculateDelay(3)
|
||||
assert.Equal(t, 450*time.Millisecond, delay)
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorTypesErrorMethodsWithoutCause tests error type Error() methods without cause
|
||||
func TestErrorTypesErrorMethodsWithoutCause(t *testing.T) {
|
||||
t.Run("HTTPError.Error without cause", func(t *testing.T) {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: 404,
|
||||
Message: "Not Found",
|
||||
}
|
||||
|
||||
errStr := httpErr.Error()
|
||||
assert.Equal(t, "HTTP 404: Not Found", errStr)
|
||||
})
|
||||
|
||||
t.Run("HTTPError.Error with different status codes", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
code int
|
||||
message string
|
||||
expected string
|
||||
}{
|
||||
{200, "OK", "HTTP 200: OK"},
|
||||
{301, "Moved", "HTTP 301: Moved"},
|
||||
{401, "Unauthorized", "HTTP 401: Unauthorized"},
|
||||
{500, "Server Error", "HTTP 500: Server Error"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
httpErr := &HTTPError{
|
||||
StatusCode: tc.code,
|
||||
Message: tc.message,
|
||||
}
|
||||
assert.Equal(t, tc.expected, httpErr.Error())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error without cause", func(t *testing.T) {
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_token",
|
||||
Message: "Token validation failed",
|
||||
Context: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Equal(t, "OIDC error [invalid_token]: Token validation failed", errStr)
|
||||
})
|
||||
|
||||
t.Run("OIDCError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("signature mismatch")
|
||||
oidcErr := &OIDCError{
|
||||
Code: "invalid_signature",
|
||||
Message: "JWT signature invalid",
|
||||
Context: make(map[string]interface{}),
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := oidcErr.Error()
|
||||
assert.Contains(t, errStr, "OIDC error [invalid_signature]: JWT signature invalid")
|
||||
assert.Contains(t, errStr, "caused by: signature mismatch")
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error without cause", func(t *testing.T) {
|
||||
sessErr := &SessionError{
|
||||
Operation: "load",
|
||||
Message: "Session not found",
|
||||
SessionID: "sess123",
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Equal(t, "Session error in load: Session not found", errStr)
|
||||
})
|
||||
|
||||
t.Run("SessionError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("database connection failed")
|
||||
sessErr := &SessionError{
|
||||
Operation: "save",
|
||||
Message: "Failed to persist session",
|
||||
SessionID: "sess456",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := sessErr.Error()
|
||||
assert.Contains(t, errStr, "Session error in save: Failed to persist session")
|
||||
assert.Contains(t, errStr, "caused by: database connection failed")
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error without cause", func(t *testing.T) {
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "access_token",
|
||||
Reason: "expired",
|
||||
Message: "Token has expired",
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Equal(t, "Token error (access_token) - expired: Token has expired", errStr)
|
||||
})
|
||||
|
||||
t.Run("TokenError.Error with cause", func(t *testing.T) {
|
||||
rootErr := errors.New("time check failed")
|
||||
tokenErr := &TokenError{
|
||||
TokenType: "id_token",
|
||||
Reason: "expired",
|
||||
Message: "Token validity period exceeded",
|
||||
Cause: rootErr,
|
||||
}
|
||||
|
||||
errStr := tokenErr.Error()
|
||||
assert.Contains(t, errStr, "Token error (id_token) - expired: Token validity period exceeded")
|
||||
assert.Contains(t, errStr, "caused by: time check failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationHealthChecks tests health check functionality
|
||||
func TestGracefulDegradationHealthChecks(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("performHealthChecks recovers degraded service", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns true
|
||||
healthCheckCalled := false
|
||||
gd.RegisterHealthCheck("test-service", func() bool {
|
||||
healthCheckCalled = true
|
||||
return true // Service is healthy
|
||||
})
|
||||
|
||||
// Mark service as degraded
|
||||
gd.markServiceDegraded("test-service")
|
||||
|
||||
// Verify service is degraded
|
||||
assert.True(t, gd.isServiceDegraded("test-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Health check should have been called
|
||||
assert.True(t, healthCheckCalled, "health check should be called")
|
||||
|
||||
// Service should be recovered
|
||||
assert.False(t, gd.isServiceDegraded("test-service"), "service should be recovered")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks marks service degraded on failure", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Register health check that returns false
|
||||
gd.RegisterHealthCheck("failing-service", func() bool {
|
||||
return false // Service is unhealthy
|
||||
})
|
||||
|
||||
// Initially not degraded
|
||||
assert.False(t, gd.isServiceDegraded("failing-service"))
|
||||
|
||||
// Manually trigger health check
|
||||
gd.performHealthChecks()
|
||||
|
||||
// Service should be marked degraded
|
||||
assert.True(t, gd.isServiceDegraded("failing-service"), "service should be degraded")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks runs multiple health checks independently", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
service1Checked := false
|
||||
service2Checked := false
|
||||
|
||||
gd.RegisterHealthCheck("service1", func() bool {
|
||||
service1Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
gd.RegisterHealthCheck("service2", func() bool {
|
||||
service2Checked = true
|
||||
return true
|
||||
})
|
||||
|
||||
// Manually trigger health checks
|
||||
gd.performHealthChecks()
|
||||
|
||||
assert.True(t, service1Checked, "service1 health check should run")
|
||||
assert.True(t, service2Checked, "service2 health check should run")
|
||||
})
|
||||
|
||||
t.Run("performHealthChecks handles empty health checks", func(t *testing.T) {
|
||||
config := DefaultGracefulDegradationConfig()
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Call performHealthChecks with no registered health checks
|
||||
// Should not panic
|
||||
assert.NotPanics(t, func() {
|
||||
gd.performHealthChecks()
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// TestGracefulDegradationServiceRecoveryTimeout tests recovery timeout behavior
|
||||
func TestGracefulDegradationServiceRecoveryTimeout(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
|
||||
t.Run("service auto-recovers after timeout", func(t *testing.T) {
|
||||
baseTimeout := GetTestDuration(50 * time.Millisecond)
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour, // Long interval, won't run during test
|
||||
RecoveryTimeout: baseTimeout,
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("auto-recover-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("auto-recover-service"))
|
||||
|
||||
// Wait for recovery timeout (longer than timeout to ensure recovery)
|
||||
time.Sleep(baseTimeout + GetTestDuration(20*time.Millisecond))
|
||||
|
||||
// Should auto-recover
|
||||
assert.False(t, gd.isServiceDegraded("auto-recover-service"), "service should auto-recover after timeout")
|
||||
})
|
||||
|
||||
t.Run("service remains degraded before timeout", func(t *testing.T) {
|
||||
config := GracefulDegradationConfig{
|
||||
HealthCheckInterval: 1 * time.Hour,
|
||||
RecoveryTimeout: 1 * time.Hour, // Very long timeout
|
||||
EnableFallbacks: true,
|
||||
}
|
||||
gd := NewGracefulDegradation(config, logger)
|
||||
defer gd.Close()
|
||||
|
||||
// Mark service degraded
|
||||
gd.markServiceDegraded("long-timeout-service")
|
||||
|
||||
// Verify degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"))
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(GetTestDuration(10 * time.Millisecond))
|
||||
|
||||
// Should still be degraded
|
||||
assert.True(t, gd.isServiceDegraded("long-timeout-service"), "service should remain degraded before timeout")
|
||||
})
|
||||
}
|
||||
|
||||
// TestErrorRecoveryManagerIntegration tests full integration of error recovery mechanisms
|
||||
func TestErrorRecoveryManagerIntegration(t *testing.T) {
|
||||
logger := GetSingletonNoOpLogger()
|
||||
erm := NewErrorRecoveryManager(logger)
|
||||
|
||||
t.Run("circuit breaker and retry integration", func(t *testing.T) {
|
||||
// Create a circuit breaker with higher max failures to allow retries
|
||||
cb := NewCircuitBreaker(CircuitBreakerConfig{
|
||||
MaxFailures: 10, // High threshold
|
||||
Timeout: 60 * time.Second,
|
||||
ResetTimeout: 30 * time.Second,
|
||||
}, logger)
|
||||
|
||||
erm.mutex.Lock()
|
||||
erm.circuitBreakers["test-service-integration"] = cb
|
||||
erm.mutex.Unlock()
|
||||
|
||||
attempts := 0
|
||||
fn := func() error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := erm.ExecuteWithRecovery(context.Background(), "test-service-integration", fn)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, attempts, 3, "should retry until success")
|
||||
})
|
||||
|
||||
t.Run("circuit breaker opens on repeated failures", func(t *testing.T) {
|
||||
fn := func() error {
|
||||
return errors.New("persistent failure")
|
||||
}
|
||||
|
||||
// First call - should fail after retries
|
||||
err1 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err1)
|
||||
|
||||
// Second call - should fail after retries
|
||||
err2 := erm.ExecuteWithRecovery(context.Background(), "failing-service", fn)
|
||||
assert.Error(t, err2)
|
||||
|
||||
// Check circuit breaker state
|
||||
cb := erm.GetCircuitBreaker("failing-service")
|
||||
state := cb.GetState()
|
||||
assert.Equal(t, CircuitBreakerOpen, state, "circuit should be open after repeated failures")
|
||||
})
|
||||
|
||||
t.Run("recovery metrics include all mechanisms", func(t *testing.T) {
|
||||
metrics := erm.GetRecoveryMetrics()
|
||||
|
||||
assert.NotNil(t, metrics)
|
||||
assert.Contains(t, metrics, "circuit_breakers")
|
||||
assert.Contains(t, metrics, "degraded_services")
|
||||
})
|
||||
}
|
||||
|
||||
// TestContainsHelperFunction tests the contains helper function edge cases
|
||||
func TestContainsHelperFunction(t *testing.T) {
|
||||
t.Run("exact match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("prefix match", func(t *testing.T) {
|
||||
assert.True(t, contains("timeout error occurred", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("suffix match", func(t *testing.T) {
|
||||
assert.True(t, contains("connection timeout", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("middle match", func(t *testing.T) {
|
||||
assert.True(t, contains("a connection timeout error", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
assert.False(t, contains("connection refused", "timeout"))
|
||||
})
|
||||
|
||||
t.Run("substring longer than string", func(t *testing.T) {
|
||||
assert.False(t, contains("abc", "abcdef"))
|
||||
})
|
||||
|
||||
t.Run("empty substring", func(t *testing.T) {
|
||||
assert.True(t, contains("test", ""))
|
||||
})
|
||||
|
||||
t.Run("empty string", func(t *testing.T) {
|
||||
assert.False(t, contains("", "test"))
|
||||
})
|
||||
|
||||
t.Run("both empty", func(t *testing.T) {
|
||||
assert.True(t, contains("", ""))
|
||||
})
|
||||
}
|
||||
+1182
-41
File diff suppressed because it is too large
Load Diff
@@ -1,797 +0,0 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Mock types for testing
|
||||
type TemplatedHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type MockConfig struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
ClientID string `json:"clientID"`
|
||||
ClientSecret string `json:"clientSecret"`
|
||||
CallbackURL string `json:"callbackURL"`
|
||||
SessionEncryptionKey string `json:"sessionEncryptionKey"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
// TestTemplateHeaderFeatures consolidates all template header-related tests
|
||||
func TestTemplateHeaderFeatures(t *testing.T) {
|
||||
t.Run("Issue55_TemplateExecutionWithWrongTypes", testIssue55TemplateExecutionWithWrongTypes)
|
||||
t.Run("Template_Parsing_Validation", testTemplateParsingValidation)
|
||||
t.Run("Middleware_Header_Templating", testMiddlewareHeaderTemplating)
|
||||
t.Run("JSON_Config_Parsing", testJSONConfigParsing)
|
||||
t.Run("Template_Double_Processing", testTemplateDoubleProcessing)
|
||||
t.Run("Template_Execution_Context", testTemplateExecutionContext)
|
||||
t.Run("Template_Integration_With_Plugin", testTemplateIntegrationWithPlugin)
|
||||
t.Run("Template_Syntax_Validation", testTemplateSyntaxValidation)
|
||||
t.Run("Missing_Field_Handling", testMissingFieldHandling)
|
||||
t.Run("Complex_Template_Expressions", testComplexTemplateExpressions)
|
||||
t.Run("Traefik_Configuration_Parsing", testTraefikConfigurationParsing)
|
||||
}
|
||||
|
||||
// testIssue55TemplateExecutionWithWrongTypes tests what happens when templates
|
||||
// receive wrong data types during execution - reproduces GitHub issue #55
|
||||
func testIssue55TemplateExecutionWithWrongTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
templateData interface{}
|
||||
errorContains string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "correct map data",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "valid-token",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "boolean as root context - reproduces issue #55",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: true,
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type bool",
|
||||
},
|
||||
{
|
||||
name: "string as root context",
|
||||
templateText: "Bearer {{.AccessToken}}",
|
||||
templateData: "just a string",
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field AccessToken in type string",
|
||||
},
|
||||
{
|
||||
name: "nested claims access with correct data",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nested claims with wrong structure",
|
||||
templateText: "User: {{.Claims.email}}",
|
||||
templateData: map[string]interface{}{
|
||||
"Claims": "not a map",
|
||||
},
|
||||
expectError: true,
|
||||
errorContains: "can't evaluate field email in type",
|
||||
},
|
||||
{
|
||||
name: "complex nested structure",
|
||||
templateText: "{{.Claims.sub}} - {{.Claims.groups}} - {{.AccessToken}}",
|
||||
templateData: map[string]interface{}{
|
||||
"AccessToken": "token123",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user-id",
|
||||
"groups": "admin,users",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.templateData)
|
||||
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
if tc.errorContains != "" {
|
||||
assert.Contains(t, err.Error(), tc.errorContains)
|
||||
}
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateParsingValidation ensures templates are parsed correctly
|
||||
func testTemplateParsingValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headerTemplates []TemplatedHeader
|
||||
shouldError bool
|
||||
}{
|
||||
{
|
||||
name: "valid bearer token template",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "multiple valid templates",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "template with conditional logic",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "X-Auth-Info", Value: "{{if .AccessToken}}Bearer {{.AccessToken}}{{else}}No Token{{end}}"},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid template syntax",
|
||||
headerTemplates: []TemplatedHeader{
|
||||
{Name: "Bad-Template", Value: "{{.AccessToken"},
|
||||
},
|
||||
shouldError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, header := range tc.headerTemplates {
|
||||
_, err := template.New(header.Name).Parse(header.Value)
|
||||
|
||||
if tc.shouldError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testMiddlewareHeaderTemplating simulates the actual middleware flow
|
||||
func testMiddlewareHeaderTemplating(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers []TemplatedHeader
|
||||
accessToken string
|
||||
idToken string
|
||||
claims map[string]interface{}
|
||||
expectedValues map[string]string
|
||||
}{
|
||||
{
|
||||
name: "authorization header with access token",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
expectedValues: map[string]string{
|
||||
"Authorization": "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple headers with claims",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Groups", Value: "{{.Claims.groups}}"},
|
||||
{Name: "X-Auth-Token", Value: "{{.AccessToken}}"},
|
||||
},
|
||||
accessToken: "token123",
|
||||
claims: map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
"groups": "admin,developers",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Email": "user@example.com",
|
||||
"X-User-Groups": "admin,developers",
|
||||
"X-Auth-Token": "token123",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex template expressions",
|
||||
headers: []TemplatedHeader{
|
||||
{Name: "X-User-Info", Value: "{{.Claims.sub}} ({{.Claims.email}})"},
|
||||
{Name: "X-Auth-Header", Value: "Bearer {{.AccessToken}} | ID: {{.IDToken}}"},
|
||||
},
|
||||
accessToken: "access-token",
|
||||
idToken: "id-token",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-12345",
|
||||
"email": "john@example.com",
|
||||
},
|
||||
expectedValues: map[string]string{
|
||||
"X-User-Info": "user-12345 (john@example.com)",
|
||||
"X-Auth-Header": "Bearer access-token | ID: id-token",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Parse all templates
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
for _, header := range tc.headers {
|
||||
tmpl, err := template.New(header.Name).Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = tmpl
|
||||
}
|
||||
|
||||
// Create template data
|
||||
templateData := map[string]interface{}{
|
||||
"AccessToken": tc.accessToken,
|
||||
"IDToken": tc.idToken,
|
||||
"Claims": tc.claims,
|
||||
}
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
|
||||
// Execute templates and set headers
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(headerName, buf.String())
|
||||
}
|
||||
|
||||
// Verify all expected headers are set correctly
|
||||
for headerName, expectedValue := range tc.expectedValues {
|
||||
actualValue := req.Header.Get(headerName)
|
||||
assert.Equal(t, expectedValue, actualValue)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testJSONConfigParsing tests that JSON configuration is properly parsed
|
||||
func testJSONConfigParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
jsonConfig string
|
||||
expectedError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid JSON configuration",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": "Bearer {{.AccessToken}}"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: false,
|
||||
description: "Properly formatted JSON with string values",
|
||||
},
|
||||
{
|
||||
name: "JSON with boolean value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": true
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Boolean value instead of string template",
|
||||
},
|
||||
{
|
||||
name: "JSON with number value",
|
||||
jsonConfig: `{
|
||||
"headers": [
|
||||
{
|
||||
"name": "Authorization",
|
||||
"value": 123
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expectedError: true,
|
||||
description: "Number value instead of string template",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var config struct {
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
}
|
||||
|
||||
err := json.Unmarshal([]byte(tc.jsonConfig), &config)
|
||||
|
||||
if tc.expectedError {
|
||||
require.Error(t, err, tc.description)
|
||||
} else {
|
||||
require.NoError(t, err, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateDoubleProcessing tests if template strings are being double-processed
|
||||
func testTemplateDoubleProcessing(t *testing.T) {
|
||||
// Simulate how Traefik passes config to the plugin
|
||||
config := &MockConfig{
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Verify that template strings are still raw (not processed)
|
||||
assert.Equal(t, "{{.Claims.email}}", config.Headers[0].Value)
|
||||
assert.Equal(t, "{{.Claims.internal_role}}", config.Headers[1].Value)
|
||||
|
||||
// Simulate template parsing during initialization
|
||||
headerTemplates := make(map[string]*template.Template)
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" || val == "<no value>" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
"get": func(m interface{}, key string) interface{} {
|
||||
if mapVal, ok := m.(map[string]interface{}); ok {
|
||||
if val, exists := mapVal[key]; exists {
|
||||
return val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
},
|
||||
}
|
||||
|
||||
for _, header := range config.Headers {
|
||||
tmpl := template.New(header.Name).Funcs(funcMap).Option("missingkey=zero")
|
||||
parsedTmpl, err := tmpl.Parse(header.Value)
|
||||
require.NoError(t, err)
|
||||
headerTemplates[header.Name] = parsedTmpl
|
||||
}
|
||||
|
||||
// Test execution with actual claims
|
||||
claims := map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
// Note: internal_role is missing
|
||||
}
|
||||
|
||||
templateData := map[string]interface{}{
|
||||
"Claims": claims,
|
||||
}
|
||||
|
||||
// Execute templates
|
||||
for headerName, tmpl := range headerTemplates {
|
||||
var buf bytes.Buffer
|
||||
err := tmpl.Execute(&buf, templateData)
|
||||
require.NoError(t, err)
|
||||
|
||||
result := buf.String()
|
||||
if headerName == "X-User-Email" {
|
||||
assert.Equal(t, "user@example.com", result)
|
||||
} else if headerName == "X-User-Role" {
|
||||
// With missingkey=zero, missing fields return "<no value>"
|
||||
assert.Equal(t, "<no value>", result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateExecutionContext tests the specific template data context
|
||||
func testTemplateExecutionContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expectedValue string
|
||||
}{
|
||||
{
|
||||
name: "Access and ID token distinction",
|
||||
templateText: "Access: {{.AccessToken}} ID: {{.IDToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"IDToken": "id-token-value",
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expectedValue: "Access: access-token-value ID: id-token-value",
|
||||
},
|
||||
{
|
||||
name: "Combining tokens and claims",
|
||||
templateText: "User: {{.Claims.sub}} Token: {{.AccessToken}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token",
|
||||
"IDToken": "id-token",
|
||||
"Claims": map[string]interface{}{
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
expectedValue: "User: user123 Token: access-token",
|
||||
},
|
||||
{
|
||||
name: "Custom non-standard claims",
|
||||
templateText: "X-User-Role: {{.Claims.role}}, X-User-Permissions: {{.Claims.permissions}}",
|
||||
data: map[string]interface{}{
|
||||
"AccessToken": "access-token-value",
|
||||
"Claims": map[string]interface{}{
|
||||
"role": "admin",
|
||||
"permissions": "read:all,write:own",
|
||||
},
|
||||
},
|
||||
expectedValue: "X-User-Role: admin, X-User-Permissions: read:all,write:own",
|
||||
},
|
||||
{
|
||||
name: "Deeply nested custom claims",
|
||||
templateText: "X-Organization: {{.Claims.app_metadata.organization.name}}, X-Team: {{.Claims.app_metadata.team}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"app_metadata": map[string]interface{}{
|
||||
"organization": map[string]interface{}{
|
||||
"name": "acme-corp",
|
||||
},
|
||||
"team": "platform",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedValue: "X-Organization: acme-corp, X-Team: platform",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expectedValue, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTemplateIntegrationWithPlugin tests template processing in the actual plugin
|
||||
func testTemplateIntegrationWithPlugin(t *testing.T) {
|
||||
// Test template integration using mock plugin components
|
||||
|
||||
// Set up test OIDC server
|
||||
var testServerURL string
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"issuer": testServerURL,
|
||||
"authorization_endpoint": testServerURL + "/auth",
|
||||
"token_endpoint": testServerURL + "/token",
|
||||
"jwks_uri": testServerURL + "/jwks",
|
||||
"userinfo_endpoint": testServerURL + "/userinfo",
|
||||
})
|
||||
case "/jwks":
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"keys": []interface{}{},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer testServer.Close()
|
||||
testServerURL = testServer.URL
|
||||
|
||||
// Create config with templates that reference potentially missing fields
|
||||
config := &MockConfig{
|
||||
ProviderURL: testServer.URL,
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
CallbackURL: "/callback",
|
||||
SessionEncryptionKey: "test-encryption-key-32-characters",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-Role", Value: "{{.Claims.internal_role}}"},
|
||||
},
|
||||
}
|
||||
|
||||
// Initialize plugin would be done here
|
||||
ctx := context.Background()
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Test would create plugin handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = config
|
||||
}
|
||||
|
||||
// testTemplateSyntaxValidation tests that template syntax is properly validated
|
||||
func testTemplateSyntaxValidation(t *testing.T) {
|
||||
validTemplates := []string{
|
||||
"{{.Claims.email}}",
|
||||
"{{.Claims.internal_role}}",
|
||||
"{{.AccessToken}}",
|
||||
"{{.IdToken}}",
|
||||
"{{.RefreshToken}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range validTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Template should be valid: %s", tmplStr)
|
||||
}
|
||||
|
||||
// Test invalid templates
|
||||
invalidTemplates := []struct {
|
||||
template string
|
||||
reason string
|
||||
}{
|
||||
{"{{call .SomeFunc}}", "function calls not allowed"},
|
||||
{"{{range .Items}}{{.}}{{end}}", "range not allowed"},
|
||||
{"{{with .Data}}{{.Field}}{{end}}", "with statements blocked"},
|
||||
{"{{index .Array 0}}", "index access blocked"},
|
||||
{"{{printf \"%s\" .Data}}", "printf blocked"},
|
||||
}
|
||||
|
||||
for _, tc := range invalidTemplates {
|
||||
err := validateTemplateSecure(tc.template)
|
||||
assert.Error(t, err, "Template should be invalid: %s (%s)", tc.template, tc.reason)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "dangerous")
|
||||
}
|
||||
|
||||
// Test safe custom functions
|
||||
safeTemplates := []string{
|
||||
"{{get .Claims \"internal_role\"}}",
|
||||
"{{default \"guest\" .Claims.role}}",
|
||||
}
|
||||
|
||||
for _, tmplStr := range safeTemplates {
|
||||
err := validateTemplateSecure(tmplStr)
|
||||
assert.NoError(t, err, "Safe custom functions should be allowed: %s", tmplStr)
|
||||
}
|
||||
}
|
||||
|
||||
// Mock validation function for template security
|
||||
func validateTemplateSecure(templateStr string) error {
|
||||
// List of potentially dangerous template actions
|
||||
dangerousFunctions := []string{
|
||||
"call", "range", "with", "index", "printf", "println", "print",
|
||||
"js", "html", "urlquery", "base64", "exec",
|
||||
}
|
||||
|
||||
for _, dangerous := range dangerousFunctions {
|
||||
if strings.Contains(templateStr, dangerous) {
|
||||
return fmt.Errorf("dangerous template function detected: %s", dangerous)
|
||||
}
|
||||
}
|
||||
|
||||
// Define safe custom functions
|
||||
funcMap := template.FuncMap{
|
||||
"get": func(data map[string]interface{}, key string) interface{} {
|
||||
return data[key]
|
||||
},
|
||||
"default": func(defaultVal interface{}, val interface{}) interface{} {
|
||||
if val == nil || val == "" {
|
||||
return defaultVal
|
||||
}
|
||||
return val
|
||||
},
|
||||
}
|
||||
|
||||
// Try to parse the template with custom functions to check for syntax errors
|
||||
_, err := template.New("test").Funcs(funcMap).Parse(templateStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// testMissingFieldHandling tests handling of missing fields in templates
|
||||
func testMissingFieldHandling(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "missing claim field",
|
||||
templateText: "{{.Claims.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing nested field",
|
||||
templateText: "{{.Claims.user.missing}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"user": map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
expected: "<no value>",
|
||||
},
|
||||
{
|
||||
name: "missing entire path",
|
||||
templateText: "{{.Missing.Path.Field}}",
|
||||
data: map[string]interface{}{},
|
||||
expected: "<no value>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testComplexTemplateExpressions tests complex template expressions
|
||||
func testComplexTemplateExpressions(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
templateText string
|
||||
data map[string]interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "conditional template",
|
||||
templateText: "{{if .Claims.admin}}Admin User{{else}}Regular User{{end}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"admin": true,
|
||||
},
|
||||
},
|
||||
expected: "Admin User",
|
||||
},
|
||||
{
|
||||
name: "multiple claims concatenation",
|
||||
templateText: "{{.Claims.firstName}} {{.Claims.lastName}} <{{.Claims.email}}>",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"firstName": "John",
|
||||
"lastName": "Doe",
|
||||
"email": "john.doe@example.com",
|
||||
},
|
||||
},
|
||||
expected: "John Doe <john.doe@example.com>",
|
||||
},
|
||||
{
|
||||
name: "array access",
|
||||
templateText: "{{index .Claims.roles 0}}",
|
||||
data: map[string]interface{}{
|
||||
"Claims": map[string]interface{}{
|
||||
"roles": []string{"admin", "user"},
|
||||
},
|
||||
},
|
||||
expected: "admin",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := template.New("test").Parse(tc.templateText)
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, tc.data)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testTraefikConfigurationParsing tests various ways Traefik might pass configuration
|
||||
func testTraefikConfigurationParsing(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
config *MockConfig
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "valid configuration with templated headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Standard configuration should work",
|
||||
},
|
||||
{
|
||||
name: "configuration with multiple headers",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{
|
||||
{Name: "Authorization", Value: "Bearer {{.AccessToken}}"},
|
||||
{Name: "X-User-Email", Value: "{{.Claims.email}}"},
|
||||
{Name: "X-User-ID", Value: "{{.Claims.sub}}"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Multiple headers should work",
|
||||
},
|
||||
{
|
||||
name: "empty headers configuration",
|
||||
config: &MockConfig{
|
||||
ProviderURL: "https://accounts.google.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
SessionEncryptionKey: "test-encryption-key-32-bytes-long",
|
||||
CallbackURL: "/oauth2/callback",
|
||||
Headers: []TemplatedHeader{},
|
||||
},
|
||||
expectError: false,
|
||||
description: "Empty headers should not cause issues",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a simple next handler
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Try to create the middleware would be done here
|
||||
ctx := context.Background()
|
||||
|
||||
// Test would create middleware handler here
|
||||
_ = ctx
|
||||
_ = next
|
||||
_ = tc.config
|
||||
|
||||
// For now, we just validate the configuration is well-formed
|
||||
if !tc.expectError {
|
||||
require.NotNil(t, tc.config, tc.description)
|
||||
require.NotEmpty(t, tc.config.ClientID, tc.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -18,5 +18,6 @@ require (
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
)
|
||||
|
||||
@@ -20,8 +20,10 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
|
||||
@@ -10,16 +10,16 @@ import (
|
||||
type GoroutineManager struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
goroutines map[string]*managedGoroutine
|
||||
logger *Logger
|
||||
wg sync.WaitGroup
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type managedGoroutine struct {
|
||||
name string
|
||||
cancel context.CancelFunc
|
||||
startTime time.Time
|
||||
cancel context.CancelFunc
|
||||
name string
|
||||
running bool
|
||||
}
|
||||
|
||||
@@ -149,10 +149,10 @@ func (m *GoroutineManager) GetStatus() map[string]GoroutineStatus {
|
||||
|
||||
// GoroutineStatus represents the status of a managed goroutine
|
||||
type GoroutineStatus struct {
|
||||
Name string
|
||||
Running bool
|
||||
StartTime time.Time
|
||||
Name string
|
||||
Runtime time.Duration
|
||||
Running bool
|
||||
}
|
||||
|
||||
// ErrShutdownTimeout is returned when shutdown times out
|
||||
|
||||
@@ -1,764 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// OAuth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestOAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthorizationRequest", func(t *testing.T) {
|
||||
// Test authorization request handling logic
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestURL string
|
||||
expectedStatus int
|
||||
checkLocation bool
|
||||
}{
|
||||
{
|
||||
name: "Valid authorization request",
|
||||
requestURL: "/auth/login",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
{
|
||||
name: "With return URL",
|
||||
requestURL: "/auth/login?return=/dashboard",
|
||||
expectedStatus: http.StatusFound,
|
||||
checkLocation: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the test case structure
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.requestURL == "" {
|
||||
t.Error("Request URL should not be empty")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
// In a real implementation, this would test the actual handler
|
||||
t.Logf("Testing %s with URL %s expecting status %d", test.name, test.requestURL, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Authorization request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleCallbackRequest", func(t *testing.T) {
|
||||
// Test callback request handling with existing mocks
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
expectedStatus int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid callback with code",
|
||||
queryParams: "code=test-code&state=test-state",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Callback with error",
|
||||
queryParams: "error=access_denied&error_description=User denied access",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing code",
|
||||
queryParams: "state=test-state",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Missing state",
|
||||
queryParams: "code=test-code",
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the callback scenarios
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Verify test case parameters
|
||||
if test.queryParams == "" && !test.expectError {
|
||||
t.Error("Query params should not be empty for successful cases")
|
||||
}
|
||||
if test.expectedStatus == 0 {
|
||||
t.Error("Expected status should be set")
|
||||
}
|
||||
|
||||
// Test session manager functionality
|
||||
if sessionManager != nil {
|
||||
t.Logf("Session manager available for test %s", test.name)
|
||||
}
|
||||
|
||||
t.Logf("Testing %s with params %s expecting status %d", test.name, test.queryParams, test.expectedStatus)
|
||||
})
|
||||
}
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Callback request test completed")
|
||||
})
|
||||
|
||||
t.Run("HandleLogout", func(t *testing.T) {
|
||||
// Test logout functionality with mock implementations
|
||||
sessionManager := NewMockSessionManager()
|
||||
logger := &MockLogger{}
|
||||
|
||||
// Test session clearing
|
||||
mockReq := &http.Request{}
|
||||
session, err := sessionManager.GetSession(mockReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set up authenticated session
|
||||
err = session.SetAuthenticated(true)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set authentication: %v", err)
|
||||
}
|
||||
session.SetIDToken("test-token")
|
||||
|
||||
// Verify session is authenticated
|
||||
if !session.GetAuthenticated() {
|
||||
t.Error("Session should be authenticated before logout")
|
||||
}
|
||||
|
||||
// Test logout by clearing session
|
||||
// session.Clear() // Method not implemented in SessionData
|
||||
// Additional logout verification would go here
|
||||
|
||||
// Verify logger doesn't cause issues
|
||||
logger.Debugf("Logout test completed")
|
||||
t.Log("Logout test completed successfully")
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Auth Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAuthHandler(t *testing.T) {
|
||||
t.Run("HandleAuthentication", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
/*
|
||||
handler := &MockAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*MockSession)
|
||||
expectedStatus int
|
||||
expectNext bool
|
||||
}{
|
||||
{
|
||||
name: "Authenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("valid-token")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectNext: true,
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated user",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(false)
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
{
|
||||
name: "Expired token",
|
||||
setupSession: func(s *MockSession) {
|
||||
s.SetAuthenticated(true)
|
||||
s.SetIDToken("expired-token")
|
||||
},
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectNext: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HandleRefreshToken", func(t *testing.T) {
|
||||
// Test authentication handling with mock types
|
||||
// validator := &MockTokenValidator{valid: true} // Currently unused
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
mockResponse *MockTokenResponse
|
||||
mockError error
|
||||
expectSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "Successful refresh",
|
||||
refreshToken: "valid-refresh-token",
|
||||
mockResponse: &MockTokenResponse{
|
||||
AccessToken: "new-access-token",
|
||||
IDToken: "new-id-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
},
|
||||
expectSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "Failed refresh",
|
||||
refreshToken: "invalid-refresh-token",
|
||||
mockError: errors.New("invalid_grant"),
|
||||
expectSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "Empty refresh token",
|
||||
refreshToken: "",
|
||||
expectSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestErrorHandler(t *testing.T) {
|
||||
t.Run("HandleHTTPErrors", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorCode int
|
||||
errorMessage string
|
||||
isAjax bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "401 Unauthorized",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Authentication required",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: "Authentication required",
|
||||
},
|
||||
{
|
||||
name: "403 Forbidden",
|
||||
errorCode: http.StatusForbidden,
|
||||
errorMessage: "Access denied",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: "Access denied",
|
||||
},
|
||||
{
|
||||
name: "500 Internal Server Error",
|
||||
errorCode: http.StatusInternalServerError,
|
||||
errorMessage: "Internal server error",
|
||||
isAjax: false,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal server error",
|
||||
},
|
||||
{
|
||||
name: "Ajax 401",
|
||||
errorCode: http.StatusUnauthorized,
|
||||
errorMessage: "Token expired",
|
||||
isAjax: true,
|
||||
expectedStatus: http.StatusUnauthorized,
|
||||
expectedBody: `{"error":"unauthorized","message":"Token expired"}`,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RecoverFromPanic", func(t *testing.T) {
|
||||
// Test with mock implementations
|
||||
/*
|
||||
handler := &MockErrorHandler{
|
||||
logger: &MockLogger{},
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
panicValue interface{}
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "String panic",
|
||||
panicValue: "something went wrong",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Error panic",
|
||||
panicValue: errors.New("critical error"),
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "Nil panic",
|
||||
panicValue: nil,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Azure OAuth Callback Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestAzureOAuthCallback(t *testing.T) {
|
||||
t.Run("AzureSpecificClaims", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
azureClaims := map[string]interface{}{
|
||||
"oid": "object-id",
|
||||
"tid": "tenant-id",
|
||||
"preferred_username": "user@example.com",
|
||||
"name": "Test User",
|
||||
"email": "user@example.com",
|
||||
"groups": []string{"group1", "group2"},
|
||||
}
|
||||
|
||||
// Test would go here when properly implemented
|
||||
_ = azureClaims
|
||||
})
|
||||
|
||||
t.Run("AzureTokenValidation", func(t *testing.T) {
|
||||
// Test with mock validator types
|
||||
/*
|
||||
validator := &MockAzureTokenValidator{
|
||||
tenantID: "test-tenant",
|
||||
clientID: "test-client",
|
||||
}
|
||||
*/
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
claims map[string]interface{}
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "Valid Azure token",
|
||||
token: "valid-azure-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: true,
|
||||
},
|
||||
{
|
||||
name: "Wrong tenant",
|
||||
token: "wrong-tenant-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "test-client",
|
||||
"tid": "wrong-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
{
|
||||
name: "Wrong audience",
|
||||
token: "wrong-audience-token",
|
||||
claims: map[string]interface{}{
|
||||
"aud": "wrong-client",
|
||||
"tid": "test-tenant",
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
expectValid: false,
|
||||
},
|
||||
}
|
||||
|
||||
// Test the authentication test cases
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
// Test with mock session
|
||||
mockSession := &MockSession{values: make(map[string]interface{})}
|
||||
// Use mock session to avoid unused variable error
|
||||
_ = mockSession
|
||||
t.Logf("Testing %s", test.name)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Concurrent Handler Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConcurrentHandlers(t *testing.T) {
|
||||
t.Run("ConcurrentCallbacks", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
successCount := int32(0)
|
||||
errorCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = successCount
|
||||
_ = errorCount
|
||||
})
|
||||
|
||||
t.Run("ConcurrentLogouts", func(t *testing.T) {
|
||||
// Test with mock configuration
|
||||
/*
|
||||
handler := &OAuthHandler{
|
||||
logger: &MockLogger{},
|
||||
sessionManager: NewMockSessionManager(),
|
||||
}
|
||||
*/
|
||||
|
||||
var wg sync.WaitGroup
|
||||
logoutCount := int32(0)
|
||||
|
||||
// Test would go here when properly implemented
|
||||
wg.Wait() // Proper usage instead of assignment
|
||||
_ = logoutCount
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Implementations
|
||||
// ============================================================================
|
||||
|
||||
type MockSessionManager struct {
|
||||
sessions map[string]*MockSession
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMockSessionManager() *MockSessionManager {
|
||||
return &MockSessionManager{
|
||||
sessions: make(map[string]*MockSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockSessionManager) GetSession(r *http.Request) (SessionData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
sessionID := "test-session"
|
||||
if session, exists := m.sessions[sessionID]; exists {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
session := &MockSession{
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
m.sessions[sessionID] = session
|
||||
return session, nil
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
values map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAuthenticated(auth bool) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["authenticated"] = auth
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAuthenticated() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
auth, ok := s.values["authenticated"].(bool)
|
||||
return ok && auth
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIDToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["id_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIDToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["id_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetAccessToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["access_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetAccessToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["access_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetRefreshToken(token string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["refresh_token"] = token
|
||||
}
|
||||
|
||||
func (s *MockSession) GetRefreshToken() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
token, _ := s.values["refresh_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func (s *MockSession) SetState(state string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["state"] = state
|
||||
}
|
||||
|
||||
func (s *MockSession) GetState() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
state, _ := s.values["state"].(string)
|
||||
return state
|
||||
}
|
||||
|
||||
func (s *MockSession) SetClaims(claims map[string]interface{}) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["claims"] = claims
|
||||
}
|
||||
|
||||
func (s *MockSession) GetClaims() map[string]interface{} {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
claims, _ := s.values["claims"].(map[string]interface{})
|
||||
return claims
|
||||
}
|
||||
|
||||
// Additional SessionData interface methods to match real interface
|
||||
func (s *MockSession) GetCSRF() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
csrf, _ := s.values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) GetNonce() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
nonce, _ := s.values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) GetCodeVerifier() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
verifier, _ := s.values["code_verifier"].(string)
|
||||
return verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) GetIncomingPath() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
path, _ := s.values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
func (s *MockSession) SetEmail(email string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["email"] = email
|
||||
}
|
||||
|
||||
func (s *MockSession) GetEmail() string {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
email, _ := s.values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCSRF(csrf string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["csrf"] = csrf
|
||||
}
|
||||
|
||||
func (s *MockSession) SetNonce(nonce string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["nonce"] = nonce
|
||||
}
|
||||
|
||||
func (s *MockSession) SetCodeVerifier(verifier string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["code_verifier"] = verifier
|
||||
}
|
||||
|
||||
func (s *MockSession) SetIncomingPath(path string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["incoming_path"] = path
|
||||
}
|
||||
|
||||
func (s *MockSession) ResetRedirectCount() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values["redirect_count"] = 0
|
||||
}
|
||||
|
||||
func (s *MockSession) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *MockSession) Clear() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.values = make(map[string]interface{})
|
||||
}
|
||||
|
||||
func (s *MockSession) returnToPoolSafely() {
|
||||
// No-op for mock
|
||||
}
|
||||
|
||||
type MockTokenValidator struct {
|
||||
valid bool
|
||||
}
|
||||
|
||||
func (v *MockTokenValidator) Validate(token string) bool {
|
||||
if token == "expired-token" {
|
||||
return false
|
||||
}
|
||||
return v.valid
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mock Handler Type Definitions (for testing)
|
||||
// ============================================================================
|
||||
|
||||
// These mock handlers are simplified versions for testing purposes
|
||||
// They don't match the actual handler implementations
|
||||
|
||||
type MockAuthHandler struct{}
|
||||
|
||||
type MockErrorHandler struct{}
|
||||
|
||||
type MockAzureTokenValidator struct {
|
||||
tenantID string
|
||||
clientID string
|
||||
}
|
||||
|
||||
func (v *MockAzureTokenValidator) ValidateAzureToken(token string, claims map[string]interface{}) bool {
|
||||
// Validate tenant ID
|
||||
if tid, ok := claims["tid"].(string); !ok || tid != v.tenantID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate audience
|
||||
if aud, ok := claims["aud"].(string); !ok || aud != v.clientID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Validate expiration
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
if time.Now().Unix() > int64(exp) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper Types and Mock Logger
|
||||
// ============================================================================
|
||||
|
||||
type MockLogger struct{}
|
||||
|
||||
func (l *MockLogger) Debugf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Errorf(format string, args ...interface{}) {}
|
||||
func (l *MockLogger) Error(msg string) {}
|
||||
|
||||
type MockTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
@@ -1,313 +0,0 @@
|
||||
// Package handlers provides HTTP request handlers for the OIDC middleware.
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OAuthHandler handles OAuth callback requests
|
||||
type OAuthHandler struct {
|
||||
logger Logger
|
||||
sessionManager SessionManager
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Error(msg string)
|
||||
}
|
||||
|
||||
// SessionManager interface for session operations
|
||||
type SessionManager interface {
|
||||
GetSession(req *http.Request) (SessionData, error)
|
||||
}
|
||||
|
||||
// SessionData interface for session data operations
|
||||
type SessionData interface {
|
||||
GetCSRF() string
|
||||
GetNonce() string
|
||||
GetCodeVerifier() string
|
||||
GetIncomingPath() string
|
||||
GetAuthenticated() bool
|
||||
GetAccessToken() string
|
||||
GetRefreshToken() string
|
||||
GetIDToken() string
|
||||
GetEmail() string
|
||||
SetAuthenticated(bool) error
|
||||
SetEmail(string)
|
||||
SetIDToken(string)
|
||||
SetAccessToken(string)
|
||||
SetRefreshToken(string)
|
||||
SetCSRF(string)
|
||||
SetNonce(string)
|
||||
SetCodeVerifier(string)
|
||||
SetIncomingPath(string)
|
||||
ResetRedirectCount()
|
||||
Save(req *http.Request, rw http.ResponseWriter) error
|
||||
returnToPoolSafely()
|
||||
}
|
||||
|
||||
// TokenExchanger interface for token operations
|
||||
type TokenExchanger interface {
|
||||
ExchangeCodeForToken(ctx context.Context, grantType string, codeOrToken string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
||||
}
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from token exchange
|
||||
type TokenResponse struct {
|
||||
IDToken string
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
}
|
||||
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleCallback handles OAuth callback requests
|
||||
func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := h.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Session error during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Session error during callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer session.returnToPoolSafely()
|
||||
|
||||
h.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Debug logging for cookie configuration
|
||||
h.logger.Debugf("Callback request headers - Host: %s, X-Forwarded-Host: %s, X-Forwarded-Proto: %s",
|
||||
req.Host, req.Header.Get("X-Forwarded-Host"), req.Header.Get("X-Forwarded-Proto"))
|
||||
|
||||
// Log all cookies in the request for debugging
|
||||
cookies := req.Cookies()
|
||||
h.logger.Debugf("Total cookies in callback request: %d", len(cookies))
|
||||
for _, cookie := range cookies {
|
||||
if strings.HasPrefix(cookie.Name, "_oidc_") {
|
||||
h.logger.Debugf("Cookie found - Name: %s, Domain: %s, Path: %s, SameSite: %v, Secure: %v, HttpOnly: %v, Value length: %d",
|
||||
cookie.Name, cookie.Domain, cookie.Path, cookie.SameSite, cookie.Secure, cookie.HttpOnly, len(cookie.Value))
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error")
|
||||
}
|
||||
h.logger.Errorf("Authentication error from provider during callback: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
h.sendErrorResponseFunc(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
h.logger.Error("No state in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log the state parameter received
|
||||
h.logger.Debugf("State parameter received in callback: %s (length: %d)", state, len(state))
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
h.logger.Errorf("CSRF token missing in session during callback. Authenticated: %v, Request URL: %s",
|
||||
session.GetAuthenticated(), req.URL.String())
|
||||
|
||||
// Enhanced debugging for missing CSRF token
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
// Log cookie names only, not values (avoid logging sensitive session data)
|
||||
cookieNames := make([]string, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
h.logger.Debugf("Available cookies (names only): %v", cookieNames)
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
cookie.Domain, cookie.Path, cookie.Secure, cookie.HttpOnly, cookie.SameSite)
|
||||
}
|
||||
|
||||
// Log session state for debugging
|
||||
h.logger.Debugf("Session state during CSRF check - Authenticated: %v, Has AccessToken: %v",
|
||||
session.GetAuthenticated(), session.GetAccessToken() != "")
|
||||
|
||||
h.sendErrorResponseFunc(rw, req, "CSRF token missing in session", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Debug log successful CSRF token retrieval
|
||||
h.logger.Debugf("CSRF token retrieved from session: %s (length: %d)", csrfToken, len(csrfToken))
|
||||
|
||||
if state != csrfToken {
|
||||
h.logger.Error("State parameter does not match CSRF token in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
h.logger.Error("No code in callback")
|
||||
h.sendErrorResponseFunc(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
codeVerifier := session.GetCodeVerifier()
|
||||
|
||||
tokenResponse, err := h.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to exchange code for token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.tokenVerifier.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
h.logger.Errorf("Failed to verify id_token during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := h.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
h.logger.Errorf("Failed to extract claims during callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
h.logger.Error("Nonce claim missing in id_token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
h.logger.Error("Nonce not found in session during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
h.logger.Error("Nonce claim does not match session nonce during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
if err := session.SetAuthenticated(true); err != nil {
|
||||
h.logger.Errorf("Failed to set authenticated state and regenerate session ID: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
session.SetCSRF("")
|
||||
session.SetNonce("")
|
||||
session.SetCodeVerifier("")
|
||||
|
||||
session.ResetRedirectCount()
|
||||
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != h.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
session.SetIncomingPath("")
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
h.logger.Errorf("Failed to save session after callback: %v", err)
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to save session after callback", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debugf("Callback successful, redirecting to %s", redirectPath)
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
// URLHelper provides utility methods for URL operations
|
||||
type URLHelper struct {
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// NewURLHelper creates a new URL helper
|
||||
func NewURLHelper(logger Logger) *URLHelper {
|
||||
return &URLHelper{logger: logger}
|
||||
}
|
||||
|
||||
// DetermineExcludedURL checks if a URL path should bypass OIDC authentication.
|
||||
// It compares the request path against configured excluded URL prefixes.
|
||||
func (h *URLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
|
||||
for excludedURL := range excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
h.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// DetermineScheme determines the URL scheme for building redirect URLs.
|
||||
// It checks X-Forwarded-Proto header first, then TLS presence.
|
||||
func (h *URLHelper) DetermineScheme(req *http.Request) string {
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
return "https"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// DetermineHost determines the host for building redirect URLs.
|
||||
// It checks X-Forwarded-Host header first, then falls back to req.Host.
|
||||
func (h *URLHelper) DetermineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
}
|
||||
return req.Host
|
||||
}
|
||||
@@ -1,899 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test mocks - implementing interfaces defined in oauth_handler.go
|
||||
type mockLogger struct {
|
||||
debugMessages []string
|
||||
errorMessages []string
|
||||
}
|
||||
|
||||
func (l *mockLogger) Debugf(format string, args ...interface{}) {
|
||||
l.debugMessages = append(l.debugMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Errorf(format string, args ...interface{}) {
|
||||
l.errorMessages = append(l.errorMessages, format)
|
||||
}
|
||||
|
||||
func (l *mockLogger) Error(msg string) {
|
||||
l.errorMessages = append(l.errorMessages, msg)
|
||||
}
|
||||
|
||||
type mockSessionManager struct {
|
||||
sessionToReturn SessionData
|
||||
errorToReturn error
|
||||
}
|
||||
|
||||
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
|
||||
return m.sessionToReturn, m.errorToReturn
|
||||
}
|
||||
|
||||
type mockSessionData struct {
|
||||
authenticated bool
|
||||
email string
|
||||
csrf string
|
||||
nonce string
|
||||
codeVerifier string
|
||||
incomingPath string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
idToken string
|
||||
saveError error
|
||||
setAuthError error
|
||||
}
|
||||
|
||||
func (s *mockSessionData) GetCSRF() string { return s.csrf }
|
||||
func (s *mockSessionData) GetNonce() string { return s.nonce }
|
||||
func (s *mockSessionData) GetCodeVerifier() string { return s.codeVerifier }
|
||||
func (s *mockSessionData) GetIncomingPath() string { return s.incomingPath }
|
||||
func (s *mockSessionData) GetAuthenticated() bool { return s.authenticated }
|
||||
func (s *mockSessionData) GetAccessToken() string { return s.accessToken }
|
||||
func (s *mockSessionData) GetRefreshToken() string { return s.refreshToken }
|
||||
func (s *mockSessionData) GetIDToken() string { return s.idToken }
|
||||
func (s *mockSessionData) GetEmail() string { return s.email }
|
||||
|
||||
func (s *mockSessionData) SetAuthenticated(auth bool) error {
|
||||
s.authenticated = auth
|
||||
return s.setAuthError
|
||||
}
|
||||
|
||||
func (s *mockSessionData) SetEmail(email string) { s.email = email }
|
||||
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }
|
||||
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }
|
||||
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }
|
||||
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }
|
||||
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }
|
||||
func (s *mockSessionData) SetCodeVerifier(verif string) { s.codeVerifier = verif }
|
||||
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }
|
||||
func (s *mockSessionData) ResetRedirectCount() {}
|
||||
func (s *mockSessionData) returnToPoolSafely() {}
|
||||
|
||||
func (s *mockSessionData) Save(req *http.Request, rw http.ResponseWriter) error {
|
||||
return s.saveError
|
||||
}
|
||||
|
||||
type mockTokenExchanger struct {
|
||||
response *TokenResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *mockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
||||
return e.response, e.err
|
||||
}
|
||||
|
||||
type mockTokenVerifier struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *mockTokenVerifier) VerifyToken(token string) error {
|
||||
return v.err
|
||||
}
|
||||
|
||||
// TestOAuthHandler_NewOAuthHandler tests the constructor
|
||||
func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
|
||||
if handler.redirURLPath != "/callback" {
|
||||
t.Errorf("Expected redirURLPath '/callback', got '%s'", handler.redirURLPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionError tests session retrieval errors
|
||||
func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
sessionManager := &mockSessionManager{errorToReturn: errors.New("session error")}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Session error") {
|
||||
t.Errorf("Expected error message to contain 'Session error', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ProviderError tests OAuth provider errors
|
||||
func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Authentication error from provider") {
|
||||
t.Errorf("Expected error message to contain 'Authentication error from provider', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
|
||||
if len(logger.errorMessages) == 0 {
|
||||
t.Error("Expected error to be logged")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingState tests missing state parameter
|
||||
func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "State parameter missing") {
|
||||
t.Errorf("Expected error message to contain 'State parameter missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCSRF tests missing CSRF token in session
|
||||
func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: ""} // Empty CSRF
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF token missing") {
|
||||
t.Errorf("Expected error message to contain 'CSRF token missing', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_CSRFMismatch tests CSRF token mismatch
|
||||
func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "different-token"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "CSRF mismatch") {
|
||||
t.Errorf("Expected error message to contain 'CSRF mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingCode tests missing authorization code
|
||||
func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusBadRequest {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusBadRequest, code)
|
||||
}
|
||||
if !strings.Contains(msg, "No authorization code received") {
|
||||
t.Errorf("Expected error message to contain 'No authorization code received', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenExchangeError tests token exchange failure
|
||||
func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce", codeVerifier: "test-verifier"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenExchanger := &mockTokenExchanger{err: errors.New("token exchange failed")}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not exchange code for token") {
|
||||
t.Errorf("Expected error message to contain 'Could not exchange code for token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_TokenVerificationError tests token verification failure
|
||||
func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "invalid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{err: errors.New("token verification failed")}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) { return nil, nil }
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not verify ID token") {
|
||||
t.Errorf("Expected error message to contain 'Could not verify ID token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_ClaimsExtractionError tests claims extraction failure
|
||||
func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return nil, errors.New("claims extraction failed")
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Could not extract claims") {
|
||||
t.Errorf("Expected error message to contain 'Could not extract claims', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInToken tests missing nonce in token
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
// Claims without nonce
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingNonceInSession tests missing nonce in session
|
||||
func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: ""} // Empty nonce
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce missing in session") {
|
||||
t.Errorf("Expected error message to contain 'Nonce missing in session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_NonceMismatch tests nonce mismatch
|
||||
func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "session-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "token-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Nonce mismatch") {
|
||||
t.Errorf("Expected error message to contain 'Nonce mismatch', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_MissingEmail tests missing email in claims
|
||||
func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"nonce": "test-nonce"}, nil // No email
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_DisallowedDomain tests disallowed email domain
|
||||
func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{csrf: "test-state", nonce: "test-nonce"}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@disallowed.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return false } // Disallow all domains
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SessionSaveError tests session save failure
|
||||
func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
saveError: errors.New("save failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token", RefreshToken: "refresh-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to save session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to save session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SetAuthenticatedError tests SetAuthenticated failure
|
||||
func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
setAuthError: errors.New("set auth failed"),
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Failed to update session") {
|
||||
t.Errorf("Expected error message to contain 'Failed to update session', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if !errorSent {
|
||||
t.Error("Expected error response to be sent")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_Success tests successful callback handling
|
||||
func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/dashboard",
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{
|
||||
IDToken: "valid-id-token",
|
||||
AccessToken: "valid-access-token",
|
||||
RefreshToken: "valid-refresh-token",
|
||||
}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
errorSent := false
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
errorSent = true
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
if errorSent {
|
||||
t.Error("Unexpected error response sent")
|
||||
}
|
||||
|
||||
// Check redirect
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/dashboard" {
|
||||
t.Errorf("Expected redirect to '/dashboard', got '%s'", location)
|
||||
}
|
||||
|
||||
// Verify session data was set correctly
|
||||
if session.email != "test@example.com" {
|
||||
t.Errorf("Expected email 'test@example.com', got '%s'", session.email)
|
||||
}
|
||||
|
||||
if session.idToken != "valid-id-token" {
|
||||
t.Errorf("Expected ID token 'valid-id-token', got '%s'", session.idToken)
|
||||
}
|
||||
|
||||
if session.accessToken != "valid-access-token" {
|
||||
t.Errorf("Expected access token 'valid-access-token', got '%s'", session.accessToken)
|
||||
}
|
||||
|
||||
if session.refreshToken != "valid-refresh-token" {
|
||||
t.Errorf("Expected refresh token 'valid-refresh-token', got '%s'", session.refreshToken)
|
||||
}
|
||||
|
||||
if !session.authenticated {
|
||||
t.Error("Expected session to be authenticated")
|
||||
}
|
||||
|
||||
// Check that temporary fields are cleared
|
||||
if session.csrf != "" {
|
||||
t.Errorf("Expected CSRF to be cleared, got '%s'", session.csrf)
|
||||
}
|
||||
|
||||
if session.nonce != "" {
|
||||
t.Errorf("Expected nonce to be cleared, got '%s'", session.nonce)
|
||||
}
|
||||
|
||||
if session.codeVerifier != "" {
|
||||
t.Errorf("Expected code verifier to be cleared, got '%s'", session.codeVerifier)
|
||||
}
|
||||
|
||||
if session.incomingPath != "" {
|
||||
t.Errorf("Expected incoming path to be cleared, got '%s'", session.incomingPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_SuccessDefaultRedirect tests successful callback with default redirect
|
||||
func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "", // No incoming path, should default to "/"
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Check redirect to default path
|
||||
if rw.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
||||
}
|
||||
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOAuthHandler_HandleCallback_RedirectURLPathExcluded tests incoming path same as redirect URL
|
||||
func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
session := &mockSessionData{
|
||||
csrf: "test-state",
|
||||
nonce: "test-nonce",
|
||||
incomingPath: "/callback", // Same as redirect URL path
|
||||
}
|
||||
sessionManager := &mockSessionManager{sessionToReturn: session}
|
||||
tokenResponse := &TokenResponse{IDToken: "valid-token", AccessToken: "access-token"}
|
||||
tokenExchanger := &mockTokenExchanger{response: tokenResponse}
|
||||
tokenVerifier := &mockTokenVerifier{}
|
||||
|
||||
extractClaims := func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
isAllowed := func(email string) bool { return true }
|
||||
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {
|
||||
t.Errorf("Unexpected error sent: %s (code: %d)", msg, code)
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
handler.HandleCallback(rw, req, "http://example.com/callback")
|
||||
|
||||
// Should redirect to default path when incoming path is same as callback path
|
||||
location := rw.Header().Get("Location")
|
||||
if location != "/" {
|
||||
t.Errorf("Expected redirect to '/', got '%s'", location)
|
||||
}
|
||||
}
|
||||
@@ -1,454 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestURLHelper_NewURLHelper tests the URLHelper constructor
|
||||
func TestURLHelper_NewURLHelper(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
if helper == nil {
|
||||
t.Fatal("Expected URLHelper to be created, got nil")
|
||||
}
|
||||
|
||||
if helper.logger != logger {
|
||||
t.Error("Logger not set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineExcludedURL tests URL exclusion checking
|
||||
func TestURLHelper_DetermineExcludedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
currentURL string
|
||||
excludedURLs map[string]struct{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Exact match",
|
||||
currentURL: "/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Prefix match",
|
||||
currentURL: "/health/status",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No match",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - first match",
|
||||
currentURL: "/api/health",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple exclusions - second match",
|
||||
currentURL: "/health/check",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty excluded URLs",
|
||||
currentURL: "/api/users",
|
||||
excludedURLs: map[string]struct{}{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Root path exclusion",
|
||||
currentURL: "/anything",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Case sensitive matching",
|
||||
currentURL: "/API/users",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Partial substring but not prefix",
|
||||
currentURL: "/user/api/test",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/api": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Empty current URL",
|
||||
currentURL: "",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "URL with query parameters",
|
||||
currentURL: "/health?status=ok",
|
||||
excludedURLs: map[string]struct{}{
|
||||
"/health": {},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := helper.DetermineExcludedURL(tt.currentURL, tt.excludedURLs)
|
||||
if result != tt.expected {
|
||||
t.Errorf("DetermineExcludedURL() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
|
||||
// Verify debug logging for excluded URLs
|
||||
if result && len(logger.debugMessages) > 0 {
|
||||
// Should have logged a debug message for excluded URL
|
||||
found := false
|
||||
for _, msg := range logger.debugMessages {
|
||||
if msg == "URL is excluded - got %s / excluded hit: %s" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected debug message for excluded URL")
|
||||
}
|
||||
}
|
||||
|
||||
// Reset logger messages for next test
|
||||
logger.debugMessages = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineScheme tests scheme determination
|
||||
func TestURLHelper_DetermineScheme(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - https",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto header present - http",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "TLS connection without X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
{
|
||||
name: "No TLS and no X-Forwarded-Proto",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Proto takes precedence over TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Proto falls back to TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://example.com", nil)
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
req.Header.Set("X-Forwarded-Proto", "")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineScheme(req)
|
||||
if result != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", result, tt.expectedScheme)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineHost tests host determination
|
||||
func TestURLHelper_DetermineHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-Host header present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "No X-Forwarded-Host, use req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "direct.example.com"
|
||||
return req
|
||||
},
|
||||
expectedHost: "direct.example.com",
|
||||
},
|
||||
{
|
||||
name: "Empty X-Forwarded-Host falls back to req.Host",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "fallback.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "")
|
||||
return req
|
||||
},
|
||||
expectedHost: "fallback.example.com",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com:8080"
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com:443")
|
||||
return req
|
||||
},
|
||||
expectedHost: "public.example.com:443",
|
||||
},
|
||||
{
|
||||
name: "req.Host with port",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com:8080", nil)
|
||||
req.Host = "example.com:8080"
|
||||
return req
|
||||
},
|
||||
expectedHost: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "Multiple X-Forwarded-Host values (first one used)",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "first.example.com, second.example.com")
|
||||
return req
|
||||
},
|
||||
expectedHost: "first.example.com, second.example.com", // Header value as-is
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
result := helper.DetermineHost(req)
|
||||
if result != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", result, tt.expectedHost)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestURLHelper_DetermineSchemeAndHost_Integration tests scheme and host working together
|
||||
func TestURLHelper_DetermineSchemeAndHost_Integration(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupRequest func() *http.Request
|
||||
expectedScheme string
|
||||
expectedHost string
|
||||
}{
|
||||
{
|
||||
name: "Both headers present",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://internal.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "public.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "public.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, TLS connection",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "https://secure.example.com", nil)
|
||||
req.Host = "secure.example.com"
|
||||
req.TLS = &tls.ConnectionState{} // Simulate TLS connection
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "secure.example.com",
|
||||
},
|
||||
{
|
||||
name: "Neither header present, no TLS",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://plain.example.com", nil)
|
||||
req.Host = "plain.example.com"
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "plain.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only scheme header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "mixed.example.com"
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "https",
|
||||
expectedHost: "mixed.example.com",
|
||||
},
|
||||
{
|
||||
name: "Mixed - only host header",
|
||||
setupRequest: func() *http.Request {
|
||||
req, _ := http.NewRequest("GET", "http://mixed.example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
return req
|
||||
},
|
||||
expectedScheme: "http",
|
||||
expectedHost: "external.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupRequest()
|
||||
|
||||
scheme := helper.DetermineScheme(req)
|
||||
host := helper.DetermineHost(req)
|
||||
|
||||
if scheme != tt.expectedScheme {
|
||||
t.Errorf("DetermineScheme() = %v, expected %v", scheme, tt.expectedScheme)
|
||||
}
|
||||
|
||||
if host != tt.expectedHost {
|
||||
t.Errorf("DetermineHost() = %v, expected %v", host, tt.expectedHost)
|
||||
}
|
||||
|
||||
// Test that we can build a complete URL
|
||||
fullURL := scheme + "://" + host + "/callback"
|
||||
expectedURL := tt.expectedScheme + "://" + tt.expectedHost + "/callback"
|
||||
if fullURL != expectedURL {
|
||||
t.Errorf("Combined URL = %v, expected %v", fullURL, expectedURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests to ensure the helper methods are performant
|
||||
func BenchmarkURLHelper_DetermineExcludedURL(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
excludedURLs := map[string]struct{}{
|
||||
"/health": {},
|
||||
"/metrics": {},
|
||||
"/status": {},
|
||||
"/api/v1": {},
|
||||
"/api/v2": {},
|
||||
"/static": {},
|
||||
"/assets": {},
|
||||
"/favicon": {},
|
||||
"/robots": {},
|
||||
"/sitemap": {},
|
||||
}
|
||||
|
||||
testURL := "/api/users"
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineExcludedURL(testURL, excludedURLs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineScheme(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineScheme(req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkURLHelper_DetermineHost(b *testing.B) {
|
||||
logger := &mockLogger{}
|
||||
helper := NewURLHelper(logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "http://example.com", nil)
|
||||
req.Host = "internal.example.com"
|
||||
req.Header.Set("X-Forwarded-Host", "external.example.com")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
helper.DetermineHost(req)
|
||||
}
|
||||
}
|
||||
+4
-2
@@ -13,6 +13,8 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
)
|
||||
|
||||
// generateNonce creates a cryptographically secure random nonce for OIDC flows.
|
||||
@@ -349,8 +351,8 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
host := utils.DetermineHost(req)
|
||||
scheme := utils.DetermineScheme(req, t.forceHTTPS)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
|
||||
+12
-19
@@ -12,30 +12,23 @@ import (
|
||||
|
||||
// HTTPClientConfig provides configuration for creating HTTP clients
|
||||
type HTTPClientConfig struct {
|
||||
// Timeout for the entire request
|
||||
Timeout time.Duration
|
||||
// MaxRedirects allowed (0 means follow Go's default of 10)
|
||||
MaxRedirects int
|
||||
// UseCookieJar enables cookie jar for the client
|
||||
UseCookieJar bool
|
||||
// Connection settings
|
||||
IdleConnTimeout time.Duration
|
||||
MaxIdleConns int
|
||||
ReadBufferSize int
|
||||
DialTimeout time.Duration
|
||||
KeepAlive time.Duration
|
||||
TLSHandshakeTimeout time.Duration
|
||||
ResponseHeaderTimeout time.Duration
|
||||
ExpectContinueTimeout time.Duration
|
||||
IdleConnTimeout time.Duration
|
||||
// Connection pool settings
|
||||
MaxIdleConns int
|
||||
MaxIdleConnsPerHost int
|
||||
MaxConnsPerHost int
|
||||
// Buffer settings
|
||||
WriteBufferSize int
|
||||
ReadBufferSize int
|
||||
// Feature flags
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
MaxRedirects int
|
||||
MaxIdleConnsPerHost int
|
||||
Timeout time.Duration
|
||||
MaxConnsPerHost int
|
||||
WriteBufferSize int
|
||||
UseCookieJar bool
|
||||
ForceHTTP2 bool
|
||||
DisableKeepAlives bool
|
||||
DisableCompression bool
|
||||
}
|
||||
|
||||
// DefaultHTTPClientConfig returns the default configuration for general use
|
||||
|
||||
@@ -110,9 +110,9 @@ func TestHTTPClientFactoryValidateHTTPClientConfig(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMsg string
|
||||
config HTTPClientConfig
|
||||
wantError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid config",
|
||||
|
||||
+6
-6
@@ -12,19 +12,19 @@ import (
|
||||
|
||||
// SharedTransportPool manages a pool of shared HTTP transports to prevent connection exhaustion
|
||||
type SharedTransportPool struct {
|
||||
mu sync.RWMutex
|
||||
transports map[string]*sharedTransport
|
||||
maxConns int
|
||||
ctx context.Context
|
||||
transports map[string]*sharedTransport
|
||||
cancel context.CancelFunc
|
||||
clientCount int32 // SECURITY FIX: Track total HTTP clients
|
||||
maxClients int32 // SECURITY FIX: Limit total clients to 5
|
||||
maxConns int
|
||||
mu sync.RWMutex
|
||||
clientCount int32
|
||||
maxClients int32
|
||||
}
|
||||
|
||||
type sharedTransport struct {
|
||||
lastUsed time.Time
|
||||
transport *http.Transport
|
||||
refCount int
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ func TestInputValidator(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("Valid token validation", func(t *testing.T) {
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc"
|
||||
validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWV9.EkN-DOsnsuRjRO6BxXemmJDm3HbxrbRzXglbN2S4sOkopdU4IsDxTI8jO19W_A4K8ZPJijNLis4EZsHeY559a4DFOd50_OqgHs3UjpMC6M6FNqI2J-I2NxrragtnDxGxdJUvDERDQVHzeNlVQiuqWDEeO_O-0KptafbfyuGqfQxH_6dp2_MeFpAc" // trufflehog:ignore
|
||||
|
||||
result := validator.ValidateToken(validToken)
|
||||
if !result.IsValid {
|
||||
@@ -428,12 +428,12 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidJWTToken",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature", // trufflehog:ignore
|
||||
expectValid: true,
|
||||
description: "Valid JWT token should pass validation",
|
||||
},
|
||||
@@ -475,7 +475,7 @@ func TestInputValidatorValidateToken(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "MaliciousJWTWithExtraData",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra", // trufflehog:ignore
|
||||
expectValid: false,
|
||||
description: "JWT with extra malicious data should fail validation",
|
||||
},
|
||||
@@ -500,8 +500,8 @@ func TestInputValidatorValidateEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidEmail",
|
||||
@@ -578,8 +578,8 @@ func TestInputValidatorValidateURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHTTPSURL",
|
||||
@@ -669,8 +669,8 @@ func TestInputValidatorValidateClaim(t *testing.T) {
|
||||
name string
|
||||
claimName string
|
||||
claimValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidStringClaim",
|
||||
@@ -750,8 +750,8 @@ func TestInputValidatorValidateHeader(t *testing.T) {
|
||||
name string
|
||||
headerName string
|
||||
headerValue string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidHeader",
|
||||
@@ -830,8 +830,8 @@ func TestInputValidatorValidateUsername(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
expectValid bool
|
||||
description string
|
||||
expectValid bool
|
||||
}{
|
||||
{
|
||||
name: "ValidUsername",
|
||||
|
||||
@@ -726,20 +726,20 @@ type MockConfig struct {
|
||||
}
|
||||
|
||||
type MockSession struct {
|
||||
id string
|
||||
userID string
|
||||
created time.Time
|
||||
lastUsed time.Time
|
||||
data map[string]interface{}
|
||||
id string
|
||||
userID string
|
||||
}
|
||||
|
||||
type TestResult struct {
|
||||
UserID int
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Error error
|
||||
UserID int
|
||||
Duration time.Duration
|
||||
Success bool
|
||||
Error error
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
|
||||
Vendored
+14
-25
@@ -18,33 +18,22 @@ const (
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
L2Config *Config
|
||||
L1Config *Config
|
||||
RedisPrefix string
|
||||
Type BackendType
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
PoolSize int
|
||||
RedisDB int
|
||||
CleanupInterval time.Duration
|
||||
MaxMemoryBytes int64
|
||||
MaxSize int
|
||||
HealthCheckInterval time.Duration
|
||||
AsyncWrites bool
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
|
||||
Vendored
+18
-28
@@ -13,40 +13,30 @@ import (
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
lastL2Error atomic.Value
|
||||
secondary CacheBackend
|
||||
primary CacheBackend
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
syncWriteCacheTypes map[string]bool
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
l1Hits atomic.Int64
|
||||
errors atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
misses atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
fallbackMode atomic.Bool
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
ctx context.Context
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
@@ -82,9 +72,9 @@ func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
SyncWriteCacheTypes map[string]bool
|
||||
AsyncBufferSize int
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
|
||||
+6
-6
@@ -17,23 +17,23 @@ import (
|
||||
|
||||
// mockBackend is a simple mock implementation of CacheBackend for testing
|
||||
type mockBackend struct {
|
||||
pingError error
|
||||
data map[string]mockEntry
|
||||
stats map[string]interface{}
|
||||
mu sync.RWMutex
|
||||
getCalls atomic.Int32
|
||||
setCalls atomic.Int32
|
||||
deleteCalls atomic.Int32
|
||||
failSet bool
|
||||
failGet bool
|
||||
failDelete bool
|
||||
failClear bool
|
||||
failPing bool
|
||||
pingError error
|
||||
stats map[string]interface{}
|
||||
getCalls atomic.Int32
|
||||
setCalls atomic.Int32
|
||||
deleteCalls atomic.Int32
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
value []byte
|
||||
}
|
||||
|
||||
// mockBatchBackend extends mockBackend with batch operations
|
||||
|
||||
Vendored
+14
-45
@@ -41,53 +41,22 @@ type CacheBackend interface {
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
StartTime time.Time
|
||||
LastErrorTime time.Time
|
||||
Type BackendType
|
||||
LastError string
|
||||
Deletes int64
|
||||
Errors int64
|
||||
Evictions int64
|
||||
CurrentSize int64
|
||||
MaxSize int64
|
||||
MemoryUsage int64
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
Sets int64
|
||||
Misses int64
|
||||
Uptime time.Duration
|
||||
Hits int64
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
|
||||
Vendored
+27
-35
@@ -11,14 +11,14 @@ import (
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
value interface{}
|
||||
element *list.Element
|
||||
key string
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
@@ -31,39 +31,31 @@ func (item *memoryCacheItem) isExpired() bool {
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Statistics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
startTime time.Time
|
||||
lastErrorTime time.Time
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
cleanupDone chan bool
|
||||
cleanupTicker *time.Ticker
|
||||
evictionPolicy string
|
||||
lastError string
|
||||
currentMemory int64
|
||||
misses atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
sets atomic.Int64
|
||||
hits atomic.Int64
|
||||
maxSize int64
|
||||
currentSize int64
|
||||
maxMemory int64
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
mu sync.RWMutex
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
|
||||
Vendored
+22
-7
@@ -43,14 +43,17 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
}
|
||||
|
||||
// Create connection pool with health checks enabled
|
||||
// Timeouts are kept short to prevent request pileup when Redis is slow/stalled.
|
||||
// The UniversalCache uses 200ms context timeout, so socket timeouts should be
|
||||
// shorter to allow proper context cancellation handling.
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 500 * time.Millisecond,
|
||||
WriteTimeout: 500 * time.Millisecond,
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
@@ -73,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
@@ -260,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -340,12 +343,19 @@ func (r *RedisBackend) prefixKey(key string) string {
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
|
||||
// It checks context cancellation at multiple points to ensure fast abort when the
|
||||
// caller's context is cancelled (e.g., due to request timeout).
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 100 * time.Millisecond
|
||||
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// Check context before each attempt to fail fast
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
// If we can't get a connection and this is the last attempt, fail
|
||||
@@ -367,6 +377,11 @@ func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*Red
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// Check context after operation - if cancelled, don't bother retrying
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// If successful, return
|
||||
if err == nil {
|
||||
return nil
|
||||
|
||||
+10
-16
@@ -9,30 +9,24 @@ import (
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
|
||||
// State
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
lastCheckTime atomic.Int64 // Unix timestamp
|
||||
|
||||
// Metrics
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
lastCheckTime atomic.Int64
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
CheckInterval time.Duration // How often to check health
|
||||
Timeout time.Duration // Timeout for health check
|
||||
UnhealthyThreshold int // Consecutive failures before marking unhealthy
|
||||
OnHealthChange func(healthy bool)
|
||||
CheckInterval time.Duration
|
||||
Timeout time.Duration
|
||||
UnhealthyThreshold int
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
|
||||
+23
-22
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
|
||||
Vendored
+5
-5
@@ -15,8 +15,8 @@ import (
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
args []string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
@@ -205,9 +205,9 @@ func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
expected interface{}
|
||||
name string
|
||||
input string
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
@@ -440,10 +440,10 @@ func TestRESPHelpers(t *testing.T) {
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command []string
|
||||
response string
|
||||
expected interface{}
|
||||
name string
|
||||
response string
|
||||
command []string
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
|
||||
Vendored
+28
-40
@@ -33,21 +33,19 @@ type Logger interface {
|
||||
|
||||
// Config provides configuration for the cache
|
||||
type Config struct {
|
||||
Logger Logger
|
||||
JWKConfig *JWKConfig
|
||||
MetadataConfig *MetadataConfig
|
||||
TokenConfig *TokenConfig
|
||||
Type Type
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableCompression bool
|
||||
MaxMemoryBytes int64
|
||||
MaxSize int
|
||||
EnableMetrics bool
|
||||
EnableAutoCleanup bool
|
||||
EnableMemoryLimit bool
|
||||
Logger Logger
|
||||
|
||||
// Type-specific configurations
|
||||
TokenConfig *TokenConfig
|
||||
MetadataConfig *MetadataConfig
|
||||
JWKConfig *JWKConfig
|
||||
EnableCompression bool
|
||||
}
|
||||
|
||||
// TokenConfig provides token-specific cache configuration
|
||||
@@ -59,11 +57,11 @@ type TokenConfig struct {
|
||||
|
||||
// MetadataConfig provides metadata-specific cache configuration
|
||||
type MetadataConfig struct {
|
||||
SecurityCriticalFields []string
|
||||
GracePeriod time.Duration
|
||||
ExtendedGracePeriod time.Duration
|
||||
MaxGracePeriod time.Duration
|
||||
SecurityCriticalMaxGracePeriod time.Duration
|
||||
SecurityCriticalFields []string
|
||||
}
|
||||
|
||||
// JWKConfig provides JWK-specific cache configuration
|
||||
@@ -75,45 +73,35 @@ type JWKConfig struct {
|
||||
|
||||
// Item represents a single cache entry
|
||||
type Item struct {
|
||||
Key string
|
||||
Value interface{}
|
||||
Size int64
|
||||
ExpiresAt time.Time
|
||||
LastAccessed time.Time
|
||||
AccessCount int64
|
||||
Value interface{}
|
||||
Metadata map[string]interface{}
|
||||
element *list.Element
|
||||
Key string
|
||||
CacheType Type
|
||||
|
||||
// Type-specific metadata
|
||||
Metadata map[string]interface{}
|
||||
|
||||
// LRU list element reference
|
||||
element *list.Element
|
||||
Size int64
|
||||
AccessCount int64
|
||||
}
|
||||
|
||||
// Cache provides a single, unified cache implementation
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*Item
|
||||
lruList *list.List
|
||||
config Config
|
||||
logger Logger
|
||||
|
||||
// Memory management
|
||||
config Config
|
||||
ctx context.Context
|
||||
logger Logger
|
||||
cancel context.CancelFunc
|
||||
lruList *list.List
|
||||
items map[string]*Item
|
||||
stopCleanup chan bool
|
||||
wg sync.WaitGroup
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Metrics
|
||||
hits int64
|
||||
misses int64
|
||||
evictions int64
|
||||
sets int64
|
||||
|
||||
// Lifecycle management
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
stopCleanup chan bool
|
||||
closed int32
|
||||
hits int64
|
||||
misses int64
|
||||
evictions int64
|
||||
sets int64
|
||||
mu sync.RWMutex
|
||||
closed int32
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default cache configuration
|
||||
|
||||
Vendored
+11
-11
@@ -1750,19 +1750,19 @@ func TestAdvancedEdgeCases(t *testing.T) {
|
||||
|
||||
// Test with various data types
|
||||
testCases := []struct {
|
||||
key string
|
||||
value interface{}
|
||||
key string
|
||||
}{
|
||||
{"string", "test string"},
|
||||
{"int", 42},
|
||||
{"float", 3.14159},
|
||||
{"bool", true},
|
||||
{"slice", []string{"a", "b", "c"}},
|
||||
{"map", map[string]int{"one": 1, "two": 2}},
|
||||
{"nil", nil},
|
||||
{"empty-string", ""},
|
||||
{"empty-slice", []string{}},
|
||||
{"empty-map", map[string]interface{}{}},
|
||||
{key: "string", value: "test string"},
|
||||
{key: "int", value: 42},
|
||||
{key: "float", value: 3.14159},
|
||||
{key: "bool", value: true},
|
||||
{key: "slice", value: []string{"a", "b", "c"}},
|
||||
{key: "map", value: map[string]int{"one": 1, "two": 2}},
|
||||
{key: "nil", value: nil},
|
||||
{key: "empty-string", value: ""},
|
||||
{key: "empty-slice", value: []string{}},
|
||||
{key: "empty-map", value: map[string]interface{}{}},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
Vendored
+2
-7
@@ -7,22 +7,17 @@ import (
|
||||
|
||||
// Manager manages multiple cache instances with singleton pattern
|
||||
type Manager struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// Core caches
|
||||
logger Logger
|
||||
tokenCache *Cache
|
||||
metadataCache *Cache
|
||||
jwkCache *Cache
|
||||
sessionCache *Cache
|
||||
generalCache *Cache
|
||||
|
||||
// Typed wrappers
|
||||
typedToken *TokenCache
|
||||
typedMetadata *MetadataCache
|
||||
typedJWK *JWKCache
|
||||
typedSession *SessionCache
|
||||
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
+26
-42
@@ -48,23 +48,12 @@ func (s State) String() string {
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
OnStateChange func(from, to State)
|
||||
MaxFailures int
|
||||
FailureThreshold float64
|
||||
Timeout time.Duration
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
ResetTimeout time.Duration
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
@@ -80,28 +69,20 @@ func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
nextRetryTime time.Time
|
||||
lastStateChange time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
config *CircuitBreakerConfig
|
||||
totalFailures atomic.Int64
|
||||
totalRequests atomic.Int64
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
stateMu sync.RWMutex
|
||||
timeMu sync.RWMutex
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
consecutiveFailures atomic.Int32
|
||||
state atomic.Int32
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
@@ -158,6 +139,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
@@ -181,6 +163,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
@@ -203,6 +186,7 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
// #nosec G115 -- MaxFailures is a small config value that fits in int32
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
@@ -310,17 +294,17 @@ func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
|
||||
@@ -28,8 +28,8 @@ type mockBackend struct {
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
value []byte
|
||||
}
|
||||
|
||||
func newMockBackend() *mockBackend {
|
||||
|
||||
+30
-49
@@ -41,26 +41,13 @@ func (h HealthStatus) String() string {
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
CheckFunc func(ctx context.Context) error
|
||||
CheckInterval time.Duration
|
||||
Timeout time.Duration
|
||||
HealthyThreshold int
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
DegradedThreshold time.Duration
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
@@ -76,31 +63,23 @@ func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
config *HealthCheckConfig
|
||||
stopChan chan struct{}
|
||||
ticker *time.Ticker
|
||||
wg sync.WaitGroup
|
||||
statusChanges atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
consecutiveSuccesses atomic.Int32
|
||||
stopped atomic.Bool
|
||||
status atomic.Int32
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
@@ -217,6 +196,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
@@ -241,6 +221,7 @@ func (hc *HealthChecker) recordFailure() {
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
@@ -340,19 +321,19 @@ func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
LastCheckTime time.Time
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
Status HealthStatus
|
||||
ConsecutiveFailures int32
|
||||
ConsecutiveSuccesses int32
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
|
||||
+8
-11
@@ -12,20 +12,16 @@ import (
|
||||
|
||||
// HealthCheckBackend wraps a cache backend with health checking
|
||||
type HealthCheckBackend struct {
|
||||
backend backends.CacheBackend
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Health tracking
|
||||
lastCheck time.Time
|
||||
backend backends.CacheBackend
|
||||
ctx context.Context
|
||||
config *HealthCheckConfig
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
checkMutex sync.RWMutex
|
||||
status atomic.Int32
|
||||
consecutiveFails atomic.Int32
|
||||
consecutiveOK atomic.Int32
|
||||
lastCheck time.Time
|
||||
checkMutex sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthCheckBackend creates a new health check wrapped backend
|
||||
@@ -150,6 +146,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
// #nosec G115 -- threshold config values are small integers that fit in int32
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
Vendored
+2
-2
@@ -292,12 +292,12 @@ type SessionCache struct {
|
||||
|
||||
// SessionData represents session information
|
||||
type SessionData struct {
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Claims map[string]interface{} `json:"claims"`
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Claims map[string]interface{} `json:"claims"`
|
||||
}
|
||||
|
||||
// NewSessionCache creates a new session cache
|
||||
|
||||
@@ -11,10 +11,10 @@ import (
|
||||
|
||||
// Mock logger for testing
|
||||
type mockLogger struct {
|
||||
mu sync.Mutex
|
||||
logs []string
|
||||
errLogs []string
|
||||
debugLog []string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (m *mockLogger) Logf(format string, args ...interface{}) {
|
||||
|
||||
+11
-11
@@ -19,20 +19,20 @@ type Logger interface {
|
||||
|
||||
// BackgroundTask represents a recurring background task
|
||||
type BackgroundTask struct {
|
||||
name string
|
||||
interval time.Duration
|
||||
taskFunc func()
|
||||
lastRun time.Time
|
||||
logger Logger
|
||||
ctx context.Context
|
||||
ticker *time.Ticker
|
||||
stopChan chan bool
|
||||
isRunning int32
|
||||
logger Logger
|
||||
waitGroup *sync.WaitGroup
|
||||
lastRun time.Time
|
||||
taskFunc func()
|
||||
cancelFunc context.CancelFunc
|
||||
name string
|
||||
runCount int64
|
||||
errorCount int64
|
||||
interval time.Duration
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
isRunning int32
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task
|
||||
@@ -183,11 +183,11 @@ func (bt *BackgroundTask) IsRunning() bool {
|
||||
|
||||
// TaskRegistry manages all background tasks
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
logger Logger
|
||||
maxTasks int
|
||||
tasks map[string]*BackgroundTask
|
||||
circuitBreaker *TaskCircuitBreaker
|
||||
maxTasks int
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// globalTaskRegistry is the singleton task registry
|
||||
|
||||
+13
-13
@@ -11,14 +11,14 @@ import (
|
||||
|
||||
// TaskCircuitBreaker prevents task creation failures from cascading
|
||||
type TaskCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
logger Logger
|
||||
taskFailures map[string]int32
|
||||
timeout time.Duration
|
||||
mu sync.RWMutex
|
||||
failureThreshold int32
|
||||
failureCount int32
|
||||
lastFailureTime time.Time
|
||||
timeout time.Duration
|
||||
state int32 // 0: closed, 1: open
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
taskFailures map[string]int32
|
||||
state int32
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents the state of the circuit breaker
|
||||
@@ -140,14 +140,14 @@ func (cb *TaskCircuitBreaker) GetState() CircuitBreakerState {
|
||||
|
||||
// TaskMemoryMonitor monitors memory usage and can trigger cleanup
|
||||
type TaskMemoryMonitor struct {
|
||||
lastCheck time.Time
|
||||
logger Logger
|
||||
registry *TaskRegistry
|
||||
stopChan chan bool
|
||||
memoryThreshold uint64
|
||||
checkInterval time.Duration
|
||||
isMonitoring int32
|
||||
stopChan chan bool
|
||||
lastCheck time.Time
|
||||
mu sync.RWMutex
|
||||
isMonitoring int32
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -310,13 +310,13 @@ func (tmm *TaskMemoryMonitor) GetStats() map[string]interface{} {
|
||||
|
||||
// WorkerPool manages a pool of worker goroutines for task execution
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
taskQueue chan func()
|
||||
workerWg sync.WaitGroup
|
||||
isRunning int32
|
||||
logger Logger
|
||||
taskQueue chan func()
|
||||
stopChan chan bool
|
||||
metrics WorkerPoolMetrics
|
||||
workerWg sync.WaitGroup
|
||||
workers int
|
||||
isRunning int32
|
||||
}
|
||||
|
||||
// WorkerPoolMetrics tracks worker pool performance
|
||||
|
||||
@@ -1,219 +0,0 @@
|
||||
// Package errors provides unified error handling for OIDC operations
|
||||
package errors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ErrorCode represents specific error types
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
// Authentication errors
|
||||
ErrCodeAuthenticationFailed ErrorCode = "AUTH_FAILED"
|
||||
ErrCodeTokenExpired ErrorCode = "TOKEN_EXPIRED"
|
||||
ErrCodeTokenInvalid ErrorCode = "TOKEN_INVALID"
|
||||
ErrCodeSessionExpired ErrorCode = "SESSION_EXPIRED"
|
||||
ErrCodeCSRFMismatch ErrorCode = "CSRF_MISMATCH"
|
||||
ErrCodeNonceMismatch ErrorCode = "NONCE_MISMATCH"
|
||||
|
||||
// Configuration errors
|
||||
ErrCodeConfigInvalid ErrorCode = "CONFIG_INVALID"
|
||||
ErrCodeProviderUnreachable ErrorCode = "PROVIDER_UNREACHABLE"
|
||||
ErrCodeMetadataFailed ErrorCode = "METADATA_FAILED"
|
||||
|
||||
// Network errors
|
||||
ErrCodeNetworkTimeout ErrorCode = "NETWORK_TIMEOUT"
|
||||
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
|
||||
ErrCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
|
||||
|
||||
// Validation errors
|
||||
ErrCodeValidationFailed ErrorCode = "VALIDATION_FAILED"
|
||||
ErrCodeDomainNotAllowed ErrorCode = "DOMAIN_NOT_ALLOWED"
|
||||
ErrCodeUserNotAllowed ErrorCode = "USER_NOT_ALLOWED"
|
||||
ErrCodeRoleNotAllowed ErrorCode = "ROLE_NOT_ALLOWED"
|
||||
)
|
||||
|
||||
// OIDCError represents a structured error with context
|
||||
type OIDCError struct {
|
||||
Code ErrorCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
HTTPStatus int `json:"http_status"`
|
||||
Internal error `json:"-"` // Internal error, not exposed
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *OIDCError) Error() string {
|
||||
if e.Details != "" {
|
||||
return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// Unwrap returns the internal error for error wrapping
|
||||
func (e *OIDCError) Unwrap() error {
|
||||
return e.Internal
|
||||
}
|
||||
|
||||
// IsRetryable indicates if the error is temporary and can be retried
|
||||
func (e *OIDCError) IsRetryable() bool {
|
||||
return e.Code == ErrCodeNetworkTimeout ||
|
||||
e.Code == ErrCodeServiceUnavailable ||
|
||||
e.Code == ErrCodeProviderUnreachable
|
||||
}
|
||||
|
||||
// IsAuthenticationError indicates if this is an authentication-related error
|
||||
func (e *OIDCError) IsAuthenticationError() bool {
|
||||
return e.Code == ErrCodeAuthenticationFailed ||
|
||||
e.Code == ErrCodeTokenExpired ||
|
||||
e.Code == ErrCodeTokenInvalid ||
|
||||
e.Code == ErrCodeSessionExpired ||
|
||||
e.Code == ErrCodeCSRFMismatch ||
|
||||
e.Code == ErrCodeNonceMismatch
|
||||
}
|
||||
|
||||
// IsAuthorizationError indicates if this is an authorization-related error
|
||||
func (e *OIDCError) IsAuthorizationError() bool {
|
||||
return e.Code == ErrCodeDomainNotAllowed ||
|
||||
e.Code == ErrCodeUserNotAllowed ||
|
||||
e.Code == ErrCodeRoleNotAllowed
|
||||
}
|
||||
|
||||
// ToJSON converts the error to a JSON response
|
||||
func (e *OIDCError) ToJSON() map[string]any {
|
||||
result := map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": string(e.Code),
|
||||
"message": e.Message,
|
||||
},
|
||||
}
|
||||
|
||||
if e.Details != "" {
|
||||
errorMap, _ := result["error"].(map[string]any) // Safe to ignore: type assertion from known type
|
||||
errorMap["details"] = e.Details
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// Error constructors for common scenarios
|
||||
|
||||
// NewAuthenticationError creates an authentication-related error
|
||||
func NewAuthenticationError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
status := http.StatusUnauthorized
|
||||
if code == ErrCodeSessionExpired {
|
||||
status = http.StatusForbidden
|
||||
}
|
||||
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: status,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthorizationError creates an authorization-related error
|
||||
func NewAuthorizationError(code ErrorCode, message string, details string) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
}
|
||||
}
|
||||
|
||||
// NewConfigurationError creates a configuration-related error
|
||||
func NewConfigurationError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: http.StatusInternalServerError,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewNetworkError creates a network-related error
|
||||
func NewNetworkError(code ErrorCode, message string, internal error) *OIDCError {
|
||||
status := http.StatusServiceUnavailable
|
||||
if code == ErrCodeRateLimited {
|
||||
status = http.StatusTooManyRequests
|
||||
}
|
||||
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
HTTPStatus: status,
|
||||
Internal: internal,
|
||||
}
|
||||
}
|
||||
|
||||
// NewValidationError creates a validation-related error
|
||||
func NewValidationError(code ErrorCode, message string, details string) *OIDCError {
|
||||
return &OIDCError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience functions for common error patterns
|
||||
|
||||
// WrapAuthenticationError wraps an existing error as an authentication error
|
||||
func WrapAuthenticationError(err error, message string) *OIDCError {
|
||||
return NewAuthenticationError(ErrCodeAuthenticationFailed, message, err)
|
||||
}
|
||||
|
||||
// WrapTokenError wraps a token-related error
|
||||
func WrapTokenError(err error, tokenType string) *OIDCError {
|
||||
message := fmt.Sprintf("Token validation failed: %s", tokenType)
|
||||
return NewAuthenticationError(ErrCodeTokenInvalid, message, err)
|
||||
}
|
||||
|
||||
// WrapProviderError wraps a provider communication error
|
||||
func WrapProviderError(err error, providerURL string) *OIDCError {
|
||||
message := fmt.Sprintf("Provider communication failed: %s", providerURL)
|
||||
return NewNetworkError(ErrCodeProviderUnreachable, message, err)
|
||||
}
|
||||
|
||||
// IsOIDCError checks if an error is an OIDCError
|
||||
func IsOIDCError(err error) (*OIDCError, bool) {
|
||||
oidcErr, ok := err.(*OIDCError)
|
||||
return oidcErr, ok
|
||||
}
|
||||
|
||||
// GetHTTPStatus extracts HTTP status from error, defaulting to 500
|
||||
func GetHTTPStatus(err error) int {
|
||||
if oidcErr, ok := IsOIDCError(err); ok {
|
||||
return oidcErr.HTTPStatus
|
||||
}
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// FormatUserMessage creates a user-friendly error message
|
||||
func FormatUserMessage(err error) string {
|
||||
if oidcErr, ok := IsOIDCError(err); ok {
|
||||
switch oidcErr.Code {
|
||||
case ErrCodeDomainNotAllowed:
|
||||
return "Your email domain is not authorized for this application"
|
||||
case ErrCodeUserNotAllowed:
|
||||
return "Your account is not authorized for this application"
|
||||
case ErrCodeRoleNotAllowed:
|
||||
return "You do not have the required permissions for this application"
|
||||
case ErrCodeSessionExpired:
|
||||
return "Your session has expired. Please log in again"
|
||||
case ErrCodeTokenExpired:
|
||||
return "Your authentication has expired. Please log in again"
|
||||
case ErrCodeProviderUnreachable:
|
||||
return "Authentication service is temporarily unavailable. Please try again later"
|
||||
case ErrCodeRateLimited:
|
||||
return "Too many requests. Please wait a moment and try again"
|
||||
default:
|
||||
return "Authentication failed. Please try again"
|
||||
}
|
||||
}
|
||||
return "An unexpected error occurred. Please try again"
|
||||
}
|
||||
@@ -1,529 +0,0 @@
|
||||
package errors
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestOIDCError_Error(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcErr *OIDCError
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Error with details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Details: "JWT signature invalid",
|
||||
},
|
||||
expected: "TOKEN_INVALID: Token validation failed (JWT signature invalid)",
|
||||
},
|
||||
{
|
||||
name: "Error without details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeAuthenticationFailed,
|
||||
Message: "Authentication failed",
|
||||
},
|
||||
expected: "AUTH_FAILED: Authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.oidcErr.Error()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_Unwrap(t *testing.T) {
|
||||
internalErr := errors.New("internal error")
|
||||
oidcErr := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Internal: internalErr,
|
||||
}
|
||||
|
||||
unwrapped := oidcErr.Unwrap()
|
||||
if unwrapped != internalErr {
|
||||
t.Errorf("Expected internal error, got %v", unwrapped)
|
||||
}
|
||||
|
||||
// Test with nil internal error
|
||||
oidcErrNoInternal := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
}
|
||||
|
||||
unwrappedNil := oidcErrNoInternal.Unwrap()
|
||||
if unwrappedNil != nil {
|
||||
t.Errorf("Expected nil, got %v", unwrappedNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsRetryable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Network timeout", ErrCodeNetworkTimeout, true},
|
||||
{"Service unavailable", ErrCodeServiceUnavailable, true},
|
||||
{"Provider unreachable", ErrCodeProviderUnreachable, true},
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, false},
|
||||
{"Token invalid", ErrCodeTokenInvalid, false},
|
||||
{"Rate limited", ErrCodeRateLimited, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsRetryable()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsAuthenticationError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, true},
|
||||
{"Token expired", ErrCodeTokenExpired, true},
|
||||
{"Token invalid", ErrCodeTokenInvalid, true},
|
||||
{"Session expired", ErrCodeSessionExpired, true},
|
||||
{"CSRF mismatch", ErrCodeCSRFMismatch, true},
|
||||
{"Nonce mismatch", ErrCodeNonceMismatch, true},
|
||||
{"Config invalid", ErrCodeConfigInvalid, false},
|
||||
{"Domain not allowed", ErrCodeDomainNotAllowed, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsAuthenticationError()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_IsAuthorizationError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expected bool
|
||||
}{
|
||||
{"Domain not allowed", ErrCodeDomainNotAllowed, true},
|
||||
{"User not allowed", ErrCodeUserNotAllowed, true},
|
||||
{"Role not allowed", ErrCodeRoleNotAllowed, true},
|
||||
{"Authentication failed", ErrCodeAuthenticationFailed, false},
|
||||
{"Token expired", ErrCodeTokenExpired, false},
|
||||
{"Config invalid", ErrCodeConfigInvalid, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
oidcErr := &OIDCError{Code: tt.code}
|
||||
result := oidcErr.IsAuthorizationError()
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %v, got %v for code %s", tt.expected, result, tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCError_ToJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oidcErr *OIDCError
|
||||
expected map[string]any
|
||||
}{
|
||||
{
|
||||
name: "Error with details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
Message: "Token validation failed",
|
||||
Details: "JWT signature invalid",
|
||||
},
|
||||
expected: map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "TOKEN_INVALID",
|
||||
"message": "Token validation failed",
|
||||
"details": "JWT signature invalid",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Error without details",
|
||||
oidcErr: &OIDCError{
|
||||
Code: ErrCodeAuthenticationFailed,
|
||||
Message: "Authentication failed",
|
||||
},
|
||||
expected: map[string]any{
|
||||
"error": map[string]any{
|
||||
"code": "AUTH_FAILED",
|
||||
"message": "Authentication failed",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.oidcErr.ToJSON()
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("Expected %+v, got %+v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthenticationError(t *testing.T) {
|
||||
internalErr := errors.New("internal error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
message string
|
||||
internal error
|
||||
expectedHTTP int
|
||||
}{
|
||||
{
|
||||
name: "Regular auth error",
|
||||
code: ErrCodeAuthenticationFailed,
|
||||
message: "Auth failed",
|
||||
internal: internalErr,
|
||||
expectedHTTP: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "Session expired error",
|
||||
code: ErrCodeSessionExpired,
|
||||
message: "Session expired",
|
||||
internal: internalErr,
|
||||
expectedHTTP: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewAuthenticationError(tt.code, tt.message, tt.internal)
|
||||
|
||||
if err.Code != tt.code {
|
||||
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
|
||||
}
|
||||
if err.Message != tt.message {
|
||||
t.Errorf("Expected message '%s', got '%s'", tt.message, err.Message)
|
||||
}
|
||||
if err.Internal != tt.internal {
|
||||
t.Errorf("Expected internal error %v, got %v", tt.internal, err.Internal)
|
||||
}
|
||||
if err.HTTPStatus != tt.expectedHTTP {
|
||||
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthorizationError(t *testing.T) {
|
||||
err := NewAuthorizationError(ErrCodeDomainNotAllowed, "Domain not allowed", "example.com not in whitelist")
|
||||
|
||||
if err.Code != ErrCodeDomainNotAllowed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeDomainNotAllowed, err.Code)
|
||||
}
|
||||
if err.Message != "Domain not allowed" {
|
||||
t.Errorf("Expected message 'Domain not allowed', got '%s'", err.Message)
|
||||
}
|
||||
if err.Details != "example.com not in whitelist" {
|
||||
t.Errorf("Expected details 'example.com not in whitelist', got '%s'", err.Details)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusForbidden {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusForbidden, err.HTTPStatus)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConfigurationError(t *testing.T) {
|
||||
internalErr := errors.New("config parse error")
|
||||
err := NewConfigurationError(ErrCodeConfigInvalid, "Invalid config", internalErr)
|
||||
|
||||
if err.Code != ErrCodeConfigInvalid {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeConfigInvalid, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusInternalServerError {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusInternalServerError, err.HTTPStatus)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewNetworkError(t *testing.T) {
|
||||
internalErr := errors.New("network error")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
code ErrorCode
|
||||
expectedHTTP int
|
||||
}{
|
||||
{
|
||||
name: "Rate limited",
|
||||
code: ErrCodeRateLimited,
|
||||
expectedHTTP: http.StatusTooManyRequests,
|
||||
},
|
||||
{
|
||||
name: "Service unavailable",
|
||||
code: ErrCodeServiceUnavailable,
|
||||
expectedHTTP: http.StatusServiceUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := NewNetworkError(tt.code, "Network error", internalErr)
|
||||
|
||||
if err.Code != tt.code {
|
||||
t.Errorf("Expected code %s, got %s", tt.code, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != tt.expectedHTTP {
|
||||
t.Errorf("Expected HTTP status %d, got %d", tt.expectedHTTP, err.HTTPStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidationError(t *testing.T) {
|
||||
err := NewValidationError(ErrCodeValidationFailed, "Validation failed", "field 'email' is required")
|
||||
|
||||
if err.Code != ErrCodeValidationFailed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeValidationFailed, err.Code)
|
||||
}
|
||||
if err.HTTPStatus != http.StatusBadRequest {
|
||||
t.Errorf("Expected HTTP status %d, got %d", http.StatusBadRequest, err.HTTPStatus)
|
||||
}
|
||||
if err.Details != "field 'email' is required" {
|
||||
t.Errorf("Expected details 'field 'email' is required', got '%s'", err.Details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapAuthenticationError(t *testing.T) {
|
||||
internalErr := errors.New("original error")
|
||||
err := WrapAuthenticationError(internalErr, "Custom auth message")
|
||||
|
||||
if err.Code != ErrCodeAuthenticationFailed {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeAuthenticationFailed, err.Code)
|
||||
}
|
||||
if err.Message != "Custom auth message" {
|
||||
t.Errorf("Expected message 'Custom auth message', got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapTokenError(t *testing.T) {
|
||||
internalErr := errors.New("token error")
|
||||
err := WrapTokenError(internalErr, "ID token")
|
||||
|
||||
if err.Code != ErrCodeTokenInvalid {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeTokenInvalid, err.Code)
|
||||
}
|
||||
if err.Message != "Token validation failed: ID token" {
|
||||
t.Errorf("Expected message 'Token validation failed: ID token', got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWrapProviderError(t *testing.T) {
|
||||
internalErr := errors.New("provider error")
|
||||
err := WrapProviderError(internalErr, "https://provider.example.com")
|
||||
|
||||
if err.Code != ErrCodeProviderUnreachable {
|
||||
t.Errorf("Expected code %s, got %s", ErrCodeProviderUnreachable, err.Code)
|
||||
}
|
||||
if err.Message != "Provider communication failed: https://provider.example.com" {
|
||||
t.Errorf("Expected specific message, got '%s'", err.Message)
|
||||
}
|
||||
if err.Internal != internalErr {
|
||||
t.Errorf("Expected internal error %v, got %v", internalErr, err.Internal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsOIDCError(t *testing.T) {
|
||||
// Test with OIDCError
|
||||
oidcErr := &OIDCError{Code: ErrCodeTokenInvalid, Message: "test"}
|
||||
result, ok := IsOIDCError(oidcErr)
|
||||
if !ok {
|
||||
t.Error("Expected IsOIDCError to return true for OIDCError")
|
||||
}
|
||||
if result != oidcErr {
|
||||
t.Error("Expected to get the same OIDCError back")
|
||||
}
|
||||
|
||||
// Test with regular error
|
||||
regularErr := errors.New("regular error")
|
||||
result, ok = IsOIDCError(regularErr)
|
||||
if ok {
|
||||
t.Error("Expected IsOIDCError to return false for regular error")
|
||||
}
|
||||
if result != nil {
|
||||
t.Error("Expected nil result for regular error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetHTTPStatus(t *testing.T) {
|
||||
// Test with OIDCError
|
||||
oidcErr := &OIDCError{
|
||||
Code: ErrCodeTokenInvalid,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
}
|
||||
status := GetHTTPStatus(oidcErr)
|
||||
if status != http.StatusUnauthorized {
|
||||
t.Errorf("Expected %d, got %d", http.StatusUnauthorized, status)
|
||||
}
|
||||
|
||||
// Test with regular error
|
||||
regularErr := errors.New("regular error")
|
||||
status = GetHTTPStatus(regularErr)
|
||||
if status != http.StatusInternalServerError {
|
||||
t.Errorf("Expected %d, got %d", http.StatusInternalServerError, status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatUserMessage(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Domain not allowed",
|
||||
err: &OIDCError{Code: ErrCodeDomainNotAllowed},
|
||||
expected: "Your email domain is not authorized for this application",
|
||||
},
|
||||
{
|
||||
name: "User not allowed",
|
||||
err: &OIDCError{Code: ErrCodeUserNotAllowed},
|
||||
expected: "Your account is not authorized for this application",
|
||||
},
|
||||
{
|
||||
name: "Role not allowed",
|
||||
err: &OIDCError{Code: ErrCodeRoleNotAllowed},
|
||||
expected: "You do not have the required permissions for this application",
|
||||
},
|
||||
{
|
||||
name: "Session expired",
|
||||
err: &OIDCError{Code: ErrCodeSessionExpired},
|
||||
expected: "Your session has expired. Please log in again",
|
||||
},
|
||||
{
|
||||
name: "Token expired",
|
||||
err: &OIDCError{Code: ErrCodeTokenExpired},
|
||||
expected: "Your authentication has expired. Please log in again",
|
||||
},
|
||||
{
|
||||
name: "Provider unreachable",
|
||||
err: &OIDCError{Code: ErrCodeProviderUnreachable},
|
||||
expected: "Authentication service is temporarily unavailable. Please try again later",
|
||||
},
|
||||
{
|
||||
name: "Rate limited",
|
||||
err: &OIDCError{Code: ErrCodeRateLimited},
|
||||
expected: "Too many requests. Please wait a moment and try again",
|
||||
},
|
||||
{
|
||||
name: "Unknown OIDC error",
|
||||
err: &OIDCError{Code: ErrCodeConfigInvalid},
|
||||
expected: "Authentication failed. Please try again",
|
||||
},
|
||||
{
|
||||
name: "Regular error",
|
||||
err: errors.New("regular error"),
|
||||
expected: "An unexpected error occurred. Please try again",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FormatUserMessage(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
// Test that all error codes are defined correctly
|
||||
codes := []ErrorCode{
|
||||
ErrCodeAuthenticationFailed,
|
||||
ErrCodeTokenExpired,
|
||||
ErrCodeTokenInvalid,
|
||||
ErrCodeSessionExpired,
|
||||
ErrCodeCSRFMismatch,
|
||||
ErrCodeNonceMismatch,
|
||||
ErrCodeConfigInvalid,
|
||||
ErrCodeProviderUnreachable,
|
||||
ErrCodeMetadataFailed,
|
||||
ErrCodeNetworkTimeout,
|
||||
ErrCodeRateLimited,
|
||||
ErrCodeServiceUnavailable,
|
||||
ErrCodeValidationFailed,
|
||||
ErrCodeDomainNotAllowed,
|
||||
ErrCodeUserNotAllowed,
|
||||
ErrCodeRoleNotAllowed,
|
||||
}
|
||||
|
||||
for _, code := range codes {
|
||||
if string(code) == "" {
|
||||
t.Errorf("Error code %v is empty", code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorConstructorCompleteness(t *testing.T) {
|
||||
// Test each constructor function to ensure they set all required fields
|
||||
internalErr := errors.New("test error")
|
||||
|
||||
// Test NewAuthenticationError
|
||||
authErr := NewAuthenticationError(ErrCodeAuthenticationFailed, "auth message", internalErr)
|
||||
if authErr.Code == "" || authErr.Message == "" || authErr.HTTPStatus == 0 {
|
||||
t.Error("NewAuthenticationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewAuthorizationError
|
||||
authzErr := NewAuthorizationError(ErrCodeDomainNotAllowed, "authz message", "details")
|
||||
if authzErr.Code == "" || authzErr.Message == "" || authzErr.HTTPStatus == 0 {
|
||||
t.Error("NewAuthorizationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewConfigurationError
|
||||
configErr := NewConfigurationError(ErrCodeConfigInvalid, "config message", internalErr)
|
||||
if configErr.Code == "" || configErr.Message == "" || configErr.HTTPStatus == 0 {
|
||||
t.Error("NewConfigurationError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewNetworkError
|
||||
netErr := NewNetworkError(ErrCodeNetworkTimeout, "network message", internalErr)
|
||||
if netErr.Code == "" || netErr.Message == "" || netErr.HTTPStatus == 0 {
|
||||
t.Error("NewNetworkError did not set all required fields")
|
||||
}
|
||||
|
||||
// Test NewValidationError
|
||||
validErr := NewValidationError(ErrCodeValidationFailed, "validation message", "details")
|
||||
if validErr.Code == "" || validErr.Message == "" || validErr.HTTPStatus == 0 {
|
||||
t.Error("NewValidationError did not set all required fields")
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user