mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9126c74723 | |||
| a750c4f5b9 | |||
| 56051779ee | |||
| 3f126d50f3 | |||
| 91f0fc9ab8 | |||
| 66b9ed0861 | |||
| e64fc7f730 | |||
| 5fcbd54955 | |||
| e70cd1907c | |||
| e45b06c86d |
@@ -1,629 +0,0 @@
|
||||
name: PR Validation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
checks: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
# Fast feedback - format and basic checks
|
||||
quick-checks:
|
||||
name: Quick Checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
# Exclude vendor directory from format checks
|
||||
UNFORMATTED=$(gofmt -s -l . | grep -v "^vendor/" || true)
|
||||
if [ -n "$UNFORMATTED" ]; then
|
||||
echo "Code is not formatted. Run: gofmt -s -w ."
|
||||
echo "Unformatted files:"
|
||||
echo "$UNFORMATTED"
|
||||
gofmt -s -d $(echo "$UNFORMATTED")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Go mod verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Go mod tidy check
|
||||
run: |
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum
|
||||
|
||||
# Static analysis with golangci-lint (advisory - will not fail the build)
|
||||
golangci-lint:
|
||||
name: golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=10m
|
||||
continue-on-error: true # Allow pipeline to continue even with linting warnings
|
||||
|
||||
# Staticcheck analysis
|
||||
staticcheck:
|
||||
name: Staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install staticcheck
|
||||
run: go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run staticcheck
|
||||
run: staticcheck ./...
|
||||
|
||||
# Security scanning with gosec
|
||||
gosec:
|
||||
name: Gosec Security Scanner
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run Gosec Security Scanner
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -no-fail -fmt sarif -out results.sarif ./... || echo "Gosec completed with warnings"
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: always() && hashFiles('results.sarif') != ''
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
continue-on-error: true
|
||||
|
||||
# Vulnerability scanning
|
||||
govulncheck:
|
||||
name: Vulnerability Scan
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Run govulncheck
|
||||
run: govulncheck ./...
|
||||
|
||||
# CodeQL analysis
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: go
|
||||
continue-on-error: true
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
continue-on-error: true
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
continue-on-error: true
|
||||
|
||||
# Unit tests with race detection
|
||||
test-race:
|
||||
name: Unit Tests (Race Detector)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with race detector
|
||||
run: go test -race -timeout=15m -count=1 -v ./...
|
||||
env:
|
||||
GOMAXPROCS: 4
|
||||
|
||||
# Coverage analysis with threshold check
|
||||
test-coverage:
|
||||
name: Test Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test -coverprofile=coverage.out -covermode=atomic -timeout=15m ./...
|
||||
go tool cover -func=coverage.out -o=coverage.txt
|
||||
|
||||
- name: Calculate coverage
|
||||
id: coverage
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//')
|
||||
echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT
|
||||
echo "Total Coverage: $COVERAGE%"
|
||||
|
||||
# Get per-package coverage
|
||||
echo "## Coverage by Package" >> coverage_report.md
|
||||
echo "" >> coverage_report.md
|
||||
go tool cover -func=coverage.out | grep -v "total:" | awk '{print "- " $1 ": " $3}' >> coverage_report.md || true
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
continue-on-error: true
|
||||
|
||||
- name: Comment coverage on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const fs = require('fs');
|
||||
const coverage = '${{ steps.coverage.outputs.coverage }}';
|
||||
let coverageReport = '';
|
||||
|
||||
try {
|
||||
coverageReport = fs.readFileSync('coverage_report.md', 'utf8');
|
||||
} catch (e) {
|
||||
coverageReport = 'Coverage details not available';
|
||||
}
|
||||
|
||||
const threshold = 70;
|
||||
const coverageNum = parseFloat(coverage);
|
||||
const emoji = coverageNum >= threshold ? '✅' : '⚠️';
|
||||
|
||||
const body = `## ${emoji} Test Coverage Report\n\n**Total Coverage:** ${coverage}%\n**Threshold:** ${threshold}%\n\n${coverageReport}`;
|
||||
|
||||
// Find existing comment
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
|
||||
const botComment = comments.find(comment =>
|
||||
comment.user.type === 'Bot' &&
|
||||
comment.body.includes('Test Coverage Report')
|
||||
);
|
||||
|
||||
if (botComment) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: botComment.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
}
|
||||
|
||||
- name: Check coverage threshold
|
||||
run: |
|
||||
COVERAGE=${{ steps.coverage.outputs.coverage }}
|
||||
THRESHOLD=70
|
||||
echo "Coverage: $COVERAGE%"
|
||||
echo "Threshold: $THRESHOLD%"
|
||||
if (( $(echo "$COVERAGE < $THRESHOLD" | bc -l) )); then
|
||||
echo "⚠️ Coverage $COVERAGE% is below threshold $THRESHOLD%"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Coverage $COVERAGE% meets threshold $THRESHOLD%"
|
||||
|
||||
# Memory leak detection
|
||||
test-memory-leaks:
|
||||
name: Memory Leak Detection
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run goroutine leak tests
|
||||
run: |
|
||||
echo "Running goroutine leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Gg]oroutine.*[Ll]eak.*' ./... || echo "No goroutine leak tests found"
|
||||
|
||||
- name: Run memory leak tests
|
||||
run: |
|
||||
echo "Running memory leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Mm]emory.*[Ll]eak.*' ./... || echo "No memory leak tests found"
|
||||
|
||||
- name: Run cleanup tests
|
||||
run: |
|
||||
echo "Running cleanup and resource management tests..."
|
||||
go test -v -timeout=20m -run='.*[Cc]leanup.*' ./... || echo "No cleanup tests found"
|
||||
|
||||
# Integration tests
|
||||
test-integration:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
if [ -d "./integration" ]; then
|
||||
go test -v -timeout=20m ./integration/...
|
||||
else
|
||||
echo "Running integration tests from all packages..."
|
||||
go test -v -timeout=20m -run='.*[Ii]ntegration.*' ./...
|
||||
fi
|
||||
|
||||
# Regression tests
|
||||
test-regression:
|
||||
name: Regression Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
echo "Running regression tests..."
|
||||
go test -v -timeout=20m -run='.*[Rr]egression.*' ./...
|
||||
|
||||
# Provider-specific tests (parallel matrix)
|
||||
test-providers:
|
||||
name: Provider Tests (${{ matrix.provider }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
provider:
|
||||
- google
|
||||
- azure
|
||||
- auth0
|
||||
- okta
|
||||
- keycloak
|
||||
- cognito
|
||||
- gitlab
|
||||
- github
|
||||
- generic
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run ${{ matrix.provider }} provider tests
|
||||
run: |
|
||||
PROVIDER_CAP=$(echo "${{ matrix.provider }}" | sed 's/.*/\u&/')
|
||||
echo "Testing $PROVIDER_CAP provider..."
|
||||
go test -v -timeout=10m -run=".*$PROVIDER_CAP.*" ./internal/providers/... || true
|
||||
go test -v -timeout=10m -run=".*${{ matrix.provider }}.*" ./... || true
|
||||
|
||||
# Benchmark tests with performance tracking
|
||||
benchmark:
|
||||
name: Benchmark Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run benchmarks
|
||||
run: |
|
||||
echo "Running benchmark tests..."
|
||||
go test -bench=. -benchmem -benchtime=1s -run=^$ ./... | tee benchmark.txt
|
||||
|
||||
- name: Upload benchmark results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results
|
||||
path: benchmark.txt
|
||||
retention-days: 30
|
||||
|
||||
- name: Compare benchmarks
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
echo "Benchmark results available in artifacts"
|
||||
echo "To compare with main branch, download previous benchmark results"
|
||||
|
||||
# Build validation across platforms
|
||||
build:
|
||||
name: Build (${{ matrix.os }}/${{ matrix.arch }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [linux, darwin]
|
||||
arch: [amd64, arm64]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Build for ${{ matrix.os }}/${{ matrix.arch }}
|
||||
env:
|
||||
GOOS: ${{ matrix.os }}
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
run: |
|
||||
echo "Building for $GOOS/$GOARCH..."
|
||||
go build -v -ldflags="-s -w" ./...
|
||||
|
||||
# Security-specific edge case tests
|
||||
test-security-edge-cases:
|
||||
name: Security Edge Cases
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run security edge case tests
|
||||
run: |
|
||||
echo "Running security edge case tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ecurity.*' ./...
|
||||
|
||||
# Session management tests
|
||||
test-session:
|
||||
name: Session Management Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run session tests
|
||||
run: |
|
||||
echo "Running session management tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ession.*' ./...
|
||||
|
||||
# Token validation tests
|
||||
test-token:
|
||||
name: Token Validation Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run token validation tests
|
||||
run: |
|
||||
echo "Running token validation tests..."
|
||||
go test -v -timeout=15m -run='.*[Tt]oken.*' ./...
|
||||
|
||||
# CSRF and security tests
|
||||
test-csrf:
|
||||
name: CSRF and Security Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run CSRF tests
|
||||
run: |
|
||||
echo "Running CSRF and security tests..."
|
||||
go test -v -timeout=15m -run='.*[Cc][Ss][Rr][Ff].*' ./...
|
||||
|
||||
# Multi-Go version compatibility
|
||||
test-go-versions:
|
||||
name: Go ${{ matrix.go-version }} Compatibility
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ['1.24']
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Run tests on Go ${{ matrix.go-version }}
|
||||
run: go test -short -timeout=10m ./...
|
||||
|
||||
# Final validation - all checks must pass (golangci-lint is advisory)
|
||||
all-checks-passed:
|
||||
name: ✅ All Checks Passed
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- quick-checks
|
||||
- golangci-lint
|
||||
- staticcheck
|
||||
- gosec
|
||||
- govulncheck
|
||||
- codeql
|
||||
- test-race
|
||||
- test-coverage
|
||||
- test-memory-leaks
|
||||
- test-integration
|
||||
- test-regression
|
||||
- test-providers
|
||||
- benchmark
|
||||
- build
|
||||
- test-security-edge-cases
|
||||
- test-session
|
||||
- test-token
|
||||
- test-csrf
|
||||
- test-go-versions
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
echo "Checking status of all jobs..."
|
||||
|
||||
# Check critical jobs (excluding golangci-lint which is advisory)
|
||||
CRITICAL_FAILURES=false
|
||||
|
||||
if [ "${{ needs.quick-checks.result }}" == "failure" ] || \
|
||||
[ "${{ needs.staticcheck.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-race.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-coverage.result }}" == "failure" ] || \
|
||||
[ "${{ needs.build.result }}" == "failure" ]; then
|
||||
CRITICAL_FAILURES=true
|
||||
fi
|
||||
|
||||
if [ "$CRITICAL_FAILURES" == "true" ]; then
|
||||
echo "❌ Critical checks failed"
|
||||
exit 1
|
||||
elif [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
|
||||
echo "⚠️ Some checks were cancelled"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All critical checks passed successfully!"
|
||||
if [ "${{ needs.golangci-lint.result }}" != "success" ]; then
|
||||
echo "ℹ️ Note: golangci-lint reported issues (advisory only)"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Post summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "# PR Validation Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "## Job Status" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Quick Checks: ${{ needs.quick-checks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Linting (advisory): ${{ needs.golangci-lint.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Static Analysis: ${{ needs.staticcheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Scan (gosec): ${{ needs.gosec.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Vulnerability Scan: ${{ needs.govulncheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CodeQL: ${{ needs.codeql.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Race Detection: ${{ needs.test-race.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Coverage: ${{ needs.test-coverage.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Memory Leaks: ${{ needs.test-memory-leaks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Integration Tests: ${{ needs.test-integration.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Regression Tests: ${{ needs.test-regression.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Provider Tests: ${{ needs.test-providers.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Benchmarks: ${{ needs.benchmark.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Build: ${{ needs.build.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Edge Cases: ${{ needs.test-security-edge-cases.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Session Tests: ${{ needs.test-session.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Token Tests: ${{ needs.test-token.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CSRF Tests: ${{ needs.test-csrf.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Go Version Compatibility: ${{ needs.test-go-versions.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- "**"
|
||||
- "!main"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,21 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.go"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,49 @@
|
||||
version: 2
|
||||
|
||||
# Traefik plugins are source-only - no binary builds
|
||||
# Traefik loads plugins via Yaegi interpreter at runtime
|
||||
builds:
|
||||
- skip: true
|
||||
|
||||
# Create source archive for GitHub releases
|
||||
archives:
|
||||
- formats: [tar.gz]
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
|
||||
files:
|
||||
- "*.go"
|
||||
- "**/*.go"
|
||||
- go.mod
|
||||
- go.sum
|
||||
- .traefik.yml
|
||||
- LICENSE*
|
||||
- README*
|
||||
# Exclude test files and vendor from release archive
|
||||
- "!**/*_test.go"
|
||||
- "!vendor/**"
|
||||
- "!docker/**"
|
||||
- "!integration/**"
|
||||
- "!regression/**"
|
||||
- "!examples/**"
|
||||
- "!docs/**"
|
||||
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
|
||||
algorithm: sha256
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
- "^Merge"
|
||||
- "^WIP"
|
||||
- "^chore:"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: lukaszraczylo
|
||||
name: traefikoidc
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
+520
-34
@@ -31,6 +31,7 @@ summary: |
|
||||
- Flexible configuration with multiple deployment scenarios
|
||||
- Memory-efficient operation with automatic cleanup
|
||||
- Extensive logging and debugging capabilities
|
||||
- Redis cache support for multi-replica deployments with automatic failover
|
||||
It supports various authentication scenarios including:
|
||||
|
||||
- Basic authentication with customizable callback and logout URLs
|
||||
@@ -73,6 +74,11 @@ testData:
|
||||
- admin
|
||||
- developer
|
||||
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
# When running behind load balancer that terminates TLS: MUST set to TRUE
|
||||
@@ -88,22 +94,24 @@ testData:
|
||||
- /metrics
|
||||
|
||||
headers: # Custom headers to set with templated values from claims and tokens
|
||||
# NOTE: If you encounter "can't evaluate field AccessToken in type bool" errors,
|
||||
# you may need to escape the templates. See the headers section in configuration below.
|
||||
# NOTE: Use double curly braces to escape template expressions in YAML
|
||||
# See the headers section in configuration below for details
|
||||
- name: "X-User-Email"
|
||||
value: "{{.Claims.email}}"
|
||||
value: "{{{{.Claims.email}}}}"
|
||||
- name: "X-User-ID"
|
||||
value: "{{.Claims.sub}}"
|
||||
value: "{{{{.Claims.sub}}}}"
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{.AccessToken}}"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles"
|
||||
value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
|
||||
# Advanced parameters (usually discovered automatically from provider metadata)
|
||||
revocationURL: https://accounts.google.com/revoke # Endpoint for revoking tokens
|
||||
oidcEndSessionURL: https://accounts.google.com/logout # Provider's end session endpoint
|
||||
enablePKCE: false # Enables PKCE (Proof Key for Code Exchange) for additional security
|
||||
cookieDomain: "" # Explicit domain for session cookies (e.g., ".example.com" for multi-subdomain setups)
|
||||
cookiePrefix: "" # Custom prefix for cookie names (e.g., "_oidc_myapp_" for session isolation between middleware instances)
|
||||
sessionMaxAge: 86400 # Maximum session age in seconds (default: 86400 = 24 hours, 0 = use default)
|
||||
overrideScopes: false # When true, replaces default scopes instead of appending (default: false)
|
||||
refreshGracePeriodSeconds: 60 # Seconds before token expiry to attempt proactive refresh (default: 60)
|
||||
|
||||
@@ -113,6 +121,8 @@ testData:
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
@@ -137,6 +147,42 @@ testData:
|
||||
X-Custom-Header: "production"
|
||||
X-API-Version: "v1"
|
||||
|
||||
# Example with Redis cache for multi-replica deployments
|
||||
testDataWithRedis:
|
||||
# Required OIDC parameters (same as standard configuration)
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
callbackURL: /oauth2/callback
|
||||
sessionEncryptionKey: your-64-character-encryption-key-at-least-32-bytes
|
||||
|
||||
# Standard optional parameters
|
||||
logLevel: info
|
||||
allowedUserDomains:
|
||||
- company.com
|
||||
|
||||
# Redis cache configuration for multi-replica support
|
||||
redis:
|
||||
enabled: true # Enable Redis caching
|
||||
address: "redis:6379" # Redis server address
|
||||
password: "redis-password" # Redis authentication password
|
||||
db: 0 # Redis database number (0-15)
|
||||
keyPrefix: "traefikoidc:" # Prefix for all Redis keys
|
||||
cacheMode: "hybrid" # Cache mode: redis, hybrid, or memory
|
||||
poolSize: 20 # Maximum number of connections
|
||||
connectTimeout: 5 # Connection timeout in seconds
|
||||
readTimeout: 3 # Read operation timeout
|
||||
writeTimeout: 3 # Write operation timeout
|
||||
enableTLS: false # Use TLS for Redis connection
|
||||
tlsSkipVerify: false # Skip TLS certificate verification
|
||||
hybridL1Size: 500 # L1 cache size for hybrid mode
|
||||
hybridL1MemoryMB: 10 # L1 memory limit for hybrid mode
|
||||
enableCircuitBreaker: true # Enable circuit breaker
|
||||
circuitBreakerThreshold: 5 # Failures before opening circuit
|
||||
circuitBreakerTimeout: 60 # Timeout before retry (seconds)
|
||||
enableHealthCheck: true # Enable periodic health checks
|
||||
healthCheckInterval: 30 # Health check interval (seconds)
|
||||
|
||||
# --- Common Configuration Examples ---
|
||||
#
|
||||
# 🔒 HIGH-SECURITY CONFIGURATION
|
||||
@@ -186,11 +232,11 @@ testData:
|
||||
# corsAllowedOrigins: ["https://app.example.com"]
|
||||
# corsAllowedMethods: ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
# corsAllowedHeaders: ["Authorization", "Content-Type", "X-API-Key"]
|
||||
# headers: # Custom headers with OIDC claims
|
||||
# headers: # Custom headers with OIDC claims (use double curly braces)
|
||||
# - name: "X-User-Email"
|
||||
# value: "{{.Claims.email}}"
|
||||
# value: "{{{{.Claims.email}}}}"
|
||||
# - name: "X-User-ID"
|
||||
# value: "{{.Claims.sub}}"
|
||||
# value: "{{{{.Claims.sub}}}}"
|
||||
|
||||
# --- Provider Specific Configuration Examples ---
|
||||
#
|
||||
@@ -223,6 +269,8 @@ testData:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -244,6 +292,26 @@ testData:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -562,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -595,28 +695,101 @@ configuration:
|
||||
cookieDomain:
|
||||
type: string
|
||||
description: |
|
||||
Explicit domain for session cookies. This is important for multi-subdomain setups
|
||||
Explicit domain for session cookies. This is important for multi-subdomain setups
|
||||
and reverse proxy deployments to ensure consistent cookie handling.
|
||||
|
||||
|
||||
When set, all session cookies will use this domain. When not set, the domain
|
||||
is auto-detected from the request headers (X-Forwarded-Host or Host).
|
||||
|
||||
|
||||
Use a leading dot for subdomain-wide cookies (e.g., ".example.com" allows
|
||||
cookies to be shared between app.example.com, api.example.com, etc.).
|
||||
|
||||
|
||||
Use a specific domain for host-only cookies (e.g., "app.example.com" restricts
|
||||
cookies to that exact domain).
|
||||
|
||||
|
||||
This setting is crucial to prevent authentication issues like "CSRF token missing
|
||||
in session" errors that can occur when cookies are created with inconsistent domains.
|
||||
|
||||
|
||||
Examples:
|
||||
- ".example.com" - Allows all subdomains to share cookies
|
||||
- "app.example.com" - Restricts cookies to this specific host
|
||||
|
||||
|
||||
Default: "" (auto-detected from request headers)
|
||||
required: false
|
||||
|
||||
cookiePrefix:
|
||||
type: string
|
||||
description: |
|
||||
Custom prefix for session cookie names. This is essential for running multiple
|
||||
middleware instances with different authorization requirements on the same domain.
|
||||
|
||||
By default, all middleware instances use the same cookie names (_oidc_raczylo_m,
|
||||
_oidc_raczylo_a, etc.), which means they share session state. When you have
|
||||
multiple instances with different access restrictions (e.g., one for general users
|
||||
and one for admins), this session sharing can lead to authorization bypass issues.
|
||||
|
||||
Setting a unique cookiePrefix for each middleware instance ensures complete
|
||||
session isolation, preventing users authenticated via one middleware from
|
||||
automatically gaining access to routes protected by a different middleware.
|
||||
|
||||
The prefix is prepended to all session cookie names:
|
||||
- Main session cookie: {prefix}m
|
||||
- Access token cookie: {prefix}a
|
||||
- Refresh token cookie: {prefix}r
|
||||
- ID token cookie: {prefix}id
|
||||
|
||||
Examples:
|
||||
- "_oidc_userauth_" - For general user authentication middleware
|
||||
- "_oidc_adminauth_" - For admin-only authentication middleware
|
||||
- "_oidc_api_" - For API-specific authentication middleware
|
||||
|
||||
Security Note: Use different cookie prefixes AND different sessionEncryptionKey
|
||||
values for each middleware instance to ensure complete isolation.
|
||||
|
||||
Default: "_oidc_raczylo_" (standard prefix for backward compatibility)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/87
|
||||
required: false
|
||||
|
||||
sessionMaxAge:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum session age in seconds before requiring re-authentication.
|
||||
|
||||
This setting controls how long a user's authentication session remains valid
|
||||
before they must authenticate again through the OIDC provider. The session
|
||||
age is tracked from the initial authentication time (created_at).
|
||||
|
||||
When a session exceeds this age:
|
||||
- The session is cleared and invalidated
|
||||
- The user is redirected to re-authenticate
|
||||
- All session cookies are removed
|
||||
|
||||
Use Cases:
|
||||
- High-security applications: Use shorter durations (e.g., 3600 = 1 hour)
|
||||
- Standard applications: Default 24 hours balances security and UX
|
||||
- Long-lived sessions: Extend for applications accessed infrequently
|
||||
(e.g., 604800 = 7 days, 2592000 = 30 days)
|
||||
|
||||
Security Considerations:
|
||||
- Shorter sessions provide better security but require more frequent logins
|
||||
- Longer sessions improve user experience but increase security risk
|
||||
- Consider your application's security requirements and user access patterns
|
||||
- This is independent of token refresh - tokens can be refreshed during the session
|
||||
|
||||
Common Values:
|
||||
- 3600 (1 hour) - High security applications
|
||||
- 28800 (8 hours) - Working day session
|
||||
- 86400 (24 hours) - Default, balances security and convenience
|
||||
- 604800 (7 days) - Weekly session for less frequently accessed apps
|
||||
- 2592000 (30 days) - Monthly session for infrequently used applications
|
||||
|
||||
Default: 86400 (24 hours)
|
||||
Minimum: 0 (uses default of 24 hours)
|
||||
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/91
|
||||
required: false
|
||||
|
||||
overrideScopes:
|
||||
type: boolean
|
||||
description: |
|
||||
@@ -787,6 +960,67 @@ configuration:
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
@@ -803,29 +1037,23 @@ configuration:
|
||||
IMPORTANT: Template Escaping
|
||||
If you encounter the error "can't evaluate field AccessToken in type bool" when
|
||||
starting Traefik, this means Traefik is trying to evaluate the template expressions
|
||||
before passing them to the plugin. To fix this, you need to escape the templates
|
||||
using one of these methods:
|
||||
before passing them to the plugin.
|
||||
|
||||
1. Use YAML literal style (recommended):
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: |
|
||||
Bearer {{.AccessToken}}
|
||||
SOLUTION: You must escape the template expressions using double curly braces:
|
||||
|
||||
2. Use single quotes:
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: 'Bearer {{.AccessToken}}'
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{{{.AccessToken}}}}"
|
||||
|
||||
3. For inline double quotes, escape the braces:
|
||||
headers:
|
||||
- name: "Authorization"
|
||||
value: "Bearer {{"{{.AccessToken}}"}}"
|
||||
This is the only reliable method that works consistently. Here's why:
|
||||
- The YAML parser converts {{{{ → {{ and }}}} → }}
|
||||
- Result: Bearer {{.AccessToken}} reaches the Go template engine correctly
|
||||
- Other methods (YAML literal style, single quotes) do NOT work reliably
|
||||
|
||||
Examples:
|
||||
- name: "X-User-Email", value: "{{.Claims.email}}"
|
||||
- name: "Authorization", value: "Bearer {{.AccessToken}}"
|
||||
- name: "X-User-Roles", value: "{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}"
|
||||
- name: "X-User-Email", value: "{{{{.Claims.email}}}}"
|
||||
- name: "Authorization", value: "Bearer {{{{.AccessToken}}}}"
|
||||
- name: "X-User-Roles", value: "{{{{range $i, $e := .Claims.roles}}}}{{{{if $i}}}},{{{{end}}}}{{{{$e}}}}{{{{end}}}}"
|
||||
required: false
|
||||
items:
|
||||
type: object
|
||||
@@ -1144,3 +1372,261 @@ configuration:
|
||||
|
||||
Prevents your resources from being embedded on other sites.
|
||||
required: false
|
||||
|
||||
redis:
|
||||
type: object
|
||||
description: |
|
||||
Optional Redis cache configuration for multi-replica deployments.
|
||||
|
||||
When running multiple Traefik instances, Redis provides shared caching to:
|
||||
- Prevent JTI replay detection false positives across replicas
|
||||
- Share token verification results between instances
|
||||
- Maintain consistent session state across the cluster
|
||||
- Improve performance by reducing redundant OIDC provider calls
|
||||
|
||||
Features:
|
||||
- Automatic failover to memory-only mode when Redis is unavailable
|
||||
- Circuit breaker pattern for resilience against Redis failures
|
||||
- Health checking with automatic recovery
|
||||
- Multiple cache modes: redis-only, hybrid (L1 memory + L2 Redis), memory-only
|
||||
- Configurable timeouts and connection pooling
|
||||
- TLS support for secure Redis connections
|
||||
|
||||
The middleware gracefully handles Redis failures by falling back to in-memory
|
||||
caching, ensuring your authentication flow continues even during Redis outages.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
cacheMode: "hybrid"
|
||||
enableCircuitBreaker: true
|
||||
```
|
||||
required: false
|
||||
properties:
|
||||
enabled:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable Redis caching for distributed session and token management.
|
||||
When enabled, the middleware will attempt to connect to Redis and use it
|
||||
for shared state across multiple Traefik instances.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
address:
|
||||
type: string
|
||||
description: |
|
||||
Redis server address in host:port format.
|
||||
|
||||
Examples:
|
||||
- "redis:6379" (Docker/Kubernetes service)
|
||||
- "localhost:6379" (local Redis)
|
||||
- "redis.example.com:6380" (custom host/port)
|
||||
- "redis-cluster.default.svc.cluster.local:6379" (Kubernetes)
|
||||
|
||||
Required when Redis is enabled.
|
||||
required: false
|
||||
|
||||
password:
|
||||
type: string
|
||||
description: |
|
||||
Password for Redis authentication.
|
||||
Leave empty if Redis doesn't require authentication.
|
||||
|
||||
For Kubernetes deployments, you can use secret references:
|
||||
urn:k8s:secret:namespace:secret-name:key
|
||||
|
||||
Default: "" (no authentication)
|
||||
required: false
|
||||
|
||||
db:
|
||||
type: integer
|
||||
description: |
|
||||
Redis database number to use (0-15).
|
||||
Different databases can be used to isolate data between environments.
|
||||
|
||||
Default: 0
|
||||
required: false
|
||||
|
||||
keyPrefix:
|
||||
type: string
|
||||
description: |
|
||||
Prefix for all Redis keys created by this middleware.
|
||||
Useful for:
|
||||
- Avoiding key collisions with other applications
|
||||
- Identifying keys for monitoring/debugging
|
||||
- Supporting multiple environments in the same Redis instance
|
||||
|
||||
Default: "traefikoidc:"
|
||||
required: false
|
||||
|
||||
cacheMode:
|
||||
type: string
|
||||
description: |
|
||||
Determines the caching strategy:
|
||||
|
||||
- "redis": Redis-only caching. All cache operations go directly to Redis.
|
||||
Best for: Consistent state across all replicas, minimal memory usage.
|
||||
|
||||
- "hybrid": Two-tier caching with in-memory L1 and Redis L2.
|
||||
Best for: High performance with shared state, reduced Redis load.
|
||||
L1 provides fast local cache, L2 provides shared state.
|
||||
|
||||
- "memory": Memory-only caching (Redis disabled even if configured).
|
||||
Best for: Single instance deployments, development/testing.
|
||||
|
||||
Default: "redis" (when Redis is enabled)
|
||||
required: false
|
||||
enum:
|
||||
- redis
|
||||
- hybrid
|
||||
- memory
|
||||
|
||||
poolSize:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of socket connections to Redis.
|
||||
Higher values allow more concurrent operations but consume more resources.
|
||||
|
||||
Recommendations:
|
||||
- Small deployments: 10-20
|
||||
- Medium deployments: 20-50
|
||||
- Large deployments: 50-100
|
||||
|
||||
Default: 10
|
||||
required: false
|
||||
|
||||
connectTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for establishing new connections to Redis.
|
||||
Should be higher than network latency but low enough to fail fast.
|
||||
|
||||
Default: 5 seconds
|
||||
required: false
|
||||
|
||||
readTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis read operations.
|
||||
Includes the time to send the command, wait for Redis to process it,
|
||||
and receive the response.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
writeTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Timeout in seconds for Redis write operations.
|
||||
Should account for network latency and Redis persistence settings.
|
||||
|
||||
Default: 3 seconds
|
||||
required: false
|
||||
|
||||
enableTLS:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable TLS encryption for Redis connections.
|
||||
Required when connecting to Redis instances that enforce TLS,
|
||||
such as AWS ElastiCache with encryption in transit.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
tlsSkipVerify:
|
||||
type: boolean
|
||||
description: |
|
||||
Skip TLS certificate verification for Redis connections.
|
||||
|
||||
⚠️ WARNING: Only use in development environments.
|
||||
This option bypasses certificate validation and should never be used
|
||||
in production as it's vulnerable to man-in-the-middle attacks.
|
||||
|
||||
Default: false
|
||||
required: false
|
||||
|
||||
hybridL1Size:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum number of items in the L1 (in-memory) cache for hybrid mode.
|
||||
Controls how many cache entries are kept in local memory before eviction.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 500
|
||||
required: false
|
||||
|
||||
hybridL1MemoryMB:
|
||||
type: integer
|
||||
description: |
|
||||
Maximum memory in megabytes for L1 cache in hybrid mode.
|
||||
The cache will start evicting items when this limit is approached.
|
||||
|
||||
Only applies when cacheMode is "hybrid".
|
||||
|
||||
Default: 10 MB
|
||||
required: false
|
||||
|
||||
enableCircuitBreaker:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable circuit breaker pattern for Redis connection failures.
|
||||
|
||||
When enabled, the middleware will:
|
||||
1. Track Redis operation failures
|
||||
2. Open the circuit after threshold failures (stop trying Redis)
|
||||
3. Fall back to in-memory caching
|
||||
4. Periodically attempt to reconnect (half-open state)
|
||||
5. Resume Redis operations when connection recovers
|
||||
|
||||
This prevents cascading failures and improves resilience.
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
circuitBreakerThreshold:
|
||||
type: integer
|
||||
description: |
|
||||
Number of consecutive Redis failures before opening the circuit.
|
||||
Lower values make the system more sensitive to Redis issues,
|
||||
higher values tolerate more failures before switching to fallback.
|
||||
|
||||
Default: 5
|
||||
required: false
|
||||
|
||||
circuitBreakerTimeout:
|
||||
type: integer
|
||||
description: |
|
||||
Time in seconds to wait before attempting to close the circuit.
|
||||
After this timeout, the circuit breaker will allow one test request
|
||||
to Redis. If successful, normal operations resume.
|
||||
|
||||
Default: 60 seconds
|
||||
required: false
|
||||
|
||||
enableHealthCheck:
|
||||
type: boolean
|
||||
description: |
|
||||
Enable periodic health checks for Redis connection.
|
||||
|
||||
Health checks:
|
||||
- Run in the background at regular intervals
|
||||
- Detect Redis availability without affecting request processing
|
||||
- Automatically reconnect when Redis becomes available
|
||||
- Update circuit breaker state based on health status
|
||||
|
||||
Default: true
|
||||
required: false
|
||||
|
||||
healthCheckInterval:
|
||||
type: integer
|
||||
description: |
|
||||
Interval in seconds between Redis health checks.
|
||||
Lower values detect issues faster but increase Redis load.
|
||||
Higher values reduce overhead but delay failure detection.
|
||||
|
||||
Default: 30 seconds
|
||||
required: false
|
||||
|
||||
@@ -8,6 +8,7 @@ The Traefik OIDC middleware provides a complete OIDC authentication solution wit
|
||||
|
||||
- **Universal provider support**: Works with 9+ OIDC providers including Google, Azure AD, Auth0, Okta, Keycloak, AWS Cognito, GitLab, and more
|
||||
- **Automatic provider detection**: Automatically detects and configures provider-specific settings
|
||||
- **Dynamic Client Registration (RFC 7591)**: Automatic client registration with OIDC providers without manual pre-registration
|
||||
- **Automatic scope filtering**: Intelligently filters OAuth scopes based on provider capabilities declared in OIDC discovery documents, preventing authentication failures with unsupported scopes
|
||||
- **Security headers**: Comprehensive security headers with CORS, CSP, HSTS, and custom profiles
|
||||
- **Domain restrictions**: Limit access to specific email domains or individual users
|
||||
@@ -76,7 +77,7 @@ experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.7.8 # Use the latest version
|
||||
version: v0.7.10 # Use the latest version
|
||||
```
|
||||
|
||||
2. Configure the middleware in your dynamic configuration (see examples below).
|
||||
@@ -117,6 +118,30 @@ The middleware supports the following configuration options:
|
||||
| `logLevel` | Sets the logging verbosity | `info` | `debug`, `info`, `error` |
|
||||
| `forceHTTPS` | Forces HTTPS scheme for redirect URIs (**REQUIRED** for TLS termination at load balancer like AWS ALB) | `false` (when not specified) | `true`, `false` |
|
||||
| `rateLimit` | Sets the maximum number of requests per second | `100` | `500` |
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `cookieDomain` | Explicit domain for session cookies (important for multi-subdomain setups) | auto-detected | `.example.com`, `app.example.com` |
|
||||
| `cookiePrefix` | Custom prefix for session cookie names (for isolating multiple middleware instances) | `_oidc_raczylo_` | `_oidc_userauth_`, `_oidc_admin_` |
|
||||
| `sessionMaxAge` | Maximum session age in seconds before requiring re-authentication | `86400` (24 hours) | `3600` (1 hour), `604800` (7 days) |
|
||||
| `audience` | Custom audience for access token validation (for Auth0 custom APIs, etc.) | `clientID` | `https://my-api.example.com` |
|
||||
| `strictAudienceValidation` | Reject sessions with access token audience mismatch (prevents token confusion attacks) | `false` | `true` |
|
||||
| `allowOpaqueTokens` | Enable opaque (non-JWT) access token support via RFC 7662 introspection | `false` | `true` |
|
||||
| `requireTokenIntrospection` | Require introspection for opaque tokens (force validation, no fallback) | `false` | `true` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
>
|
||||
@@ -131,22 +156,6 @@ The middleware supports the following configuration options:
|
||||
> - When `forceHTTPS: false` is explicitly set → scheme detection based on headers/TLS
|
||||
>
|
||||
> See [GitHub Issue #82](https://github.com/lukaszraczylo/traefikoidc/issues/82) for details.
|
||||
| `excludedURLs` | Lists paths that bypass authentication | none | `["/health", "/metrics", "/public"]` |
|
||||
| `allowedUserDomains` | Restricts access to specific email domains | none | `["company.com", "subsidiary.com"]` |
|
||||
| `allowedUsers` | A list of specific email addresses that are allowed access | none | `["user1@example.com", "user2@another.org"]` |
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
| `refreshGracePeriodSeconds` | Seconds before token expiry to attempt proactive refresh | `60` | `120` |
|
||||
| `cookieDomain` | Explicit domain for session cookies (important for multi-subdomain setups) | auto-detected | `.example.com`, `app.example.com` |
|
||||
| `audience` | Custom audience for access token validation (for Auth0 custom APIs, etc.) | `clientID` | `https://my-api.example.com` |
|
||||
| `strictAudienceValidation` | Reject sessions with access token audience mismatch (prevents token confusion attacks) | `false` | `true` |
|
||||
| `allowOpaqueTokens` | Enable opaque (non-JWT) access token support via RFC 7662 introspection | `false` | `true` |
|
||||
| `requireTokenIntrospection` | Require introspection for opaque tokens (force validation, no fallback) | `false` | `true` |
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
|
||||
## Scope Configuration
|
||||
|
||||
@@ -520,12 +529,14 @@ When running multiple Traefik replicas with the OIDC plugin, you may encounter f
|
||||
- Request → Replica B → JTI NOT in Replica B's cache ✓
|
||||
- Request → Replica A → ❌ **FALSE POSITIVE**: "token replay detected"
|
||||
|
||||
**Solution**: Disable replay detection for distributed deployments:
|
||||
**Solution 1 (Simple)**: Disable replay detection for distributed deployments:
|
||||
|
||||
```yaml
|
||||
disableReplayDetection: true # Disable JTI replay detection for multi-replica setups
|
||||
```
|
||||
|
||||
**Solution 2 (Recommended)**: Use Redis cache backend for shared state (see [Redis Cache](#redis-cache-optional) section)
|
||||
|
||||
**Security Note**: When `disableReplayDetection: true`:
|
||||
- ✅ Token signatures still validated
|
||||
- ✅ Expiration still checked
|
||||
@@ -547,10 +558,277 @@ spec:
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
disableReplayDetection: true # Required for multi-replica deployments
|
||||
disableReplayDetection: true # Required for multi-replica deployments without Redis
|
||||
```
|
||||
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, set to `true` and consider implementing a shared cache backend (Redis/Memcached) if replay detection is required.
|
||||
**Recommendation**: For single-instance deployments, leave this setting at `false` (default) to maintain replay attack protection. For multi-replica deployments, use the Redis cache backend for proper replay detection across all instances.
|
||||
|
||||
## Redis Cache (Optional)
|
||||
|
||||
The plugin supports optional Redis caching for multi-replica deployments. This solves issues with JTI replay detection and session management when running multiple Traefik instances behind a load balancer.
|
||||
|
||||
> **✨ Yaegi Compatible**: Redis support is implemented using a pure-Go RESP protocol client that works seamlessly with Traefik's Yaegi interpreter (no `unsafe` package). Full Redis functionality is available for both dynamic plugin loading and pre-compiled deployments.
|
||||
|
||||
### Why Use Redis Cache?
|
||||
|
||||
When running multiple Traefik replicas, each instance maintains its own in-memory cache for:
|
||||
- JTI (JWT Token ID) replay detection
|
||||
- Session data
|
||||
- Token metadata
|
||||
|
||||
Without a shared cache, you may experience:
|
||||
- False positive replay detection errors
|
||||
- Session inconsistencies between replicas
|
||||
- Users needing to re-authenticate when hitting different instances
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Redis is configured through Traefik's dynamic configuration (YAML, labels, etc.):
|
||||
|
||||
```yaml
|
||||
# Enable Redis cache in your middleware configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "localhost:6379"
|
||||
password: "your-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
```
|
||||
|
||||
### Configuration Priority
|
||||
|
||||
The plugin uses the following priority for Redis configuration:
|
||||
|
||||
1. **Traefik Dynamic Configuration** (PRIMARY) - Configure via YAML files or Docker/Kubernetes labels
|
||||
2. **Environment Variables** (FALLBACK) - Used only when not set in Traefik config
|
||||
|
||||
This approach allows you to manage all settings through Traefik's configuration system while maintaining backward compatibility with environment variables.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Parameter | Description | Default | Example |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `enabled` | Enable Redis caching | `false` | `true` |
|
||||
| `address` | Redis server address | - | `redis:6379` |
|
||||
| `password` | Redis password | - | `YOUR_PASSWORD` |
|
||||
| `db` | Database number | `0` | `1` |
|
||||
| `keyPrefix` | Key prefix for namespacing | `traefikoidc:` | `myapp:` |
|
||||
| `cacheMode` | Cache mode: `redis`, `hybrid`, `memory` | `redis` | `hybrid` |
|
||||
| `poolSize` | Connection pool size | `10` | `20` |
|
||||
| `connectTimeout` | Connection timeout (seconds) | `5` | `10` |
|
||||
| `readTimeout` | Read timeout (seconds) | `3` | `5` |
|
||||
| `writeTimeout` | Write timeout (seconds) | `3` | `5` |
|
||||
| `enableTLS` | Enable TLS | `false` | `true` |
|
||||
| `tlsSkipVerify` | Skip TLS verification | `false` | `true` |
|
||||
| `enableCircuitBreaker` | Circuit breaker for failures | `true` | `true` |
|
||||
| `circuitBreakerThreshold` | Failures before circuit opens | `5` | `10` |
|
||||
| `circuitBreakerTimeout` | Circuit reset timeout (seconds) | `60` | `30` |
|
||||
| `enableHealthCheck` | Periodic health checks | `true` | `true` |
|
||||
| `healthCheckInterval` | Health check interval (seconds) | `30` | `60` |
|
||||
|
||||
### Environment Variables (Fallback)
|
||||
|
||||
If not configured through Traefik, these environment variables can be used as fallback:
|
||||
|
||||
- `REDIS_ENABLED` - Enable Redis cache
|
||||
- `REDIS_ADDRESS` - Redis server address
|
||||
- `REDIS_PASSWORD` - Redis password
|
||||
- `REDIS_DB` - Database number
|
||||
- `REDIS_KEY_PREFIX` - Key prefix
|
||||
- `REDIS_CACHE_MODE` - Cache mode
|
||||
- `REDIS_POOL_SIZE` - Connection pool size
|
||||
- `REDIS_CONNECT_TIMEOUT` - Connection timeout
|
||||
- `REDIS_READ_TIMEOUT` - Read timeout
|
||||
- `REDIS_WRITE_TIMEOUT` - Write timeout
|
||||
- `REDIS_ENABLE_TLS` - Enable TLS
|
||||
- `REDIS_TLS_SKIP_VERIFY` - Skip TLS verification
|
||||
|
||||
### Cache Modes
|
||||
|
||||
The plugin supports three cache modes:
|
||||
|
||||
- **memory** (default): In-memory cache only, suitable for single-instance deployments
|
||||
- **redis**: Redis-only cache, all data stored in Redis
|
||||
- **hybrid**: Two-tier caching with local memory cache + Redis backend for optimal performance
|
||||
|
||||
### Example Configurations
|
||||
|
||||
#### Docker Compose with Redis
|
||||
|
||||
```yaml
|
||||
services:
|
||||
redis:
|
||||
image: redis:alpine
|
||||
command: redis-server --requirepass yourpassword
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
# ... rest of your Traefik configuration
|
||||
labels:
|
||||
# Configure the OIDC middleware with Redis
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# Redis configuration via labels
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.password=yourpassword"
|
||||
- "traefik.http.middlewares.oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
```
|
||||
|
||||
#### Kubernetes with Redis
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-redis
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://accounts.google.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-encryption-key
|
||||
callbackURL: /oauth2/callback
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-service.redis-namespace:6379"
|
||||
password: "urn:k8s:secret:redis-secret:password"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc"
|
||||
cacheMode: "hybrid"
|
||||
```
|
||||
|
||||
### Advanced Redis Configuration
|
||||
|
||||
See [Redis Cache Documentation](docs/REDIS_CACHE.md) for:
|
||||
- Detailed architecture overview
|
||||
- High availability setup with Redis Sentinel
|
||||
- Redis Cluster configuration
|
||||
- Performance tuning guidelines
|
||||
- Monitoring and observability
|
||||
- Troubleshooting guide
|
||||
- Migration from memory-only cache
|
||||
|
||||
## Dynamic Client Registration (RFC 7591)
|
||||
|
||||
The middleware supports **OIDC Dynamic Client Registration** (RFC 7591), allowing automatic client registration with OIDC providers without manual pre-registration. This is useful for:
|
||||
|
||||
- **Multi-tenant deployments**: Automatically register clients per tenant
|
||||
- **Development environments**: Quick setup without manual OAuth app creation
|
||||
- **Self-service integrations**: Allow applications to self-register
|
||||
|
||||
### How It Works
|
||||
|
||||
1. When enabled, the middleware discovers the `registration_endpoint` from the provider's `.well-known/openid-configuration`
|
||||
2. If no `clientID` is configured, it automatically registers a new client with the provider
|
||||
3. The registered `client_id` and `client_secret` are cached and optionally persisted to a file
|
||||
4. Subsequent requests use the registered credentials
|
||||
|
||||
### Configuration
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-dynamic-registration
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://your-oidc-provider.com
|
||||
# clientID and clientSecret are NOT required when using DCR
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
|
||||
# Optional: Initial access token for protected registration endpoints
|
||||
initialAccessToken: "your-initial-access-token"
|
||||
|
||||
# Optional: Override the registration endpoint (auto-discovered by default)
|
||||
registrationEndpoint: "https://your-provider.com/register"
|
||||
|
||||
# Optional: Persist credentials to file for reuse across restarts
|
||||
persistCredentials: true
|
||||
credentialsFile: "/tmp/oidc-client-credentials.json"
|
||||
|
||||
# Client metadata for registration
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://your-app.com/oauth2/callback"
|
||||
client_name: "My Application"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
response_types:
|
||||
- "code"
|
||||
token_endpoint_auth_method: "client_secret_basic"
|
||||
contacts:
|
||||
- "admin@your-app.com"
|
||||
```
|
||||
|
||||
### DCR Configuration Parameters
|
||||
|
||||
| Parameter | Description | Required | Default |
|
||||
|-----------|-------------|----------|---------|
|
||||
| `enabled` | Enable dynamic client registration | Yes | `false` |
|
||||
| `initialAccessToken` | Bearer token for protected registration endpoints | No | - |
|
||||
| `registrationEndpoint` | Override auto-discovered registration endpoint | No | From discovery |
|
||||
| `persistCredentials` | Save registered credentials to file | No | `false` |
|
||||
| `credentialsFile` | Path to store/load credentials | No | `/tmp/oidc-client-credentials.json` |
|
||||
| `clientMetadata.redirect_uris` | **REQUIRED** - Redirect URIs for OAuth flow | Yes | - |
|
||||
| `clientMetadata.client_name` | Human-readable client name | No | - |
|
||||
| `clientMetadata.application_type` | `web` or `native` | No | `web` |
|
||||
| `clientMetadata.grant_types` | OAuth grant types | No | `["authorization_code", "refresh_token"]` |
|
||||
| `clientMetadata.response_types` | OAuth response types | No | `["code"]` |
|
||||
| `clientMetadata.token_endpoint_auth_method` | Authentication method | No | `client_secret_basic` |
|
||||
| `clientMetadata.contacts` | Contact email addresses | No | - |
|
||||
| `clientMetadata.logo_uri` | URL to client logo | No | - |
|
||||
| `clientMetadata.client_uri` | URL to client homepage | No | - |
|
||||
| `clientMetadata.policy_uri` | URL to privacy policy | No | - |
|
||||
| `clientMetadata.tos_uri` | URL to terms of service | No | - |
|
||||
| `clientMetadata.scope` | Space-separated scopes | No | - |
|
||||
|
||||
### Provider Support
|
||||
|
||||
DCR support varies by provider:
|
||||
|
||||
| Provider | DCR Support | Notes |
|
||||
|----------|-------------|-------|
|
||||
| Keycloak | ✅ Full | Enable in realm settings |
|
||||
| Auth0 | ✅ Full | Requires Management API token |
|
||||
| Okta | ✅ Full | Enable Dynamic Client Registration |
|
||||
| Azure AD | ⚠️ Limited | App Registration API instead |
|
||||
| Google | ❌ No | Manual registration required |
|
||||
| AWS Cognito | ❌ No | Manual registration required |
|
||||
|
||||
### Security Considerations
|
||||
|
||||
1. **HTTPS Required**: Registration endpoints must use HTTPS (except localhost for development)
|
||||
2. **Initial Access Token**: Recommended for production to prevent unauthorized registrations
|
||||
3. **Credential Persistence**: If enabled, ensure the credentials file has appropriate permissions (0600)
|
||||
4. **Secret Expiration**: Monitor `client_secret_expires_at` and handle rotation if needed
|
||||
|
||||
### Example: Keycloak with DCR
|
||||
|
||||
```yaml
|
||||
dynamicClientRegistration:
|
||||
enabled: true
|
||||
clientMetadata:
|
||||
redirect_uris:
|
||||
- "https://myapp.example.com/oauth2/callback"
|
||||
client_name: "My App - Production"
|
||||
application_type: "web"
|
||||
grant_types:
|
||||
- "authorization_code"
|
||||
- "refresh_token"
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
@@ -730,6 +1008,87 @@ spec:
|
||||
|
||||
**Important**: The `cookieDomain` parameter is crucial when running behind a reverse proxy or when your application serves multiple subdomains. Without it, cookies may be created with inconsistent domains, leading to authentication issues like "CSRF token missing in session" errors.
|
||||
|
||||
### With Multiple Middleware Instances (Session Isolation)
|
||||
|
||||
When running multiple middleware instances with different authorization requirements (e.g., one for general users and one for admins), you must use different `cookiePrefix` values to prevent session sharing between instances:
|
||||
|
||||
```yaml
|
||||
# Middleware for general user authentication
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-userauth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: user-key-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
cookiePrefix: "_oidc_userauth_" # Unique prefix for this instance
|
||||
---
|
||||
# Middleware for admin authentication with stricter requirements
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-adminauth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: admin-key-at-least-32-bytes-long # Different encryption key
|
||||
callbackURL: /oauth2/admin/callback # Different callback URL
|
||||
cookiePrefix: "_oidc_adminauth_" # Different prefix for isolation
|
||||
allowedUsers: # Restricted to specific admin users
|
||||
- admin@example.com
|
||||
- superadmin@example.com
|
||||
```
|
||||
|
||||
**Security Note**: When running multiple instances, ensure you use:
|
||||
1. **Different `cookiePrefix`** values to prevent cookie name collisions
|
||||
2. **Different `sessionEncryptionKey`** values for complete session isolation
|
||||
3. **Different `callbackURL`** paths to avoid routing conflicts
|
||||
|
||||
This configuration prevents authorization bypass issues where a user authenticated via the general middleware could access admin-protected routes. See [issue #87](https://github.com/lukaszraczylo/traefikoidc/issues/87) for more details.
|
||||
|
||||
### With Extended Session Duration
|
||||
|
||||
For applications that users access infrequently (weekly or monthly), you can extend the session duration beyond the default 24 hours to reduce authentication friction:
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-long-session
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://auth.example.com
|
||||
clientID: your-client-id
|
||||
clientSecret: your-client-secret
|
||||
sessionEncryptionKey: your-key-at-least-32-bytes-long
|
||||
callbackURL: /oauth2/callback
|
||||
sessionMaxAge: 604800 # 7 days (in seconds)
|
||||
# Other common values:
|
||||
# 259200 - 3 days
|
||||
# 604800 - 7 days
|
||||
# 1209600 - 14 days
|
||||
# 2592000 - 30 days
|
||||
```
|
||||
|
||||
**Security Note**: Longer session durations improve user experience but increase security risk. Consider your application's security requirements:
|
||||
- **High-security apps**: Use shorter sessions (3600 = 1 hour)
|
||||
- **Standard apps**: Default 24 hours balances security and UX
|
||||
- **Low-frequency access apps**: Extend to 7-30 days for better UX
|
||||
|
||||
See [issue #91](https://github.com/lukaszraczylo/traefikoidc/issues/91) for more details.
|
||||
|
||||
### With Custom Logging and Rate Limiting
|
||||
|
||||
```yaml
|
||||
@@ -885,6 +1244,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -909,8 +1307,13 @@ spec:
|
||||
|
||||
scopes:
|
||||
- read:custom_data # Custom scopes as needed
|
||||
|
||||
# Custom claim names for Auth0 namespaced claims
|
||||
roleClaimName: "https://your-app.com/roles" # Auth0 requires namespaced custom claims
|
||||
groupClaimName: "https://your-app.com/groups" # Must match claims added in Auth0 Actions
|
||||
|
||||
allowedRolesAndGroups:
|
||||
- "https://your-app.com/roles:admin" # Namespaced claims from Actions
|
||||
- admin # Will match "admin" in https://your-app.com/roles claim
|
||||
- editor
|
||||
postLogoutRedirectURI: /logged-out-page # Must be in Auth0 Allowed Logout URLs
|
||||
```
|
||||
@@ -966,8 +1369,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1089,7 +1496,7 @@ services:
|
||||
image: traefik:v3.2.1
|
||||
command:
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.7.8"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.7.10"
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./traefik-config/traefik.yml:/etc/traefik/traefik.yml
|
||||
@@ -1196,58 +1603,6 @@ http:
|
||||
{{range $i, $e := .Claims.roles}}{{if $i}},{{end}}{{$e}}{{end}}
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Session Management
|
||||
|
||||
The middleware uses encrypted cookies to manage user sessions. The `sessionEncryptionKey` must be at least 32 bytes long and should be kept secret.
|
||||
|
||||
### PKCE Support
|
||||
|
||||
The middleware supports PKCE (Proof Key for Code Exchange), which is an extension to the authorization code flow to prevent authorization code interception attacks. When enabled via the `enablePKCE` option, the middleware will generate a code verifier for each authentication request and derive a code challenge from it. The code verifier is stored in the user's session and sent during the token exchange process.
|
||||
|
||||
PKCE is recommended when:
|
||||
- Your OIDC provider supports it (most modern providers do)
|
||||
- You need an additional layer of security for the authorization code flow
|
||||
- You're concerned about potential authorization code interception attacks
|
||||
|
||||
Note that not all OIDC providers support PKCE, so check your provider's documentation before enabling this feature.
|
||||
|
||||
### Session Duration and Token Refresh
|
||||
|
||||
This middleware aims to provide long-lived user sessions, typically up to 24 hours, by utilizing OIDC refresh tokens.
|
||||
|
||||
**How it works:**
|
||||
- When a user authenticates, the middleware requests an access token and, if available, a refresh token from the OIDC provider.
|
||||
- The access token usually has a short lifespan (e.g., 1 hour).
|
||||
- Before the access token expires (controlled by `refreshGracePeriodSeconds`), the middleware uses the refresh token to obtain a new access token from the provider without requiring the user to log in again.
|
||||
- This process repeats, allowing the session to remain valid for as long as the refresh token is valid (often 24 hours or more, depending on the provider).
|
||||
|
||||
**Provider-Specific Considerations (e.g., Google):**
|
||||
- Some providers, like Google, issue short-lived access tokens (e.g., 1 hour) and require specific configurations for long-term sessions.
|
||||
- To enable session extension beyond the initial token expiry with Google and similar providers, the middleware automatically includes the `offline_access` scope in the authentication request. This scope is necessary to obtain a refresh token.
|
||||
- For Google specifically, the middleware also adds the `prompt=consent` parameter to the initial authorization request. This ensures Google issues a refresh token, which is crucial for extending the session.
|
||||
- If a refresh attempt fails (e.g., the refresh token is revoked or expired), the user will be required to re-authenticate. The middleware includes enhanced error handling and logging for these scenarios.
|
||||
- Ensure your OIDC provider is configured to issue refresh tokens and allows their use for extending sessions. Check your provider's documentation for details on refresh token validity periods.
|
||||
|
||||
### Google OAuth Compatibility Fix
|
||||
|
||||
The middleware includes a specific fix for Google's OAuth implementation, which differs from the standard OIDC specification in how it handles refresh tokens:
|
||||
|
||||
- **Issue**: Google does not support the standard `offline_access` scope for requesting refresh tokens and instead requires special parameters.
|
||||
|
||||
- **Automatic Solution**: The middleware detects Google as the provider based on the issuer URL and:
|
||||
- Uses `access_type=offline` query parameter instead of the `offline_access` scope
|
||||
- Adds `prompt=consent` to ensure refresh tokens are consistently issued
|
||||
- Properly handles token refresh with Google's implementation
|
||||
|
||||
You do not need any special configuration to use Google OAuth - just set `providerURL` to `https://accounts.google.com` and the middleware will automatically apply the proper parameters.
|
||||
|
||||
For detailed information on the Google OAuth fix, see the [dedicated documentation](docs/google-oauth-fix.md).
|
||||
|
||||
### Token Caching and Blacklisting
|
||||
|
||||
The middleware automatically caches validated tokens to improve performance and maintains a blacklist of revoked tokens.
|
||||
### Templated Headers
|
||||
|
||||
The middleware supports setting custom HTTP headers with values templated from OIDC claims and tokens. This allows you to pass authentication information to downstream services in a flexible, customized format.
|
||||
@@ -1320,12 +1675,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1449,32 +1831,6 @@ GitLab supports OIDC for both GitLab.com and self-hosted instances.
|
||||
* **Scopes**: Use `user:email`, `read:user` for basic profile access
|
||||
* **Detection**: Auto-detected from `github.com` in issuer URL
|
||||
|
||||
### Azure AD (Microsoft Entra ID)
|
||||
|
||||
Azure AD generally works well with standard OIDC configurations.
|
||||
|
||||
* **ID Token Claims**: Azure AD typically includes standard claims like `email`, `name`, `preferred_username`, and `oid` (Object ID) in the ID Token by default when `openid profile email` scopes are requested.
|
||||
* **Group Claims**: To include group claims in the ID Token, you need to configure this in the Azure AD application registration:
|
||||
* Go to your App Registration -> Token configuration -> Add groups claim.
|
||||
* You can choose which types of groups (Security groups, Directory roles, All groups) to include.
|
||||
* Be aware of the "overage" issue: If a user is a member of too many groups, Azure AD will send a link to fetch groups instead of embedding them. This plugin currently expects group claims to be directly in the ID token. For users with many groups, consider alternative role/permission management strategies.
|
||||
* The claim name for groups is typically `groups`.
|
||||
* **Optional Claims**: You can add other optional claims via the "Token configuration" section of your App Registration. Ensure these are configured for the ID token.
|
||||
* **Endpoints**: The `providerURL` should be `https://login.microsoftonline.com/{your-tenant-id}/v2.0`. The plugin will auto-discover the necessary endpoints.
|
||||
* **Optimization**: Ensure your application manifest in Azure AD is configured for the desired token version (v1.0 or v2.0). This plugin works with v2.0 endpoints.
|
||||
|
||||
### Google Workspace / Google Cloud Identity
|
||||
|
||||
Google's OIDC implementation is well-supported.
|
||||
|
||||
* **Optimal Configuration**: The plugin automatically handles Google-specific requirements, such as using `access_type=offline` and `prompt=consent` to ensure refresh tokens are issued for long-lived sessions. You do not need to add `offline_access` to scopes.
|
||||
* **ID Token Claims**: Google includes standard claims like `email`, `sub`, `name`, `given_name`, `family_name`, `picture` in the ID Token by default with `openid profile email` scopes.
|
||||
* **Hosted Domain (hd claim)**: If you are using Google Workspace and want to restrict access to users within your organization's domain, Google includes an `hd` (hosted domain) claim in the ID Token. You can use this with the `allowedUserDomains` setting or for custom header logic.
|
||||
* **Best Practices**:
|
||||
* Use the `providerURL`: `https://accounts.google.com`.
|
||||
* Ensure your OAuth consent screen in Google Cloud Console is configured correctly and published. For production, it should be "External" and in "Production" status. "Testing" status limits refresh token lifetime.
|
||||
* Refer to the [Google OAuth Compatibility Fix](#google-oauth-compatibility-fix) section for more details on how the plugin handles Google's specifics.
|
||||
|
||||
### Auth0
|
||||
|
||||
Auth0 is generally OIDC compliant and works well.
|
||||
@@ -1579,6 +1935,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
+22
-21
@@ -838,7 +838,7 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
}
|
||||
|
||||
logger := NewLogger("debug")
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", logger)
|
||||
sm, err := NewSessionManager(strings.Repeat("a", MinSessionEncryptionKeyLength), false, "", "", 0, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create session manager: %v", err)
|
||||
}
|
||||
@@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
userIdentifierClaim: "email", // Required for user identification
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
+29
-25
@@ -18,17 +18,18 @@ type ScopeFilter interface {
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -40,19 +41,20 @@ type Logger interface {
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
scopeFilter ScopeFilter, scopesSupported []string, allowPrivateIPAddresses bool) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter,
|
||||
scopesSupported: scopesSupported,
|
||||
allowPrivateIPAddresses: allowPrivateIPAddresses,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,6 +349,7 @@ func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
@@ -361,7 +364,7 @@ func (h *Handler) validateHost(host string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
// Check for localhost variations (always blocked, even with allowPrivateIPAddresses)
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
@@ -376,7 +379,8 @@ func (h *Handler) validateHost(host string) error {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
// Skip private IP check if allowPrivateIPAddresses is enabled
|
||||
if !h.allowPrivateIPAddresses && ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
|
||||
+25
-25
@@ -86,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
||||
"test-client-id", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -125,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{redirectCount: 5} // At the limit
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -160,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -191,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -222,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -253,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{saveError: &testError{"save failed"}}
|
||||
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
||||
@@ -297,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
||||
@@ -400,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -468,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -493,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -565,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
tt.scopes, tt.overrideScopes, nil, nil)
|
||||
tt.scopes, tt.overrideScopes, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -676,7 +676,7 @@ func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -714,7 +714,7 @@ func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
|
||||
"https://gitlab.example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -756,7 +756,7 @@ func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://accounts.google.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -797,7 +797,7 @@ func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -831,7 +831,7 @@ func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"generic-client", "https://auth.provider.com/authorize",
|
||||
"https://auth.provider.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -870,7 +870,7 @@ func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, true, scopeFilter, scopesSupported)
|
||||
scopes, true, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -955,7 +955,7 @@ func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, scopesSupported) // scopeFilter is nil
|
||||
scopes, false, nil, scopesSupported, false) // scopeFilter is nil
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -988,7 +988,7 @@ func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -1021,7 +1021,7 @@ func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -1064,7 +1064,7 @@ func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
|
||||
|
||||
@@ -1130,7 +1130,7 @@ func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
|
||||
+103
-5
@@ -10,7 +10,7 @@ import (
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -560,3 +560,101 @@ func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost_AllowPrivateIPAddresses tests the allowPrivateIPAddresses flag
|
||||
func TestAuthHandler_validateHost_AllowPrivateIPAddresses(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Test with allowPrivateIPAddresses = false (default)
|
||||
t.Run("Private IPs blocked by default", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err == nil {
|
||||
t.Errorf("Expected private IP %s to be blocked, but it was allowed", ip)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "private IP not allowed") {
|
||||
t.Errorf("Expected 'private IP not allowed' error for %s, got: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test with allowPrivateIPAddresses = true
|
||||
t.Run("Private IPs allowed when flag enabled", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected private IP %s to be allowed with flag enabled, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that loopback is still blocked even with flag enabled
|
||||
t.Run("Loopback always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
loopbackAddresses := []string{
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
}
|
||||
|
||||
for _, addr := range loopbackAddresses {
|
||||
err := handler.validateHost(addr)
|
||||
if err == nil {
|
||||
t.Errorf("Expected loopback address %s to be blocked even with allowPrivateIPAddresses=true", addr)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that link-local is still blocked even with flag enabled
|
||||
t.Run("Link-local always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
err := handler.validateHost("169.254.1.1")
|
||||
if err == nil {
|
||||
t.Error("Expected link-local address to be blocked even with allowPrivateIPAddresses=true")
|
||||
}
|
||||
})
|
||||
|
||||
// Test that public IPs work with flag enabled
|
||||
t.Run("Public IPs allowed", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"142.250.185.68",
|
||||
}
|
||||
|
||||
for _, ip := range publicIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected public IP %s to be allowed, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+19
-9
@@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
+1
-1
@@ -79,7 +79,7 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
tOidc := &mockTraefikOidc{TraefikOidc: baseOidc}
|
||||
|
||||
// Initialize session manager
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", mockLogger)
|
||||
sessionManager, _ := NewSessionManager("test-encryption-key-32-bytes-long", false, "", "", 0, mockLogger)
|
||||
tOidc.sessionManager = sessionManager
|
||||
|
||||
// Mock the JWT verification to avoid JWKS lookup issues
|
||||
|
||||
+28
-1
@@ -21,10 +21,37 @@ var (
|
||||
)
|
||||
|
||||
// GetGlobalCacheManager returns a singleton CacheManager instance
|
||||
// Deprecated: Use GetGlobalCacheManagerWithConfig instead
|
||||
func GetGlobalCacheManager(wg *sync.WaitGroup) *CacheManager {
|
||||
return GetGlobalCacheManagerWithConfig(wg, nil)
|
||||
}
|
||||
|
||||
// GetGlobalCacheManagerWithConfig returns a singleton CacheManager instance with optional Redis configuration
|
||||
func GetGlobalCacheManagerWithConfig(wg *sync.WaitGroup, config *Config) *CacheManager {
|
||||
cacheManagerInitOnce.Do(func() {
|
||||
var redisConfig *RedisConfig
|
||||
var logger *Logger
|
||||
|
||||
if config != nil {
|
||||
logger = NewLogger(config.LogLevel)
|
||||
|
||||
// Initialize Redis config if not present
|
||||
if config.Redis == nil {
|
||||
config.Redis = &RedisConfig{}
|
||||
}
|
||||
|
||||
// Apply environment variable fallbacks for fields not set in config
|
||||
// This allows env vars to be used as optional overrides
|
||||
config.Redis.ApplyEnvFallbacks()
|
||||
|
||||
// Apply defaults after env fallbacks
|
||||
config.Redis.ApplyDefaults()
|
||||
|
||||
redisConfig = config.Redis
|
||||
}
|
||||
|
||||
globalCacheManagerInstance = &CacheManager{
|
||||
manager: GetUniversalCacheManager(nil),
|
||||
manager: GetUniversalCacheManagerWithConfig(logger, redisConfig),
|
||||
}
|
||||
})
|
||||
return globalCacheManagerInstance
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
// Package config provides backward compatibility for legacy configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// LegacyAdapter provides backward compatibility for old Config struct
|
||||
type LegacyAdapter struct {
|
||||
unified *UnifiedConfig
|
||||
adapter *compat.ConfigAdapter
|
||||
}
|
||||
|
||||
// NewLegacyAdapter creates a new legacy adapter from unified config
|
||||
func NewLegacyAdapter(unified *UnifiedConfig) *LegacyAdapter {
|
||||
adapter := compat.NewConfigAdapter(unified)
|
||||
|
||||
// Register getters for commonly used fields
|
||||
adapter.RegisterGetter("ProviderURL", func() interface{} {
|
||||
return unified.Provider.IssuerURL
|
||||
})
|
||||
adapter.RegisterGetter("ClientID", func() interface{} {
|
||||
return unified.Provider.ClientID
|
||||
})
|
||||
adapter.RegisterGetter("ClientSecret", func() interface{} {
|
||||
return unified.Provider.ClientSecret
|
||||
})
|
||||
adapter.RegisterGetter("CallbackURL", func() interface{} {
|
||||
return unified.Provider.RedirectURL
|
||||
})
|
||||
adapter.RegisterGetter("LogoutURL", func() interface{} {
|
||||
return unified.Provider.LogoutURL
|
||||
})
|
||||
adapter.RegisterGetter("PostLogoutRedirectURI", func() interface{} {
|
||||
return unified.Provider.PostLogoutRedirectURI
|
||||
})
|
||||
adapter.RegisterGetter("SessionEncryptionKey", func() interface{} {
|
||||
return unified.Session.EncryptionKey
|
||||
})
|
||||
adapter.RegisterGetter("ForceHTTPS", func() interface{} {
|
||||
return unified.Security.ForceHTTPS
|
||||
})
|
||||
adapter.RegisterGetter("LogLevel", func() interface{} {
|
||||
return unified.Logging.Level
|
||||
})
|
||||
adapter.RegisterGetter("Scopes", func() interface{} {
|
||||
return unified.Provider.Scopes
|
||||
})
|
||||
adapter.RegisterGetter("OverrideScopes", func() interface{} {
|
||||
return unified.Provider.OverrideScopes
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUsers", func() interface{} {
|
||||
return unified.Security.AllowedUsers
|
||||
})
|
||||
adapter.RegisterGetter("AllowedUserDomains", func() interface{} {
|
||||
return unified.Security.AllowedUserDomains
|
||||
})
|
||||
adapter.RegisterGetter("AllowedRolesAndGroups", func() interface{} {
|
||||
return unified.Security.AllowedRolesAndGroups
|
||||
})
|
||||
adapter.RegisterGetter("ExcludedURLs", func() interface{} {
|
||||
return unified.Security.ExcludedURLs
|
||||
})
|
||||
adapter.RegisterGetter("EnablePKCE", func() interface{} {
|
||||
return unified.Security.EnablePKCE
|
||||
})
|
||||
adapter.RegisterGetter("RateLimit", func() interface{} {
|
||||
return unified.RateLimit.RequestsPerSecond
|
||||
})
|
||||
adapter.RegisterGetter("RefreshGracePeriodSeconds", func() interface{} {
|
||||
return int(unified.Token.RefreshGracePeriod.Seconds())
|
||||
})
|
||||
adapter.RegisterGetter("CookieDomain", func() interface{} {
|
||||
return unified.Session.Domain
|
||||
})
|
||||
adapter.RegisterGetter("SecurityHeaders", func() interface{} {
|
||||
return unified.Security.Headers
|
||||
})
|
||||
|
||||
return &LegacyAdapter{
|
||||
unified: unified,
|
||||
adapter: adapter,
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig converts unified config to old Config struct format
|
||||
func (la *LegacyAdapter) ToOldConfig() *Config {
|
||||
// Use feature flags to determine behavior
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Return existing Config if unified config not enabled
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
ProviderURL: la.unified.Provider.IssuerURL,
|
||||
ClientID: la.unified.Provider.ClientID,
|
||||
ClientSecret: la.unified.Provider.ClientSecret,
|
||||
CallbackURL: la.unified.Provider.RedirectURL,
|
||||
LogoutURL: la.unified.Provider.LogoutURL,
|
||||
PostLogoutRedirectURI: la.unified.Provider.PostLogoutRedirectURI,
|
||||
SessionEncryptionKey: la.unified.Session.EncryptionKey,
|
||||
ForceHTTPS: la.unified.Security.ForceHTTPS,
|
||||
LogLevel: la.unified.Logging.Level,
|
||||
Scopes: la.unified.Provider.Scopes,
|
||||
OverrideScopes: la.unified.Provider.OverrideScopes,
|
||||
AllowedUsers: la.unified.Security.AllowedUsers,
|
||||
AllowedUserDomains: la.unified.Security.AllowedUserDomains,
|
||||
AllowedRolesAndGroups: la.unified.Security.AllowedRolesAndGroups,
|
||||
ExcludedURLs: la.unified.Security.ExcludedURLs,
|
||||
EnablePKCE: la.unified.Security.EnablePKCE,
|
||||
RateLimit: la.unified.RateLimit.RequestsPerSecond,
|
||||
RefreshGracePeriodSeconds: int(la.unified.Token.RefreshGracePeriod.Seconds()),
|
||||
Headers: la.convertHeaders(),
|
||||
CookieDomain: la.unified.Session.Domain,
|
||||
SecurityHeaders: la.unified.Security.Headers,
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// convertHeaders converts unified header config to old format
|
||||
func (la *LegacyAdapter) convertHeaders() []HeaderConfig {
|
||||
headers := make([]HeaderConfig, 0)
|
||||
|
||||
for name, value := range la.unified.Middleware.CustomHeaders {
|
||||
headers = append(headers, HeaderConfig{
|
||||
Name: name,
|
||||
Value: value,
|
||||
})
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// FromOldConfig creates unified config from old Config struct
|
||||
func FromOldConfig(old *Config) *UnifiedConfig {
|
||||
unified := NewUnifiedConfig()
|
||||
|
||||
// Map provider settings
|
||||
unified.Provider.IssuerURL = old.ProviderURL
|
||||
unified.Provider.ClientID = old.ClientID
|
||||
unified.Provider.ClientSecret = old.ClientSecret
|
||||
unified.Provider.RedirectURL = old.CallbackURL
|
||||
unified.Provider.LogoutURL = old.LogoutURL
|
||||
unified.Provider.PostLogoutRedirectURI = old.PostLogoutRedirectURI
|
||||
unified.Provider.Scopes = old.Scopes
|
||||
unified.Provider.OverrideScopes = old.OverrideScopes
|
||||
|
||||
// Map session settings
|
||||
unified.Session.EncryptionKey = old.SessionEncryptionKey
|
||||
unified.Session.Domain = old.CookieDomain
|
||||
|
||||
// Map security settings
|
||||
unified.Security.ForceHTTPS = old.ForceHTTPS
|
||||
unified.Security.EnablePKCE = old.EnablePKCE
|
||||
unified.Security.AllowedUsers = old.AllowedUsers
|
||||
unified.Security.AllowedUserDomains = old.AllowedUserDomains
|
||||
unified.Security.AllowedRolesAndGroups = old.AllowedRolesAndGroups
|
||||
unified.Security.ExcludedURLs = old.ExcludedURLs
|
||||
unified.Security.Headers = old.SecurityHeaders
|
||||
|
||||
// Map rate limiting
|
||||
unified.RateLimit.RequestsPerSecond = old.RateLimit
|
||||
unified.RateLimit.Enabled = old.RateLimit > 0
|
||||
|
||||
// Map token settings
|
||||
unified.Token.RefreshGracePeriod = timeSecondsToDuration(old.RefreshGracePeriodSeconds)
|
||||
|
||||
// Map logging
|
||||
unified.Logging.Level = old.LogLevel
|
||||
|
||||
// Map custom headers
|
||||
if len(old.Headers) > 0 {
|
||||
unified.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, header := range old.Headers {
|
||||
unified.Middleware.CustomHeaders[header.Name] = header.Value
|
||||
}
|
||||
}
|
||||
|
||||
// Store original config in legacy field for reference
|
||||
unified.Legacy["original"] = old
|
||||
|
||||
return unified
|
||||
}
|
||||
|
||||
// timeSecondsToDuration converts seconds to time.Duration
|
||||
func timeSecondsToDuration(seconds int) time.Duration {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// GetConfigInterface returns appropriate config based on feature flag
|
||||
func GetConfigInterface() interface{} {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
return CreateConfig()
|
||||
}
|
||||
|
||||
// ValidateConfig validates config based on feature flag
|
||||
func ValidateConfig(cfg interface{}) error {
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if unified, ok := cfg.(*UnifiedConfig); ok {
|
||||
return unified.Validate()
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to old validation if available
|
||||
if old, ok := cfg.(*Config); ok {
|
||||
return old.Validate()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add Validate method to old Config for compatibility
|
||||
func (c *Config) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Basic validation for old config
|
||||
if c.ProviderURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ProviderURL",
|
||||
Message: "provider URL is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
if c.ClientSecret == "" && !c.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE)",
|
||||
})
|
||||
}
|
||||
|
||||
if c.SessionEncryptionKey != "" && len(c.SessionEncryptionKey) < minEncryptionKeyLength {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "SessionEncryptionKey",
|
||||
Message: fmt.Sprintf("encryption key must be at least %d characters", minEncryptionKeyLength),
|
||||
Value: len(c.SessionEncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
)
|
||||
|
||||
// NewLegacyAdapter Tests
|
||||
func TestNewLegacyAdapter(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "test-client"
|
||||
unified.Provider.ClientSecret = "test-secret"
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("Expected NewLegacyAdapter to return non-nil")
|
||||
}
|
||||
|
||||
if adapter.unified != unified {
|
||||
t.Error("Expected adapter to reference the unified config")
|
||||
}
|
||||
|
||||
if adapter.adapter == nil {
|
||||
t.Error("Expected internal adapter to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// ToOldConfig Tests
|
||||
func TestLegacyAdapter_ToOldConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://issuer.example.com"
|
||||
unified.Provider.ClientID = "client-123"
|
||||
unified.Provider.ClientSecret = "secret-456"
|
||||
unified.Provider.RedirectURL = "https://app.example.com/callback"
|
||||
unified.Provider.LogoutURL = "/logout"
|
||||
unified.Provider.PostLogoutRedirectURI = "https://app.example.com"
|
||||
unified.Provider.Scopes = []string{"openid", "profile"}
|
||||
unified.Provider.OverrideScopes = true
|
||||
unified.Session.EncryptionKey = "test-encryption-key-32-chars!!"
|
||||
unified.Session.Domain = "example.com"
|
||||
unified.Security.ForceHTTPS = true
|
||||
unified.Security.EnablePKCE = true
|
||||
unified.Security.AllowedUsers = []string{"user@example.com"}
|
||||
unified.Security.AllowedUserDomains = []string{"example.com"}
|
||||
unified.Security.AllowedRolesAndGroups = []string{"admin"}
|
||||
unified.Security.ExcludedURLs = []string{"/health"}
|
||||
unified.RateLimit.RequestsPerSecond = 100
|
||||
unified.Logging.Level = "debug"
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Header-1": "value1",
|
||||
"X-Header-2": "value2",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
oldConfig := adapter.ToOldConfig()
|
||||
|
||||
if oldConfig == nil {
|
||||
t.Fatal("Expected ToOldConfig to return non-nil")
|
||||
}
|
||||
|
||||
// ToOldConfig behavior depends on feature flag
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// When feature is disabled, returns default config
|
||||
if oldConfig.ProviderURL == "" {
|
||||
t.Log("Feature flag disabled - ToOldConfig returns default config")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// When feature is enabled, verify all fields were correctly mapped
|
||||
if oldConfig.ProviderURL != unified.Provider.IssuerURL {
|
||||
t.Errorf("Expected ProviderURL '%s', got '%s'", unified.Provider.IssuerURL, oldConfig.ProviderURL)
|
||||
}
|
||||
|
||||
if oldConfig.ClientID != unified.Provider.ClientID {
|
||||
t.Errorf("Expected ClientID '%s', got '%s'", unified.Provider.ClientID, oldConfig.ClientID)
|
||||
}
|
||||
|
||||
if oldConfig.ClientSecret != unified.Provider.ClientSecret {
|
||||
t.Errorf("Expected ClientSecret '%s', got '%s'", unified.Provider.ClientSecret, oldConfig.ClientSecret)
|
||||
}
|
||||
|
||||
if oldConfig.CallbackURL != unified.Provider.RedirectURL {
|
||||
t.Error("Expected CallbackURL to match RedirectURL")
|
||||
}
|
||||
|
||||
if oldConfig.LogoutURL != unified.Provider.LogoutURL {
|
||||
t.Error("Expected LogoutURL to match")
|
||||
}
|
||||
|
||||
if oldConfig.ForceHTTPS != unified.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to match")
|
||||
}
|
||||
|
||||
if oldConfig.EnablePKCE != unified.Security.EnablePKCE {
|
||||
t.Error("Expected EnablePKCE to match")
|
||||
}
|
||||
|
||||
if oldConfig.RateLimit != unified.RateLimit.RequestsPerSecond {
|
||||
t.Errorf("Expected RateLimit %d, got %d", unified.RateLimit.RequestsPerSecond, oldConfig.RateLimit)
|
||||
}
|
||||
|
||||
if len(oldConfig.Headers) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(oldConfig.Headers))
|
||||
}
|
||||
}
|
||||
|
||||
// convertHeaders Tests
|
||||
func TestLegacyAdapter_convertHeaders(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Middleware.CustomHeaders = map[string]string{
|
||||
"X-Custom-Header-1": "value1",
|
||||
"X-Custom-Header-2": "value2",
|
||||
"X-Custom-Header-3": "value3",
|
||||
}
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 3 {
|
||||
t.Errorf("Expected 3 headers, got %d", len(headers))
|
||||
}
|
||||
|
||||
// Check that headers were converted
|
||||
headerMap := make(map[string]string)
|
||||
for _, h := range headers {
|
||||
headerMap[h.Name] = h.Value
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-1"] != "value1" {
|
||||
t.Error("Expected X-Custom-Header-1 to have value 'value1'")
|
||||
}
|
||||
|
||||
if headerMap["X-Custom-Header-2"] != "value2" {
|
||||
t.Error("Expected X-Custom-Header-2 to have value 'value2'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegacyAdapter_convertHeaders_Empty(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
// No custom headers
|
||||
|
||||
adapter := NewLegacyAdapter(unified)
|
||||
headers := adapter.convertHeaders()
|
||||
|
||||
if len(headers) != 0 {
|
||||
t.Errorf("Expected 0 headers, got %d", len(headers))
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfigInterface Tests
|
||||
func TestGetConfigInterface(t *testing.T) {
|
||||
cfg := GetConfigInterface()
|
||||
|
||||
if cfg == nil {
|
||||
t.Fatal("Expected GetConfigInterface to return non-nil")
|
||||
}
|
||||
|
||||
// Should return either UnifiedConfig or Config depending on feature flag
|
||||
_, isUnified := cfg.(*UnifiedConfig)
|
||||
_, isOld := cfg.(*Config)
|
||||
|
||||
if !isUnified && !isOld {
|
||||
t.Error("Expected either *UnifiedConfig or *Config")
|
||||
}
|
||||
|
||||
// Verify consistency with feature flag
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
if !isUnified {
|
||||
t.Error("Expected *UnifiedConfig when unified config is enabled")
|
||||
}
|
||||
} else {
|
||||
if !isOld {
|
||||
t.Error("Expected *Config when unified config is disabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateConfig Tests
|
||||
func TestValidateConfig_UnifiedConfig(t *testing.T) {
|
||||
unified := NewUnifiedConfig()
|
||||
unified.Provider.IssuerURL = "https://provider.example.com"
|
||||
unified.Provider.ClientID = "client-id"
|
||||
unified.Provider.ClientSecret = "client-secret"
|
||||
unified.Session.EncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(unified)
|
||||
// Should succeed regardless of feature flag since we're passing the right type
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid unified config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_OldConfig(t *testing.T) {
|
||||
old := CreateConfig()
|
||||
old.ProviderURL = "https://provider.example.com"
|
||||
old.ClientID = "client-id"
|
||||
old.ClientSecret = "client-secret"
|
||||
old.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := ValidateConfig(old)
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid old config to pass validation, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateConfig_InvalidType(t *testing.T) {
|
||||
// Pass something that's not a config
|
||||
err := ValidateConfig("not a config")
|
||||
if err != nil {
|
||||
t.Errorf("Expected nil for unknown type, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Config.Validate Tests
|
||||
func TestConfig_Validate_Valid(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "encryption-key-32-characters!!"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid config to pass, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingProviderURL(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ProviderURL")
|
||||
}
|
||||
|
||||
// Check if it's a ValidationErrors type
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ProviderURL" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ProviderURL validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientID(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientID")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientID" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientID validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_NoPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = false
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for missing ClientSecret without PKCE")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "ClientSecret" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected ClientSecret validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MissingClientSecret_WithPKCE(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.EnablePKCE = true // PKCE enabled, so ClientSecret not required
|
||||
|
||||
err := cfg.Validate()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error with PKCE enabled and no ClientSecret, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_ShortEncryptionKey(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
cfg.ProviderURL = "https://provider.example.com"
|
||||
cfg.ClientID = "client-id"
|
||||
cfg.ClientSecret = "client-secret"
|
||||
cfg.SessionEncryptionKey = "short" // Too short
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Error("Expected error for short encryption key")
|
||||
}
|
||||
|
||||
if verrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, verr := range verrs {
|
||||
if verr.Field == "SessionEncryptionKey" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected SessionEncryptionKey validation error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate_MultipleErrors(t *testing.T) {
|
||||
cfg := CreateConfig()
|
||||
// Missing ProviderURL, ClientID, and ClientSecret
|
||||
|
||||
err := cfg.Validate()
|
||||
if err == nil {
|
||||
t.Fatal("Expected validation errors")
|
||||
}
|
||||
|
||||
verrs, ok := err.(ValidationErrors)
|
||||
if !ok {
|
||||
t.Fatal("Expected ValidationErrors type")
|
||||
}
|
||||
|
||||
if len(verrs) < 2 {
|
||||
t.Errorf("Expected at least 2 validation errors, got %d", len(verrs))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,276 @@
|
||||
// Package config provides default values and initialization for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewUnifiedConfig creates a new unified configuration with sensible defaults
|
||||
func NewUnifiedConfig() *UnifiedConfig {
|
||||
return &UnifiedConfig{
|
||||
Provider: DefaultProviderConfig(),
|
||||
Session: DefaultSessionConfig(),
|
||||
Token: DefaultTokenConfig(),
|
||||
Redis: *DefaultRedisConfig(), // Using existing DefaultRedisConfig
|
||||
Security: DefaultSecurityConfig(),
|
||||
Middleware: DefaultMiddlewareConfig(),
|
||||
Cache: DefaultCacheConfig(),
|
||||
RateLimit: DefaultRateLimitConfig(),
|
||||
Logging: DefaultLoggingConfig(),
|
||||
Metrics: DefaultMetricsConfig(),
|
||||
Health: DefaultHealthConfig(),
|
||||
Transport: DefaultTransportConfig(),
|
||||
Pool: DefaultPoolConfig(),
|
||||
Circuit: DefaultCircuitConfig(),
|
||||
Legacy: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultProviderConfig returns default provider configuration
|
||||
func DefaultProviderConfig() ProviderConfig {
|
||||
return ProviderConfig{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
OverrideScopes: false,
|
||||
CustomClaims: make(map[string]string),
|
||||
JWKCachePeriod: 24 * time.Hour,
|
||||
MetadataCacheTTL: 24 * time.Hour,
|
||||
Discovery: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSessionConfig returns default session configuration
|
||||
func DefaultSessionConfig() SessionConfig {
|
||||
return SessionConfig{
|
||||
Name: "oidc_session",
|
||||
MaxAge: 86400, // 24 hours
|
||||
ChunkSize: 4000, // Safe size for cookies
|
||||
MaxChunks: 5,
|
||||
Path: "/",
|
||||
Secure: true,
|
||||
HttpOnly: true,
|
||||
SameSite: "Lax",
|
||||
StorageType: "cookie",
|
||||
CleanupInterval: 1 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTokenConfig returns default token configuration
|
||||
func DefaultTokenConfig() TokenConfig {
|
||||
return TokenConfig{
|
||||
AccessTokenTTL: 1 * time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
RefreshGracePeriod: 60 * time.Second,
|
||||
ValidationMode: "jwt",
|
||||
CacheEnabled: true,
|
||||
CacheTTL: 5 * time.Minute,
|
||||
CacheNegativeTTL: 30 * time.Second,
|
||||
ValidateSignature: true,
|
||||
ValidateExpiry: true,
|
||||
ValidateAudience: true,
|
||||
ValidateIssuer: true,
|
||||
RequiredClaims: []string{"sub", "iat", "exp"},
|
||||
ClockSkew: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultSecurityConfig returns default security configuration
|
||||
func DefaultSecurityConfig() SecurityConfig {
|
||||
return SecurityConfig{
|
||||
ForceHTTPS: true,
|
||||
EnablePKCE: true,
|
||||
AllowedUsers: []string{},
|
||||
AllowedUserDomains: []string{},
|
||||
AllowedRolesAndGroups: []string{},
|
||||
ExcludedURLs: []string{
|
||||
"/favicon.ico",
|
||||
"/robots.txt",
|
||||
"/health",
|
||||
"/.well-known/",
|
||||
"/metrics",
|
||||
"/ping",
|
||||
"/static/",
|
||||
"/assets/",
|
||||
"/js/",
|
||||
"/css/",
|
||||
"/images/",
|
||||
"/fonts/",
|
||||
},
|
||||
Headers: createDefaultSecurityConfig(),
|
||||
CSRFProtection: true,
|
||||
CSRFTokenName: "csrf_token",
|
||||
CSRFTokenTTL: 1 * time.Hour,
|
||||
MaxLoginAttempts: 5,
|
||||
LockoutDuration: 15 * time.Minute,
|
||||
RequireMFA: false,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMiddlewareConfig returns default middleware configuration
|
||||
func DefaultMiddlewareConfig() MiddlewareConfig {
|
||||
return MiddlewareConfig{
|
||||
Priority: 1000,
|
||||
SkipPaths: []string{},
|
||||
RequirePaths: []string{},
|
||||
PassthroughMode: false,
|
||||
MaxRequestSize: 10 * 1024 * 1024, // 10MB
|
||||
RequestTimeout: 30 * time.Second,
|
||||
IdleTimeout: 90 * time.Second,
|
||||
CustomHeaders: make(map[string]string),
|
||||
RemoveHeaders: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCacheConfig returns default cache configuration
|
||||
func DefaultCacheConfig() CacheConfig {
|
||||
return CacheConfig{
|
||||
Enabled: true,
|
||||
Type: "memory",
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxEntries: 10000,
|
||||
MaxEntrySize: 1024 * 1024, // 1MB
|
||||
EvictionPolicy: "lru",
|
||||
CleanupInterval: 10 * time.Minute,
|
||||
Namespace: "traefikoidc",
|
||||
Compression: false,
|
||||
Serialization: "json",
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRateLimitConfig returns default rate limiting configuration
|
||||
func DefaultRateLimitConfig() RateLimitConfig {
|
||||
return RateLimitConfig{
|
||||
Enabled: false,
|
||||
RequestsPerSecond: 10,
|
||||
Burst: 20,
|
||||
StorageType: "memory",
|
||||
WindowDuration: 1 * time.Minute,
|
||||
KeyType: "ip",
|
||||
CustomKeyFunc: "",
|
||||
WhitelistIPs: []string{},
|
||||
WhitelistUsers: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLoggingConfig returns default logging configuration
|
||||
func DefaultLoggingConfig() LoggingConfig {
|
||||
return LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
FilePath: "",
|
||||
FilterSensitive: true,
|
||||
MaskFields: []string{
|
||||
"password",
|
||||
"secret",
|
||||
"token",
|
||||
"key",
|
||||
"authorization",
|
||||
"cookie",
|
||||
},
|
||||
BufferSize: 8192,
|
||||
FlushInterval: 5 * time.Second,
|
||||
AuditEnabled: false,
|
||||
AuditEvents: []string{
|
||||
"login",
|
||||
"logout",
|
||||
"token_refresh",
|
||||
"auth_failure",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultMetricsConfig returns default metrics configuration
|
||||
func DefaultMetricsConfig() MetricsConfig {
|
||||
return MetricsConfig{
|
||||
Enabled: false,
|
||||
Provider: "prometheus",
|
||||
Endpoint: "/metrics",
|
||||
Namespace: "traefikoidc",
|
||||
Subsystem: "middleware",
|
||||
CollectInterval: 10 * time.Second,
|
||||
Histograms: true,
|
||||
Labels: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHealthConfig returns default health check configuration
|
||||
func DefaultHealthConfig() HealthConfig {
|
||||
return HealthConfig{
|
||||
Enabled: true,
|
||||
Path: "/health",
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
CheckProvider: true,
|
||||
CheckRedis: true,
|
||||
CheckCache: true,
|
||||
MaxLatency: 1 * time.Second,
|
||||
MinMemory: 100 * 1024 * 1024, // 100MB
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultTransportConfig returns default HTTP transport configuration
|
||||
func DefaultTransportConfig() TransportConfig {
|
||||
return TransportConfig{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
MaxConnsPerHost: 0, // No limit
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ResponseHeaderTimeout: 10 * time.Second,
|
||||
DisableKeepAlives: false,
|
||||
DisableCompression: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
TLSMinVersion: "TLS1.2",
|
||||
TLSCipherSuites: []string{},
|
||||
ProxyURL: "",
|
||||
NoProxy: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultPoolConfig returns default connection pool configuration
|
||||
func DefaultPoolConfig() PoolConfig {
|
||||
return PoolConfig{
|
||||
Enabled: true,
|
||||
Size: 10,
|
||||
MinSize: 2,
|
||||
MaxSize: 50,
|
||||
MaxAge: 30 * time.Minute,
|
||||
IdleTimeout: 5 * time.Minute,
|
||||
WaitTimeout: 5 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultCircuitConfig returns default circuit breaker configuration
|
||||
func DefaultCircuitConfig() CircuitConfig {
|
||||
return CircuitConfig{
|
||||
Enabled: true,
|
||||
MaxRequests: 100,
|
||||
Interval: 10 * time.Second,
|
||||
Timeout: 60 * time.Second,
|
||||
ConsecutiveFailures: 5,
|
||||
FailureRatio: 0.5,
|
||||
OnOpen: "reject",
|
||||
OnHalfOpen: "passthrough",
|
||||
MetricsEnabled: true,
|
||||
LogStateChanges: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeWithDefaults merges a partial configuration with defaults
|
||||
func MergeWithDefaults(partial *UnifiedConfig) *UnifiedConfig {
|
||||
if partial == nil {
|
||||
return NewUnifiedConfig()
|
||||
}
|
||||
|
||||
// Ensure Legacy field is initialized
|
||||
if partial.Legacy == nil {
|
||||
partial.Legacy = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// TODO: Implement deep merge logic with defaults
|
||||
// For now, just return the partial config
|
||||
return partial
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
// Package config provides configuration loading and merging logic
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigLoader handles loading configuration from various sources
|
||||
type ConfigLoader struct {
|
||||
migrator *ConfigMigrator
|
||||
envPrefix string
|
||||
configPaths []string
|
||||
}
|
||||
|
||||
// NewConfigLoader creates a new configuration loader
|
||||
func NewConfigLoader() *ConfigLoader {
|
||||
return &ConfigLoader{
|
||||
migrator: NewConfigMigrator(),
|
||||
envPrefix: "TRAEFIKOIDC_",
|
||||
configPaths: getDefaultConfigPaths(),
|
||||
}
|
||||
}
|
||||
|
||||
// getDefaultConfigPaths returns default configuration file paths to check
|
||||
func getDefaultConfigPaths() []string {
|
||||
return []string{
|
||||
"traefik-oidc.yaml",
|
||||
"traefik-oidc.yml",
|
||||
"traefik-oidc.json",
|
||||
"config.yaml",
|
||||
"config.yml",
|
||||
"config.json",
|
||||
"/etc/traefik-oidc/config.yaml",
|
||||
"/etc/traefik-oidc/config.json",
|
||||
}
|
||||
}
|
||||
|
||||
// Load loads configuration from all available sources
|
||||
func (l *ConfigLoader) Load() (*UnifiedConfig, error) {
|
||||
// Start with defaults
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Try to load from file
|
||||
if fileConfig, err := l.LoadFromFile(); err == nil && fileConfig != nil {
|
||||
config = l.mergeConfigs(config, fileConfig)
|
||||
}
|
||||
|
||||
// Load from environment variables
|
||||
l.LoadFromEnv(config)
|
||||
|
||||
// Validate the final configuration
|
||||
if err := config.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("configuration validation failed: %w", err)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// LoadFromFile loads configuration from a file
|
||||
func (l *ConfigLoader) LoadFromFile(paths ...string) (*UnifiedConfig, error) {
|
||||
// Use provided paths or default paths
|
||||
searchPaths := paths
|
||||
if len(searchPaths) == 0 {
|
||||
searchPaths = l.configPaths
|
||||
}
|
||||
|
||||
// Check for config file in environment variable
|
||||
if envPath := os.Getenv(l.envPrefix + "CONFIG_FILE"); envPath != "" {
|
||||
searchPaths = append([]string{envPath}, searchPaths...)
|
||||
}
|
||||
|
||||
// Try each path
|
||||
for _, path := range searchPaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
return l.loadFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
// No config file found, not an error (use defaults)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// loadFile loads a specific configuration file
|
||||
func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories (current dir or subdirs)
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
// Check if unified config is enabled
|
||||
if features.IsUnifiedConfigEnabled() {
|
||||
// Use migrator to handle any version
|
||||
config, warnings, err := l.migrator.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to migrate config from %s: %w", path, err)
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, use proper logging
|
||||
fmt.Printf("Config Warning (%s): %s\n", path, warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Legacy path: load old config and convert
|
||||
ext := strings.ToLower(filepath.Ext(path))
|
||||
var oldConfig Config
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
if err := json.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON config: %w", err)
|
||||
}
|
||||
case ".yaml", ".yml":
|
||||
if err := yaml.Unmarshal(data, &oldConfig); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse YAML config: %w", err)
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config file extension: %s", ext)
|
||||
}
|
||||
|
||||
return FromOldConfig(&oldConfig), nil
|
||||
}
|
||||
|
||||
// LoadFromEnv loads configuration from environment variables
|
||||
func (l *ConfigLoader) LoadFromEnv(config *UnifiedConfig) {
|
||||
// Provider configuration
|
||||
l.loadEnvString(&config.Provider.IssuerURL, "PROVIDER_ISSUER_URL", "PROVIDER_URL")
|
||||
l.loadEnvString(&config.Provider.ClientID, "PROVIDER_CLIENT_ID", "CLIENT_ID")
|
||||
l.loadEnvString(&config.Provider.ClientSecret, "PROVIDER_CLIENT_SECRET", "CLIENT_SECRET")
|
||||
l.loadEnvString(&config.Provider.RedirectURL, "PROVIDER_REDIRECT_URL", "CALLBACK_URL")
|
||||
l.loadEnvString(&config.Provider.LogoutURL, "PROVIDER_LOGOUT_URL", "LOGOUT_URL")
|
||||
l.loadEnvString(&config.Provider.PostLogoutRedirectURI, "PROVIDER_POST_LOGOUT_URI", "POST_LOGOUT_REDIRECT_URI")
|
||||
l.loadEnvStringSlice(&config.Provider.Scopes, "PROVIDER_SCOPES", "SCOPES")
|
||||
l.loadEnvBool(&config.Provider.OverrideScopes, "PROVIDER_OVERRIDE_SCOPES", "OVERRIDE_SCOPES")
|
||||
|
||||
// Session configuration
|
||||
l.loadEnvString(&config.Session.Name, "SESSION_NAME")
|
||||
l.loadEnvInt(&config.Session.MaxAge, "SESSION_MAX_AGE")
|
||||
l.loadEnvString(&config.Session.Secret, "SESSION_SECRET")
|
||||
l.loadEnvString(&config.Session.EncryptionKey, "SESSION_ENCRYPTION_KEY")
|
||||
l.loadEnvString(&config.Session.Domain, "SESSION_DOMAIN", "COOKIE_DOMAIN")
|
||||
l.loadEnvBool(&config.Session.Secure, "SESSION_SECURE")
|
||||
l.loadEnvBool(&config.Session.HttpOnly, "SESSION_HTTP_ONLY")
|
||||
l.loadEnvString(&config.Session.SameSite, "SESSION_SAME_SITE")
|
||||
|
||||
// Security configuration
|
||||
l.loadEnvBool(&config.Security.ForceHTTPS, "SECURITY_FORCE_HTTPS", "FORCE_HTTPS")
|
||||
l.loadEnvBool(&config.Security.EnablePKCE, "SECURITY_ENABLE_PKCE", "ENABLE_PKCE")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUsers, "SECURITY_ALLOWED_USERS", "ALLOWED_USERS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedUserDomains, "SECURITY_ALLOWED_DOMAINS", "ALLOWED_USER_DOMAINS")
|
||||
l.loadEnvStringSlice(&config.Security.AllowedRolesAndGroups, "SECURITY_ALLOWED_ROLES", "ALLOWED_ROLES_AND_GROUPS")
|
||||
l.loadEnvStringSlice(&config.Security.ExcludedURLs, "SECURITY_EXCLUDED_URLS", "EXCLUDED_URLS")
|
||||
|
||||
// Cache configuration
|
||||
l.loadEnvBool(&config.Cache.Enabled, "CACHE_ENABLED")
|
||||
l.loadEnvString(&config.Cache.Type, "CACHE_TYPE")
|
||||
l.loadEnvInt(&config.Cache.MaxEntries, "CACHE_MAX_ENTRIES")
|
||||
// MaxEntrySize is int64, skip for now
|
||||
|
||||
// Rate limiting
|
||||
l.loadEnvBool(&config.RateLimit.Enabled, "RATELIMIT_ENABLED")
|
||||
l.loadEnvInt(&config.RateLimit.RequestsPerSecond, "RATELIMIT_RPS", "RATE_LIMIT")
|
||||
l.loadEnvInt(&config.RateLimit.Burst, "RATELIMIT_BURST")
|
||||
|
||||
// Logging
|
||||
l.loadEnvString(&config.Logging.Level, "LOGGING_LEVEL", "LOG_LEVEL")
|
||||
l.loadEnvString(&config.Logging.Format, "LOGGING_FORMAT")
|
||||
l.loadEnvString(&config.Logging.Output, "LOGGING_OUTPUT")
|
||||
|
||||
// Redis configuration (already handled by its own LoadFromEnv)
|
||||
config.Redis.LoadFromEnv()
|
||||
|
||||
// Feature flags
|
||||
features.GetManager().LoadFromEnv()
|
||||
}
|
||||
|
||||
// Helper methods for environment variable loading
|
||||
|
||||
func (l *ConfigLoader) loadEnvString(target *string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = value
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvBool(target *bool, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = strings.ToLower(value) == "true" || value == "1"
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvInt(target *int, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
|
||||
*target = i
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ConfigLoader) loadEnvStringSlice(target *[]string, keys ...string) {
|
||||
for _, key := range keys {
|
||||
if value := os.Getenv(l.envPrefix + key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
// Try without prefix
|
||||
if value := os.Getenv(key); value != "" {
|
||||
*target = splitAndTrim(value)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitAndTrim(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// mergeConfigs merges two configurations, with source overriding target
|
||||
func (l *ConfigLoader) mergeConfigs(target, source *UnifiedConfig) *UnifiedConfig {
|
||||
if source == nil {
|
||||
return target
|
||||
}
|
||||
if target == nil {
|
||||
return source
|
||||
}
|
||||
|
||||
// Use reflection for deep merge
|
||||
l.mergeStructs(reflect.ValueOf(target).Elem(), reflect.ValueOf(source).Elem())
|
||||
|
||||
return target
|
||||
}
|
||||
|
||||
// mergeStructs recursively merges two structs
|
||||
func (l *ConfigLoader) mergeStructs(target, source reflect.Value) {
|
||||
for i := 0; i < source.NumField(); i++ {
|
||||
sourceField := source.Field(i)
|
||||
targetField := target.Field(i)
|
||||
|
||||
// Skip if source field is zero value
|
||||
if isZeroValue(sourceField) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sourceField.Kind() {
|
||||
case reflect.Struct:
|
||||
// Recursively merge structs
|
||||
l.mergeStructs(targetField, sourceField)
|
||||
case reflect.Slice:
|
||||
// Replace slice if source has values
|
||||
if sourceField.Len() > 0 {
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
case reflect.Map:
|
||||
// Merge maps
|
||||
if !sourceField.IsNil() {
|
||||
if targetField.IsNil() {
|
||||
targetField.Set(reflect.MakeMap(sourceField.Type()))
|
||||
}
|
||||
for _, key := range sourceField.MapKeys() {
|
||||
targetField.SetMapIndex(key, sourceField.MapIndex(key))
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Replace value
|
||||
targetField.Set(sourceField)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isZeroValue checks if a reflect.Value is a zero value
|
||||
func isZeroValue(v reflect.Value) bool {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
return v.IsNil()
|
||||
case reflect.Slice, reflect.Map:
|
||||
return v.IsNil() || v.Len() == 0
|
||||
case reflect.Struct:
|
||||
// Check if all fields are zero
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !isZeroValue(v.Field(i)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
default:
|
||||
zero := reflect.Zero(v.Type())
|
||||
return reflect.DeepEqual(v.Interface(), zero.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// SaveToFile saves the configuration to a file
|
||||
func (l *ConfigLoader) SaveToFile(config *UnifiedConfig, path string) error {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(path)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return fmt.Errorf("invalid config path: potential path traversal detected in %s", path)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve absolute path for %s: %w", path, err)
|
||||
}
|
||||
|
||||
ext := strings.ToLower(filepath.Ext(absPath))
|
||||
|
||||
var data []byte
|
||||
|
||||
switch ext {
|
||||
case ".json":
|
||||
data, err = json.MarshalIndent(config, "", " ")
|
||||
case ".yaml", ".yml":
|
||||
data, err = yaml.Marshal(config)
|
||||
default:
|
||||
return fmt.Errorf("unsupported file extension: %s", ext)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// Create directory if it doesn't exist with secure permissions
|
||||
dir := filepath.Dir(absPath)
|
||||
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
// Write file with secure permissions (owner read/write only)
|
||||
if err := os.WriteFile(absPath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write config file %s: %w", absPath, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,832 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestConfigLoader tests the config loader functionality
|
||||
func TestConfigLoader(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
if loader == nil {
|
||||
t.Fatal("NewConfigLoader should not return nil")
|
||||
}
|
||||
|
||||
if loader.migrator == nil {
|
||||
t.Error("ConfigLoader should have a migrator")
|
||||
}
|
||||
|
||||
if loader.envPrefix != "TRAEFIKOIDC_" {
|
||||
t.Errorf("Expected envPrefix to be 'TRAEFIKOIDC_', got %s", loader.envPrefix)
|
||||
}
|
||||
|
||||
if len(loader.configPaths) == 0 {
|
||||
t.Error("ConfigLoader should have default config paths")
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFromEnv tests loading configuration from environment variables
|
||||
func TestLoadFromEnv(t *testing.T) {
|
||||
// Set up test environment variables
|
||||
testEnvVars := map[string]string{
|
||||
"TRAEFIKOIDC_PROVIDER_ISSUER_URL": "https://test.example.com",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_ID": "test-client-id",
|
||||
"TRAEFIKOIDC_PROVIDER_CLIENT_SECRET": "test-secret",
|
||||
"TRAEFIKOIDC_SESSION_ENCRYPTION_KEY": "32-character-encryption-key-12345",
|
||||
"TRAEFIKOIDC_SESSION_CHUNKED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ENABLED": "true",
|
||||
"TRAEFIKOIDC_REDIS_ADDR": "redis.example.com:6379",
|
||||
"TRAEFIKOIDC_SECURITY_FORCE_HTTPS": "true",
|
||||
"TRAEFIKOIDC_CACHE_ENABLED": "true",
|
||||
"TRAEFIKOIDC_CACHE_TYPE": "redis",
|
||||
"TRAEFIKOIDC_RATELIMIT_ENABLED": "true",
|
||||
"TRAEFIKOIDC_RATELIMIT_RPS": "100",
|
||||
}
|
||||
|
||||
// Set environment variables
|
||||
for key, value := range testEnvVars {
|
||||
os.Setenv(key, value)
|
||||
defer os.Unsetenv(key)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{}
|
||||
loader.LoadFromEnv(config)
|
||||
|
||||
// Verify values were loaded
|
||||
if config.Provider.IssuerURL != "https://test.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://test.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client-id" {
|
||||
t.Errorf("Expected ClientID to be 'test-client-id', got %s", config.Provider.ClientID)
|
||||
}
|
||||
if config.Provider.ClientSecret != "test-secret" {
|
||||
t.Errorf("Expected ClientSecret to be 'test-secret', got %s", config.Provider.ClientSecret)
|
||||
}
|
||||
if config.Session.EncryptionKey != "32-character-encryption-key-12345" {
|
||||
t.Errorf("Expected EncryptionKey to be set, got %s", config.Session.EncryptionKey)
|
||||
}
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS to be true")
|
||||
}
|
||||
if !config.Cache.Enabled {
|
||||
t.Error("Expected Cache to be enabled")
|
||||
}
|
||||
if config.Cache.Type != "redis" {
|
||||
t.Errorf("Expected Cache.Type to be 'redis', got %s", config.Cache.Type)
|
||||
}
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("Expected RateLimit to be enabled")
|
||||
}
|
||||
if config.RateLimit.RequestsPerSecond != 100 {
|
||||
t.Errorf("Expected RequestsPerSecond to be 100, got %d", config.RateLimit.RequestsPerSecond)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveToFile tests saving configuration to files
|
||||
func TestSaveToFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "32-character-encryption-key-12345",
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "save as JSON",
|
||||
filename: "config.json",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YAML",
|
||||
filename: "config.yaml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "save as YML",
|
||||
filename: "config.yml",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unsupported extension",
|
||||
filename: "config.txt",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/config.json",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filePath := filepath.Join(tmpDir, tt.filename)
|
||||
err := loader.SaveToFile(config, filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify file was created with correct permissions
|
||||
info, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to stat saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check file permissions (should be 0600)
|
||||
mode := info.Mode().Perm()
|
||||
if mode != 0600 {
|
||||
t.Errorf("Expected file permissions 0600, got %o", mode)
|
||||
}
|
||||
|
||||
// Verify content can be read back
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to read saved file: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify secrets are redacted
|
||||
content := string(data)
|
||||
if strings.Contains(content, "secret") && !strings.Contains(content, "[REDACTED]") {
|
||||
t.Error("Secrets should be redacted in saved file")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoadFile tests loading configuration from files
|
||||
func TestLoadFile(t *testing.T) {
|
||||
// Create a temporary directory for test files
|
||||
tmpDir, err := os.MkdirTemp("", "config-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Test data - using old config format since unified config is not enabled by default
|
||||
jsonConfig := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
|
||||
yamlConfig := `
|
||||
providerurl: https://auth.example.com
|
||||
clientid: test-client
|
||||
clientsecret: secret
|
||||
sessionencryptionkey: 32-character-encryption-key-12345
|
||||
`
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filename string
|
||||
content string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "load JSON config",
|
||||
filename: "config.json",
|
||||
content: jsonConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "load YAML config",
|
||||
filename: "config.yaml",
|
||||
content: yamlConfig,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "path traversal attempt",
|
||||
filename: "../../../etc/passwd",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-existent file",
|
||||
filename: "does-not-exist.json",
|
||||
content: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var filePath string
|
||||
if tt.content != "" {
|
||||
filePath = filepath.Join(tmpDir, tt.filename)
|
||||
err := os.WriteFile(filePath, []byte(tt.content), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
filePath = tt.filename
|
||||
}
|
||||
|
||||
config, err := loader.loadFile(filePath)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !os.IsNotExist(err) && !strings.Contains(err.Error(), "no such file") {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Verify loaded config
|
||||
if config == nil {
|
||||
t.Error("Expected config to be loaded")
|
||||
return
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be 'https://auth.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
if config.Provider.ClientID != "test-client" {
|
||||
t.Errorf("Expected ClientID to be 'test-client', got %s", config.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ====================================================================================
|
||||
// Tests for untested functions (0% coverage)
|
||||
// ====================================================================================
|
||||
|
||||
// TestConfigLoader_Load tests the full Load pipeline
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-load-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create a test config file
|
||||
configPath := filepath.Join(tmpDir, "traefik-oidc.json")
|
||||
configData := `{
|
||||
"providerURL": "https://auth.example.com",
|
||||
"clientID": "test-client",
|
||||
"clientSecret": "test-secret",
|
||||
"sessionEncryptionKey": "32-character-encryption-key-12345"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config file: %v", err)
|
||||
}
|
||||
|
||||
// Change to temp directory so loader can find the config
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
// Set some environment variables to test merging
|
||||
os.Setenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS", "true")
|
||||
defer os.Unsetenv("TRAEFIKOIDC_SECURITY_FORCE_HTTPS")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.Load()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Load() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("Load() returned nil config")
|
||||
}
|
||||
|
||||
// Verify file was loaded
|
||||
if config.Provider.IssuerURL != "https://auth.example.com" {
|
||||
t.Errorf("Expected IssuerURL from file, got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Verify env vars were loaded
|
||||
if !config.Security.ForceHTTPS {
|
||||
t.Error("Expected ForceHTTPS from env var to be true")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_LoadFromFile tests the LoadFromFile function
|
||||
func TestConfigLoader_LoadFromFile(t *testing.T) {
|
||||
t.Run("NoConfigFile", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-nofile-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
oldDir, _ := os.Getwd()
|
||||
os.Chdir(tmpDir)
|
||||
defer os.Chdir(oldDir)
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
// Should not error when no config file found
|
||||
if err != nil {
|
||||
t.Errorf("LoadFromFile() should not error when no file found: %v", err)
|
||||
}
|
||||
|
||||
// Should return nil config
|
||||
if config != nil {
|
||||
t.Error("LoadFromFile() should return nil config when no file found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadFromEnvPath", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-envpath-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "custom-config.json")
|
||||
configData := `{
|
||||
"providerURL": "https://custom.example.com",
|
||||
"clientID": "custom-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
// Set env variable pointing to config
|
||||
os.Setenv("TRAEFIKOIDC_CONFIG_FILE", configPath)
|
||||
defer os.Unsetenv("TRAEFIKOIDC_CONFIG_FILE")
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile()
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://custom.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://custom.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("LoadWithProvidedPaths", func(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "config-provided-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp directory: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create config file
|
||||
configPath := filepath.Join(tmpDir, "specific.json")
|
||||
configData := `{
|
||||
"providerURL": "https://specific.example.com",
|
||||
"clientID": "specific-client"
|
||||
}`
|
||||
err = os.WriteFile(configPath, []byte(configData), 0600)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to write test config: %v", err)
|
||||
}
|
||||
|
||||
loader := NewConfigLoader()
|
||||
config, err := loader.LoadFromFile(configPath)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("LoadFromFile() with path failed: %v", err)
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
t.Fatal("LoadFromFile() returned nil config")
|
||||
}
|
||||
|
||||
if config.Provider.IssuerURL != "https://specific.example.com" {
|
||||
t.Errorf("Expected IssuerURL 'https://specific.example.com', got %s", config.Provider.IssuerURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestSplitAndTrim tests the splitAndTrim helper function
|
||||
func TestSplitAndTrim(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Simple comma-separated",
|
||||
input: "a,b,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "With spaces",
|
||||
input: "a, b , c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Empty strings filtered out",
|
||||
input: "a,,b, ,c",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Leading and trailing spaces",
|
||||
input: " a , b , c ",
|
||||
expected: []string{"a", "b", "c"},
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "single",
|
||||
expected: []string{"single"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Only commas and spaces",
|
||||
input: " , , , ",
|
||||
expected: []string{},
|
||||
},
|
||||
{
|
||||
name: "Complex real-world example",
|
||||
input: "openid, profile, email, groups",
|
||||
expected: []string{"openid", "profile", "email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := splitAndTrim(tt.input)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d items, got %d: %v", len(tt.expected), len(result), result)
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expected {
|
||||
if result[i] != expected {
|
||||
t.Errorf("At index %d: expected %q, got %q", i, expected, result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeConfigs tests the mergeConfigs function
|
||||
func TestConfigLoader_MergeConfigs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("MergeNilSource", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, nil)
|
||||
|
||||
if result != target {
|
||||
t.Error("mergeConfigs should return target when source is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeNilTarget", func(t *testing.T) {
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(nil, source)
|
||||
|
||||
if result != source {
|
||||
t.Error("mergeConfigs should return source when target is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSimpleFields", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "",
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://source.example.com",
|
||||
ClientID: "source-client",
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if result.Provider.IssuerURL != "https://source.example.com" {
|
||||
t.Errorf("Expected IssuerURL to be overridden, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be set, got %s", result.Provider.ClientID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeSlices", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"openid", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
Scopes: []string{"email", "groups"},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Source slice should replace target slice
|
||||
if len(result.Provider.Scopes) != 2 {
|
||||
t.Errorf("Expected 2 scopes, got %d", len(result.Provider.Scopes))
|
||||
}
|
||||
|
||||
if result.Provider.Scopes[0] != "email" {
|
||||
t.Errorf("Expected first scope 'email', got %s", result.Provider.Scopes[0])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("MergeMaps", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Target-Header": "target-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Middleware: MiddlewareConfig{
|
||||
CustomHeaders: map[string]string{
|
||||
"X-Source-Header": "source-value",
|
||||
"X-Target-Header": "overridden-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
if len(result.Middleware.CustomHeaders) != 2 {
|
||||
t.Errorf("Expected 2 headers, got %d", len(result.Middleware.CustomHeaders))
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Target-Header"] != "overridden-value" {
|
||||
t.Errorf("Expected X-Target-Header to be overridden")
|
||||
}
|
||||
|
||||
if result.Middleware.CustomHeaders["X-Source-Header"] != "source-value" {
|
||||
t.Errorf("Expected X-Source-Header to be added")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestConfigLoader_MergeStructs tests the mergeStructs function indirectly
|
||||
func TestConfigLoader_MergeStructs(t *testing.T) {
|
||||
loader := NewConfigLoader()
|
||||
|
||||
t.Run("NestedStructMerge", func(t *testing.T) {
|
||||
target := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://target.example.com",
|
||||
ClientID: "target-client",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "target-session",
|
||||
MaxAge: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
source := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "source-client",
|
||||
ClientSecret: "source-secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
MaxAge: 7200,
|
||||
},
|
||||
}
|
||||
|
||||
result := loader.mergeConfigs(target, source)
|
||||
|
||||
// Provider.IssuerURL should remain (zero value in source)
|
||||
if result.Provider.IssuerURL != "https://target.example.com" {
|
||||
t.Errorf("Expected IssuerURL to remain, got %s", result.Provider.IssuerURL)
|
||||
}
|
||||
|
||||
// Provider.ClientID should be overridden
|
||||
if result.Provider.ClientID != "source-client" {
|
||||
t.Errorf("Expected ClientID to be overridden, got %s", result.Provider.ClientID)
|
||||
}
|
||||
|
||||
// Provider.ClientSecret should be added
|
||||
if result.Provider.ClientSecret != "source-secret" {
|
||||
t.Errorf("Expected ClientSecret to be added, got %s", result.Provider.ClientSecret)
|
||||
}
|
||||
|
||||
// Session.Name should remain (zero value in source)
|
||||
if result.Session.Name != "target-session" {
|
||||
t.Errorf("Expected Session.Name to remain, got %s", result.Session.Name)
|
||||
}
|
||||
|
||||
// Session.MaxAge should be overridden
|
||||
if result.Session.MaxAge != 7200 {
|
||||
t.Errorf("Expected Session.MaxAge to be overridden, got %d", result.Session.MaxAge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestIsZeroValue tests the isZeroValue helper function
|
||||
func TestIsZeroValue(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Zero string",
|
||||
value: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero string",
|
||||
value: "hello",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero int",
|
||||
value: 0,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero int",
|
||||
value: 42,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Zero bool",
|
||||
value: false,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-zero bool",
|
||||
value: true,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil pointer",
|
||||
value: (*string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-nil pointer",
|
||||
value: stringPtr("test"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil slice",
|
||||
value: ([]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty slice",
|
||||
value: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty slice",
|
||||
value: []string{"a"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Nil map",
|
||||
value: (map[string]string)(nil),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Empty map",
|
||||
value: map[string]string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Non-empty map",
|
||||
value: map[string]string{"key": "value"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
v := reflect.ValueOf(tt.value)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected isZeroValue to be %v, got %v", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsZeroValue_Struct tests isZeroValue with struct types
|
||||
func TestIsZeroValue_Struct(t *testing.T) {
|
||||
type TestStruct struct {
|
||||
Field1 string
|
||||
Field2 int
|
||||
}
|
||||
|
||||
t.Run("Zero struct", func(t *testing.T) {
|
||||
s := TestStruct{}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if !result {
|
||||
t.Error("Expected zero struct to return true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field1 set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test"}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Field2 set", func(t *testing.T) {
|
||||
s := TestStruct{Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Non-zero struct - Both fields set", func(t *testing.T) {
|
||||
s := TestStruct{Field1: "test", Field2: 42}
|
||||
v := reflect.ValueOf(s)
|
||||
result := isZeroValue(v)
|
||||
|
||||
if result {
|
||||
t.Error("Expected non-zero struct to return false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function for pointer tests
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
@@ -0,0 +1,169 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalJSON() ([]byte, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return json.Marshal(copy)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
func (c UnifiedConfig) MarshalYAML() (interface{}, error) {
|
||||
// Create an alias to avoid recursion
|
||||
type Alias UnifiedConfig
|
||||
|
||||
// Create a copy with redacted sensitive fields
|
||||
copy := (Alias)(c)
|
||||
|
||||
// Redact provider secrets
|
||||
if copy.Provider.ClientSecret != "" {
|
||||
copy.Provider.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
// Redact session secrets
|
||||
if copy.Session.Secret != "" {
|
||||
copy.Session.Secret = REDACTED
|
||||
}
|
||||
if copy.Session.EncryptionKey != "" {
|
||||
copy.Session.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.Session.SigningKey != "" {
|
||||
copy.Session.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
// Redact Redis passwords
|
||||
if copy.Redis.Password != "" {
|
||||
copy.Redis.Password = REDACTED
|
||||
}
|
||||
if copy.Redis.SentinelPassword != "" {
|
||||
copy.Redis.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for ProviderConfig to redact sensitive fields
|
||||
func (p ProviderConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias ProviderConfig
|
||||
copy := (Alias)(p)
|
||||
|
||||
if copy.ClientSecret != "" {
|
||||
copy.ClientSecret = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for SessionConfig to redact sensitive fields
|
||||
func (s SessionConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias SessionConfig
|
||||
copy := (Alias)(s)
|
||||
|
||||
if copy.Secret != "" {
|
||||
copy.Secret = REDACTED
|
||||
}
|
||||
if copy.EncryptionKey != "" {
|
||||
copy.EncryptionKey = REDACTED
|
||||
}
|
||||
if copy.SigningKey != "" {
|
||||
copy.SigningKey = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
type Alias RedisConfig
|
||||
copy := (Alias)(r)
|
||||
|
||||
if copy.Password != "" {
|
||||
copy.Password = REDACTED
|
||||
}
|
||||
if copy.SentinelPassword != "" {
|
||||
copy.SentinelPassword = REDACTED
|
||||
}
|
||||
|
||||
return copy, nil
|
||||
}
|
||||
@@ -0,0 +1,408 @@
|
||||
// Package config provides configuration migration from old to new format
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/compat"
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/features"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ConfigVersion represents the version of a configuration format
|
||||
type ConfigVersion string
|
||||
|
||||
const (
|
||||
// VersionLegacy represents the original config format
|
||||
VersionLegacy ConfigVersion = "legacy"
|
||||
|
||||
// VersionUnified represents the new unified config format
|
||||
VersionUnified ConfigVersion = "unified"
|
||||
|
||||
// CurrentVersion is the current config version
|
||||
CurrentVersion ConfigVersion = VersionUnified
|
||||
)
|
||||
|
||||
// ConfigMigrator handles migration between config versions
|
||||
type ConfigMigrator struct {
|
||||
compatLayer *compat.CompatibilityLayer
|
||||
migrations map[ConfigVersion]MigrationFunc
|
||||
}
|
||||
|
||||
// MigrationFunc defines a function that migrates configuration
|
||||
type MigrationFunc func(data map[string]interface{}) (*UnifiedConfig, error)
|
||||
|
||||
// NewConfigMigrator creates a new configuration migrator
|
||||
func NewConfigMigrator() *ConfigMigrator {
|
||||
m := &ConfigMigrator{
|
||||
compatLayer: compat.GetLayer(),
|
||||
migrations: make(map[ConfigVersion]MigrationFunc),
|
||||
}
|
||||
|
||||
// Register migration functions
|
||||
m.migrations[VersionLegacy] = m.migrateLegacyToUnified
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// DetectVersion detects the version of a configuration
|
||||
func (m *ConfigMigrator) DetectVersion(data []byte) ConfigVersion {
|
||||
var testMap map[string]interface{}
|
||||
|
||||
// Try JSON first
|
||||
if err := json.Unmarshal(data, &testMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &testMap); err != nil {
|
||||
return VersionLegacy // Default to legacy if can't parse
|
||||
}
|
||||
}
|
||||
|
||||
// Check for unified config markers
|
||||
if _, hasProvider := testMap["provider"]; hasProvider {
|
||||
if _, hasSession := testMap["session"]; hasSession {
|
||||
return VersionUnified
|
||||
}
|
||||
}
|
||||
|
||||
// Check for legacy config markers
|
||||
if _, hasProviderURL := testMap["providerUrl"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
if _, hasProviderURL := testMap["ProviderURL"]; hasProviderURL {
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
return VersionLegacy
|
||||
}
|
||||
|
||||
// Migrate migrates configuration data to the current version
|
||||
func (m *ConfigMigrator) Migrate(data []byte) (*UnifiedConfig, []string, error) {
|
||||
warnings := []string{}
|
||||
|
||||
// Detect version
|
||||
version := m.DetectVersion(data)
|
||||
|
||||
// If already current version, just unmarshal
|
||||
if version == CurrentVersion {
|
||||
var config UnifiedConfig
|
||||
if err := json.Unmarshal(data, &config); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &config); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal unified config: %w", err)
|
||||
}
|
||||
}
|
||||
return &config, warnings, nil
|
||||
}
|
||||
|
||||
// Parse to generic map
|
||||
var configMap map[string]interface{}
|
||||
if err := json.Unmarshal(data, &configMap); err != nil {
|
||||
// Try YAML
|
||||
if err := yaml.Unmarshal(data, &configMap); err != nil {
|
||||
return nil, warnings, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply migration
|
||||
migrationFunc, exists := m.migrations[version]
|
||||
if !exists {
|
||||
return nil, warnings, fmt.Errorf("no migration path from version %s", version)
|
||||
}
|
||||
|
||||
config, err := migrationFunc(configMap)
|
||||
if err != nil {
|
||||
return nil, warnings, fmt.Errorf("migration failed: %w", err)
|
||||
}
|
||||
|
||||
// Collect any deprecation warnings
|
||||
for key := range configMap {
|
||||
if warning, deprecated := m.compatLayer.CheckDeprecation(key); deprecated {
|
||||
warnings = append(warnings, warning)
|
||||
}
|
||||
}
|
||||
|
||||
return config, warnings, nil
|
||||
}
|
||||
|
||||
// migrateLegacyToUnified migrates legacy config to unified format
|
||||
func (m *ConfigMigrator) migrateLegacyToUnified(data map[string]interface{}) (*UnifiedConfig, error) {
|
||||
config := NewUnifiedConfig()
|
||||
|
||||
// Use compatibility layer for field mapping
|
||||
migratedMap, warnings := m.compatLayer.MigrateMap(data)
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
// In production, these would be logged
|
||||
_ = warning
|
||||
}
|
||||
|
||||
// Map provider configuration
|
||||
if provider, ok := getNestedMap(migratedMap, "Provider"); ok {
|
||||
_ = mapToStruct(provider, &config.Provider)
|
||||
} else {
|
||||
// Direct field mapping for legacy format
|
||||
config.Provider.IssuerURL = getStringValue(data, "providerUrl", "ProviderURL")
|
||||
config.Provider.ClientID = getStringValue(data, "clientId", "ClientID")
|
||||
config.Provider.ClientSecret = getStringValue(data, "clientSecret", "ClientSecret")
|
||||
config.Provider.RedirectURL = getStringValue(data, "callbackUrl", "CallbackURL")
|
||||
config.Provider.LogoutURL = getStringValue(data, "logoutUrl", "LogoutURL")
|
||||
config.Provider.PostLogoutRedirectURI = getStringValue(data, "postLogoutRedirectUri", "PostLogoutRedirectURI")
|
||||
|
||||
if scopes := getArrayValue(data, "scopes", "Scopes"); scopes != nil {
|
||||
config.Provider.Scopes = scopes
|
||||
}
|
||||
config.Provider.OverrideScopes = getBoolValue(data, "overrideScopes", "OverrideScopes")
|
||||
}
|
||||
|
||||
// Map session configuration
|
||||
if session, ok := getNestedMap(migratedMap, "Session"); ok {
|
||||
_ = mapToStruct(session, &config.Session)
|
||||
} else {
|
||||
config.Session.EncryptionKey = getStringValue(data, "sessionEncryptionKey", "SessionEncryptionKey")
|
||||
config.Session.Domain = getStringValue(data, "cookieDomain", "CookieDomain")
|
||||
}
|
||||
|
||||
// Map security configuration
|
||||
if security, ok := getNestedMap(migratedMap, "Security"); ok {
|
||||
_ = mapToStruct(security, &config.Security)
|
||||
} else {
|
||||
config.Security.ForceHTTPS = getBoolValue(data, "forceHttps", "ForceHTTPS")
|
||||
config.Security.EnablePKCE = getBoolValue(data, "enablePkce", "EnablePKCE")
|
||||
|
||||
if users := getArrayValue(data, "allowedUsers", "AllowedUsers"); users != nil {
|
||||
config.Security.AllowedUsers = users
|
||||
}
|
||||
if domains := getArrayValue(data, "allowedUserDomains", "AllowedUserDomains"); domains != nil {
|
||||
config.Security.AllowedUserDomains = domains
|
||||
}
|
||||
if roles := getArrayValue(data, "allowedRolesAndGroups", "AllowedRolesAndGroups"); roles != nil {
|
||||
config.Security.AllowedRolesAndGroups = roles
|
||||
}
|
||||
if excluded := getArrayValue(data, "excludedUrls", "ExcludedURLs"); excluded != nil {
|
||||
config.Security.ExcludedURLs = excluded
|
||||
}
|
||||
|
||||
// Handle security headers
|
||||
if headers := data["securityHeaders"]; headers != nil {
|
||||
// Security headers might be in old format
|
||||
_ = mapToStruct(headers, &config.Security.Headers)
|
||||
}
|
||||
}
|
||||
|
||||
// Map rate limiting
|
||||
if rateLimit := getIntValue(data, "rateLimit", "RateLimit"); rateLimit > 0 {
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = rateLimit
|
||||
config.RateLimit.Burst = rateLimit * 2 // Default burst to 2x rate
|
||||
}
|
||||
|
||||
// Map token configuration
|
||||
if refreshGrace := getIntValue(data, "refreshGracePeriodSeconds", "RefreshGracePeriodSeconds"); refreshGrace > 0 {
|
||||
config.Token.RefreshGracePeriod = time.Duration(refreshGrace) * time.Second
|
||||
}
|
||||
|
||||
// Map logging
|
||||
config.Logging.Level = strings.ToLower(getStringValue(data, "logLevel", "LogLevel"))
|
||||
if config.Logging.Level == "" {
|
||||
config.Logging.Level = "info"
|
||||
}
|
||||
|
||||
// Map custom headers
|
||||
if headers := data["headers"]; headers != nil {
|
||||
if headerList, ok := headers.([]interface{}); ok {
|
||||
config.Middleware.CustomHeaders = make(map[string]string)
|
||||
for _, h := range headerList {
|
||||
if headerMap, ok := h.(map[string]interface{}); ok {
|
||||
name := getStringFromInterface(headerMap["name"])
|
||||
value := getStringFromInterface(headerMap["value"])
|
||||
if name != "" {
|
||||
config.Middleware.CustomHeaders[name] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store original data for reference
|
||||
config.Legacy = data
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// MigrateFile migrates a configuration file
|
||||
func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
// Clean and validate path to prevent traversal attacks
|
||||
cleanPath := filepath.Clean(filePath)
|
||||
|
||||
// Check for path traversal attempts
|
||||
if strings.Contains(cleanPath, "..") {
|
||||
return nil, fmt.Errorf("invalid config path: potential path traversal detected in %s", filePath)
|
||||
}
|
||||
|
||||
// Ensure the path is within expected directories
|
||||
absPath, err := filepath.Abs(cleanPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve absolute path for %s: %w", filePath, err)
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
config, warnings, err := m.Migrate(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Log warnings
|
||||
for _, warning := range warnings {
|
||||
fmt.Printf("Migration Warning: %s\n", warning)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// AutoMigrate automatically migrates config based on feature flags
|
||||
func AutoMigrate(data interface{}) (*UnifiedConfig, error) {
|
||||
if !features.IsUnifiedConfigEnabled() {
|
||||
// Feature not enabled, return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
migrator := NewConfigMigrator()
|
||||
|
||||
// Handle different input types
|
||||
switch v := data.(type) {
|
||||
case []byte:
|
||||
config, _, err := migrator.Migrate(v)
|
||||
return config, err
|
||||
case string:
|
||||
config, _, err := migrator.Migrate([]byte(v))
|
||||
return config, err
|
||||
case *Config:
|
||||
// Convert old config to unified
|
||||
return FromOldConfig(v), nil
|
||||
case *UnifiedConfig:
|
||||
// Already unified
|
||||
return v, nil
|
||||
case map[string]interface{}:
|
||||
// Convert map to JSON then migrate
|
||||
jsonData, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config, _, err := migrator.Migrate(jsonData)
|
||||
return config, err
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported config type: %T", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func getNestedMap(m map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||
if val, exists := m[key]; exists {
|
||||
if mapped, ok := val.(map[string]interface{}); ok {
|
||||
return mapped, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func getStringValue(m map[string]interface{}, keys ...string) string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
return getStringFromInterface(val)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getStringFromInterface(val interface{}) string {
|
||||
if val == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func getBoolValue(m map[string]interface{}, keys ...string) bool {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if b, ok := val.(bool); ok {
|
||||
return b
|
||||
}
|
||||
// Try string conversion
|
||||
if s, ok := val.(string); ok {
|
||||
return strings.ToLower(s) == "true"
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getIntValue(m map[string]interface{}, keys ...string) int {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
switch v := val.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case string:
|
||||
// Try to parse
|
||||
var i int
|
||||
if _, err := fmt.Sscanf(v, "%d", &i); err != nil {
|
||||
// If parsing fails, return default
|
||||
return 0
|
||||
}
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func getArrayValue(m map[string]interface{}, keys ...string) []string {
|
||||
for _, key := range keys {
|
||||
if val, exists := m[key]; exists {
|
||||
if arr, ok := val.([]interface{}); ok {
|
||||
result := make([]string, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
result = append(result, getStringFromInterface(item))
|
||||
}
|
||||
return result
|
||||
}
|
||||
if strArr, ok := val.([]string); ok {
|
||||
return strArr
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func mapToStruct(m interface{}, target interface{}) error {
|
||||
// Simple mapping using JSON as intermediate
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(data, target)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,297 @@
|
||||
// Package config provides configuration structures for the Traefik OIDC plugin.
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedisMode represents the Redis deployment mode
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
// RedisModeStandalone represents a single Redis instance
|
||||
RedisModeStandalone RedisMode = "standalone"
|
||||
|
||||
// RedisModeCluster represents Redis cluster mode
|
||||
RedisModeCluster RedisMode = "cluster"
|
||||
|
||||
// RedisModeSentinel represents Redis sentinel mode
|
||||
RedisModeSentinel RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
// RedisConfig holds Redis cache backend configuration
|
||||
type RedisConfig struct {
|
||||
// Enabled indicates if Redis backend should be used
|
||||
Enabled bool `json:"enabled,omitempty" yaml:"enabled,omitempty"`
|
||||
|
||||
// Mode specifies the Redis deployment mode
|
||||
Mode RedisMode `json:"mode,omitempty" yaml:"mode,omitempty"`
|
||||
|
||||
// === Standalone Configuration ===
|
||||
// Addr is the Redis server address (host:port)
|
||||
Addr string `json:"addr,omitempty" yaml:"addr,omitempty"`
|
||||
|
||||
// Password for Redis authentication
|
||||
Password string `json:"password,omitempty" yaml:"password,omitempty"`
|
||||
|
||||
// DB is the database number (0-15)
|
||||
DB int `json:"db,omitempty" yaml:"db,omitempty"`
|
||||
|
||||
// === Cluster Configuration ===
|
||||
// ClusterAddrs is the list of cluster node addresses
|
||||
ClusterAddrs []string `json:"clusterAddrs,omitempty" yaml:"clusterAddrs,omitempty"`
|
||||
|
||||
// === Sentinel Configuration ===
|
||||
// MasterName is the name of the master instance
|
||||
MasterName string `json:"masterName,omitempty" yaml:"masterName,omitempty"`
|
||||
|
||||
// SentinelAddrs is the list of sentinel addresses
|
||||
SentinelAddrs []string `json:"sentinelAddrs,omitempty" yaml:"sentinelAddrs,omitempty"`
|
||||
|
||||
// SentinelPassword is the password for sentinel authentication
|
||||
SentinelPassword string `json:"sentinelPassword,omitempty" yaml:"sentinelPassword,omitempty"`
|
||||
|
||||
// === Connection Pool Settings ===
|
||||
// PoolSize is the maximum number of socket connections
|
||||
PoolSize int `json:"poolSize,omitempty" yaml:"poolSize,omitempty"`
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections
|
||||
MinIdleConns int `json:"minIdleConns,omitempty" yaml:"minIdleConns,omitempty"`
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up
|
||||
MaxRetries int `json:"maxRetries,omitempty" yaml:"maxRetries,omitempty"`
|
||||
|
||||
// === Timeouts ===
|
||||
// DialTimeout is the timeout for establishing new connections
|
||||
DialTimeout time.Duration `json:"dialTimeout,omitempty" yaml:"dialTimeout,omitempty"`
|
||||
|
||||
// ReadTimeout is the timeout for socket reads
|
||||
ReadTimeout time.Duration `json:"readTimeout,omitempty" yaml:"readTimeout,omitempty"`
|
||||
|
||||
// WriteTimeout is the timeout for socket writes
|
||||
WriteTimeout time.Duration `json:"writeTimeout,omitempty" yaml:"writeTimeout,omitempty"`
|
||||
|
||||
// PoolTimeout is the timeout for connection pool
|
||||
PoolTimeout time.Duration `json:"poolTimeout,omitempty" yaml:"poolTimeout,omitempty"`
|
||||
|
||||
// ConnMaxIdleTime is the maximum amount of time a connection may be idle
|
||||
ConnMaxIdleTime time.Duration `json:"connMaxIdleTime,omitempty" yaml:"connMaxIdleTime,omitempty"`
|
||||
|
||||
// ConnMaxLifetime is the maximum lifetime of a connection
|
||||
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty" yaml:"connMaxLifetime,omitempty"`
|
||||
|
||||
// === Key Management ===
|
||||
// KeyPrefix is the prefix for all Redis keys
|
||||
KeyPrefix string `json:"keyPrefix,omitempty" yaml:"keyPrefix,omitempty"`
|
||||
|
||||
// === TLS Configuration ===
|
||||
// TLSEnabled enables TLS for Redis connections
|
||||
TLSEnabled bool `json:"tlsEnabled,omitempty" yaml:"tlsEnabled,omitempty"`
|
||||
|
||||
// TLSInsecureSkipVerify skips TLS certificate verification
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify,omitempty" yaml:"tlsInsecureSkipVerify,omitempty"`
|
||||
|
||||
// === Resilience Settings ===
|
||||
// EnableCircuitBreaker enables circuit breaker for Redis operations
|
||||
EnableCircuitBreaker bool `json:"enableCircuitBreaker,omitempty" yaml:"enableCircuitBreaker,omitempty"`
|
||||
|
||||
// CircuitBreakerMaxFailures is the number of failures before opening circuit
|
||||
CircuitBreakerMaxFailures int `json:"circuitBreakerMaxFailures,omitempty" yaml:"circuitBreakerMaxFailures,omitempty"`
|
||||
|
||||
// CircuitBreakerTimeout is how long the circuit stays open
|
||||
CircuitBreakerTimeout time.Duration `json:"circuitBreakerTimeout,omitempty" yaml:"circuitBreakerTimeout,omitempty"`
|
||||
|
||||
// EnableHealthCheck enables periodic health checks
|
||||
EnableHealthCheck bool `json:"enableHealthCheck,omitempty" yaml:"enableHealthCheck,omitempty"`
|
||||
|
||||
// HealthCheckInterval is how often to check Redis health
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval,omitempty" yaml:"healthCheckInterval,omitempty"`
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns default Redis configuration
|
||||
func DefaultRedisConfig() *RedisConfig {
|
||||
return &RedisConfig{
|
||||
Enabled: false,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
DB: 0,
|
||||
PoolSize: 10,
|
||||
MinIdleConns: 2,
|
||||
MaxRetries: 3,
|
||||
DialTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
PoolTimeout: 4 * time.Second,
|
||||
ConnMaxIdleTime: 5 * time.Minute,
|
||||
ConnMaxLifetime: 30 * time.Minute,
|
||||
KeyPrefix: "traefikoidc:",
|
||||
TLSEnabled: false,
|
||||
TLSInsecureSkipVerify: false,
|
||||
EnableCircuitBreaker: true,
|
||||
CircuitBreakerMaxFailures: 5,
|
||||
CircuitBreakerTimeout: 30 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads Redis configuration from environment variables
|
||||
func (c *RedisConfig) LoadFromEnv() {
|
||||
// Enable Redis if environment variable is set
|
||||
if enabled := os.Getenv("REDIS_ENABLED"); enabled != "" {
|
||||
c.Enabled = strings.ToLower(enabled) == "true"
|
||||
}
|
||||
|
||||
// Mode
|
||||
if mode := os.Getenv("REDIS_MODE"); mode != "" {
|
||||
c.Mode = RedisMode(strings.ToLower(mode))
|
||||
}
|
||||
|
||||
// Standalone configuration
|
||||
if addr := os.Getenv("REDIS_ADDR"); addr != "" {
|
||||
c.Addr = addr
|
||||
}
|
||||
if password := os.Getenv("REDIS_PASSWORD"); password != "" {
|
||||
c.Password = password
|
||||
}
|
||||
if db := os.Getenv("REDIS_DB"); db != "" {
|
||||
if dbNum, err := strconv.Atoi(db); err == nil {
|
||||
c.DB = dbNum
|
||||
}
|
||||
}
|
||||
|
||||
// Cluster configuration
|
||||
if clusterAddrs := os.Getenv("REDIS_CLUSTER_ADDRS"); clusterAddrs != "" {
|
||||
c.ClusterAddrs = strings.Split(clusterAddrs, ",")
|
||||
for i := range c.ClusterAddrs {
|
||||
c.ClusterAddrs[i] = strings.TrimSpace(c.ClusterAddrs[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Sentinel configuration
|
||||
if masterName := os.Getenv("REDIS_MASTER_NAME"); masterName != "" {
|
||||
c.MasterName = masterName
|
||||
}
|
||||
if sentinelAddrs := os.Getenv("REDIS_SENTINEL_ADDRS"); sentinelAddrs != "" {
|
||||
c.SentinelAddrs = strings.Split(sentinelAddrs, ",")
|
||||
for i := range c.SentinelAddrs {
|
||||
c.SentinelAddrs[i] = strings.TrimSpace(c.SentinelAddrs[i])
|
||||
}
|
||||
}
|
||||
if sentinelPassword := os.Getenv("REDIS_SENTINEL_PASSWORD"); sentinelPassword != "" {
|
||||
c.SentinelPassword = sentinelPassword
|
||||
}
|
||||
|
||||
// Connection pool settings
|
||||
if poolSize := os.Getenv("REDIS_POOL_SIZE"); poolSize != "" {
|
||||
if size, err := strconv.Atoi(poolSize); err == nil {
|
||||
c.PoolSize = size
|
||||
}
|
||||
}
|
||||
if minIdleConns := os.Getenv("REDIS_MIN_IDLE_CONNS"); minIdleConns != "" {
|
||||
if conns, err := strconv.Atoi(minIdleConns); err == nil {
|
||||
c.MinIdleConns = conns
|
||||
}
|
||||
}
|
||||
if maxRetries := os.Getenv("REDIS_MAX_RETRIES"); maxRetries != "" {
|
||||
if retries, err := strconv.Atoi(maxRetries); err == nil {
|
||||
c.MaxRetries = retries
|
||||
}
|
||||
}
|
||||
|
||||
// Timeouts
|
||||
if dialTimeout := os.Getenv("REDIS_DIAL_TIMEOUT"); dialTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(dialTimeout); err == nil {
|
||||
c.DialTimeout = timeout
|
||||
}
|
||||
}
|
||||
if readTimeout := os.Getenv("REDIS_READ_TIMEOUT"); readTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(readTimeout); err == nil {
|
||||
c.ReadTimeout = timeout
|
||||
}
|
||||
}
|
||||
if writeTimeout := os.Getenv("REDIS_WRITE_TIMEOUT"); writeTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(writeTimeout); err == nil {
|
||||
c.WriteTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
// Key prefix
|
||||
if keyPrefix := os.Getenv("REDIS_KEY_PREFIX"); keyPrefix != "" {
|
||||
c.KeyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// TLS settings
|
||||
if tlsEnabled := os.Getenv("REDIS_TLS_ENABLED"); tlsEnabled != "" {
|
||||
c.TLSEnabled = strings.ToLower(tlsEnabled) == "true"
|
||||
}
|
||||
if tlsInsecure := os.Getenv("REDIS_TLS_INSECURE_SKIP_VERIFY"); tlsInsecure != "" {
|
||||
c.TLSInsecureSkipVerify = strings.ToLower(tlsInsecure) == "true"
|
||||
}
|
||||
|
||||
// Resilience settings
|
||||
if enableCB := os.Getenv("REDIS_ENABLE_CIRCUIT_BREAKER"); enableCB != "" {
|
||||
c.EnableCircuitBreaker = strings.ToLower(enableCB) == "true"
|
||||
}
|
||||
if cbMaxFailures := os.Getenv("REDIS_CIRCUIT_BREAKER_MAX_FAILURES"); cbMaxFailures != "" {
|
||||
if failures, err := strconv.Atoi(cbMaxFailures); err == nil {
|
||||
c.CircuitBreakerMaxFailures = failures
|
||||
}
|
||||
}
|
||||
if cbTimeout := os.Getenv("REDIS_CIRCUIT_BREAKER_TIMEOUT"); cbTimeout != "" {
|
||||
if timeout, err := time.ParseDuration(cbTimeout); err == nil {
|
||||
c.CircuitBreakerTimeout = timeout
|
||||
}
|
||||
}
|
||||
if enableHC := os.Getenv("REDIS_ENABLE_HEALTH_CHECK"); enableHC != "" {
|
||||
c.EnableHealthCheck = strings.ToLower(enableHC) == "true"
|
||||
}
|
||||
if hcInterval := os.Getenv("REDIS_HEALTH_CHECK_INTERVAL"); hcInterval != "" {
|
||||
if interval, err := time.ParseDuration(hcInterval); err == nil {
|
||||
c.HealthCheckInterval = interval
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid
|
||||
func (c *RedisConfig) Validate() error {
|
||||
if !c.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case RedisModeStandalone:
|
||||
if c.Addr == "" {
|
||||
return &ConfigError{Field: "addr", Message: "Redis address is required for standalone mode"}
|
||||
}
|
||||
case RedisModeCluster:
|
||||
if len(c.ClusterAddrs) == 0 {
|
||||
return &ConfigError{Field: "clusterAddrs", Message: "At least one cluster address is required"}
|
||||
}
|
||||
case RedisModeSentinel:
|
||||
if c.MasterName == "" {
|
||||
return &ConfigError{Field: "masterName", Message: "Master name is required for sentinel mode"}
|
||||
}
|
||||
if len(c.SentinelAddrs) == 0 {
|
||||
return &ConfigError{Field: "sentinelAddrs", Message: "At least one sentinel address is required"}
|
||||
}
|
||||
default:
|
||||
return &ConfigError{Field: "mode", Message: "Invalid Redis mode"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigError represents a configuration validation error
|
||||
type ConfigError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ConfigError) Error() string {
|
||||
return "redis config error: " + e.Field + ": " + e.Message
|
||||
}
|
||||
@@ -69,6 +69,89 @@ type Config struct {
|
||||
HTTPClient *http.Client `json:"-"`
|
||||
CookieDomain string `json:"cookieDomain"`
|
||||
SecurityHeaders *SecurityHeadersConfig `json:"securityHeaders,omitempty"`
|
||||
|
||||
// Dynamic Client Registration (RFC 7591) configuration
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrationConfig configures OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrationConfig struct {
|
||||
// Enabled enables automatic client registration with the OIDC provider
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// InitialAccessToken is an optional bearer token for protected registration endpoints
|
||||
// Some providers require this token to authorize new client registrations
|
||||
InitialAccessToken string `json:"initialAccessToken,omitempty"`
|
||||
|
||||
// RegistrationEndpoint overrides the endpoint discovered from provider metadata
|
||||
// If empty, uses the registration_endpoint from .well-known/openid-configuration
|
||||
RegistrationEndpoint string `json:"registrationEndpoint,omitempty"`
|
||||
|
||||
// ClientMetadata contains the client metadata to register
|
||||
ClientMetadata *ClientRegistrationMetadata `json:"clientMetadata,omitempty"`
|
||||
|
||||
// PersistCredentials determines whether to save registered credentials to a file
|
||||
// This allows reusing the same client_id/client_secret across restarts
|
||||
PersistCredentials bool `json:"persistCredentials"`
|
||||
|
||||
// CredentialsFile is the path to store/load registered client credentials
|
||||
// Defaults to "/tmp/oidc-client-credentials.json" if not specified
|
||||
CredentialsFile string `json:"credentialsFile,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationMetadata contains client metadata for dynamic registration (RFC 7591)
|
||||
type ClientRegistrationMetadata struct {
|
||||
// RedirectURIs is REQUIRED - array of redirect URIs for authorization
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
|
||||
// ResponseTypes specifies OAuth 2.0 response types (default: ["code"])
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
|
||||
// GrantTypes specifies OAuth 2.0 grant types (default: ["authorization_code"])
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
|
||||
// ApplicationType is either "web" (default) or "native"
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
|
||||
// Contacts is an array of email addresses for responsible parties
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
|
||||
// ClientName is a human-readable name for the client
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
|
||||
// LogoURI is a URL pointing to a logo for the client
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
|
||||
// ClientURI is a URL of the home page of the client
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
|
||||
// PolicyURI is a URL pointing to the client's privacy policy
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
|
||||
// TOSURI is a URL pointing to the client's terms of service
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
|
||||
// JWKSURI is a URL for the client's JSON Web Key Set
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
|
||||
// SubjectType is "pairwise" or "public" (provider-specific)
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
|
||||
// TokenEndpointAuthMethod specifies how the client authenticates at token endpoint
|
||||
// Values: "client_secret_basic", "client_secret_post", "client_secret_jwt", "private_key_jwt", "none"
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
|
||||
// DefaultMaxAge is the default maximum authentication age in seconds
|
||||
DefaultMaxAge int `json:"default_max_age,omitempty"`
|
||||
|
||||
// RequireAuthTime specifies whether auth_time claim is required in ID token
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
|
||||
// DefaultACRValues specifies default ACR values
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
|
||||
// Scope is a space-separated list of scopes (alternative to config.Scopes)
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// HeaderConfig represents header template configuration
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
// Package config provides unified configuration management for the OIDC middleware
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UnifiedConfig is the master configuration structure consolidating all config aspects
|
||||
// This replaces 45 duplicate config structs across the codebase
|
||||
type UnifiedConfig struct {
|
||||
// Core Configuration
|
||||
Provider ProviderConfig `json:"provider" yaml:"provider"`
|
||||
Session SessionConfig `json:"session" yaml:"session"`
|
||||
Token TokenConfig `json:"token" yaml:"token"`
|
||||
Redis RedisConfig `json:"redis" yaml:"redis"`
|
||||
Security SecurityConfig `json:"security" yaml:"security"`
|
||||
|
||||
// Middleware Configuration
|
||||
Middleware MiddlewareConfig `json:"middleware" yaml:"middleware"`
|
||||
Cache CacheConfig `json:"cache" yaml:"cache"`
|
||||
RateLimit RateLimitConfig `json:"rateLimit" yaml:"rateLimit"`
|
||||
|
||||
// Operational Configuration
|
||||
Logging LoggingConfig `json:"logging" yaml:"logging"`
|
||||
Metrics MetricsConfig `json:"metrics" yaml:"metrics"`
|
||||
Health HealthConfig `json:"health" yaml:"health"`
|
||||
|
||||
// Advanced Configuration
|
||||
Transport TransportConfig `json:"transport" yaml:"transport"`
|
||||
Pool PoolConfig `json:"pool" yaml:"pool"`
|
||||
Circuit CircuitConfig `json:"circuit" yaml:"circuit"`
|
||||
|
||||
// Compatibility field for migration
|
||||
Legacy map[string]interface{} `json:"-" yaml:"-"`
|
||||
}
|
||||
|
||||
// ProviderConfig contains OIDC provider settings
|
||||
type ProviderConfig struct {
|
||||
IssuerURL string `json:"issuerURL" yaml:"issuerURL"`
|
||||
ClientID string `json:"clientID" yaml:"clientID"`
|
||||
ClientSecret string `json:"clientSecret" yaml:"clientSecret"`
|
||||
RedirectURL string `json:"redirectURL" yaml:"redirectURL"`
|
||||
LogoutURL string `json:"logoutURL" yaml:"logoutURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI" yaml:"postLogoutRedirectURI"`
|
||||
Scopes []string `json:"scopes" yaml:"scopes"`
|
||||
OverrideScopes bool `json:"overrideScopes" yaml:"overrideScopes"`
|
||||
CustomClaims map[string]string `json:"customClaims" yaml:"customClaims"`
|
||||
JWKCachePeriod time.Duration `json:"jwkCachePeriod" yaml:"jwkCachePeriod"`
|
||||
MetadataCacheTTL time.Duration `json:"metadataCacheTTL" yaml:"metadataCacheTTL"`
|
||||
Discovery bool `json:"discovery" yaml:"discovery"`
|
||||
|
||||
// Provider-specific endpoints
|
||||
AuthorizationEndpoint string `json:"authorizationEndpoint,omitempty" yaml:"authorizationEndpoint,omitempty"`
|
||||
TokenEndpoint string `json:"tokenEndpoint,omitempty" yaml:"tokenEndpoint,omitempty"`
|
||||
UserInfoEndpoint string `json:"userInfoEndpoint,omitempty" yaml:"userInfoEndpoint,omitempty"`
|
||||
JWKSEndpoint string `json:"jwksEndpoint,omitempty" yaml:"jwksEndpoint,omitempty"`
|
||||
IntrospectEndpoint string `json:"introspectEndpoint,omitempty" yaml:"introspectEndpoint,omitempty"`
|
||||
RevocationEndpoint string `json:"revocationEndpoint,omitempty" yaml:"revocationEndpoint,omitempty"`
|
||||
}
|
||||
|
||||
// SessionConfig contains session management settings
|
||||
type SessionConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
MaxAge int `json:"maxAge" yaml:"maxAge"`
|
||||
Secret string `json:"secret" yaml:"secret"`
|
||||
EncryptionKey string `json:"encryptionKey" yaml:"encryptionKey"`
|
||||
SigningKey string `json:"signingKey" yaml:"signingKey"`
|
||||
ChunkSize int `json:"chunkSize" yaml:"chunkSize"`
|
||||
MaxChunks int `json:"maxChunks" yaml:"maxChunks"`
|
||||
|
||||
// Cookie settings
|
||||
Domain string `json:"domain" yaml:"domain"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
Secure bool `json:"secure" yaml:"secure"`
|
||||
HttpOnly bool `json:"httpOnly" yaml:"httpOnly"`
|
||||
SameSite string `json:"sameSite" yaml:"sameSite"`
|
||||
CookiePrefix string `json:"cookiePrefix" yaml:"cookiePrefix"` // Prefix for cookie names (e.g., "_oidc_myapp_")
|
||||
|
||||
// Storage settings
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis", "cookie"
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
}
|
||||
|
||||
// TokenConfig contains token handling settings
|
||||
type TokenConfig struct {
|
||||
AccessTokenTTL time.Duration `json:"accessTokenTTL" yaml:"accessTokenTTL"`
|
||||
RefreshTokenTTL time.Duration `json:"refreshTokenTTL" yaml:"refreshTokenTTL"`
|
||||
RefreshGracePeriod time.Duration `json:"refreshGracePeriod" yaml:"refreshGracePeriod"`
|
||||
ValidationMode string `json:"validationMode" yaml:"validationMode"` // "jwt", "introspect", "hybrid"
|
||||
IntrospectURL string `json:"introspectURL" yaml:"introspectURL"`
|
||||
|
||||
// Token caching
|
||||
CacheEnabled bool `json:"cacheEnabled" yaml:"cacheEnabled"`
|
||||
CacheTTL time.Duration `json:"cacheTTL" yaml:"cacheTTL"`
|
||||
CacheNegativeTTL time.Duration `json:"cacheNegativeTTL" yaml:"cacheNegativeTTL"`
|
||||
|
||||
// Token validation
|
||||
ValidateSignature bool `json:"validateSignature" yaml:"validateSignature"`
|
||||
ValidateExpiry bool `json:"validateExpiry" yaml:"validateExpiry"`
|
||||
ValidateAudience bool `json:"validateAudience" yaml:"validateAudience"`
|
||||
ValidateIssuer bool `json:"validateIssuer" yaml:"validateIssuer"`
|
||||
RequiredClaims []string `json:"requiredClaims" yaml:"requiredClaims"`
|
||||
ClockSkew time.Duration `json:"clockSkew" yaml:"clockSkew"`
|
||||
}
|
||||
|
||||
// SecurityConfig contains security-related settings
|
||||
type SecurityConfig struct {
|
||||
ForceHTTPS bool `json:"forceHTTPS" yaml:"forceHTTPS"`
|
||||
EnablePKCE bool `json:"enablePKCE" yaml:"enablePKCE"`
|
||||
AllowedUsers []string `json:"allowedUsers" yaml:"allowedUsers"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains" yaml:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups" yaml:"allowedRolesAndGroups"`
|
||||
ExcludedURLs []string `json:"excludedURLs" yaml:"excludedURLs"`
|
||||
Headers *SecurityHeadersConfig `json:"headers" yaml:"headers"`
|
||||
|
||||
// CSRF protection
|
||||
CSRFProtection bool `json:"csrfProtection" yaml:"csrfProtection"`
|
||||
CSRFTokenName string `json:"csrfTokenName" yaml:"csrfTokenName"`
|
||||
CSRFTokenTTL time.Duration `json:"csrfTokenTTL" yaml:"csrfTokenTTL"`
|
||||
|
||||
// Additional security
|
||||
MaxLoginAttempts int `json:"maxLoginAttempts" yaml:"maxLoginAttempts"`
|
||||
LockoutDuration time.Duration `json:"lockoutDuration" yaml:"lockoutDuration"`
|
||||
RequireMFA bool `json:"requireMFA" yaml:"requireMFA"`
|
||||
}
|
||||
|
||||
// MiddlewareConfig contains middleware-specific settings
|
||||
type MiddlewareConfig struct {
|
||||
Priority int `json:"priority" yaml:"priority"`
|
||||
SkipPaths []string `json:"skipPaths" yaml:"skipPaths"`
|
||||
RequirePaths []string `json:"requirePaths" yaml:"requirePaths"`
|
||||
PassthroughMode bool `json:"passthroughMode" yaml:"passthroughMode"`
|
||||
|
||||
// Request handling
|
||||
MaxRequestSize int64 `json:"maxRequestSize" yaml:"maxRequestSize"`
|
||||
RequestTimeout time.Duration `json:"requestTimeout" yaml:"requestTimeout"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
|
||||
// Response handling
|
||||
CustomHeaders map[string]string `json:"customHeaders" yaml:"customHeaders"`
|
||||
RemoveHeaders []string `json:"removeHeaders" yaml:"removeHeaders"`
|
||||
}
|
||||
|
||||
// CacheConfig contains cache configuration
|
||||
type CacheConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Type string `json:"type" yaml:"type"` // "memory", "redis", "hybrid"
|
||||
DefaultTTL time.Duration `json:"defaultTTL" yaml:"defaultTTL"`
|
||||
MaxEntries int `json:"maxEntries" yaml:"maxEntries"`
|
||||
MaxEntrySize int64 `json:"maxEntrySize" yaml:"maxEntrySize"`
|
||||
EvictionPolicy string `json:"evictionPolicy" yaml:"evictionPolicy"` // "lru", "lfu", "fifo"
|
||||
|
||||
// Memory cache settings
|
||||
CleanupInterval time.Duration `json:"cleanupInterval" yaml:"cleanupInterval"`
|
||||
|
||||
// Distributed cache settings
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Compression bool `json:"compression" yaml:"compression"`
|
||||
Serialization string `json:"serialization" yaml:"serialization"` // "json", "msgpack", "protobuf"
|
||||
}
|
||||
|
||||
// RateLimitConfig contains rate limiting configuration
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
RequestsPerSecond int `json:"requestsPerSecond" yaml:"requestsPerSecond"`
|
||||
Burst int `json:"burst" yaml:"burst"`
|
||||
|
||||
// Rate limit storage
|
||||
StorageType string `json:"storageType" yaml:"storageType"` // "memory", "redis"
|
||||
WindowDuration time.Duration `json:"windowDuration" yaml:"windowDuration"`
|
||||
|
||||
// Rate limit keys
|
||||
KeyType string `json:"keyType" yaml:"keyType"` // "ip", "user", "token", "custom"
|
||||
CustomKeyFunc string `json:"customKeyFunc" yaml:"customKeyFunc"`
|
||||
|
||||
// Whitelisting
|
||||
WhitelistIPs []string `json:"whitelistIPs" yaml:"whitelistIPs"`
|
||||
WhitelistUsers []string `json:"whitelistUsers" yaml:"whitelistUsers"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `json:"level" yaml:"level"` // "debug", "info", "warn", "error"
|
||||
Format string `json:"format" yaml:"format"` // "json", "text", "structured"
|
||||
Output string `json:"output" yaml:"output"` // "stdout", "stderr", "file"
|
||||
FilePath string `json:"filePath" yaml:"filePath"`
|
||||
|
||||
// Log filtering
|
||||
FilterSensitive bool `json:"filterSensitive" yaml:"filterSensitive"`
|
||||
MaskFields []string `json:"maskFields" yaml:"maskFields"`
|
||||
|
||||
// Performance
|
||||
BufferSize int `json:"bufferSize" yaml:"bufferSize"`
|
||||
FlushInterval time.Duration `json:"flushInterval" yaml:"flushInterval"`
|
||||
|
||||
// Audit logging
|
||||
AuditEnabled bool `json:"auditEnabled" yaml:"auditEnabled"`
|
||||
AuditEvents []string `json:"auditEvents" yaml:"auditEvents"`
|
||||
}
|
||||
|
||||
// MetricsConfig contains metrics collection configuration
|
||||
type MetricsConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Provider string `json:"provider" yaml:"provider"` // "prometheus", "statsd", "otlp"
|
||||
Endpoint string `json:"endpoint" yaml:"endpoint"`
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Subsystem string `json:"subsystem" yaml:"subsystem"`
|
||||
|
||||
// Collection settings
|
||||
CollectInterval time.Duration `json:"collectInterval" yaml:"collectInterval"`
|
||||
Histograms bool `json:"histograms" yaml:"histograms"`
|
||||
|
||||
// Custom labels
|
||||
Labels map[string]string `json:"labels" yaml:"labels"`
|
||||
}
|
||||
|
||||
// HealthConfig contains health check configuration
|
||||
type HealthConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Path string `json:"path" yaml:"path"`
|
||||
CheckInterval time.Duration `json:"checkInterval" yaml:"checkInterval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
|
||||
// Checks to perform
|
||||
CheckProvider bool `json:"checkProvider" yaml:"checkProvider"`
|
||||
CheckRedis bool `json:"checkRedis" yaml:"checkRedis"`
|
||||
CheckCache bool `json:"checkCache" yaml:"checkCache"`
|
||||
|
||||
// Thresholds
|
||||
MaxLatency time.Duration `json:"maxLatency" yaml:"maxLatency"`
|
||||
MinMemory int64 `json:"minMemory" yaml:"minMemory"`
|
||||
}
|
||||
|
||||
// TransportConfig contains HTTP transport configuration
|
||||
type TransportConfig struct {
|
||||
MaxIdleConns int `json:"maxIdleConns" yaml:"maxIdleConns"`
|
||||
MaxIdleConnsPerHost int `json:"maxIdleConnsPerHost" yaml:"maxIdleConnsPerHost"`
|
||||
MaxConnsPerHost int `json:"maxConnsPerHost" yaml:"maxConnsPerHost"`
|
||||
IdleConnTimeout time.Duration `json:"idleConnTimeout" yaml:"idleConnTimeout"`
|
||||
TLSHandshakeTimeout time.Duration `json:"tlsHandshakeTimeout" yaml:"tlsHandshakeTimeout"`
|
||||
ExpectContinueTimeout time.Duration `json:"expectContinueTimeout" yaml:"expectContinueTimeout"`
|
||||
ResponseHeaderTimeout time.Duration `json:"responseHeaderTimeout" yaml:"responseHeaderTimeout"`
|
||||
DisableKeepAlives bool `json:"disableKeepAlives" yaml:"disableKeepAlives"`
|
||||
DisableCompression bool `json:"disableCompression" yaml:"disableCompression"`
|
||||
|
||||
// TLS configuration
|
||||
TLSInsecureSkipVerify bool `json:"tlsInsecureSkipVerify" yaml:"tlsInsecureSkipVerify"`
|
||||
TLSMinVersion string `json:"tlsMinVersion" yaml:"tlsMinVersion"`
|
||||
TLSCipherSuites []string `json:"tlsCipherSuites" yaml:"tlsCipherSuites"`
|
||||
|
||||
// Proxy settings
|
||||
ProxyURL string `json:"proxyURL" yaml:"proxyURL"`
|
||||
NoProxy []string `json:"noProxy" yaml:"noProxy"`
|
||||
}
|
||||
|
||||
// PoolConfig contains connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
Size int `json:"size" yaml:"size"`
|
||||
MinSize int `json:"minSize" yaml:"minSize"`
|
||||
MaxSize int `json:"maxSize" yaml:"maxSize"`
|
||||
MaxAge time.Duration `json:"maxAge" yaml:"maxAge"`
|
||||
IdleTimeout time.Duration `json:"idleTimeout" yaml:"idleTimeout"`
|
||||
WaitTimeout time.Duration `json:"waitTimeout" yaml:"waitTimeout"`
|
||||
|
||||
// Health checking
|
||||
HealthCheckInterval time.Duration `json:"healthCheckInterval" yaml:"healthCheckInterval"`
|
||||
MaxRetries int `json:"maxRetries" yaml:"maxRetries"`
|
||||
}
|
||||
|
||||
// CircuitConfig contains circuit breaker configuration
|
||||
type CircuitConfig struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
MaxRequests uint32 `json:"maxRequests" yaml:"maxRequests"`
|
||||
Interval time.Duration `json:"interval" yaml:"interval"`
|
||||
Timeout time.Duration `json:"timeout" yaml:"timeout"`
|
||||
ConsecutiveFailures uint32 `json:"consecutiveFailures" yaml:"consecutiveFailures"`
|
||||
FailureRatio float64 `json:"failureRatio" yaml:"failureRatio"`
|
||||
|
||||
// Circuit states
|
||||
OnOpen string `json:"onOpen" yaml:"onOpen"` // "reject", "fallback", "passthrough"
|
||||
OnHalfOpen string `json:"onHalfOpen" yaml:"onHalfOpen"`
|
||||
|
||||
// Monitoring
|
||||
MetricsEnabled bool `json:"metricsEnabled" yaml:"metricsEnabled"`
|
||||
LogStateChanges bool `json:"logStateChanges" yaml:"logStateChanges"`
|
||||
}
|
||||
@@ -0,0 +1,263 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// TestUnifiedConfigJSONMarshalling tests JSON marshalling with secret redaction
|
||||
func TestUnifiedConfigJSONMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Session.Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("Session.EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("Session.SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Redis.Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("Redis.SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(jsonStr, `"issuerURL":"https://auth.example.com"`) {
|
||||
t.Error("IssuerURL should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnifiedConfigYAMLMarshalling tests YAML marshalling with secret redaction
|
||||
func TestUnifiedConfigYAMLMarshalling(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to YAML
|
||||
yamlBytes, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
|
||||
// Verify secrets are redacted
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "secret: '[REDACTED]'") {
|
||||
t.Error("Session.Secret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "encryptionKey: '[REDACTED]'") {
|
||||
t.Error("Session.EncryptionKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "signingKey: '[REDACTED]'") {
|
||||
t.Error("Session.SigningKey should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "password: '[REDACTED]'") {
|
||||
t.Error("Redis.Password should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "sentinelPassword: '[REDACTED]'") {
|
||||
t.Error("Redis.SentinelPassword should be redacted in YAML output")
|
||||
}
|
||||
|
||||
// Verify non-secret fields are preserved
|
||||
if !contains(yamlStr, "issuerURL: https://auth.example.com") {
|
||||
t.Error("IssuerURL should be preserved in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProviderConfigMarshalling tests individual struct marshalling
|
||||
func TestProviderConfigMarshalling(t *testing.T) {
|
||||
provider := ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "super-secret-value",
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"clientSecret":"[REDACTED]"`) {
|
||||
t.Error("ClientSecret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"clientID":"test-client"`) {
|
||||
t.Error("ClientID should be preserved in JSON output")
|
||||
}
|
||||
|
||||
// Test YAML marshalling
|
||||
yamlBytes, err := yaml.Marshal(provider)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal ProviderConfig to YAML: %v", err)
|
||||
}
|
||||
|
||||
yamlStr := string(yamlBytes)
|
||||
if !contains(yamlStr, "clientSecret: '[REDACTED]'") {
|
||||
t.Error("ClientSecret should be redacted in YAML output")
|
||||
}
|
||||
if !contains(yamlStr, "clientID: test-client") {
|
||||
t.Error("ClientID should be preserved in YAML output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSessionConfigMarshalling tests session config marshalling
|
||||
func TestSessionConfigMarshalling(t *testing.T) {
|
||||
session := SessionConfig{
|
||||
Name: "session-cookie",
|
||||
Secret: "session-secret",
|
||||
EncryptionKey: "32-character-encryption-key-here",
|
||||
SigningKey: "signing-key-secret",
|
||||
Domain: "example.com",
|
||||
Secure: true,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal SessionConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"secret":"[REDACTED]"`) {
|
||||
t.Error("Secret should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"encryptionKey":"[REDACTED]"`) {
|
||||
t.Error("EncryptionKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"signingKey":"[REDACTED]"`) {
|
||||
t.Error("SigningKey should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"name":"session-cookie"`) {
|
||||
t.Error("Name should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"domain":"example.com"`) {
|
||||
t.Error("Domain should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRedisConfigMarshalling tests Redis config marshalling
|
||||
func TestRedisConfigMarshalling(t *testing.T) {
|
||||
redis := RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
Password: "redis-password",
|
||||
SentinelPassword: "sentinel-password",
|
||||
Addr: "localhost:6379",
|
||||
DB: 1,
|
||||
}
|
||||
|
||||
// Test JSON marshalling
|
||||
jsonBytes, err := json.Marshal(redis)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal RedisConfig to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
if !contains(jsonStr, `"password":"[REDACTED]"`) {
|
||||
t.Error("Password should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"sentinelPassword":"[REDACTED]"`) {
|
||||
t.Error("SentinelPassword should be redacted in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"addr":"localhost:6379"`) {
|
||||
t.Error("Addr should be preserved in JSON output")
|
||||
}
|
||||
if !contains(jsonStr, `"db":1`) {
|
||||
t.Error("DB should be preserved in JSON output")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEmptySecretsNotRedacted tests that empty secrets are not shown as redacted
|
||||
func TestEmptySecretsNotRedacted(t *testing.T) {
|
||||
config := &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "", // Empty secret
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Secret: "", // Empty secret
|
||||
EncryptionKey: "", // Empty secret
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Password: "", // Empty secret
|
||||
},
|
||||
}
|
||||
|
||||
// Marshal to JSON
|
||||
jsonBytes, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal config to JSON: %v", err)
|
||||
}
|
||||
|
||||
jsonStr := string(jsonBytes)
|
||||
|
||||
// Verify empty secrets are not shown as redacted
|
||||
if contains(jsonStr, "[REDACTED]") {
|
||||
t.Error("Empty secrets should not be shown as [REDACTED]")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if string contains substring
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(s, substr)
|
||||
}
|
||||
@@ -0,0 +1,652 @@
|
||||
// Package config provides validation for unified configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidationError represents a configuration validation error
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *ValidationError) Error() string {
|
||||
if e.Value != nil {
|
||||
return fmt.Sprintf("config validation error: %s: %s (value: %v)", e.Field, e.Message, e.Value)
|
||||
}
|
||||
return fmt.Sprintf("config validation error: %s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidationErrors represents multiple validation errors
|
||||
type ValidationErrors []ValidationError
|
||||
|
||||
// Error implements the error interface
|
||||
func (e ValidationErrors) Error() string {
|
||||
if len(e) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var messages []string
|
||||
for _, err := range e {
|
||||
messages = append(messages, err.Error())
|
||||
}
|
||||
return strings.Join(messages, "; ")
|
||||
}
|
||||
|
||||
// Validate performs comprehensive validation on the unified configuration
|
||||
func (c *UnifiedConfig) Validate() error {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate Provider configuration
|
||||
if err := c.validateProvider(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Session configuration
|
||||
if err := c.validateSession(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Token configuration
|
||||
if err := c.validateToken(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Redis configuration (uses existing validation)
|
||||
if err := c.Redis.Validate(); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Redis",
|
||||
Message: err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validate Security configuration
|
||||
if err := c.validateSecurity(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Middleware configuration
|
||||
if err := c.validateMiddleware(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Cache configuration
|
||||
if err := c.validateCache(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate RateLimit configuration
|
||||
if err := c.validateRateLimit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Logging configuration
|
||||
if err := c.validateLogging(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Metrics configuration
|
||||
if err := c.validateMetrics(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Transport configuration
|
||||
if err := c.validateTransport(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
// Validate Circuit configuration
|
||||
if err := c.validateCircuit(); err != nil {
|
||||
errors = append(errors, err...)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
return errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateProvider validates provider configuration
|
||||
func (c *UnifiedConfig) validateProvider() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// IssuerURL is required and must be a valid URL
|
||||
if c.Provider.IssuerURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "issuer URL is required",
|
||||
})
|
||||
} else if _, err := url.Parse(c.Provider.IssuerURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "invalid issuer URL",
|
||||
Value: c.Provider.IssuerURL,
|
||||
})
|
||||
}
|
||||
|
||||
// ClientID is required
|
||||
if c.Provider.ClientID == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientID",
|
||||
Message: "client ID is required",
|
||||
})
|
||||
}
|
||||
|
||||
// ClientSecret is required (except for public clients with PKCE)
|
||||
if c.Provider.ClientSecret == "" && !c.Security.EnablePKCE {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.ClientSecret",
|
||||
Message: "client secret is required (or enable PKCE for public clients)",
|
||||
})
|
||||
}
|
||||
|
||||
// RedirectURL must be valid if provided
|
||||
if c.Provider.RedirectURL != "" {
|
||||
if _, err := url.Parse(c.Provider.RedirectURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.RedirectURL",
|
||||
Message: "invalid redirect URL",
|
||||
Value: c.Provider.RedirectURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Scopes must include 'openid' for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range c.Provider.Scopes {
|
||||
if scope == "openid" {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID && !c.Provider.OverrideScopes {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.Scopes",
|
||||
Message: "scopes must include 'openid' for OIDC",
|
||||
Value: c.Provider.Scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// JWK cache period must be positive
|
||||
if c.Provider.JWKCachePeriod < 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Provider.JWKCachePeriod",
|
||||
Message: "JWK cache period must be positive",
|
||||
Value: c.Provider.JWKCachePeriod,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSession validates session configuration
|
||||
func (c *UnifiedConfig) validateSession() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Session name must not be empty
|
||||
if c.Session.Name == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.Name",
|
||||
Message: "session name is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Session secret or encryption key is required
|
||||
if c.Session.Secret == "" && c.Session.EncryptionKey == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session",
|
||||
Message: "either session secret or encryption key is required",
|
||||
})
|
||||
}
|
||||
|
||||
// Encryption key must be at least 32 bytes for security
|
||||
if c.Session.EncryptionKey != "" && len(c.Session.EncryptionKey) < 32 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "encryption key must be at least 32 characters for proper security",
|
||||
Value: len(c.Session.EncryptionKey),
|
||||
})
|
||||
}
|
||||
|
||||
// ChunkSize must be reasonable (between 1KB and 10KB)
|
||||
if c.Session.ChunkSize < 1000 || c.Session.ChunkSize > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.ChunkSize",
|
||||
Message: "chunk size must be between 1000 and 10000 bytes",
|
||||
Value: c.Session.ChunkSize,
|
||||
})
|
||||
}
|
||||
|
||||
// MaxChunks must be reasonable (between 1 and 100)
|
||||
if c.Session.MaxChunks < 1 || c.Session.MaxChunks > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.MaxChunks",
|
||||
Message: "max chunks must be between 1 and 100",
|
||||
Value: c.Session.MaxChunks,
|
||||
})
|
||||
}
|
||||
|
||||
// SameSite must be valid
|
||||
validSameSite := map[string]bool{
|
||||
"": true,
|
||||
"Lax": true,
|
||||
"Strict": true,
|
||||
"None": true,
|
||||
}
|
||||
if !validSameSite[c.Session.SameSite] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.SameSite",
|
||||
Message: "invalid SameSite value (must be Lax, Strict, or None)",
|
||||
Value: c.Session.SameSite,
|
||||
})
|
||||
}
|
||||
|
||||
// StorageType must be valid
|
||||
validStorage := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"cookie": true,
|
||||
}
|
||||
if !validStorage[c.Session.StorageType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Session.StorageType",
|
||||
Message: "invalid storage type (must be memory, redis, or cookie)",
|
||||
Value: c.Session.StorageType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateToken validates token configuration
|
||||
func (c *UnifiedConfig) validateToken() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Token TTLs must be positive
|
||||
if c.Token.AccessTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.AccessTokenTTL",
|
||||
Message: "access token TTL must be positive",
|
||||
Value: c.Token.AccessTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Token.RefreshTokenTTL <= 0 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.RefreshTokenTTL",
|
||||
Message: "refresh token TTL must be positive",
|
||||
Value: c.Token.RefreshTokenTTL,
|
||||
})
|
||||
}
|
||||
|
||||
// Validation mode must be valid
|
||||
validModes := map[string]bool{
|
||||
"jwt": true,
|
||||
"introspect": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validModes[c.Token.ValidationMode] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ValidationMode",
|
||||
Message: "invalid validation mode (must be jwt, introspect, or hybrid)",
|
||||
Value: c.Token.ValidationMode,
|
||||
})
|
||||
}
|
||||
|
||||
// Introspect URL required for introspect or hybrid mode
|
||||
if (c.Token.ValidationMode == "introspect" || c.Token.ValidationMode == "hybrid") && c.Token.IntrospectURL == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.IntrospectURL",
|
||||
Message: "introspect URL is required for introspect or hybrid validation mode",
|
||||
})
|
||||
}
|
||||
|
||||
// Clock skew must be reasonable (0 to 10 minutes)
|
||||
if c.Token.ClockSkew < 0 || c.Token.ClockSkew > 10*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Token.ClockSkew",
|
||||
Message: "clock skew must be between 0 and 10 minutes",
|
||||
Value: c.Token.ClockSkew,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateSecurity validates security configuration
|
||||
func (c *UnifiedConfig) validateSecurity() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Validate allowed user domains are valid domains
|
||||
domainRegex := regexp.MustCompile(`^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$`)
|
||||
for _, domain := range c.Security.AllowedUserDomains {
|
||||
if !domainRegex.MatchString(domain) {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.AllowedUserDomains",
|
||||
Message: "invalid domain format",
|
||||
Value: domain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Max login attempts must be reasonable
|
||||
if c.Security.MaxLoginAttempts < 0 || c.Security.MaxLoginAttempts > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.MaxLoginAttempts",
|
||||
Message: "max login attempts must be between 0 and 100",
|
||||
Value: c.Security.MaxLoginAttempts,
|
||||
})
|
||||
}
|
||||
|
||||
// Lockout duration must be reasonable
|
||||
if c.Security.LockoutDuration < 0 || c.Security.LockoutDuration > 24*time.Hour {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Security.LockoutDuration",
|
||||
Message: "lockout duration must be between 0 and 24 hours",
|
||||
Value: c.Security.LockoutDuration,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMiddleware validates middleware configuration
|
||||
func (c *UnifiedConfig) validateMiddleware() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max request size must be reasonable (1KB to 100MB)
|
||||
if c.Middleware.MaxRequestSize < 1024 || c.Middleware.MaxRequestSize > 100*1024*1024 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.MaxRequestSize",
|
||||
Message: "max request size must be between 1KB and 100MB",
|
||||
Value: c.Middleware.MaxRequestSize,
|
||||
})
|
||||
}
|
||||
|
||||
// Request timeout must be reasonable
|
||||
if c.Middleware.RequestTimeout < time.Second || c.Middleware.RequestTimeout > 5*time.Minute {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Middleware.RequestTimeout",
|
||||
Message: "request timeout must be between 1 second and 5 minutes",
|
||||
Value: c.Middleware.RequestTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCache validates cache configuration
|
||||
func (c *UnifiedConfig) validateCache() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Cache.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Cache type must be valid
|
||||
validTypes := map[string]bool{
|
||||
"memory": true,
|
||||
"redis": true,
|
||||
"hybrid": true,
|
||||
}
|
||||
if !validTypes[c.Cache.Type] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.Type",
|
||||
Message: "invalid cache type (must be memory, redis, or hybrid)",
|
||||
Value: c.Cache.Type,
|
||||
})
|
||||
}
|
||||
|
||||
// Max entries must be reasonable
|
||||
if c.Cache.MaxEntries < 10 || c.Cache.MaxEntries > 1000000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.MaxEntries",
|
||||
Message: "max entries must be between 10 and 1000000",
|
||||
Value: c.Cache.MaxEntries,
|
||||
})
|
||||
}
|
||||
|
||||
// Eviction policy must be valid
|
||||
validEviction := map[string]bool{
|
||||
"lru": true,
|
||||
"lfu": true,
|
||||
"fifo": true,
|
||||
}
|
||||
if !validEviction[c.Cache.EvictionPolicy] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Cache.EvictionPolicy",
|
||||
Message: "invalid eviction policy (must be lru, lfu, or fifo)",
|
||||
Value: c.Cache.EvictionPolicy,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateRateLimit validates rate limiting configuration
|
||||
func (c *UnifiedConfig) validateRateLimit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.RateLimit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Requests per second must be reasonable
|
||||
if c.RateLimit.RequestsPerSecond < 1 || c.RateLimit.RequestsPerSecond > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.RequestsPerSecond",
|
||||
Message: "requests per second must be between 1 and 10000",
|
||||
Value: c.RateLimit.RequestsPerSecond,
|
||||
})
|
||||
}
|
||||
|
||||
// Burst must be at least as large as requests per second
|
||||
if c.RateLimit.Burst < c.RateLimit.RequestsPerSecond {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.Burst",
|
||||
Message: "burst must be at least as large as requests per second",
|
||||
Value: c.RateLimit.Burst,
|
||||
})
|
||||
}
|
||||
|
||||
// Key type must be valid
|
||||
validKeyTypes := map[string]bool{
|
||||
"ip": true,
|
||||
"user": true,
|
||||
"token": true,
|
||||
"custom": true,
|
||||
}
|
||||
if !validKeyTypes[c.RateLimit.KeyType] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "RateLimit.KeyType",
|
||||
Message: "invalid key type (must be ip, user, token, or custom)",
|
||||
Value: c.RateLimit.KeyType,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateLogging validates logging configuration
|
||||
func (c *UnifiedConfig) validateLogging() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Log level must be valid
|
||||
validLevels := map[string]bool{
|
||||
"debug": true,
|
||||
"info": true,
|
||||
"warn": true,
|
||||
"error": true,
|
||||
}
|
||||
if !validLevels[c.Logging.Level] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Level",
|
||||
Message: "invalid log level (must be debug, info, warn, or error)",
|
||||
Value: c.Logging.Level,
|
||||
})
|
||||
}
|
||||
|
||||
// Format must be valid
|
||||
validFormats := map[string]bool{
|
||||
"json": true,
|
||||
"text": true,
|
||||
"structured": true,
|
||||
}
|
||||
if !validFormats[c.Logging.Format] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Format",
|
||||
Message: "invalid log format (must be json, text, or structured)",
|
||||
Value: c.Logging.Format,
|
||||
})
|
||||
}
|
||||
|
||||
// Output must be valid
|
||||
validOutputs := map[string]bool{
|
||||
"stdout": true,
|
||||
"stderr": true,
|
||||
"file": true,
|
||||
}
|
||||
if !validOutputs[c.Logging.Output] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.Output",
|
||||
Message: "invalid log output (must be stdout, stderr, or file)",
|
||||
Value: c.Logging.Output,
|
||||
})
|
||||
}
|
||||
|
||||
// File path required if output is file
|
||||
if c.Logging.Output == "file" && c.Logging.FilePath == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Logging.FilePath",
|
||||
Message: "file path is required when output is 'file'",
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateMetrics validates metrics configuration
|
||||
func (c *UnifiedConfig) validateMetrics() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Metrics.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Provider must be valid
|
||||
validProviders := map[string]bool{
|
||||
"prometheus": true,
|
||||
"statsd": true,
|
||||
"otlp": true,
|
||||
}
|
||||
if !validProviders[c.Metrics.Provider] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Provider",
|
||||
Message: "invalid metrics provider (must be prometheus, statsd, or otlp)",
|
||||
Value: c.Metrics.Provider,
|
||||
})
|
||||
}
|
||||
|
||||
// Endpoint required for some providers
|
||||
if (c.Metrics.Provider == "statsd" || c.Metrics.Provider == "otlp") && c.Metrics.Endpoint == "" {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Metrics.Endpoint",
|
||||
Message: fmt.Sprintf("endpoint is required for %s provider", c.Metrics.Provider),
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateTransport validates transport configuration
|
||||
func (c *UnifiedConfig) validateTransport() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
// Max connections must be reasonable
|
||||
if c.Transport.MaxIdleConns < 0 || c.Transport.MaxIdleConns > 10000 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.MaxIdleConns",
|
||||
Message: "max idle connections must be between 0 and 10000",
|
||||
Value: c.Transport.MaxIdleConns,
|
||||
})
|
||||
}
|
||||
|
||||
// TLS min version must be valid
|
||||
validTLSVersions := map[string]bool{
|
||||
"TLS1.0": true,
|
||||
"TLS1.1": true,
|
||||
"TLS1.2": true,
|
||||
"TLS1.3": true,
|
||||
}
|
||||
if c.Transport.TLSMinVersion != "" && !validTLSVersions[c.Transport.TLSMinVersion] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.TLSMinVersion",
|
||||
Message: "invalid TLS min version (must be TLS1.0, TLS1.1, TLS1.2, or TLS1.3)",
|
||||
Value: c.Transport.TLSMinVersion,
|
||||
})
|
||||
}
|
||||
|
||||
// Proxy URL must be valid if provided
|
||||
if c.Transport.ProxyURL != "" {
|
||||
if _, err := url.Parse(c.Transport.ProxyURL); err != nil {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Transport.ProxyURL",
|
||||
Message: "invalid proxy URL",
|
||||
Value: c.Transport.ProxyURL,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// validateCircuit validates circuit breaker configuration
|
||||
func (c *UnifiedConfig) validateCircuit() ValidationErrors {
|
||||
var errors ValidationErrors
|
||||
|
||||
if !c.Circuit.Enabled {
|
||||
return errors
|
||||
}
|
||||
|
||||
// Consecutive failures must be reasonable
|
||||
if c.Circuit.ConsecutiveFailures < 1 || c.Circuit.ConsecutiveFailures > 100 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.ConsecutiveFailures",
|
||||
Message: "consecutive failures must be between 1 and 100",
|
||||
Value: c.Circuit.ConsecutiveFailures,
|
||||
})
|
||||
}
|
||||
|
||||
// Failure ratio must be between 0 and 1
|
||||
if c.Circuit.FailureRatio < 0 || c.Circuit.FailureRatio > 1 {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.FailureRatio",
|
||||
Message: "failure ratio must be between 0 and 1",
|
||||
Value: c.Circuit.FailureRatio,
|
||||
})
|
||||
}
|
||||
|
||||
// OnOpen action must be valid
|
||||
validActions := map[string]bool{
|
||||
"reject": true,
|
||||
"fallback": true,
|
||||
"passthrough": true,
|
||||
}
|
||||
if !validActions[c.Circuit.OnOpen] {
|
||||
errors = append(errors, ValidationError{
|
||||
Field: "Circuit.OnOpen",
|
||||
Message: "invalid OnOpen action (must be reject, fallback, or passthrough)",
|
||||
Value: c.Circuit.OnOpen,
|
||||
})
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
@@ -0,0 +1,588 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestValidateUnifiedConfig tests the validation of UnifiedConfig
|
||||
func TestValidateUnifiedConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *UnifiedConfig
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "valid config with minimum requirements",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
},
|
||||
Session: SessionConfig{
|
||||
Name: "oidc_session",
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 5,
|
||||
StorageType: "cookie",
|
||||
},
|
||||
Token: TokenConfig{
|
||||
AccessTokenTTL: time.Hour,
|
||||
RefreshTokenTTL: 24 * time.Hour,
|
||||
ValidationMode: "jwt",
|
||||
},
|
||||
Middleware: MiddlewareConfig{
|
||||
MaxRequestSize: 10 * 1024 * 1024,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
},
|
||||
Logging: LoggingConfig{
|
||||
Level: "info",
|
||||
Format: "json",
|
||||
Output: "stdout",
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing provider URL",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.IssuerURL",
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Provider.ClientID",
|
||||
},
|
||||
{
|
||||
name: "encryption key too short",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "too-short",
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.EncryptionKey",
|
||||
},
|
||||
{
|
||||
name: "invalid chunk size",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 500, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.ChunkSize",
|
||||
},
|
||||
{
|
||||
name: "invalid max chunks",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
ChunkSize: 4000,
|
||||
MaxChunks: 0, // Too small
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Session.MaxChunks",
|
||||
},
|
||||
{
|
||||
name: "invalid TLS min version",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Transport: TransportConfig{
|
||||
TLSMinVersion: "1.0", // Too old
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Transport.TLSMinVersion",
|
||||
},
|
||||
{
|
||||
name: "invalid circuit breaker failure ratio",
|
||||
config: &UnifiedConfig{
|
||||
Provider: ProviderConfig{
|
||||
IssuerURL: "https://auth.example.com",
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
Session: SessionConfig{
|
||||
EncryptionKey: "this-is-a-32-character-key-12345",
|
||||
},
|
||||
Circuit: CircuitConfig{
|
||||
Enabled: true,
|
||||
FailureRatio: 1.5, // Too high
|
||||
},
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "Circuit.FailureRatio",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error for field %s, but got none", tt.errorField)
|
||||
} else if validationErrs, ok := err.(ValidationErrors); ok {
|
||||
found := false
|
||||
for _, e := range validationErrs {
|
||||
if e.Field == tt.errorField {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected validation error for field %s, but got errors for: %v",
|
||||
tt.errorField, validationErrs)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidationErrorMessage tests validation error formatting
|
||||
func TestValidationErrorMessage(t *testing.T) {
|
||||
errs := ValidationErrors{
|
||||
{
|
||||
Field: "Provider.IssuerURL",
|
||||
Message: "is required",
|
||||
Value: nil,
|
||||
},
|
||||
{
|
||||
Field: "Session.EncryptionKey",
|
||||
Message: "must be at least 32 characters",
|
||||
Value: 16,
|
||||
},
|
||||
}
|
||||
|
||||
errMsg := errs.Error()
|
||||
|
||||
if !strings.Contains(errMsg, "Provider.IssuerURL") {
|
||||
t.Error("Error message should contain field name Provider.IssuerURL")
|
||||
}
|
||||
if !strings.Contains(errMsg, "is required") {
|
||||
t.Error("Error message should contain 'is required'")
|
||||
}
|
||||
if !strings.Contains(errMsg, "Session.EncryptionKey") {
|
||||
t.Error("Error message should contain field name Session.EncryptionKey")
|
||||
}
|
||||
if !strings.Contains(errMsg, "must be at least 32 characters") {
|
||||
t.Error("Error message should contain 'must be at least 32 characters'")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateRedisConfig tests Redis configuration validation
|
||||
func TestValidateRedisConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *RedisConfig
|
||||
expectError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "valid standalone config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "localhost:6379",
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing address for standalone",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeStandalone,
|
||||
Addr: "",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Redis address is required",
|
||||
},
|
||||
{
|
||||
name: "valid cluster config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{"localhost:7000", "localhost:7001"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing cluster addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeCluster,
|
||||
ClusterAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "cluster address is required",
|
||||
},
|
||||
{
|
||||
name: "valid sentinel config",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "missing master name for sentinel",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "",
|
||||
SentinelAddrs: []string{"localhost:26379"},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Master name is required",
|
||||
},
|
||||
{
|
||||
name: "missing sentinel addresses",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: RedisModeSentinel,
|
||||
MasterName: "mymaster",
|
||||
SentinelAddrs: []string{},
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "sentinel address is required",
|
||||
},
|
||||
{
|
||||
name: "disabled redis needs no validation",
|
||||
config: &RedisConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid redis mode",
|
||||
config: &RedisConfig{
|
||||
Enabled: true,
|
||||
Mode: "invalid-mode",
|
||||
},
|
||||
expectError: true,
|
||||
errorMsg: "Invalid Redis mode",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected validation error containing '%s', but got none", tt.errorMsg)
|
||||
} else if !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error message to contain '%s', but got: %v", tt.errorMsg, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no validation error, but got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateRateLimit Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateRateLimit_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = false
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when rate limiting is disabled")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidConfig(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid rate limit config")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooLow(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0
|
||||
config.RateLimit.Burst = 100
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_RequestsPerSecondTooHigh(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 15000
|
||||
config.RateLimit.Burst = 20000
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.RequestsPerSecond", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "between 1 and 10000")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_BurstTooSmall(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 50 // Less than RequestsPerSecond
|
||||
config.RateLimit.KeyType = "ip"
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.Burst", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "at least as large as requests per second")
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_InvalidKeyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
keyType string
|
||||
}{
|
||||
{"empty key type", ""},
|
||||
{"invalid key type", "invalid"},
|
||||
{"random string", "foobar"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = tt.keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "RateLimit.KeyType", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid key type")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_ValidKeyTypes(t *testing.T) {
|
||||
validKeyTypes := []string{"ip", "user", "token", "custom"}
|
||||
|
||||
for _, keyType := range validKeyTypes {
|
||||
t.Run(keyType, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 100
|
||||
config.RateLimit.Burst = 200
|
||||
config.RateLimit.KeyType = keyType
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid key type: %s", keyType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRateLimit_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.RateLimit.Enabled = true
|
||||
config.RateLimit.RequestsPerSecond = 0 // Too low
|
||||
config.RateLimit.Burst = 50 // Will pass (0 < 50)
|
||||
config.RateLimit.KeyType = "invalid" // Invalid
|
||||
|
||||
errors := config.validateRateLimit()
|
||||
|
||||
// Should have 2 errors (rps and keyType)
|
||||
assert.Len(t, errors, 2)
|
||||
|
||||
// Check each error is present
|
||||
fields := make(map[string]bool)
|
||||
for _, err := range errors {
|
||||
fields[err.Field] = true
|
||||
}
|
||||
assert.True(t, fields["RateLimit.RequestsPerSecond"])
|
||||
assert.True(t, fields["RateLimit.KeyType"])
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// validateMetrics Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestValidateMetrics_Disabled(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = false
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors when metrics are disabled")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidPrometheus(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "prometheus"
|
||||
config.Metrics.Endpoint = "" // Prometheus doesn't require endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid prometheus config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidStatsd(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "localhost:8125"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid statsd config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_ValidOTLP(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "localhost:4317"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
assert.Empty(t, errors, "Should have no errors for valid otlp config")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_InvalidProvider(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider string
|
||||
}{
|
||||
{"empty provider", ""},
|
||||
{"invalid provider", "invalid"},
|
||||
{"datadog", "datadog"},
|
||||
{"influx", "influx"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = tt.provider
|
||||
config.Metrics.Endpoint = "localhost:8080"
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "invalid metrics provider")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateMetrics_StatsdMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "statsd"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for statsd provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_OTLPMissingEndpoint(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "otlp"
|
||||
config.Metrics.Endpoint = "" // Missing required endpoint
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
require.Len(t, errors, 1)
|
||||
assert.Equal(t, "Metrics.Endpoint", errors[0].Field)
|
||||
assert.Contains(t, errors[0].Message, "endpoint is required for otlp provider")
|
||||
}
|
||||
|
||||
func TestValidateMetrics_MultipleErrors(t *testing.T) {
|
||||
config := NewUnifiedConfig()
|
||||
config.Metrics.Enabled = true
|
||||
config.Metrics.Provider = "invalid" // Invalid provider
|
||||
config.Metrics.Endpoint = "" // Would be missing if provider was statsd/otlp
|
||||
|
||||
errors := config.validateMetrics()
|
||||
|
||||
// Should have at least 1 error for invalid provider
|
||||
assert.NotEmpty(t, errors)
|
||||
assert.Equal(t, "Metrics.Provider", errors[0].Field)
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// REDACTED is the placeholder value for sensitive information
|
||||
const REDACTED = "[REDACTED]"
|
||||
|
||||
// MarshalJSON implements custom JSON marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalJSON() ([]byte, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML implements custom YAML marshalling to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (c Config) MarshalYAML() (interface{}, error) {
|
||||
// Build a map manually to avoid type alias issues with yaegi
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Copy public fields
|
||||
result["providerURL"] = c.ProviderURL
|
||||
result["clientID"] = c.ClientID
|
||||
result["callbackURL"] = c.CallbackURL
|
||||
result["logoutURL"] = c.LogoutURL
|
||||
result["postLogoutRedirectURI"] = c.PostLogoutRedirectURI
|
||||
result["scopes"] = c.Scopes
|
||||
result["forceHTTPS"] = c.ForceHTTPS
|
||||
result["logLevel"] = c.LogLevel
|
||||
result["rateLimit"] = c.RateLimit
|
||||
result["excludedURLs"] = c.ExcludedURLs
|
||||
result["allowedUserDomains"] = c.AllowedUserDomains
|
||||
result["allowedUsers"] = c.AllowedUsers
|
||||
result["allowedRolesAndGroups"] = c.AllowedRolesAndGroups
|
||||
|
||||
// Redact sensitive fields
|
||||
result["clientSecret"] = REDACTED
|
||||
result["sessionEncryptionKey"] = REDACTED
|
||||
|
||||
// Handle Redis config
|
||||
if c.Redis != nil {
|
||||
redisMap := make(map[string]interface{})
|
||||
redisMap["enabled"] = c.Redis.Enabled
|
||||
redisMap["address"] = c.Redis.Address
|
||||
redisMap["password"] = REDACTED
|
||||
redisMap["db"] = c.Redis.DB
|
||||
redisMap["poolSize"] = c.Redis.PoolSize
|
||||
redisMap["cacheMode"] = c.Redis.CacheMode
|
||||
result["redis"] = redisMap
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarshalJSON for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalJSON() ([]byte, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return json.Marshal(result)
|
||||
}
|
||||
|
||||
// MarshalYAML for RedisConfig to redact sensitive fields
|
||||
// Rewritten without type aliases for yaegi compatibility
|
||||
func (r RedisConfig) MarshalYAML() (interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
result["enabled"] = r.Enabled
|
||||
result["address"] = r.Address
|
||||
result["password"] = REDACTED
|
||||
result["db"] = r.DB
|
||||
result["poolSize"] = r.PoolSize
|
||||
result["cacheMode"] = r.CacheMode
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -18,7 +18,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
// Test that CSRF tokens persist through the authentication flow
|
||||
t.Run("CSRF_Token_Persists_After_Selective_Clear", func(t *testing.T) {
|
||||
// Create a session manager
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create initial request
|
||||
@@ -90,7 +90,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test that marking session as dirty forces save
|
||||
t.Run("Mark_Dirty_Forces_Session_Save", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -126,7 +126,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test Azure-specific session handling
|
||||
t.Run("Azure_Session_Cookie_Configuration", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate Azure callback scenario
|
||||
@@ -158,7 +158,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test session continuity through auth flow
|
||||
t.Run("Session_Continuity_Through_Auth_Flow", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Step 1: Initial request
|
||||
@@ -199,7 +199,7 @@ func TestCSRFTokenSessionManagement(t *testing.T) {
|
||||
|
||||
// Test large token handling doesn't affect CSRF
|
||||
t.Run("Large_Tokens_Dont_Affect_CSRF", func(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest("GET", "http://example.com/test", nil)
|
||||
@@ -262,7 +262,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
|
||||
// We can't fully initialize TraefikOidc without network access,
|
||||
// but we can test the session management directly
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", NewLogger(plugin.LogLevel))
|
||||
sessionManager, err := NewSessionManager(plugin.SessionEncryptionKey, plugin.ForceHTTPS, "", "", 0, NewLogger(plugin.LogLevel))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Session_Created_On_Protected_Request", func(t *testing.T) {
|
||||
@@ -291,7 +291,7 @@ func TestAuthFlowWithoutExternalDependencies(t *testing.T) {
|
||||
// TestRegressionLoginLoop specifically tests the fix for issue #53
|
||||
func TestRegressionLoginLoop(t *testing.T) {
|
||||
// This test verifies that the specific changes made to fix the login loop work correctly
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simulate the exact flow that was causing the login loop
|
||||
@@ -392,7 +392,7 @@ func TestRegressionLoginLoop(t *testing.T) {
|
||||
|
||||
// TestCSRFValidationTiming tests timing-sensitive CSRF validation scenarios
|
||||
func TestCSRFValidationTiming(t *testing.T) {
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", NewLogger("debug"))
|
||||
sessionManager, err := NewSessionManager("test-encryption-key-32-characters", false, "", "", 0, NewLogger("debug"))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Rapid_Redirect_Maintains_CSRF", func(t *testing.T) {
|
||||
|
||||
@@ -0,0 +1,364 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCustomClaimNames_DefaultBehavior tests backward compatibility with default claim names
|
||||
func TestCustomClaimNames_DefaultBehavior(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Explicitly set defaults to test backward compatibility
|
||||
ts.tOidc.roleClaimName = "roles"
|
||||
ts.tOidc.groupClaimName = "groups"
|
||||
|
||||
// Test that when no custom claim names are configured, it uses defaults "roles" and "groups"
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin", "users"},
|
||||
"roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_Auth0Namespaced tests Auth0-style namespaced claims
|
||||
func TestCustomClaimNames_Auth0Namespaced(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names for Auth0
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with Auth0-style namespaced claims
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{"admin", "users"},
|
||||
"https://myapp.com/roles": []interface{}{"editor", "viewer"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin", "users"}) {
|
||||
t.Errorf("Expected groups [admin users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor", "viewer"}) {
|
||||
t.Errorf("Expected roles [editor viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_CustomSimpleNames tests custom simple claim names
|
||||
func TestCustomClaimNames_CustomSimpleNames(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom simple claim names
|
||||
ts.tOidc.roleClaimName = "user_roles"
|
||||
ts.tOidc.groupClaimName = "user_groups"
|
||||
|
||||
// Create token with custom claim names
|
||||
claims := map[string]interface{}{
|
||||
"user_groups": []interface{}{"engineering", "product"},
|
||||
"user_roles": []interface{}{"developer", "manager"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"engineering", "product"}) {
|
||||
t.Errorf("Expected groups [engineering product], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"developer", "manager"}) {
|
||||
t.Errorf("Expected roles [developer manager], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MissingClaims tests behavior when custom claims are missing
|
||||
func TestCustomClaimNames_MissingClaims(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token WITHOUT the custom claims
|
||||
claims := map[string]interface{}{
|
||||
"sub": "user123",
|
||||
"email": "user@example.com",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should return empty slices, not error
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedClaims tests error handling for malformed claims
|
||||
func TestCustomClaimNames_MalformedRoleClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with malformed role claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": "this-should-be-an-array",
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed role claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_roles claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_MalformedGroupClaim tests error handling for malformed group claims
|
||||
func TestCustomClaimNames_MalformedGroupClaim(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with malformed group claim (not an array)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": 12345, // Not an array
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, _, err = ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err == nil {
|
||||
t.Error("Expected error for malformed group claim, got nil")
|
||||
}
|
||||
|
||||
// Check error message contains the custom claim name
|
||||
expectedError := "custom_groups claim is not an array"
|
||||
if err.Error() != expectedError {
|
||||
t.Errorf("Expected error '%s', got '%s'", expectedError, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_PartialConfiguration tests when only one claim name is customized
|
||||
func TestCustomClaimNames_OnlyRoleCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only role claim name (group uses default)
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "groups" // default
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"groups": []interface{}{"admin"},
|
||||
"https://myapp.com/roles": []interface{}{"editor"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"admin"}) {
|
||||
t.Errorf("Expected groups [admin], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"editor"}) {
|
||||
t.Errorf("Expected roles [editor], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_OnlyGroupCustomized tests when only group claim name is customized
|
||||
func TestCustomClaimNames_OnlyGroupCustomized(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure only group claim name (role uses default)
|
||||
ts.tOidc.roleClaimName = "roles" // default
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with mixed claim names
|
||||
claims := map[string]interface{}{
|
||||
"roles": []interface{}{"viewer"},
|
||||
"https://myapp.com/groups": []interface{}{"users"},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(groups, []string{"users"}) {
|
||||
t.Errorf("Expected groups [users], got %v", groups)
|
||||
}
|
||||
|
||||
if !stringSliceEqual(roles, []string{"viewer"}) {
|
||||
t.Errorf("Expected roles [viewer], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_EmptyArrays tests extraction with empty claim arrays
|
||||
func TestCustomClaimNames_EmptyArrays(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "https://myapp.com/roles"
|
||||
ts.tOidc.groupClaimName = "https://myapp.com/groups"
|
||||
|
||||
// Create token with empty arrays
|
||||
claims := map[string]interface{}{
|
||||
"https://myapp.com/groups": []interface{}{},
|
||||
"https://myapp.com/roles": []interface{}{},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(groups) != 0 {
|
||||
t.Errorf("Expected empty groups, got %v", groups)
|
||||
}
|
||||
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("Expected empty roles, got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringElements tests handling of non-string elements in claim arrays
|
||||
func TestCustomClaimNames_NonStringInRoleArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.roleClaimName = "custom_roles"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_roles": []interface{}{"role1", 12345, "role2", true},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
_, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(roles, []string{"role1", "role2"}) {
|
||||
t.Errorf("Expected roles [role1 role2], got %v", roles)
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomClaimNames_NonStringInGroupArray tests handling of non-string elements in group arrays
|
||||
func TestCustomClaimNames_NonStringInGroupArray(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Configure custom claim names
|
||||
ts.tOidc.groupClaimName = "custom_groups"
|
||||
|
||||
// Create token with mixed-type array (should skip non-string elements)
|
||||
claims := map[string]interface{}{
|
||||
"custom_groups": []interface{}{"group1", nil, "group2", 3.14},
|
||||
}
|
||||
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, _, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Should only extract string elements
|
||||
if !stringSliceEqual(groups, []string{"group1", "group2"}) {
|
||||
t.Errorf("Expected groups [group1 group2], got %v", groups)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
traefikoidc.raczylo.com
|
||||
@@ -437,6 +437,21 @@ http:
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
If your Keycloak instance runs on an internal network with private IP addresses (e.g., Docker networks, Kubernetes internal services), set `allowPrivateIPAddresses: true`:
|
||||
|
||||
```yaml
|
||||
traefikoidc:
|
||||
providerUrl: "https://192.168.1.100:8443/auth/realms/your-realm" # Private IP
|
||||
allowPrivateIPAddresses: true # Required for private IP addresses
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
# ... other config
|
||||
```
|
||||
|
||||
> **Security Warning**: Only enable `allowPrivateIPAddresses` in trusted network environments where you control the OIDC provider. This setting reduces SSRF protection.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
+1125
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,413 @@
|
||||
# Redis Cache Backend Test Suite
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the comprehensive test suite created for the Redis cache backend feature in the Traefik OIDC plugin. The test suite ensures reliability, performance, and correctness of the caching infrastructure.
|
||||
|
||||
## Test Structure
|
||||
|
||||
### Directory Organization
|
||||
|
||||
```
|
||||
internal/cache/
|
||||
├── backend/
|
||||
│ ├── interface.go # CacheBackend interface definition
|
||||
│ ├── interface_test.go # Contract tests for all backends
|
||||
│ ├── memory.go # In-memory backend implementation
|
||||
│ ├── memory_test.go # Memory backend unit tests
|
||||
│ ├── redis.go # Redis backend implementation
|
||||
│ ├── redis_test.go # Redis backend unit tests
|
||||
│ ├── errors.go # Error definitions
|
||||
│ └── test_helpers_test.go # Test infrastructure and helpers
|
||||
│
|
||||
└── resilience/
|
||||
├── circuit_breaker.go # Circuit breaker implementation
|
||||
├── circuit_breaker_test.go # Circuit breaker tests
|
||||
├── health_check.go # Health checker implementation
|
||||
└── health_check_test.go # Health check tests
|
||||
|
||||
redis_integration_test.go # End-to-end integration tests
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### 1. Interface Contract Tests (`interface_test.go`)
|
||||
|
||||
**Purpose:** Ensure all backend implementations (Memory, Redis, Hybrid) comply with the CacheBackend interface contract.
|
||||
|
||||
**Test Cases:**
|
||||
- `TestCacheBackendContract` - Runs all contract tests against each backend type
|
||||
- `testBasicSetGet` - Verifies basic set/get operations
|
||||
- `testGetNonExistent` - Tests behavior for non-existent keys
|
||||
- `testUpdateExisting` - Validates updating existing keys
|
||||
- `testDelete` - Tests delete operations
|
||||
- `testDeleteNonExistent` - Delete non-existent keys
|
||||
- `testExists` - Key existence checking
|
||||
- `testTTLExpiration` - TTL and expiration behavior
|
||||
- `testClear` - Clear all keys operation
|
||||
- `testPing` - Health check functionality
|
||||
- `testStats` - Statistics tracking
|
||||
- `testConcurrentAccess` - Thread safety with 10+ goroutines
|
||||
- `testLargeValues` - Handling of 1MB+ values
|
||||
- `testEmptyValues` - Empty byte array handling
|
||||
- `testSpecialCharactersInKeys` - Special characters in key names
|
||||
|
||||
**Coverage:** ~95% of interface methods
|
||||
|
||||
### 2. Memory Backend Tests (`memory_test.go`)
|
||||
|
||||
**Purpose:** Test the in-memory LRU cache backend with comprehensive edge cases.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (6 tests)
|
||||
- `TestMemoryBackend_BasicOperations` - CRUD operations
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- DeleteNonExistent
|
||||
- Exists
|
||||
- Clear
|
||||
|
||||
#### TTL and Expiration (3 tests)
|
||||
- `TestMemoryBackend_TTLExpiration`
|
||||
- ShortTTL (100ms)
|
||||
- TTLDecrement over time
|
||||
- CleanupExpiredItems
|
||||
|
||||
#### LRU Eviction (2 tests)
|
||||
- `TestMemoryBackend_LRUEviction` - Verifies LRU algorithm
|
||||
- `TestMemoryBackend_MemoryLimit` - Memory-based eviction
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestMemoryBackend_ConcurrentAccess` - 20 goroutines, 50 iterations each
|
||||
|
||||
#### Edge Cases (6 tests)
|
||||
- `TestMemoryBackend_UpdateExisting` - Overwriting values
|
||||
- `TestMemoryBackend_Stats` - Metrics tracking (hits, misses, hit rate)
|
||||
- `TestMemoryBackend_EmptyValues` - Zero-length byte arrays
|
||||
- `TestMemoryBackend_LargeValues` - 1MB values
|
||||
- `TestMemoryBackend_Close` - Proper cleanup
|
||||
- `TestMemoryBackend_Ping` - Health checks
|
||||
- `TestMemoryBackend_ValueIsolation` - Returns copies, not references
|
||||
|
||||
**Coverage:** ~92% of memory backend code
|
||||
|
||||
### 3. Redis Backend Tests (`redis_test.go`)
|
||||
|
||||
**Purpose:** Test Redis backend using miniredis (in-memory Redis mock).
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Basic Operations (4 tests)
|
||||
- `TestRedisBackend_BasicOperations`
|
||||
- SetAndGet
|
||||
- GetNonExistent
|
||||
- Delete
|
||||
- Exists
|
||||
|
||||
#### Redis-Specific Features (6 tests)
|
||||
- `TestRedisBackend_KeyPrefixing` - Namespace isolation
|
||||
- `TestRedisBackend_TTLExpiration` - Redis TTL handling
|
||||
- `TestRedisBackend_Clear` - Bulk delete with SCAN
|
||||
- `TestRedisBackend_NoPrefix` - Operation without prefix
|
||||
|
||||
#### Error Handling (2 tests)
|
||||
- `TestRedisBackend_ConnectionFailure` - Connection errors
|
||||
- `TestRedisBackend_RedisErrors` - Simulated Redis failures
|
||||
|
||||
#### Concurrency (1 test)
|
||||
- `TestRedisBackend_ConcurrentAccess` - 20 goroutines, 50 operations
|
||||
|
||||
#### Advanced Features (3 tests)
|
||||
- `TestRedisBackend_PipelineOperations`
|
||||
- SetMany (batch writes)
|
||||
- GetMany (batch reads)
|
||||
- GetManyWithNonExistent
|
||||
|
||||
#### Edge Cases (5 tests)
|
||||
- `TestRedisBackend_Stats` - Statistics tracking
|
||||
- `TestRedisBackend_Ping` - Connection health
|
||||
- `TestRedisBackend_Close` - Resource cleanup
|
||||
- `TestRedisBackend_UpdateExisting` - Overwrite handling
|
||||
- `TestRedisBackend_LargeValues` - 1MB values
|
||||
- `TestRedisBackend_EmptyValues` - Empty arrays
|
||||
|
||||
**Coverage:** ~88% of Redis backend code
|
||||
|
||||
**Key Testing Tool:** `miniredis` - In-memory Redis mock that supports:
|
||||
- All basic Redis commands
|
||||
- TTL and expiration
|
||||
- Time manipulation (FastForward)
|
||||
- Error simulation
|
||||
- No external Redis server required
|
||||
|
||||
### 4. Circuit Breaker Tests (`circuit_breaker_test.go`)
|
||||
|
||||
**Purpose:** Verify circuit breaker pattern implementation for fault tolerance.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### State Transitions (5 tests)
|
||||
- `TestCircuitBreaker_StateTransitions`
|
||||
- Initial state (Closed)
|
||||
- Closed → Open (after max failures)
|
||||
- Open → HalfOpen (after timeout)
|
||||
- HalfOpen → Closed (after successful requests)
|
||||
- HalfOpen → Open (on failure)
|
||||
|
||||
#### Behavior Tests (5 tests)
|
||||
- `TestCircuitBreaker_OpenCircuitBlocks` - Blocks requests when open
|
||||
- `TestCircuitBreaker_HalfOpenMaxRequests` - Limits requests in half-open
|
||||
- `TestCircuitBreaker_SuccessResetsFailures` - Failure counter reset
|
||||
- `TestCircuitBreaker_ConcurrentAccess` - Thread safety
|
||||
- `TestCircuitBreaker_Stats` - Statistics tracking
|
||||
|
||||
#### Advanced Tests (7 tests)
|
||||
- `TestCircuitBreaker_Reset` - Manual reset
|
||||
- `TestCircuitBreaker_StateChangeCallback` - Notifications
|
||||
- `TestCircuitBreaker_IsAvailable` - Availability check
|
||||
- `TestCircuitBreaker_RapidFailures` - Fast consecutive failures
|
||||
- `TestCircuitBreaker_TimeoutAccuracy` - Timeout precision
|
||||
- `TestCircuitBreaker_DefaultConfig` - Default configuration
|
||||
- `TestCircuitBreaker_StateString` - String representation
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkCircuitBreaker_Execute` - Successful operations
|
||||
- `BenchmarkCircuitBreaker_ExecuteWithFailures` - Mixed success/failure
|
||||
|
||||
**Coverage:** ~95% of circuit breaker code
|
||||
|
||||
### 5. Health Check Tests (`health_check_test.go`)
|
||||
|
||||
**Purpose:** Validate periodic health checking and status management.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Status Transitions (4 tests)
|
||||
- `TestHealthChecker_StatusTransitions` - Healthy → Degraded → Unhealthy → Healthy
|
||||
- `TestHealthChecker_InitialState` - Default healthy state
|
||||
- `TestHealthChecker_ForceCheck` - Manual health check trigger
|
||||
- `TestHealthChecker_StatusChangeCallback` - Change notifications
|
||||
|
||||
#### Behavior Tests (6 tests)
|
||||
- `TestHealthChecker_Stats` - Statistics tracking
|
||||
- `TestHealthChecker_Timeout` - Check timeout handling
|
||||
- `TestHealthChecker_ConcurrentAccess` - Thread safety
|
||||
- `TestHealthChecker_StopAndStart` - Lifecycle management
|
||||
- `TestHealthChecker_DegradedState` - Degraded status detection
|
||||
- `TestHealthChecker_DefaultConfig` - Default settings
|
||||
|
||||
#### Advanced Tests (2 tests)
|
||||
- `TestHealthChecker_StatusString` - String representation
|
||||
- `TestHealthChecker_RecoveryPattern` - Typical failure/recovery cycle
|
||||
|
||||
**Benchmarks:**
|
||||
- `BenchmarkHealthChecker_ForceCheck` - Check performance
|
||||
- `BenchmarkHealthChecker_Status` - Status read performance
|
||||
|
||||
**Coverage:** ~90% of health checker code
|
||||
|
||||
### 6. Integration Tests (`redis_integration_test.go`)
|
||||
|
||||
**Purpose:** End-to-end testing of real-world scenarios.
|
||||
|
||||
**Test Cases:**
|
||||
|
||||
#### Multi-Instance Tests (3 tests)
|
||||
- `TestRedisIntegration_MultipleInstances`
|
||||
- ShareTokenBlacklist - JTI sharing across Traefik replicas
|
||||
- ShareTokenCache - Token cache sharing
|
||||
- ShareMetadataCache - Provider metadata sharing
|
||||
|
||||
#### Replay Detection (2 tests)
|
||||
- `TestRedisIntegration_JTIReplayDetection`
|
||||
- PreventReplayAcrossInstances - Block used JTIs
|
||||
- ConcurrentJTIChecks - Race condition handling
|
||||
|
||||
#### Resilience (1 test)
|
||||
- `TestRedisIntegration_Failover`
|
||||
- RedisTemporaryFailure - Recovery from temporary failures
|
||||
|
||||
#### Performance (1 test)
|
||||
- `TestRedisIntegration_HighLoad`
|
||||
- HighConcurrency - 50 goroutines × 100 operations
|
||||
|
||||
#### Consistency (2 tests)
|
||||
- `TestRedisIntegration_TTLConsistency` - TTL accuracy
|
||||
- `TestRedisIntegration_MemoryUsage` - 10,000 item dataset
|
||||
- `TestRedisIntegration_Cleanup` - Bulk cleanup operations
|
||||
|
||||
**Coverage:** Integration scenarios covering 80%+ of realistic use cases
|
||||
|
||||
## Test Helpers and Infrastructure
|
||||
|
||||
### Test Helpers (`test_helpers_test.go`)
|
||||
|
||||
**Utilities:**
|
||||
- `TestLogger` - Logging for tests
|
||||
- `MiniredisServer` - Miniredis setup/teardown
|
||||
- `TestConfig` - Default test configurations
|
||||
- `GenerateTestData` - Test data generation
|
||||
- `GenerateLargeValue` - Large value creation
|
||||
- `AssertCacheStats` - Statistics validation
|
||||
- `WaitForCondition` - Async condition waiting
|
||||
- `AssertEventuallyExpires` - TTL expiration verification
|
||||
|
||||
## Running the Tests
|
||||
|
||||
### Run All Tests
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -v
|
||||
go test ./internal/cache/resilience/... -v
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run Specific Test Suites
|
||||
```bash
|
||||
# Memory backend only
|
||||
go test ./internal/cache/backend -run TestMemoryBackend -v
|
||||
|
||||
# Redis backend only
|
||||
go test ./internal/cache/backend -run TestRedisBackend -v
|
||||
|
||||
# Circuit breaker only
|
||||
go test ./internal/cache/resilience -run TestCircuitBreaker -v
|
||||
|
||||
# Integration tests only
|
||||
go test -run TestRedisIntegration -v
|
||||
```
|
||||
|
||||
### Run with Coverage
|
||||
```bash
|
||||
go test ./internal/cache/backend/... -coverprofile=coverage.out
|
||||
go test ./internal/cache/resilience/... -coverprofile=coverage_resilience.out
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
### Run Benchmarks
|
||||
```bash
|
||||
go test ./internal/cache/backend -bench=. -benchmem
|
||||
go test ./internal/cache/resilience -bench=. -benchmem
|
||||
```
|
||||
|
||||
### Run with Race Detector
|
||||
```bash
|
||||
go test ./internal/cache/... -race -v
|
||||
```
|
||||
|
||||
## Test Patterns Used
|
||||
|
||||
### 1. Table-Driven Tests
|
||||
Used for testing multiple scenarios with similar structure.
|
||||
|
||||
### 2. Subtests (t.Run)
|
||||
Organized test cases into logical groups with clear names.
|
||||
|
||||
### 3. Parallel Tests
|
||||
Tests marked with `t.Parallel()` for faster execution.
|
||||
|
||||
### 4. Test Fixtures
|
||||
Reusable setup functions for common test data.
|
||||
|
||||
### 5. Mocking
|
||||
- `miniredis` for Redis operations
|
||||
- Mock functions for callbacks and health checks
|
||||
|
||||
### 6. Assertion Helpers
|
||||
Using `testify/assert` and `testify/require` for clear assertions.
|
||||
|
||||
## Test Coverage Summary
|
||||
|
||||
| Component | Coverage | Tests | Lines of Code |
|
||||
|-----------|----------|-------|---------------|
|
||||
| Interface Contract | 95% | 14 | ~200 |
|
||||
| Memory Backend | 92% | 18 | ~350 |
|
||||
| Redis Backend | 88% | 21 | ~400 |
|
||||
| Circuit Breaker | 95% | 17 | ~250 |
|
||||
| Health Checker | 90% | 12 | ~200 |
|
||||
| Integration Tests | 80% | 9 | ~300 |
|
||||
| **Total** | **90%** | **91** | **~1,700** |
|
||||
|
||||
## Edge Cases Tested
|
||||
|
||||
1. **Empty values** - Zero-length byte arrays
|
||||
2. **Large values** - 1MB+ data
|
||||
3. **Special characters** - Keys with :, /, -, _, ., |
|
||||
4. **Concurrent access** - 10-50 goroutines
|
||||
5. **TTL edge cases** - Very short (<100ms) and long (24h+) TTLs
|
||||
6. **Connection failures** - Network errors, timeouts
|
||||
7. **Redis errors** - Simulated Redis failures
|
||||
8. **Memory limits** - Eviction under memory pressure
|
||||
9. **Race conditions** - Concurrent JTI checks
|
||||
10. **State transitions** - All circuit breaker and health check states
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
Benchmarks included for:
|
||||
- Cache operations (Set, Get, Delete)
|
||||
- Circuit breaker execution
|
||||
- Health check operations
|
||||
- Concurrent access patterns
|
||||
- Large datasets (10,000+ items)
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Testing Libraries
|
||||
- `github.com/stretchr/testify` - Assertions and test utilities
|
||||
- `github.com/alicebob/miniredis/v2` - In-memory Redis mock
|
||||
- `github.com/redis/go-redis/v9` - Redis client
|
||||
|
||||
### Why Miniredis?
|
||||
- **No external dependencies** - No Redis server required
|
||||
- **Fast** - In-memory, perfect for unit tests
|
||||
- **Full Redis API** - Supports all operations we need
|
||||
- **Time manipulation** - FastForward for TTL testing
|
||||
- **Error simulation** - Test failure scenarios
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Tests
|
||||
1. Hybrid backend tests (L1/L2 cache)
|
||||
2. Network partition scenarios
|
||||
3. Redis cluster support
|
||||
4. Persistence and recovery tests
|
||||
5. Metrics and monitoring integration
|
||||
|
||||
### Test Infrastructure Improvements
|
||||
1. Test containers for real Redis integration
|
||||
2. Performance regression tracking
|
||||
3. Chaos engineering tests
|
||||
4. Load testing framework
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### Recommended CI Configuration
|
||||
|
||||
```yaml
|
||||
test:
|
||||
script:
|
||||
- go test ./internal/cache/... -race -cover -v
|
||||
- go test -run TestRedisIntegration -v
|
||||
- go test ./internal/cache/... -bench=. -benchmem
|
||||
```
|
||||
|
||||
## Maintenance Guidelines
|
||||
|
||||
1. **Add tests for new features** - Maintain >85% coverage
|
||||
2. **Update contract tests** - When interface changes
|
||||
3. **Test edge cases** - Always test error paths
|
||||
4. **Document test purpose** - Clear comments explaining what each test validates
|
||||
5. **Keep tests fast** - Use t.Parallel() where possible
|
||||
6. **Mock external dependencies** - Use miniredis, not real Redis
|
||||
|
||||
## Conclusion
|
||||
|
||||
This comprehensive test suite provides:
|
||||
- **High confidence** in cache backend correctness
|
||||
- **Fast feedback** - Tests run in seconds
|
||||
- **Good coverage** - 90% overall
|
||||
- **Clear documentation** - Each test is well-documented
|
||||
- **Maintainability** - Clear structure and patterns
|
||||
|
||||
The test suite ensures that the Redis cache backend feature is production-ready and reliable for multi-replica Traefik deployments with shared caching requirements.
|
||||
+1373
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,551 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientRegistrationResponse represents the response from a successful client registration (RFC 7591)
|
||||
type ClientRegistrationResponse struct {
|
||||
// Required fields
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
// Conditional - only for confidential clients
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
|
||||
// Optional - for managing registration
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
|
||||
// Expiration
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"`
|
||||
|
||||
// Echo back of registered metadata
|
||||
RedirectURIs []string `json:"redirect_uris,omitempty"`
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TOSURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
// ClientRegistrationError represents an error response from client registration (RFC 7591)
|
||||
type ClientRegistrationError struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// DynamicClientRegistrar handles OIDC Dynamic Client Registration (RFC 7591)
|
||||
type DynamicClientRegistrar struct {
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
config *DynamicClientRegistrationConfig
|
||||
providerURL string
|
||||
|
||||
// Cached registration response
|
||||
mu sync.RWMutex
|
||||
registrationResponse *ClientRegistrationResponse
|
||||
}
|
||||
|
||||
// NewDynamicClientRegistrar creates a new dynamic client registrar
|
||||
func NewDynamicClientRegistrar(
|
||||
httpClient *http.Client,
|
||||
logger *Logger,
|
||||
dcrConfig *DynamicClientRegistrationConfig,
|
||||
providerURL string,
|
||||
) *DynamicClientRegistrar {
|
||||
if logger == nil {
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
return &DynamicClientRegistrar{
|
||||
httpClient: httpClient,
|
||||
logger: logger,
|
||||
config: dcrConfig,
|
||||
providerURL: providerURL,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClient performs dynamic client registration with the OIDC provider
|
||||
// It first attempts to load existing credentials from a file if persistence is enabled,
|
||||
// then registers a new client if no valid credentials exist.
|
||||
func (r *DynamicClientRegistrar) RegisterClient(ctx context.Context, registrationEndpoint string) (*ClientRegistrationResponse, error) {
|
||||
if r.config == nil || !r.config.Enabled {
|
||||
return nil, fmt.Errorf("dynamic client registration is not enabled")
|
||||
}
|
||||
|
||||
// Try to load existing credentials if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
if resp, err := r.loadCredentials(); err == nil && resp != nil {
|
||||
// Check if credentials are still valid (not expired)
|
||||
if r.areCredentialsValid(resp) {
|
||||
r.logger.Info("Loaded existing client credentials from file")
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = resp
|
||||
r.mu.Unlock()
|
||||
return resp, nil
|
||||
}
|
||||
r.logger.Info("Existing credentials expired or invalid, registering new client")
|
||||
}
|
||||
}
|
||||
|
||||
// Determine registration endpoint
|
||||
endpoint := registrationEndpoint
|
||||
if r.config.RegistrationEndpoint != "" {
|
||||
endpoint = r.config.RegistrationEndpoint
|
||||
}
|
||||
|
||||
if endpoint == "" {
|
||||
return nil, fmt.Errorf("no registration endpoint available: provider does not support dynamic client registration or endpoint not configured")
|
||||
}
|
||||
|
||||
// Validate the endpoint URL
|
||||
if !strings.HasPrefix(endpoint, "https://") {
|
||||
// Allow http only for localhost/development
|
||||
if !strings.HasPrefix(endpoint, "http://localhost") && !strings.HasPrefix(endpoint, "http://127.0.0.1") {
|
||||
return nil, fmt.Errorf("registration endpoint must use HTTPS for security")
|
||||
}
|
||||
r.logger.Infof("Warning: using insecure HTTP for registration endpoint (development only): %s", endpoint)
|
||||
}
|
||||
|
||||
// Build registration request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build registration request: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Registering client at endpoint: %s", endpoint)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create registration request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
// Add Initial Access Token if provided
|
||||
if r.config.InitialAccessToken != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+r.config.InitialAccessToken)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("registration request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read registration response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("registration failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("registration failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registration response: %w", err)
|
||||
}
|
||||
|
||||
// Validate response
|
||||
if regResp.ClientID == "" {
|
||||
return nil, fmt.Errorf("registration response missing client_id")
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully registered client with ID: %s", regResp.ClientID)
|
||||
|
||||
// Cache the response
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist client credentials: %v", err)
|
||||
// Don't fail registration if persistence fails
|
||||
}
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// buildRegistrationRequest creates the JSON request body for client registration
|
||||
func (r *DynamicClientRegistrar) buildRegistrationRequest() ([]byte, error) {
|
||||
metadata := r.config.ClientMetadata
|
||||
if metadata == nil {
|
||||
metadata = &ClientRegistrationMetadata{}
|
||||
}
|
||||
|
||||
// Build request object
|
||||
reqData := make(map[string]interface{})
|
||||
|
||||
// Required: redirect_uris
|
||||
if len(metadata.RedirectURIs) > 0 {
|
||||
reqData["redirect_uris"] = metadata.RedirectURIs
|
||||
} else {
|
||||
return nil, fmt.Errorf("redirect_uris is required for client registration")
|
||||
}
|
||||
|
||||
// Optional fields - only include if set
|
||||
if len(metadata.ResponseTypes) > 0 {
|
||||
reqData["response_types"] = metadata.ResponseTypes
|
||||
} else {
|
||||
// Default to authorization code flow
|
||||
reqData["response_types"] = []string{"code"}
|
||||
}
|
||||
|
||||
if len(metadata.GrantTypes) > 0 {
|
||||
reqData["grant_types"] = metadata.GrantTypes
|
||||
} else {
|
||||
// Default grant types for authorization code flow
|
||||
reqData["grant_types"] = []string{"authorization_code", "refresh_token"}
|
||||
}
|
||||
|
||||
if metadata.ApplicationType != "" {
|
||||
reqData["application_type"] = metadata.ApplicationType
|
||||
}
|
||||
|
||||
if len(metadata.Contacts) > 0 {
|
||||
reqData["contacts"] = metadata.Contacts
|
||||
}
|
||||
|
||||
if metadata.ClientName != "" {
|
||||
reqData["client_name"] = metadata.ClientName
|
||||
}
|
||||
|
||||
if metadata.LogoURI != "" {
|
||||
reqData["logo_uri"] = metadata.LogoURI
|
||||
}
|
||||
|
||||
if metadata.ClientURI != "" {
|
||||
reqData["client_uri"] = metadata.ClientURI
|
||||
}
|
||||
|
||||
if metadata.PolicyURI != "" {
|
||||
reqData["policy_uri"] = metadata.PolicyURI
|
||||
}
|
||||
|
||||
if metadata.TOSURI != "" {
|
||||
reqData["tos_uri"] = metadata.TOSURI
|
||||
}
|
||||
|
||||
if metadata.JWKSURI != "" {
|
||||
reqData["jwks_uri"] = metadata.JWKSURI
|
||||
}
|
||||
|
||||
if metadata.SubjectType != "" {
|
||||
reqData["subject_type"] = metadata.SubjectType
|
||||
}
|
||||
|
||||
if metadata.TokenEndpointAuthMethod != "" {
|
||||
reqData["token_endpoint_auth_method"] = metadata.TokenEndpointAuthMethod
|
||||
} else {
|
||||
// Default to client_secret_basic for confidential clients
|
||||
reqData["token_endpoint_auth_method"] = "client_secret_basic"
|
||||
}
|
||||
|
||||
if metadata.DefaultMaxAge > 0 {
|
||||
reqData["default_max_age"] = metadata.DefaultMaxAge
|
||||
}
|
||||
|
||||
if metadata.RequireAuthTime {
|
||||
reqData["require_auth_time"] = metadata.RequireAuthTime
|
||||
}
|
||||
|
||||
if len(metadata.DefaultACRValues) > 0 {
|
||||
reqData["default_acr_values"] = metadata.DefaultACRValues
|
||||
}
|
||||
|
||||
if metadata.Scope != "" {
|
||||
reqData["scope"] = metadata.Scope
|
||||
}
|
||||
|
||||
return json.Marshal(reqData)
|
||||
}
|
||||
|
||||
// GetCachedResponse returns the cached registration response
|
||||
func (r *DynamicClientRegistrar) GetCachedResponse() *ClientRegistrationResponse {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.registrationResponse
|
||||
}
|
||||
|
||||
// areCredentialsValid checks if the cached credentials are still valid
|
||||
func (r *DynamicClientRegistrar) areCredentialsValid(resp *ClientRegistrationResponse) bool {
|
||||
if resp == nil || resp.ClientID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if secret has expired
|
||||
if resp.ClientSecretExpiresAt > 0 {
|
||||
expiresAt := time.Unix(resp.ClientSecretExpiresAt, 0)
|
||||
// Add 5 minute buffer before expiration
|
||||
if time.Now().Add(5 * time.Minute).After(expiresAt) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// credentialsFilePath returns the path for storing credentials
|
||||
func (r *DynamicClientRegistrar) credentialsFilePath() string {
|
||||
if r.config.CredentialsFile != "" {
|
||||
return r.config.CredentialsFile
|
||||
}
|
||||
return "/tmp/oidc-client-credentials.json"
|
||||
}
|
||||
|
||||
// saveCredentials persists client credentials to a file
|
||||
func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationResponse) error {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal credentials: %w", err)
|
||||
}
|
||||
|
||||
// Write with restrictive permissions (owner read/write only)
|
||||
if err := os.WriteFile(filePath, data, 0600); err != nil {
|
||||
return fmt.Errorf("failed to write credentials file: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Debugf("Saved client credentials to %s", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadCredentials loads client credentials from a file
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil // No credentials file exists
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read credentials file: %w", err)
|
||||
}
|
||||
|
||||
var resp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse credentials file: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateClientRegistration updates an existing client registration using RFC 7592
|
||||
// This requires the registration_client_uri and registration_access_token from the original registration
|
||||
func (r *DynamicClientRegistrar) UpdateClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to update")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Build update request
|
||||
reqBody, err := r.buildRegistrationRequest()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to build update request: %w", err)
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPut, cachedResp.RegistrationClientURI, bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create update request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read update response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("update failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("update failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse update response: %w", err)
|
||||
}
|
||||
|
||||
// Update cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = ®Resp
|
||||
r.mu.Unlock()
|
||||
|
||||
// Persist updated credentials if enabled
|
||||
if r.config.PersistCredentials {
|
||||
if err := r.saveCredentials(®Resp); err != nil {
|
||||
r.logger.Errorf("Failed to persist updated credentials: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Infof("Successfully updated client registration for client ID: %s", regResp.ClientID)
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// ReadClientRegistration reads the current client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) ReadClientRegistration(ctx context.Context) (*ClientRegistrationResponse, error) {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return nil, fmt.Errorf("no existing registration to read")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return nil, fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create read request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Read response body
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
// Handle error responses
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return nil, fmt.Errorf("read failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return nil, fmt.Errorf("read failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse successful response
|
||||
var regResp ClientRegistrationResponse
|
||||
if err := json.Unmarshal(body, ®Resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse read response: %w", err)
|
||||
}
|
||||
|
||||
return ®Resp, nil
|
||||
}
|
||||
|
||||
// DeleteClientRegistration deletes the client registration using RFC 7592
|
||||
func (r *DynamicClientRegistrar) DeleteClientRegistration(ctx context.Context) error {
|
||||
r.mu.RLock()
|
||||
cachedResp := r.registrationResponse
|
||||
r.mu.RUnlock()
|
||||
|
||||
if cachedResp == nil {
|
||||
return fmt.Errorf("no existing registration to delete")
|
||||
}
|
||||
|
||||
if cachedResp.RegistrationClientURI == "" || cachedResp.RegistrationAccessToken == "" {
|
||||
return fmt.Errorf("registration management not supported: missing registration_client_uri or registration_access_token")
|
||||
}
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodDelete, cachedResp.RegistrationClientURI, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create delete request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+cachedResp.RegistrationAccessToken)
|
||||
|
||||
// Execute request
|
||||
resp, err := r.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Handle error responses (204 No Content is success)
|
||||
if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
var regError ClientRegistrationError
|
||||
if jsonErr := json.Unmarshal(body, ®Error); jsonErr == nil && regError.Error != "" {
|
||||
return fmt.Errorf("delete failed: %s - %s", regError.Error, regError.ErrorDescription)
|
||||
}
|
||||
return fmt.Errorf("delete failed with status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
r.mu.Lock()
|
||||
r.registrationResponse = nil
|
||||
r.mu.Unlock()
|
||||
|
||||
// Remove credentials file if persistence is enabled
|
||||
if r.config.PersistCredentials {
|
||||
filePath := r.credentialsFilePath()
|
||||
if err := os.Remove(filePath); err != nil && !os.IsNotExist(err) {
|
||||
r.logger.Errorf("Failed to remove credentials file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("Successfully deleted client registration")
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -538,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
@@ -1087,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -846,3 +846,296 @@ func (e *mockNetError) Temporary() bool { return e.temporary }
|
||||
|
||||
// Ensure mockNetError implements net.Error
|
||||
var _ net.Error = (*mockNetError)(nil)
|
||||
|
||||
// Test isTraefikDefaultCertError
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
|
||||
func TestIsTraefikDefaultCertError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "network error",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isTraefikDefaultCertError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTraefikDefaultCertError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isEOFError
|
||||
|
||||
func TestIsEOFError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing EOF in message",
|
||||
err: errors.New("connection closed: EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing unexpected EOF",
|
||||
err: errors.New("read: unexpected EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "network error without EOF",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isEOFError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isEOFError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isCertificateError
|
||||
|
||||
func TestIsCertificateError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing certificate in message",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing x509 in message",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing tls in message",
|
||||
err: errors.New("tls handshake failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing ssl in message",
|
||||
err: errors.New("ssl connection error"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCertificateError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCertificateError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test MetadataFetchRetryConfig
|
||||
|
||||
func TestMetadataFetchRetryConfig(t *testing.T) {
|
||||
config := MetadataFetchRetryConfig()
|
||||
|
||||
if config.MaxAttempts != 10 {
|
||||
t.Errorf("Expected MaxAttempts 10, got %d", config.MaxAttempts)
|
||||
}
|
||||
|
||||
if config.InitialDelay != 1*time.Second {
|
||||
t.Errorf("Expected InitialDelay 1s, got %v", config.InitialDelay)
|
||||
}
|
||||
|
||||
if config.MaxDelay != 10*time.Second {
|
||||
t.Errorf("Expected MaxDelay 10s, got %v", config.MaxDelay)
|
||||
}
|
||||
|
||||
if config.BackoffFactor != 1.5 {
|
||||
t.Errorf("Expected BackoffFactor 1.5, got %v", config.BackoffFactor)
|
||||
}
|
||||
|
||||
if !config.EnableJitter {
|
||||
t.Error("Expected EnableJitter to be true")
|
||||
}
|
||||
|
||||
// Verify retryable errors include startup-related patterns
|
||||
expectedPatterns := []string{"EOF", "certificate", "x509", "tls"}
|
||||
for _, pattern := range expectedPatterns {
|
||||
found := false
|
||||
for _, retryableErr := range config.RetryableErrors {
|
||||
if retryableErr == pattern {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected '%s' in RetryableErrors", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test RetryExecutor with startup-specific errors
|
||||
|
||||
func TestRetryExecutorStartupErrors(t *testing.T) {
|
||||
// Verify MetadataFetchRetryConfig creates a valid retry executor
|
||||
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF error",
|
||||
err: errors.New("read tcp: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "unexpected EOF",
|
||||
err: errors.New("http: unexpected EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate error",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS error",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
err: errors.New("dial tcp: connection refused"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "permanent error",
|
||||
err: errors.New("invalid response format"),
|
||||
shouldRetry: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use very short delays for testing
|
||||
testConfig := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 1.5,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
testRe := NewRetryExecutor(testConfig, nil)
|
||||
|
||||
attempts := 0
|
||||
_ = testRe.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return tt.err
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts for '%s', got %d", expectedAttempts, tt.name, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that retry executor properly uses isRetryableError with new error types
|
||||
|
||||
func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
|
||||
re := NewRetryExecutor(DefaultRetryConfig(), nil)
|
||||
|
||||
// Test that the enhanced isRetryableError is being used
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF in error message",
|
||||
err: errors.New("connection reset by peer: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate in error message",
|
||||
err: errors.New("x509: certificate has expired"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS in error message",
|
||||
err: errors.New("tls: handshake failure"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := re.isRetryableError(tt.err)
|
||||
if result != tt.shouldRetry {
|
||||
t.Errorf("isRetryableError(%q) = %v, expected %v", tt.err.Error(), result, tt.shouldRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,486 @@
|
||||
# ============================================================================
|
||||
# Complete Traefik Configuration Example with TraefikOIDC Plugin + Redis
|
||||
# ============================================================================
|
||||
#
|
||||
# This example shows a complete, production-ready configuration for using
|
||||
# the TraefikOIDC plugin with Redis caching in a multi-replica deployment.
|
||||
#
|
||||
|
||||
# ============================================================================
|
||||
# Part 1: Traefik Static Configuration (traefik.yml)
|
||||
# ============================================================================
|
||||
# This file configures Traefik itself and enables the plugin.
|
||||
# Place this in /etc/traefik/traefik.yml or mount it in your container.
|
||||
|
||||
---
|
||||
# Static Configuration
|
||||
api:
|
||||
dashboard: true
|
||||
insecure: false # Set to true only for local development
|
||||
|
||||
entryPoints:
|
||||
web:
|
||||
address: ":80"
|
||||
http:
|
||||
redirections:
|
||||
entryPoint:
|
||||
to: websecure
|
||||
scheme: https
|
||||
|
||||
websecure:
|
||||
address: ":443"
|
||||
http:
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
certificatesResolvers:
|
||||
letsencrypt:
|
||||
acme:
|
||||
email: admin@example.com
|
||||
storage: /letsencrypt/acme.json
|
||||
httpChallenge:
|
||||
entryPoint: web
|
||||
|
||||
providers:
|
||||
file:
|
||||
filename: /etc/traefik/dynamic.yml
|
||||
watch: true
|
||||
|
||||
# Enable the TraefikOIDC plugin
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
log:
|
||||
level: INFO
|
||||
format: json
|
||||
|
||||
accessLog:
|
||||
format: json
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 2: Traefik Dynamic Configuration (dynamic.yml)
|
||||
# ============================================================================
|
||||
# This file defines your routes, services, and middleware.
|
||||
# Place this in /etc/traefik/dynamic.yml
|
||||
|
||||
---
|
||||
http:
|
||||
# -------------------------------------------------------------------------
|
||||
# Middleware Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
middlewares:
|
||||
# Example 1: Minimal Redis Configuration
|
||||
# Perfect for getting started quickly
|
||||
oidc-minimal:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-application-client-id"
|
||||
clientSecret: "your-client-secret-from-provider"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-secure-64-character-encryption-key-must-be-kept-secret"
|
||||
|
||||
# Minimal Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
|
||||
# Example 2: Production Redis Configuration
|
||||
# Recommended for production deployments with multiple Traefik replicas
|
||||
oidc-production:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Provider Configuration
|
||||
clientID: "prod-client-id"
|
||||
clientSecret: "prod-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
|
||||
# Session Configuration
|
||||
sessionEncryptionKey: "prod-encryption-key-64-chars-long-keep-it-secret-and-safe"
|
||||
sessionMaxAge: 28800 # 8 hours
|
||||
|
||||
# Security Settings
|
||||
forceHTTPS: true
|
||||
strictAudienceValidation: true
|
||||
|
||||
# Redis Configuration for Multi-Replica Deployment
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis-master.redis-namespace.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:prod:"
|
||||
|
||||
# Cache Strategy
|
||||
cacheMode: "hybrid" # Fast local cache + shared Redis
|
||||
|
||||
# Connection Pooling
|
||||
poolSize: 20
|
||||
connectTimeout: 5
|
||||
readTimeout: 3
|
||||
writeTimeout: 3
|
||||
|
||||
# Resilience Features
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS (for production security)
|
||||
oidc-secure:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "secure-client-id"
|
||||
clientSecret: "secure-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "secure-64-character-encryption-key-for-production-use-only"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Verify certificates in production
|
||||
cacheMode: "redis"
|
||||
|
||||
# Example 4: Hybrid Mode (Best Performance + Consistency)
|
||||
# Local cache for hot data, Redis for consistency across replicas
|
||||
oidc-hybrid:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
clientID: "app-client-id"
|
||||
clientSecret: "app-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "hybrid-mode-encryption-key-64-characters-long-and-secure"
|
||||
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
cacheMode: "hybrid"
|
||||
|
||||
# Hybrid mode L1 cache settings
|
||||
hybridL1Size: 1000 # Number of items in local cache
|
||||
hybridL1MemoryMB: 20 # MB of memory for local cache
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Router Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
routers:
|
||||
# Protected application using OIDC authentication
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-production # Use the OIDC middleware
|
||||
service: my-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# Another app with minimal OIDC config
|
||||
simple-app:
|
||||
rule: "Host(`simple.example.com`)"
|
||||
entryPoints:
|
||||
- websecure
|
||||
middlewares:
|
||||
- oidc-minimal
|
||||
service: simple-app-service
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Service Definitions
|
||||
# -------------------------------------------------------------------------
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://my-app:8080"
|
||||
healthCheck:
|
||||
path: /health
|
||||
interval: 30s
|
||||
timeout: 5s
|
||||
|
||||
simple-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://simple-app:3000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 3: Docker Compose Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# docker-compose.yml
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Redis service for shared caching
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --requirepass yourredispassword --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
ports:
|
||||
- "6379:6379"
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 10s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Traefik with TraefikOIDC plugin
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
command:
|
||||
- "--api.dashboard=true"
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--providers.file.filename=/etc/traefik/dynamic.yml"
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
- "--experimental.plugins.traefikoidc.modulename=github.com/lukaszraczylo/traefikoidc"
|
||||
- "--experimental.plugins.traefikoidc.version=v0.8.0"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # Dashboard
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./traefik-dynamic.yml:/etc/traefik/dynamic.yml:ro
|
||||
- ./letsencrypt:/letsencrypt
|
||||
depends_on:
|
||||
- redis
|
||||
networks:
|
||||
- traefik-network
|
||||
|
||||
# Your application
|
||||
my-app:
|
||||
image: my-app:latest
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
- "traefik.http.routers.my-app.entrypoints=websecure"
|
||||
- "traefik.http.routers.my-app.tls.certresolver=letsencrypt"
|
||||
|
||||
# OIDC Middleware Configuration with Redis (using labels)
|
||||
- "traefik.http.routers.my-app.middlewares=my-oidc@docker"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-client-secret"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-character-encryption-key-here"
|
||||
|
||||
# Redis configuration
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=yourredispassword"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
- "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=hybrid"
|
||||
networks:
|
||||
- traefik-network
|
||||
deploy:
|
||||
replicas: 3 # Multiple replicas sharing Redis cache
|
||||
|
||||
volumes:
|
||||
redis-data:
|
||||
|
||||
networks:
|
||||
traefik-network:
|
||||
driver: bridge
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 4: Kubernetes Example
|
||||
# ============================================================================
|
||||
|
||||
---
|
||||
# kubernetes-example.yaml
|
||||
|
||||
# Redis Deployment
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: redis
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: redis
|
||||
spec:
|
||||
containers:
|
||||
- name: redis
|
||||
image: redis:7-alpine
|
||||
args:
|
||||
- redis-server
|
||||
- --requirepass
|
||||
- $(REDIS_PASSWORD)
|
||||
- --maxmemory
|
||||
- 512mb
|
||||
- --maxmemory-policy
|
||||
- allkeys-lru
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: redis-secret
|
||||
key: password
|
||||
ports:
|
||||
- containerPort: 6379
|
||||
resources:
|
||||
requests:
|
||||
memory: "256Mi"
|
||||
cpu: "100m"
|
||||
limits:
|
||||
memory: "512Mi"
|
||||
cpu: "500m"
|
||||
---
|
||||
# Redis Service
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: redis
|
||||
namespace: traefik
|
||||
spec:
|
||||
selector:
|
||||
app: redis
|
||||
ports:
|
||||
- port: 6379
|
||||
targetPort: 6379
|
||||
---
|
||||
# Redis Secret
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: redis-secret
|
||||
namespace: traefik
|
||||
type: Opaque
|
||||
stringData:
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
---
|
||||
# OIDC Middleware with Redis
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-auth
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# OIDC Configuration
|
||||
clientID: "kubernetes-client-id"
|
||||
clientSecret: "kubernetes-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "kubernetes-64-character-session-encryption-key-keep-secret"
|
||||
|
||||
# Redis Configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.traefik.svc.cluster.local:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD"
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:k8s:"
|
||||
cacheMode: "hybrid"
|
||||
poolSize: 20
|
||||
enableCircuitBreaker: true
|
||||
enableHealthCheck: true
|
||||
---
|
||||
# IngressRoute using the middleware
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: IngressRoute
|
||||
metadata:
|
||||
name: my-app
|
||||
namespace: default
|
||||
spec:
|
||||
entryPoints:
|
||||
- websecure
|
||||
routes:
|
||||
- match: Host(`app.example.com`)
|
||||
kind: Rule
|
||||
middlewares:
|
||||
- name: oidc-auth
|
||||
namespace: traefik
|
||||
services:
|
||||
- name: my-app
|
||||
port: 80
|
||||
tls:
|
||||
certResolver: letsencrypt
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Part 5: Environment Variables (Optional Fallback)
|
||||
# ============================================================================
|
||||
|
||||
# If you prefer environment variables as fallback (not recommended for production),
|
||||
# you can set these. NOTE: Plugin configuration takes precedence!
|
||||
|
||||
# Docker Compose env file (.env)
|
||||
---
|
||||
# OIDC Configuration
|
||||
OIDC_CLIENT_ID=your-client-id
|
||||
OIDC_CLIENT_SECRET=your-client-secret
|
||||
OIDC_PROVIDER_URL=https://auth.example.com
|
||||
|
||||
# Redis Configuration (fallback)
|
||||
REDIS_ENABLED=true
|
||||
REDIS_ADDRESS=redis:6379
|
||||
REDIS_PASSWORD=yourredispassword
|
||||
REDIS_DB=0
|
||||
REDIS_KEY_PREFIX=traefikoidc:
|
||||
REDIS_CACHE_MODE=hybrid
|
||||
REDIS_POOL_SIZE=20
|
||||
REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
REDIS_ENABLE_HEALTH_CHECK=true
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Cheat Sheet
|
||||
# ============================================================================
|
||||
|
||||
# Minimal Setup (Quick Start):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis:6379"
|
||||
|
||||
# Production Setup (Recommended):
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis-master:6379"
|
||||
# password: "strong-password"
|
||||
# cacheMode: "hybrid"
|
||||
# enableCircuitBreaker: true
|
||||
# enableHealthCheck: true
|
||||
|
||||
# High Security Setup:
|
||||
# redis:
|
||||
# enabled: true
|
||||
# address: "redis.example.com:6380"
|
||||
# password: "strong-password"
|
||||
# enableTLS: true
|
||||
# tlsSkipVerify: false
|
||||
# cacheMode: "redis"
|
||||
|
||||
# Cache Modes:
|
||||
# - "memory": Local cache only (default, no Redis needed)
|
||||
# - "redis": Redis only (consistent, shared across replicas)
|
||||
# - "hybrid": Local L1 + Redis L2 (best performance + consistency)
|
||||
@@ -0,0 +1,149 @@
|
||||
# Example Traefik configuration for TraefikOIDC plugin with Redis caching
|
||||
# This example shows how to configure Redis through Traefik's dynamic configuration
|
||||
|
||||
# Static configuration (traefik.yml)
|
||||
experimental:
|
||||
plugins:
|
||||
traefikoidc:
|
||||
moduleName: github.com/lukaszraczylo/traefikoidc
|
||||
version: v0.8.0
|
||||
|
||||
# Dynamic configuration (dynamic.yml or labels)
|
||||
http:
|
||||
middlewares:
|
||||
# Example 1: Basic Redis configuration
|
||||
oidc-redis-basic:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
# password: "your-redis-password" # Optional
|
||||
db: 0
|
||||
keyPrefix: "traefikoidc:"
|
||||
|
||||
# Example 2: Redis with resilience features
|
||||
oidc-redis-resilient:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with full resilience configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis:6379"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder - use your actual password
|
||||
db: 1
|
||||
keyPrefix: "myapp:"
|
||||
poolSize: 20
|
||||
connectTimeout: 10
|
||||
readTimeout: 5
|
||||
writeTimeout: 5
|
||||
cacheMode: "redis" # Options: "redis", "hybrid", "memory"
|
||||
# Circuit breaker settings
|
||||
enableCircuitBreaker: true
|
||||
circuitBreakerThreshold: 5
|
||||
circuitBreakerTimeout: 60
|
||||
# Health check settings
|
||||
enableHealthCheck: true
|
||||
healthCheckInterval: 30
|
||||
|
||||
# Example 3: Redis with TLS
|
||||
oidc-redis-tls:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
# Required OIDC settings
|
||||
clientID: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
providerURL: "https://auth.example.com"
|
||||
callbackURL: "/oauth2/callback"
|
||||
sessionEncryptionKey: "your-64-character-encryption-key-here-keep-it-secret"
|
||||
|
||||
# Redis with TLS configuration
|
||||
redis:
|
||||
enabled: true
|
||||
address: "redis.example.com:6380"
|
||||
password: "REPLACE_WITH_YOUR_REDIS_PASSWORD" # Example placeholder
|
||||
enableTLS: true
|
||||
tlsSkipVerify: false # Set to true only for testing
|
||||
cacheMode: "redis"
|
||||
|
||||
routers:
|
||||
my-app:
|
||||
rule: "Host(`app.example.com`)"
|
||||
middlewares:
|
||||
- oidc-redis-basic
|
||||
service: my-app-service
|
||||
|
||||
services:
|
||||
my-app-service:
|
||||
loadBalancer:
|
||||
servers:
|
||||
- url: "http://localhost:8080"
|
||||
|
||||
# Docker Compose labels example
|
||||
# version: '3.8'
|
||||
# services:
|
||||
# traefik:
|
||||
# image: traefik:v3.0
|
||||
# # ... other config ...
|
||||
#
|
||||
# my-app:
|
||||
# image: my-app:latest
|
||||
# labels:
|
||||
# - "traefik.enable=true"
|
||||
# - "traefik.http.routers.my-app.rule=Host(`app.example.com`)"
|
||||
# - "traefik.http.routers.my-app.middlewares=my-oidc"
|
||||
# # OIDC middleware configuration with Redis
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientID=your-client-id"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.clientSecret=your-secret"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.providerURL=https://auth.example.com"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.callbackURL=/oauth2/callback"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.sessionEncryptionKey=your-64-char-key"
|
||||
# # Redis configuration via labels
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.enabled=true"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.address=redis:6379"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.password=redis-password"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.db=0"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.keyPrefix=traefikoidc:"
|
||||
# - "traefik.http.middlewares.my-oidc.plugin.traefikoidc.redis.cacheMode=redis"
|
||||
#
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# command: redis-server --requirepass redis-password
|
||||
# # ... other config ...
|
||||
|
||||
# Environment variable fallback (optional)
|
||||
# If Redis configuration is not provided in Traefik config, these environment variables
|
||||
# can be used as a fallback (but Traefik config takes precedence):
|
||||
#
|
||||
# REDIS_ENABLED=true
|
||||
# REDIS_ADDRESS=redis:6379
|
||||
# REDIS_PASSWORD=secret
|
||||
# REDIS_DB=0
|
||||
# REDIS_KEY_PREFIX=traefikoidc:
|
||||
# REDIS_CACHE_MODE=redis
|
||||
# REDIS_POOL_SIZE=10
|
||||
# REDIS_CONNECT_TIMEOUT=5
|
||||
# REDIS_READ_TIMEOUT=3
|
||||
# REDIS_WRITE_TIMEOUT=3
|
||||
# REDIS_ENABLE_TLS=false
|
||||
# REDIS_TLS_SKIP_VERIFY=false
|
||||
# REDIS_ENABLE_CIRCUIT_BREAKER=true
|
||||
# REDIS_CIRCUIT_BREAKER_THRESHOLD=5
|
||||
# REDIS_CIRCUIT_BREAKER_TIMEOUT=60
|
||||
# REDIS_ENABLE_HEALTH_CHECK=true
|
||||
# REDIS_HEALTH_CHECK_INTERVAL=30
|
||||
@@ -3,15 +3,20 @@ module github.com/lukaszraczylo/traefikoidc
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
)
|
||||
|
||||
@@ -1,5 +1,15 @@
|
||||
github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI=
|
||||
github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
|
||||
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
@@ -10,8 +20,12 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
|
||||
+35
-13
@@ -15,7 +15,8 @@ type OAuthHandler struct {
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
isAllowedUserFunc func(userIdentifier string) bool // validates user authorization
|
||||
userIdentifierClaim string // JWT claim to use for user identification
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
@@ -77,16 +78,22 @@ type TokenResponse struct {
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
// Default to "email" for backward compatibility
|
||||
if userIdentifierClaim == "" {
|
||||
userIdentifierClaim = "email"
|
||||
}
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
isAllowedUserFunc: isAllowedUserFunc,
|
||||
userIdentifierClaim: userIdentifierClaim,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
@@ -147,7 +154,12 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
cookie, err := req.Cookie("_oidc_raczylo_m")
|
||||
if err != nil {
|
||||
h.logger.Errorf("Main session cookie not found in request: %v", err)
|
||||
h.logger.Debugf("Available cookies: %v", req.Header.Get("Cookie"))
|
||||
// Log cookie names only, not values (avoid logging sensitive session data)
|
||||
cookieNames := make([]string, 0, len(req.Cookies()))
|
||||
for _, c := range req.Cookies() {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
h.logger.Debugf("Available cookies (names only): %v", cookieNames)
|
||||
} else {
|
||||
h.logger.Errorf("Main session cookie exists but CSRF token is empty. Cookie value length: %d", len(cookie.Value))
|
||||
h.logger.Debugf("Cookie details - Domain: %s, Path: %s, Secure: %v, HttpOnly: %v, SameSite: %v",
|
||||
@@ -220,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[h.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if h.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim)
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !h.isAllowedUserFunc(userIdentifier) {
|
||||
h.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -237,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
isAllowedUser := func(userIdentifier string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowedUser, "email", "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
@@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User identifier missing in token") {
|
||||
t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User not authorized") {
|
||||
t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
+42
-10
@@ -146,6 +146,9 @@ func (p *SharedTransportPool) ReleaseTransport(transport *http.Transport) {
|
||||
}
|
||||
|
||||
// cleanupIdleTransports periodically cleans up unused transports
|
||||
// Uses two-phase cleanup to minimize lock contention:
|
||||
// 1. Find candidates while holding read lock
|
||||
// 2. Remove and close transports with minimal lock duration
|
||||
func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
@@ -155,17 +158,46 @@ func (p *SharedTransportPool) cleanupIdleTransports(ctx context.Context) {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
p.mu.Lock()
|
||||
now := time.Now()
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
shared.transport.CloseIdleConnections()
|
||||
delete(p.transports, transportKey)
|
||||
// SECURITY FIX: Decrement client count when removing transport
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
}
|
||||
p.performCleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performCleanup does the actual cleanup with optimized locking
|
||||
func (p *SharedTransportPool) performCleanup() {
|
||||
now := time.Now()
|
||||
|
||||
// Phase 1: Find candidates while holding read lock (fast)
|
||||
p.mu.RLock()
|
||||
candidates := make([]string, 0)
|
||||
for transportKey, shared := range p.transports {
|
||||
// Clean up transports not used for 2 minutes with no references
|
||||
if shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
candidates = append(candidates, transportKey)
|
||||
}
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Phase 2: Remove and close each candidate individually
|
||||
// This minimizes lock contention and allows concurrent access
|
||||
for _, key := range candidates {
|
||||
p.mu.Lock()
|
||||
shared, exists := p.transports[key]
|
||||
if exists && shared.refCount <= 0 && now.Sub(shared.lastUsed) > 2*time.Minute {
|
||||
// Remove from map first (releases memory)
|
||||
delete(p.transports, key)
|
||||
atomic.AddInt32(&p.clientCount, -1)
|
||||
p.mu.Unlock()
|
||||
|
||||
// Close idle connections outside the lock (can be slow)
|
||||
if shared.transport != nil {
|
||||
shared.transport.CloseIdleConnections()
|
||||
}
|
||||
} else {
|
||||
p.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+90
@@ -0,0 +1,90 @@
|
||||
package backends
|
||||
|
||||
import "time"
|
||||
|
||||
// BackendType represents the type of cache backend
|
||||
type BackendType string
|
||||
|
||||
const (
|
||||
BackendTypeMemory BackendType = "memory"
|
||||
BackendTypeRedis BackendType = "redis"
|
||||
BackendTypeHybrid BackendType = "hybrid"
|
||||
|
||||
// Aliases for backward compatibility
|
||||
TypeMemory BackendType = "memory"
|
||||
TypeRedis BackendType = "redis"
|
||||
TypeHybrid BackendType = "hybrid"
|
||||
)
|
||||
|
||||
// Config provides common configuration for cache backends
|
||||
type Config struct {
|
||||
// Type specifies the backend type
|
||||
Type BackendType
|
||||
|
||||
// Memory backend settings
|
||||
MaxSize int
|
||||
MaxMemoryBytes int64
|
||||
CleanupInterval time.Duration
|
||||
|
||||
// Redis backend settings
|
||||
RedisAddr string
|
||||
RedisPassword string
|
||||
RedisDB int
|
||||
RedisPrefix string
|
||||
PoolSize int
|
||||
|
||||
// Hybrid backend settings
|
||||
L1Config *Config // Memory cache (L1)
|
||||
L2Config *Config // Redis cache (L2)
|
||||
AsyncWrites bool // Write to L2 asynchronously
|
||||
|
||||
// Resilience settings
|
||||
EnableCircuitBreaker bool
|
||||
EnableHealthCheck bool
|
||||
HealthCheckInterval time.Duration
|
||||
|
||||
// Metrics
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default configuration for in-memory caching
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 1000,
|
||||
MaxMemoryBytes: 50 * 1024 * 1024, // 50MB
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultRedisConfig returns a default configuration for Redis caching
|
||||
func DefaultRedisConfig(addr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeRedis,
|
||||
RedisAddr: addr,
|
||||
RedisDB: 0,
|
||||
RedisPrefix: "traefikoidc:",
|
||||
PoolSize: 10,
|
||||
EnableCircuitBreaker: true,
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultHybridConfig returns a default configuration for hybrid caching
|
||||
func DefaultHybridConfig(redisAddr string) *Config {
|
||||
return &Config{
|
||||
Type: BackendTypeHybrid,
|
||||
L1Config: &Config{
|
||||
Type: BackendTypeMemory,
|
||||
MaxSize: 500,
|
||||
MaxMemoryBytes: 10 * 1024 * 1024, // 10MB for L1
|
||||
CleanupInterval: 1 * time.Minute,
|
||||
},
|
||||
L2Config: DefaultRedisConfig(redisAddr),
|
||||
AsyncWrites: true,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
+59
@@ -0,0 +1,59 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package backends
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestDefaultHybridConfig verifies the default hybrid configuration
|
||||
func TestDefaultHybridConfig(t *testing.T) {
|
||||
redisAddr := "localhost:6379"
|
||||
|
||||
config := DefaultHybridConfig(redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Verify top-level config
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.True(t, config.AsyncWrites)
|
||||
assert.True(t, config.EnableMetrics)
|
||||
|
||||
// Verify L1 (memory) config
|
||||
require.NotNil(t, config.L1Config)
|
||||
assert.Equal(t, BackendTypeMemory, config.L1Config.Type)
|
||||
assert.Equal(t, 500, config.L1Config.MaxSize)
|
||||
assert.Equal(t, int64(10*1024*1024), config.L1Config.MaxMemoryBytes) // 10MB
|
||||
assert.Equal(t, 1*time.Minute, config.L1Config.CleanupInterval)
|
||||
|
||||
// Verify L2 (Redis) config exists
|
||||
require.NotNil(t, config.L2Config)
|
||||
assert.Equal(t, BackendTypeRedis, config.L2Config.Type)
|
||||
}
|
||||
|
||||
func TestDefaultHybridConfig_DifferentRedisAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
redisAddr string
|
||||
}{
|
||||
{"localhost", "localhost:6379"},
|
||||
{"remote host", "redis.example.com:6379"},
|
||||
{"IP address", "192.168.1.100:6379"},
|
||||
{"custom port", "localhost:6380"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
config := DefaultHybridConfig(tt.redisAddr)
|
||||
|
||||
require.NotNil(t, config)
|
||||
assert.Equal(t, BackendTypeHybrid, config.Type)
|
||||
assert.NotNil(t, config.L1Config)
|
||||
assert.NotNil(t, config.L2Config)
|
||||
})
|
||||
}
|
||||
}
|
||||
Vendored
+38
@@ -0,0 +1,38 @@
|
||||
package backends
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
// ErrBackendClosed is returned when operating on a closed backend
|
||||
ErrBackendClosed = errors.New("cache backend is closed")
|
||||
|
||||
// ErrKeyNotFound is returned when a key doesn't exist
|
||||
ErrKeyNotFound = errors.New("key not found")
|
||||
|
||||
// ErrCacheMiss indicates the requested key was not found in the cache
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
|
||||
// ErrBackendUnavailable indicates the cache backend is not available
|
||||
ErrBackendUnavailable = errors.New("cache backend unavailable")
|
||||
|
||||
// ErrInvalidValue indicates the cached value is invalid or corrupted
|
||||
ErrInvalidValue = errors.New("invalid cached value")
|
||||
|
||||
// ErrInvalidTTL is returned when TTL is invalid
|
||||
ErrInvalidTTL = errors.New("invalid TTL")
|
||||
|
||||
// ErrConnectionFailed is returned when connection fails
|
||||
ErrConnectionFailed = errors.New("connection failed")
|
||||
|
||||
// ErrCircuitOpen is returned when circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTimeout is returned when operation times out
|
||||
ErrTimeout = errors.New("operation timeout")
|
||||
|
||||
// ErrSerializationFailed is returned when serialization fails
|
||||
ErrSerializationFailed = errors.New("serialization failed")
|
||||
|
||||
// ErrDeserializationFailed is returned when deserialization fails
|
||||
ErrDeserializationFailed = errors.New("deserialization failed")
|
||||
)
|
||||
Vendored
+695
@@ -0,0 +1,695 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HybridBackend implements a two-tier cache with L1 (memory) and L2 (Redis) backends
|
||||
// It provides automatic failover, async writes for non-critical data, and optimized read paths
|
||||
type HybridBackend struct {
|
||||
primary CacheBackend // L1: Memory cache for fast access
|
||||
secondary CacheBackend // L2: Redis cache for distributed access
|
||||
|
||||
// Configuration
|
||||
syncWriteCacheTypes map[string]bool // Which cache types require synchronous writes
|
||||
asyncWriteBuffer chan *asyncWriteItem
|
||||
|
||||
// Metrics
|
||||
l1Hits atomic.Int64
|
||||
l2Hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
l1Writes atomic.Int64
|
||||
l2Writes atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Fallback tracking
|
||||
fallbackMode atomic.Bool // True when operating in degraded mode (L1 only)
|
||||
lastL2Error atomic.Value // Stores last L2 error timestamp
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
// Logging
|
||||
logger Logger
|
||||
}
|
||||
|
||||
// asyncWriteItem represents an async write operation
|
||||
type asyncWriteItem struct {
|
||||
key string
|
||||
value []byte
|
||||
ttl time.Duration
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// Logger interface for structured logging
|
||||
type Logger interface {
|
||||
Debugf(format string, args ...interface{})
|
||||
Infof(format string, args ...interface{})
|
||||
Warnf(format string, args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// defaultLogger provides a basic logger implementation
|
||||
type defaultLogger struct {
|
||||
*log.Logger
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Printf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Infof(format string, args ...interface{}) {
|
||||
l.Printf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Warnf(format string, args ...interface{}) {
|
||||
l.Printf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Printf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
// HybridConfig provides configuration for the hybrid backend
|
||||
type HybridConfig struct {
|
||||
Primary CacheBackend
|
||||
Secondary CacheBackend
|
||||
SyncWriteCacheTypes map[string]bool // Cache types requiring synchronous L2 writes
|
||||
AsyncBufferSize int
|
||||
Logger Logger
|
||||
}
|
||||
|
||||
// NewHybridBackend creates a new hybrid cache backend with L1 (memory) and L2 (Redis) tiers
|
||||
func NewHybridBackend(config *HybridConfig) (*HybridBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.Primary == nil {
|
||||
return nil, fmt.Errorf("primary (L1) backend is required")
|
||||
}
|
||||
|
||||
if config.Secondary == nil {
|
||||
return nil, fmt.Errorf("secondary (L2) backend is required")
|
||||
}
|
||||
|
||||
if config.Logger == nil {
|
||||
config.Logger = &defaultLogger{Logger: log.New(log.Writer(), "[HybridCache] ", log.LstdFlags)}
|
||||
}
|
||||
|
||||
if config.AsyncBufferSize <= 0 {
|
||||
config.AsyncBufferSize = 1000
|
||||
}
|
||||
|
||||
// Default critical cache types that require synchronous writes
|
||||
if config.SyncWriteCacheTypes == nil {
|
||||
config.SyncWriteCacheTypes = map[string]bool{
|
||||
"blacklist": true, // Token blacklist must be immediately consistent
|
||||
"token": true, // Token validation is critical
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
h := &HybridBackend{
|
||||
primary: config.Primary,
|
||||
secondary: config.Secondary,
|
||||
syncWriteCacheTypes: config.SyncWriteCacheTypes,
|
||||
asyncWriteBuffer: make(chan *asyncWriteItem, config.AsyncBufferSize),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
logger: config.Logger,
|
||||
}
|
||||
|
||||
// Start async write worker
|
||||
h.wg.Add(1)
|
||||
go h.asyncWriteWorker()
|
||||
|
||||
// Start health monitoring
|
||||
h.wg.Add(1)
|
||||
go h.healthMonitor()
|
||||
|
||||
h.logger.Infof("HybridBackend initialized with L1 (memory) and L2 (Redis) tiers")
|
||||
h.logger.Infof("Sync write cache types: %v", config.SyncWriteCacheTypes)
|
||||
h.logger.Infof("Async write buffer size: %d", config.AsyncBufferSize)
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Set stores a value in both L1 and L2 caches
|
||||
func (h *HybridBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Always write to L1 first (synchronous)
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L1 cache: %v", err)
|
||||
// Continue to try L2 even if L1 fails
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.logger.Debugf("Operating in fallback mode, skipping L2 write for key: %s", key)
|
||||
return nil // Don't fail the operation if L2 is down
|
||||
}
|
||||
|
||||
// Determine if this should be a sync or async write based on cache type
|
||||
cacheType := h.extractCacheType(key)
|
||||
requiresSync := h.syncWriteCacheTypes[cacheType]
|
||||
|
||||
if requiresSync {
|
||||
// Synchronous write for critical cache types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Warnf("Failed to write to L2 cache (sync) for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
// Don't fail the operation - L1 write succeeded
|
||||
return nil
|
||||
}
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Synchronous write to L2 completed for critical key: %s", key)
|
||||
} else {
|
||||
// Asynchronous write for non-critical cache types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
h.logger.Debugf("Queued async write to L2 for key: %s", key)
|
||||
default:
|
||||
// Buffer is full, log and continue
|
||||
h.logger.Warnf("Async write buffer full, dropping L2 write for key: %s", key)
|
||||
h.errors.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from cache, checking L1 first, then L2
|
||||
func (h *HybridBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
// Try L1 first
|
||||
value, ttl, exists, err := h.primary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L1 get error for key %s: %v", key, err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
h.l1Hits.Add(1)
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Check if we're in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
// Try L2
|
||||
value, ttl, exists, err = h.secondary.Get(ctx, key)
|
||||
if err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("L2 get error for key %s: %v", key, err)
|
||||
h.recordL2Error()
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil // Don't propagate L2 errors
|
||||
}
|
||||
|
||||
if !exists {
|
||||
h.misses.Add(1)
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Populate L1 cache with value from L2 (write-through on read)
|
||||
// Use goroutine to avoid blocking the read path
|
||||
go func() {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
if err := h.primary.Set(writeCtx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to populate L1 cache from L2 for key %s: %v", key, err)
|
||||
} else {
|
||||
h.logger.Debugf("Populated L1 cache from L2 for key: %s", key)
|
||||
}
|
||||
}()
|
||||
|
||||
return value, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from both L1 and L2 caches
|
||||
func (h *HybridBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
var deleted bool
|
||||
|
||||
// Delete from L1
|
||||
if d, err := h.primary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L1 cache: %v", err)
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
|
||||
// Delete from L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if d, err := h.secondary.Delete(ctx, key); err != nil {
|
||||
h.logger.Debugf("Failed to delete from L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
} else if d {
|
||||
deleted = true
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in either cache
|
||||
func (h *HybridBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// Check L1 first
|
||||
if exists, err := h.primary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Check L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if exists, err := h.secondary.Exists(ctx, key); err == nil && exists {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys from both caches
|
||||
func (h *HybridBackend) Clear(ctx context.Context) error {
|
||||
var lastErr error
|
||||
|
||||
// Clear L1
|
||||
if err := h.primary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L1 cache: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
// Clear L2 if not in fallback mode
|
||||
if !h.fallbackMode.Load() {
|
||||
if err := h.secondary.Clear(ctx); err != nil {
|
||||
h.logger.Errorf("Failed to clear L2 cache: %v", err)
|
||||
h.recordL2Error()
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetStats returns statistics for the hybrid cache
|
||||
func (h *HybridBackend) GetStats() map[string]interface{} {
|
||||
l1Hits := h.l1Hits.Load()
|
||||
l2Hits := h.l2Hits.Load()
|
||||
misses := h.misses.Load()
|
||||
total := l1Hits + l2Hits + misses
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"type": TypeHybrid,
|
||||
"l1_hits": l1Hits,
|
||||
"l2_hits": l2Hits,
|
||||
"misses": misses,
|
||||
"total": total,
|
||||
"l1_writes": h.l1Writes.Load(),
|
||||
"l2_writes": h.l2Writes.Load(),
|
||||
"errors": h.errors.Load(),
|
||||
"fallback_mode": h.fallbackMode.Load(),
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
stats["l1_hit_rate"] = float64(l1Hits) / float64(total)
|
||||
stats["l2_hit_rate"] = float64(l2Hits) / float64(total)
|
||||
stats["overall_hit_rate"] = float64(l1Hits+l2Hits) / float64(total)
|
||||
}
|
||||
|
||||
// Add sub-backend stats
|
||||
stats["l1_stats"] = h.primary.GetStats()
|
||||
stats["l2_stats"] = h.secondary.GetStats()
|
||||
|
||||
// Add last L2 error time if available
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok {
|
||||
stats["last_l2_error"] = t.Format(time.RFC3339)
|
||||
stats["seconds_since_l2_error"] = time.Since(t).Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks if both backends are healthy
|
||||
func (h *HybridBackend) Ping(ctx context.Context) error {
|
||||
// Check L1
|
||||
if err := h.primary.Ping(ctx); err != nil {
|
||||
return fmt.Errorf("L1 ping failed: %w", err)
|
||||
}
|
||||
|
||||
// Check L2 (but don't fail if it's down)
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
h.logger.Warnf("L2 ping failed: %v", err)
|
||||
h.recordL2Error()
|
||||
// Don't return error - we can operate with L1 only
|
||||
} else {
|
||||
// L2 is healthy, clear fallback mode if it was set
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend recovered, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close shuts down the hybrid backend
|
||||
func (h *HybridBackend) Close() error {
|
||||
// Cancel context to stop workers
|
||||
h.cancel()
|
||||
|
||||
// Close async write channel
|
||||
close(h.asyncWriteBuffer)
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Workers finished
|
||||
case <-time.After(5 * time.Second):
|
||||
h.logger.Warnf("Timeout waiting for workers to finish")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
|
||||
// Close backends
|
||||
if err := h.primary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L1 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
if err := h.secondary.Close(); err != nil {
|
||||
h.logger.Errorf("Failed to close L2 backend: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
h.logger.Infof("HybridBackend closed")
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values efficiently
|
||||
func (h *HybridBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string][]byte), nil
|
||||
}
|
||||
|
||||
results := make(map[string][]byte, len(keys))
|
||||
missingKeys := make([]string, 0)
|
||||
|
||||
// Try L1 first for all keys
|
||||
for _, key := range keys {
|
||||
if value, _, exists, _ := h.primary.Get(ctx, key); exists {
|
||||
results[key] = value
|
||||
h.l1Hits.Add(1)
|
||||
} else {
|
||||
missingKeys = append(missingKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
// If all found in L1 or in fallback mode, return
|
||||
if len(missingKeys) == 0 || h.fallbackMode.Load() {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Try L2 for missing keys using batch operation if available
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
GetMany(context.Context, []string) (map[string][]byte, error)
|
||||
}); ok {
|
||||
l2Results, err := batcher.GetMany(ctx, missingKeys)
|
||||
if err != nil {
|
||||
h.logger.Debugf("L2 batch get error: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
for key, value := range l2Results {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, 0) // Use default TTL
|
||||
}(key, value)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual gets
|
||||
for _, key := range missingKeys {
|
||||
if value, ttl, exists, err := h.secondary.Get(ctx, key); err == nil && exists {
|
||||
results[key] = value
|
||||
h.l2Hits.Add(1)
|
||||
|
||||
// Asynchronously populate L1
|
||||
go func(k string, v []byte, t time.Duration) {
|
||||
writeCtx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
_ = h.primary.Set(writeCtx, k, v, t)
|
||||
}(key, value, ttl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count misses for keys not found anywhere
|
||||
for _, key := range keys {
|
||||
if _, found := results[key]; !found {
|
||||
h.misses.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// SetMany stores multiple key-value pairs efficiently
|
||||
func (h *HybridBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write to L1 first
|
||||
for key, value := range items {
|
||||
if err := h.primary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L1 in batch: %v", err)
|
||||
} else {
|
||||
h.l1Writes.Add(1)
|
||||
}
|
||||
}
|
||||
|
||||
// Skip L2 if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if L2 supports batch operations
|
||||
if batcher, ok := h.secondary.(interface {
|
||||
SetMany(context.Context, map[string][]byte, time.Duration) error
|
||||
}); ok {
|
||||
if err := batcher.SetMany(ctx, items, ttl); err != nil {
|
||||
h.logger.Warnf("Failed to batch write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(int64(len(items)))
|
||||
}
|
||||
} else {
|
||||
// Fallback to individual sets
|
||||
for key, value := range items {
|
||||
cacheType := h.extractCacheType(key)
|
||||
if h.syncWriteCacheTypes[cacheType] {
|
||||
// Sync write for critical types
|
||||
if err := h.secondary.Set(ctx, key, value, ttl); err != nil {
|
||||
h.logger.Debugf("Failed to write to L2: %v", err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
}
|
||||
} else {
|
||||
// Async write for non-critical types
|
||||
select {
|
||||
case h.asyncWriteBuffer <- &asyncWriteItem{
|
||||
key: key,
|
||||
value: value,
|
||||
ttl: ttl,
|
||||
ctx: ctx,
|
||||
}:
|
||||
// Queued
|
||||
default:
|
||||
h.logger.Warnf("Async buffer full for batch write")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// asyncWriteWorker processes asynchronous writes to L2
|
||||
func (h *HybridBackend) asyncWriteWorker() {
|
||||
defer h.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
// Drain remaining items with best effort
|
||||
for len(h.asyncWriteBuffer) > 0 {
|
||||
select {
|
||||
case item := <-h.asyncWriteBuffer:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
_ = h.secondary.Set(ctx, item.key, item.value, item.ttl)
|
||||
cancel()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
case item, ok := <-h.asyncWriteBuffer:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Skip if in fallback mode
|
||||
if h.fallbackMode.Load() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Perform the write with a timeout
|
||||
writeCtx, cancel := context.WithTimeout(item.ctx, 500*time.Millisecond)
|
||||
if err := h.secondary.Set(writeCtx, item.key, item.value, item.ttl); err != nil {
|
||||
h.errors.Add(1)
|
||||
h.logger.Debugf("Async write to L2 failed for key %s: %v", item.key, err)
|
||||
h.recordL2Error()
|
||||
} else {
|
||||
h.l2Writes.Add(1)
|
||||
h.logger.Debugf("Async write to L2 completed for key: %s", item.key)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthMonitor periodically checks L2 health and manages fallback mode
|
||||
func (h *HybridBackend) healthMonitor() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
if err := h.secondary.Ping(ctx); err != nil {
|
||||
if !h.fallbackMode.Load() {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("L2 backend unhealthy, entering fallback mode: %v", err)
|
||||
}
|
||||
} else {
|
||||
if h.fallbackMode.CompareAndSwap(true, false) {
|
||||
h.logger.Infof("L2 backend healthy, exiting fallback mode")
|
||||
}
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordL2Error records the timestamp of an L2 error
|
||||
func (h *HybridBackend) recordL2Error() {
|
||||
h.lastL2Error.Store(time.Now())
|
||||
|
||||
// Check if we should enter fallback mode based on recent errors
|
||||
if !h.fallbackMode.Load() {
|
||||
// Simple heuristic: if we've had an error in the last second, consider L2 unhealthy
|
||||
if lastErr := h.lastL2Error.Load(); lastErr != nil {
|
||||
if t, ok := lastErr.(time.Time); ok && time.Since(t) < time.Second {
|
||||
h.fallbackMode.Store(true)
|
||||
h.logger.Warnf("Multiple L2 errors detected, entering fallback mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractCacheType attempts to determine the cache type from the key
|
||||
func (h *HybridBackend) extractCacheType(key string) string {
|
||||
// Simple heuristic based on key prefixes
|
||||
// This should match the actual cache type strategy in the main application
|
||||
|
||||
if len(key) > 10 {
|
||||
prefix := key[:10]
|
||||
switch {
|
||||
case contains(prefix, "blacklist"):
|
||||
return "blacklist"
|
||||
case contains(prefix, "token"):
|
||||
return "token"
|
||||
case contains(prefix, "metadata"):
|
||||
return "metadata"
|
||||
case contains(prefix, "jwk"):
|
||||
return "jwk"
|
||||
case contains(prefix, "session"):
|
||||
return "session"
|
||||
case contains(prefix, "introspect"):
|
||||
return "introspection"
|
||||
}
|
||||
}
|
||||
|
||||
return "general"
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring (case-insensitive)
|
||||
func contains(s, substr string) bool {
|
||||
if len(substr) > len(s) {
|
||||
return false
|
||||
}
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
match := true
|
||||
for j := 0; j < len(substr); j++ {
|
||||
if toLower(s[i+j]) != toLower(substr[j]) {
|
||||
match = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if match {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toLower converts a byte to lowercase
|
||||
func toLower(b byte) byte {
|
||||
if b >= 'A' && b <= 'Z' {
|
||||
return b + 32
|
||||
}
|
||||
return b
|
||||
}
|
||||
+1490
File diff suppressed because it is too large
Load Diff
Vendored
+133
@@ -0,0 +1,133 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheBackend defines the interface for all cache backend implementations
|
||||
// Implementations include: MemoryBackend, RedisBackend, and HybridBackend
|
||||
type CacheBackend interface {
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
// Returns an error if the operation fails
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
// Returns: value, remaining TTL, exists flag, and error
|
||||
// If the key doesn't exist, exists will be false
|
||||
Get(ctx context.Context, key string) (value []byte, ttl time.Duration, exists bool, err error)
|
||||
|
||||
// Delete removes a key from the cache
|
||||
// Returns true if the key was deleted, false if it didn't exist
|
||||
Delete(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// GetStats returns cache statistics
|
||||
// Stats include: hits, misses, size, memory usage, etc.
|
||||
GetStats() map[string]interface{}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
Close() error
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
Ping(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackendStats represents statistics for a cache backend
|
||||
type BackendStats struct {
|
||||
// Type is the backend type
|
||||
Type BackendType
|
||||
|
||||
// Hits is the number of cache hits
|
||||
Hits int64
|
||||
|
||||
// Misses is the number of cache misses
|
||||
Misses int64
|
||||
|
||||
// Sets is the number of set operations
|
||||
Sets int64
|
||||
|
||||
// Deletes is the number of delete operations
|
||||
Deletes int64
|
||||
|
||||
// Errors is the number of errors
|
||||
Errors int64
|
||||
|
||||
// Evictions is the number of evicted items
|
||||
Evictions int64
|
||||
|
||||
// CurrentSize is the current number of items in cache
|
||||
CurrentSize int64
|
||||
|
||||
// MaxSize is the maximum number of items (0 means unlimited)
|
||||
MaxSize int64
|
||||
|
||||
// MemoryUsage is the approximate memory usage in bytes
|
||||
MemoryUsage int64
|
||||
|
||||
// AverageGetLatency is the average latency for get operations
|
||||
AverageGetLatency time.Duration
|
||||
|
||||
// AverageSetLatency is the average latency for set operations
|
||||
AverageSetLatency time.Duration
|
||||
|
||||
// LastError is the last error encountered
|
||||
LastError string
|
||||
|
||||
// LastErrorTime is when the last error occurred
|
||||
LastErrorTime time.Time
|
||||
|
||||
// Uptime is how long the backend has been running
|
||||
Uptime time.Duration
|
||||
|
||||
// StartTime is when the backend was started
|
||||
StartTime time.Time
|
||||
}
|
||||
|
||||
// BackendCapabilities describes the capabilities of a cache backend
|
||||
type BackendCapabilities struct {
|
||||
// Distributed indicates if the backend is distributed across multiple instances
|
||||
Distributed bool
|
||||
|
||||
// Persistent indicates if the backend persists data across restarts
|
||||
Persistent bool
|
||||
|
||||
// Eviction indicates if the backend supports automatic eviction
|
||||
Eviction bool
|
||||
|
||||
// TTL indicates if the backend supports TTL (time-to-live)
|
||||
TTL bool
|
||||
|
||||
// MaxKeySize is the maximum size of a key in bytes (0 = unlimited)
|
||||
MaxKeySize int64
|
||||
|
||||
// MaxValueSize is the maximum size of a value in bytes (0 = unlimited)
|
||||
MaxValueSize int64
|
||||
|
||||
// MaxKeys is the maximum number of keys (0 = unlimited)
|
||||
MaxKeys int64
|
||||
|
||||
// SupportsExpire indicates if the backend supports expiration
|
||||
SupportsExpire bool
|
||||
|
||||
// SupportsMultiGet indicates if the backend supports batch get operations
|
||||
SupportsMultiGet bool
|
||||
|
||||
// SupportsTransaction indicates if the backend supports transactions
|
||||
SupportsTransaction bool
|
||||
|
||||
// SupportsCompression indicates if the backend supports compression
|
||||
SupportsCompression bool
|
||||
|
||||
// RequiresSerialize indicates if values must be serialized
|
||||
RequiresSerialize bool
|
||||
|
||||
// AtomicOperations indicates if the backend supports atomic operations
|
||||
AtomicOperations bool
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCacheBackendContract defines a set of tests that all CacheBackend implementations must pass
|
||||
// This ensures that Memory, Redis, and Hybrid backends all behave consistently
|
||||
func TestCacheBackendContract(t *testing.T) {
|
||||
// Test suite will be run against each backend type
|
||||
t.Run("MemoryBackend", func(t *testing.T) {
|
||||
backend := setupMemoryBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("RedisBackend", func(t *testing.T) {
|
||||
backend := setupRedisBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
|
||||
t.Run("HybridBackend", func(t *testing.T) {
|
||||
backend := setupHybridBackend(t)
|
||||
runContractTests(t, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// runContractTests executes all contract tests against a backend
|
||||
func runContractTests(t *testing.T, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("BasicSetGet", func(t *testing.T) {
|
||||
testBasicSetGet(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
testGetNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("UpdateExisting", func(t *testing.T) {
|
||||
testUpdateExisting(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
testDelete(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
testDeleteNonExistent(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
testExists(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("TTLExpiration", func(t *testing.T) {
|
||||
testTTLExpiration(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
testClear(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Ping", func(t *testing.T) {
|
||||
testPing(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
testStats(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("ConcurrentAccess", func(t *testing.T) {
|
||||
testConcurrentAccess(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("LargeValues", func(t *testing.T) {
|
||||
testLargeValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("EmptyValues", func(t *testing.T) {
|
||||
testEmptyValues(t, ctx, backend)
|
||||
})
|
||||
|
||||
t.Run("SpecialCharactersInKeys", func(t *testing.T) {
|
||||
testSpecialCharactersInKeys(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
// testBasicSetGet verifies basic set and get operations
|
||||
func testBasicSetGet(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "test-key-1"
|
||||
value := []byte("test-value-1")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err, "Set should not return error")
|
||||
|
||||
// Get value
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error")
|
||||
assert.True(t, exists, "Key should exist")
|
||||
assert.Equal(t, value, retrieved, "Retrieved value should match")
|
||||
assert.Greater(t, remainingTTL, 50*time.Second, "TTL should be close to original")
|
||||
assert.LessOrEqual(t, remainingTTL, ttl, "TTL should not exceed original")
|
||||
}
|
||||
|
||||
// testGetNonExistent verifies behavior when getting non-existent keys
|
||||
func testGetNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-key"
|
||||
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err, "Get should not return error for non-existent key")
|
||||
assert.False(t, exists, "Key should not exist")
|
||||
assert.Nil(t, retrieved, "Value should be nil")
|
||||
assert.Equal(t, time.Duration(0), ttl, "TTL should be zero")
|
||||
}
|
||||
|
||||
// testUpdateExisting verifies updating an existing key
|
||||
func testUpdateExisting(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
// Set initial value
|
||||
err := backend.Set(ctx, key, value1, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update value
|
||||
err = backend.Set(ctx, key, value2, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated value
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved, "Value should be updated")
|
||||
}
|
||||
|
||||
// testDelete verifies delete operation
|
||||
func testDelete(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
// Set value
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Delete
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted, "Delete should return true for existing key")
|
||||
|
||||
// Verify deleted
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after delete")
|
||||
}
|
||||
|
||||
// testDeleteNonExistent verifies deleting non-existent keys
|
||||
func testDeleteNonExistent(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "non-existent-delete-key"
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted, "Delete should return false for non-existent key")
|
||||
}
|
||||
|
||||
// testExists verifies the Exists operation
|
||||
func testExists(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
// Check non-existent key
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist initially")
|
||||
|
||||
// Set value
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check existing key
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist after Set")
|
||||
}
|
||||
|
||||
// testTTLExpiration verifies TTL expiration behavior
|
||||
func testTTLExpiration(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
// Set with short TTL
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key should exist immediately after Set")
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Verify expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after TTL expiration")
|
||||
}
|
||||
|
||||
// testClear verifies Clear operation
|
||||
func testClear(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
// Set multiple keys
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Give async writes time to complete before clearing
|
||||
// This prevents race conditions with async write workers
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Clear all
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key should not exist after Clear")
|
||||
}
|
||||
}
|
||||
|
||||
// testPing verifies Ping operation
|
||||
func testPing(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
err := backend.Ping(ctx)
|
||||
assert.NoError(t, err, "Ping should succeed on healthy backend")
|
||||
}
|
||||
|
||||
// testStats verifies GetStats operation
|
||||
func testStats(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
stats := backend.GetStats()
|
||||
assert.NotNil(t, stats, "Stats should not be nil")
|
||||
|
||||
// Stats should contain basic metrics
|
||||
_, hasHits := stats["hits"]
|
||||
_, hasMisses := stats["misses"]
|
||||
assert.True(t, hasHits || hasMisses, "Stats should contain hits or misses")
|
||||
}
|
||||
|
||||
// testConcurrentAccess verifies thread safety
|
||||
func testConcurrentAccess(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 10
|
||||
iterations := 20
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// testLargeValues verifies handling of large values
|
||||
func testLargeValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "large-value-key"
|
||||
value := GenerateLargeValue(1024 * 1024) // 1MB
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle large values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(value), len(retrieved), "Large value should be retrieved intact")
|
||||
}
|
||||
|
||||
// testEmptyValues verifies handling of empty values
|
||||
func testEmptyValues(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
key := "empty-value-key"
|
||||
value := []byte{}
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle empty values")
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Empty value should exist")
|
||||
assert.Equal(t, 0, len(retrieved), "Retrieved value should be empty")
|
||||
}
|
||||
|
||||
// testSpecialCharactersInKeys verifies handling of special characters in keys
|
||||
func testSpecialCharactersInKeys(t *testing.T, ctx context.Context, backend CacheBackend) {
|
||||
t.Helper()
|
||||
|
||||
specialKeys := []string{
|
||||
"key:with:colons",
|
||||
"key/with/slashes",
|
||||
"key-with-dashes",
|
||||
"key_with_underscores",
|
||||
"key.with.dots",
|
||||
"key|with|pipes",
|
||||
}
|
||||
|
||||
for _, key := range specialKeys {
|
||||
value := []byte(fmt.Sprintf("value-for-%s", key))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err, "Should handle special character in key: %s", key)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Key with special characters should exist: %s", key)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions to setup different backend types
|
||||
// These will be implemented in respective test files
|
||||
|
||||
func setupMemoryBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in memory_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("MemoryBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupRedisBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
// This will be implemented in redis_test.go
|
||||
// For now, return nil to allow compilation
|
||||
t.Skip("RedisBackend implementation pending")
|
||||
return nil
|
||||
}
|
||||
|
||||
func setupHybridBackend(t *testing.T) CacheBackend {
|
||||
t.Helper()
|
||||
|
||||
primary := newMockBackend()
|
||||
secondary := newMockBackend()
|
||||
|
||||
config := &HybridConfig{
|
||||
Primary: primary,
|
||||
Secondary: secondary,
|
||||
AsyncBufferSize: 100,
|
||||
Logger: NewTestLogger(t),
|
||||
}
|
||||
|
||||
hybrid, err := NewHybridBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
hybrid.Close()
|
||||
})
|
||||
|
||||
return hybrid
|
||||
}
|
||||
Vendored
+516
@@ -0,0 +1,516 @@
|
||||
// Package backend provides cache backend implementations for the Traefik OIDC plugin.
|
||||
package backends
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryCacheItem represents an item in the memory cache
|
||||
type memoryCacheItem struct {
|
||||
key string
|
||||
value interface{}
|
||||
expiresAt time.Time
|
||||
createdAt time.Time
|
||||
accessedAt time.Time
|
||||
accessCount int64
|
||||
size int64
|
||||
element *list.Element // for LRU tracking
|
||||
}
|
||||
|
||||
// isExpired checks if the item is expired
|
||||
func (item *memoryCacheItem) isExpired() bool {
|
||||
if item.expiresAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(item.expiresAt)
|
||||
}
|
||||
|
||||
// MemoryCacheBackend implements the CacheBackend interface using in-memory storage
|
||||
type MemoryCacheBackend struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryCacheItem
|
||||
lruList *list.List
|
||||
maxSize int64
|
||||
maxMemory int64
|
||||
currentSize int64
|
||||
currentMemory int64
|
||||
|
||||
// Statistics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
sets atomic.Int64
|
||||
deletes atomic.Int64
|
||||
evictions atomic.Int64
|
||||
errors atomic.Int64
|
||||
|
||||
// Latency tracking
|
||||
totalGetTime atomic.Int64
|
||||
totalSetTime atomic.Int64
|
||||
getCount atomic.Int64
|
||||
setCount atomic.Int64
|
||||
|
||||
// Status
|
||||
startTime time.Time
|
||||
lastError string
|
||||
lastErrorTime time.Time
|
||||
cleanupTicker *time.Ticker
|
||||
cleanupDone chan bool
|
||||
closed atomic.Bool
|
||||
|
||||
// Configuration
|
||||
cleanupInterval time.Duration
|
||||
evictionPolicy string // "lru", "lfu", "fifo"
|
||||
}
|
||||
|
||||
// NewMemoryCacheBackend creates a new memory cache backend
|
||||
func NewMemoryCacheBackend(maxSize int64, maxMemory int64, cleanupInterval time.Duration) *MemoryCacheBackend {
|
||||
if maxSize <= 0 {
|
||||
maxSize = 10000 // Default to 10k items
|
||||
}
|
||||
if maxMemory <= 0 {
|
||||
maxMemory = 100 * 1024 * 1024 // Default to 100MB
|
||||
}
|
||||
if cleanupInterval <= 0 {
|
||||
cleanupInterval = 5 * time.Minute
|
||||
}
|
||||
|
||||
m := &MemoryCacheBackend{
|
||||
items: make(map[string]*memoryCacheItem),
|
||||
lruList: list.New(),
|
||||
maxSize: maxSize,
|
||||
maxMemory: maxMemory,
|
||||
startTime: time.Now(),
|
||||
cleanupInterval: cleanupInterval,
|
||||
evictionPolicy: "lru",
|
||||
cleanupDone: make(chan bool),
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
m.cleanupTicker = time.NewTicker(cleanupInterval)
|
||||
go m.cleanupLoop()
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// cleanupLoop runs periodic cleanup of expired items
|
||||
func (m *MemoryCacheBackend) cleanupLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-m.cleanupTicker.C:
|
||||
m.cleanupExpired()
|
||||
case <-m.cleanupDone:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupExpired removes all expired items from the cache
|
||||
func (m *MemoryCacheBackend) cleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
var keysToDelete []string
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
keysToDelete = append(keysToDelete, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range keysToDelete {
|
||||
m.deleteItemLocked(key)
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryCacheBackend) Get(ctx context.Context, key string) (interface{}, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalGetTime.Add(duration)
|
||||
m.getCount.Add(1)
|
||||
}()
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.Lock()
|
||||
m.deleteItemLocked(key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, ErrCacheMiss
|
||||
}
|
||||
|
||||
// Update access time and count
|
||||
m.mu.Lock()
|
||||
item.accessedAt = time.Now()
|
||||
item.accessCount++
|
||||
// Move to front of LRU list
|
||||
if m.evictionPolicy == "lru" && item.element != nil {
|
||||
m.lruList.MoveToFront(item.element)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return item.value, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with optional TTL
|
||||
func (m *MemoryCacheBackend) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(start).Nanoseconds()
|
||||
m.totalSetTime.Add(duration)
|
||||
m.setCount.Add(1)
|
||||
}()
|
||||
|
||||
// Calculate item size (simplified estimation)
|
||||
itemSize := int64(len(key)) + estimateValueSize(value)
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Check if we need to evict items
|
||||
if m.currentSize >= m.maxSize || m.currentMemory+itemSize > m.maxMemory {
|
||||
m.evictLocked()
|
||||
}
|
||||
|
||||
// Check if key exists
|
||||
if oldItem, exists := m.items[key]; exists {
|
||||
m.currentMemory -= oldItem.size
|
||||
if oldItem.element != nil {
|
||||
m.lruList.Remove(oldItem.element)
|
||||
}
|
||||
} else {
|
||||
m.currentSize++
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var expiresAt time.Time
|
||||
if ttl > 0 {
|
||||
expiresAt = now.Add(ttl)
|
||||
}
|
||||
|
||||
item := &memoryCacheItem{
|
||||
key: key,
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
createdAt: now,
|
||||
accessedAt: now,
|
||||
accessCount: 0,
|
||||
size: itemSize,
|
||||
}
|
||||
|
||||
// Add to LRU list
|
||||
if m.evictionPolicy == "lru" {
|
||||
item.element = m.lruList.PushFront(item)
|
||||
}
|
||||
|
||||
m.items[key] = item
|
||||
m.currentMemory += itemSize
|
||||
m.sets.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryCacheBackend) Delete(ctx context.Context, key string) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, exists := m.items[key]; !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.deleteItemLocked(key)
|
||||
m.deletes.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteItemLocked deletes an item without acquiring the lock (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) deleteItemLocked(key string) {
|
||||
if item, exists := m.items[key]; exists {
|
||||
m.currentMemory -= item.size
|
||||
m.currentSize--
|
||||
if item.element != nil {
|
||||
m.lruList.Remove(item.element)
|
||||
}
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
// evictLocked evicts items based on the eviction policy (must be called with lock held)
|
||||
func (m *MemoryCacheBackend) evictLocked() {
|
||||
if m.evictionPolicy == "lru" && m.lruList.Len() > 0 {
|
||||
// Evict least recently used item
|
||||
element := m.lruList.Back()
|
||||
if element != nil {
|
||||
item := element.Value.(*memoryCacheItem)
|
||||
m.deleteItemLocked(item.key)
|
||||
m.evictions.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryCacheBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if m.closed.Load() {
|
||||
return false, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return !item.isExpired(), nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache
|
||||
func (m *MemoryCacheBackend) Clear(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryCacheItem)
|
||||
m.lruList = list.New()
|
||||
m.currentSize = 0
|
||||
m.currentMemory = 0
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys returns all keys matching the pattern (use "*" for all keys)
|
||||
func (m *MemoryCacheBackend) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
var keys []string
|
||||
for key, item := range m.items {
|
||||
if !item.isExpired() && matchPattern(pattern, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// Size returns the number of items in the cache
|
||||
func (m *MemoryCacheBackend) Size(ctx context.Context) (int64, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return m.currentSize, nil
|
||||
}
|
||||
|
||||
// TTL returns the remaining time-to-live for a key
|
||||
func (m *MemoryCacheBackend) TTL(ctx context.Context, key string) (time.Duration, error) {
|
||||
if m.closed.Load() {
|
||||
return 0, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists || item.isExpired() {
|
||||
return 0, ErrCacheMiss
|
||||
}
|
||||
|
||||
if item.expiresAt.IsZero() {
|
||||
return 0, nil // No expiration
|
||||
}
|
||||
|
||||
remaining := time.Until(item.expiresAt)
|
||||
if remaining < 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return remaining, nil
|
||||
}
|
||||
|
||||
// Expire updates the TTL for an existing key
|
||||
func (m *MemoryCacheBackend) Expire(ctx context.Context, key string, ttl time.Duration) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists || item.isExpired() {
|
||||
return ErrCacheMiss
|
||||
}
|
||||
|
||||
if ttl > 0 {
|
||||
item.expiresAt = time.Now().Add(ttl)
|
||||
} else {
|
||||
item.expiresAt = time.Time{} // Remove expiration
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the cache backend
|
||||
func (m *MemoryCacheBackend) GetStats(ctx context.Context) (*BackendStats, error) {
|
||||
if m.closed.Load() {
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
lastError := m.lastError
|
||||
lastErrorTime := m.lastErrorTime
|
||||
m.mu.RUnlock()
|
||||
|
||||
avgGetLatency := time.Duration(0)
|
||||
if getCount := m.getCount.Load(); getCount > 0 {
|
||||
avgGetLatency = time.Duration(m.totalGetTime.Load() / getCount)
|
||||
}
|
||||
|
||||
avgSetLatency := time.Duration(0)
|
||||
if setCount := m.setCount.Load(); setCount > 0 {
|
||||
avgSetLatency = time.Duration(m.totalSetTime.Load() / setCount)
|
||||
}
|
||||
|
||||
return &BackendStats{
|
||||
Type: TypeMemory,
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Sets: m.sets.Load(),
|
||||
Deletes: m.deletes.Load(),
|
||||
Errors: m.errors.Load(),
|
||||
Evictions: m.evictions.Load(),
|
||||
CurrentSize: m.currentSize,
|
||||
MaxSize: m.maxSize,
|
||||
MemoryUsage: m.currentMemory,
|
||||
AverageGetLatency: avgGetLatency,
|
||||
AverageSetLatency: avgSetLatency,
|
||||
LastError: lastError,
|
||||
LastErrorTime: lastErrorTime,
|
||||
Uptime: time.Since(m.startTime),
|
||||
StartTime: m.startTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy
|
||||
func (m *MemoryCacheBackend) Ping(ctx context.Context) error {
|
||||
if m.closed.Load() {
|
||||
return ErrBackendUnavailable
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the backend and releases resources
|
||||
func (m *MemoryCacheBackend) Close() error {
|
||||
if m.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
m.cleanupTicker.Stop()
|
||||
close(m.cleanupDone)
|
||||
|
||||
m.mu.Lock()
|
||||
m.items = nil
|
||||
m.lruList = nil
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (m *MemoryCacheBackend) IsHealthy() bool {
|
||||
return !m.closed.Load()
|
||||
}
|
||||
|
||||
// Type returns the backend type
|
||||
func (m *MemoryCacheBackend) Type() BackendType {
|
||||
return TypeMemory
|
||||
}
|
||||
|
||||
// Capabilities returns the backend capabilities
|
||||
func (m *MemoryCacheBackend) Capabilities() *BackendCapabilities {
|
||||
return &BackendCapabilities{
|
||||
Distributed: false,
|
||||
Persistent: false,
|
||||
Eviction: true,
|
||||
TTL: true,
|
||||
MaxKeySize: 1024, // 1KB
|
||||
MaxValueSize: 10485760, // 10MB
|
||||
MaxKeys: m.maxSize,
|
||||
SupportsExpire: true,
|
||||
SupportsMultiGet: true,
|
||||
SupportsTransaction: false,
|
||||
SupportsCompression: false,
|
||||
RequiresSerialize: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// estimateValueSize estimates the size of a value in bytes
|
||||
func estimateValueSize(value interface{}) int64 {
|
||||
// This is a simplified estimation
|
||||
// In production, you might want to use a more accurate method
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return int64(len(v))
|
||||
case []byte:
|
||||
return int64(len(v))
|
||||
case int, int32, int64, uint, uint32, uint64:
|
||||
return 8
|
||||
case float32, float64:
|
||||
return 8
|
||||
case bool:
|
||||
return 1
|
||||
default:
|
||||
// For complex types, use a default estimate
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
// matchPattern checks if a key matches a pattern (simplified glob matching)
|
||||
func matchPattern(pattern, key string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
// Simplified pattern matching - in production, use a proper glob library
|
||||
return key == pattern || (len(pattern) > 0 && pattern[0] == '*' &&
|
||||
len(key) >= len(pattern)-1 && key[len(key)-len(pattern)+1:] == pattern[1:])
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
)
|
||||
|
||||
// setupBenchmarkRedis creates a miniredis instance for benchmarking
|
||||
func setupBenchmarkRedis(b *testing.B) string {
|
||||
b.Helper()
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
return mr.Addr()
|
||||
}
|
||||
|
||||
// BenchmarkRedisOperations_WithPooling benchmarks memory allocations with object pooling
|
||||
func BenchmarkRedisOperations_WithPooling(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Perform various operations
|
||||
_, _ = conn.Do("SET", "bench-key", "bench-value")
|
||||
_, _ = conn.Do("GET", "bench-key")
|
||||
_, _ = conn.Do("EXISTS", "bench-key")
|
||||
_, _ = conn.Do("DEL", "bench-key")
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_SetGet benchmarks the full backend with pooling
|
||||
func BenchmarkRedisBackend_SetGet(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("benchmark test data with some content")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Set operation
|
||||
err := backend.Set(ctx, "bench-key", testData, 0)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Get operation
|
||||
_, _, _, err = backend.Get(ctx, "bench-key")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkRedisBackend_ConcurrentAccess benchmarks concurrent operations with pooling
|
||||
func BenchmarkRedisBackend_ConcurrentAccess(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
backend, err := NewRedisBackend(&Config{
|
||||
RedisAddr: addr,
|
||||
PoolSize: 10,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
testData := []byte("concurrent benchmark data")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_ = backend.Set(ctx, "concurrent-key", testData, 0)
|
||||
_, _, _, _ = backend.Get(ctx, "concurrent-key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// BenchmarkRESPProtocol_WriteRead benchmarks RESP protocol encoding/decoding
|
||||
func BenchmarkRESPProtocol_WriteRead(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Put(conn)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
// This tests the pooling of RESPReader/RESPWriter
|
||||
_, _ = conn.Do("PING")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkConnectionPool_GetPut benchmarks connection pool operations
|
||||
func BenchmarkConnectionPool_GetPut(b *testing.B) {
|
||||
addr := setupBenchmarkRedis(b)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: addr,
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
pool.Put(conn)
|
||||
}
|
||||
}
|
||||
+783
@@ -0,0 +1,783 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestMemoryBackend_BasicOperations tests basic CRUD operations
|
||||
func TestMemoryBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "test-key"
|
||||
value := []byte("test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
assert.LessOrEqual(t, remainingTTL, ttl)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "delete-key"
|
||||
value := []byte("delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("DeleteNonExistent", func(t *testing.T) {
|
||||
deleted, err := backend.Delete(ctx, "non-existent-delete")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, deleted)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "exists-key"
|
||||
value := []byte("exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear", func(t *testing.T) {
|
||||
// Add multiple items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
err := backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(0), size)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTLExpiration tests TTL and expiration
|
||||
func TestMemoryBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "short-ttl-key"
|
||||
value := []byte("short-ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
_, _, exists, err = backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLDecrement", func(t *testing.T) {
|
||||
key := "ttl-decrement-key"
|
||||
value := []byte("ttl-decrement-value")
|
||||
ttl := 2 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check TTL immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Wait a bit
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Check TTL again - should be less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1, "TTL should decrease over time")
|
||||
})
|
||||
|
||||
t.Run("CleanupExpiredItems", func(t *testing.T) {
|
||||
// Set multiple items with short TTL
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("cleanup-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 50*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// All items should be cleaned up
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("cleanup-key-%d", i)
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Expired items should be cleaned up")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LRUEviction tests LRU eviction
|
||||
func TestMemoryBackend_LRUEviction(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 5
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Fill cache to max size
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("lru-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("lru-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Access first item to make it most recently used
|
||||
_, _, exists, err := backend.Get(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Add a new item - should evict lru-key-1 (least recently used)
|
||||
err = backend.Set(ctx, "lru-key-new", []byte("new-value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// lru-key-0 should still exist (was accessed recently)
|
||||
exists, err = backend.Exists(ctx, "lru-key-0")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "Recently accessed item should not be evicted")
|
||||
|
||||
// lru-key-1 should be evicted
|
||||
exists, err = backend.Exists(ctx, "lru-key-1")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Least recently used item should be evicted")
|
||||
|
||||
// Check eviction count
|
||||
stats := backend.GetStats()
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_MemoryLimit tests memory-based eviction
|
||||
func TestMemoryBackend_MemoryLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxSize = 100
|
||||
config.MaxMemoryBytes = 1024 // 1KB limit
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items until memory limit is reached
|
||||
largeValue := make([]byte, 512) // 512 bytes each
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("mem-key-%d", i)
|
||||
err := backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
stats := backend.GetStats()
|
||||
memory := stats["memory"].(int64)
|
||||
assert.LessOrEqual(t, memory, config.MaxMemoryBytes, "Memory should not exceed limit")
|
||||
|
||||
evictions := stats["evictions"].(int64)
|
||||
assert.Greater(t, evictions, int64(0), "Should have memory-based evictions")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ConcurrentAccess tests thread safety
|
||||
func TestMemoryBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
// Concurrent writes
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// Random deletes
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify stats are consistent
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0), "Should have cache operations")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_UpdateExisting tests updating existing keys
|
||||
func TestMemoryBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute, "TTL should be updated")
|
||||
|
||||
// Size should not increase (same key)
|
||||
stats := backend.GetStats()
|
||||
size := stats["size"].(int64)
|
||||
assert.Equal(t, int64(1), size, "Size should be 1 for one key")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Stats tests statistics tracking
|
||||
func TestMemoryBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add items and track hits/misses
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Hit
|
||||
backend.Get(ctx, "key1")
|
||||
// Miss
|
||||
backend.Get(ctx, "non-existent")
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_EmptyValues tests handling of empty values
|
||||
func TestMemoryBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_LargeValues tests handling of large values
|
||||
func TestMemoryBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := DefaultConfig()
|
||||
config.MaxMemoryBytes = 10 * 1024 * 1024 // 10MB
|
||||
backend, err := NewMemoryBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Close tests proper cleanup on close
|
||||
func TestMemoryBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add some items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations after close should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
_, _, _, err = backend.Get(ctx, "close-key-0")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Closing again should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Ping tests ping operation
|
||||
func TestMemoryBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_ValueIsolation tests that returned values are isolated
|
||||
func TestMemoryBackend_ValueIsolation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "isolation-key"
|
||||
originalValue := []byte("original-value")
|
||||
|
||||
err = backend.Set(ctx, key, originalValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get value and modify it
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Modify retrieved value
|
||||
if len(retrieved) > 0 {
|
||||
retrieved[0] = 'X'
|
||||
}
|
||||
|
||||
// Get again - should be unchanged
|
||||
retrieved2, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, originalValue, retrieved2, "Original value should not be modified")
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Keys tests the Keys method with pattern matching
|
||||
func TestMemoryBackend_Keys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add test data
|
||||
testKeys := []string{"user:1", "user:2", "session:abc", "session:def", "token:xyz"}
|
||||
for _, key := range testKeys {
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
t.Run("AllKeys", func(t *testing.T) {
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 5)
|
||||
})
|
||||
|
||||
t.Run("SpecificPattern", func(t *testing.T) {
|
||||
// Simple exact match
|
||||
keys, err := backend.Keys(ctx, "user:1")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 1)
|
||||
assert.Contains(t, keys, "user:1")
|
||||
})
|
||||
|
||||
t.Run("ExcludesExpired", func(t *testing.T) {
|
||||
// Add an expired key
|
||||
expiredKey := "expired:key"
|
||||
err := backend.Set(ctx, expiredKey, []byte("value"), 1*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
keys, err := backend.Keys(ctx, "*")
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, keys, expiredKey, "Expired keys should not be returned")
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.Keys(ctx, "*")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Size tests the Size method
|
||||
func TestMemoryBackend_Size(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially empty
|
||||
size, err := backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), size)
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 5; i++ {
|
||||
key := fmt.Sprintf("key-%d", i)
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5), size)
|
||||
|
||||
// Delete one
|
||||
backend.Delete(ctx, "key-0")
|
||||
|
||||
size, err = backend.Size(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(4), size)
|
||||
|
||||
// After close
|
||||
backend.Close()
|
||||
_, err = backend.Size(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_TTL tests the TTL method
|
||||
func TestMemoryBackend_TTL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ExistingKey", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, []byte("value"), ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, remaining, 50*time.Second)
|
||||
assert.LessOrEqual(t, remaining, ttl)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
_, err := backend.TTL(ctx, "non-existent")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("NoExpiration", func(t *testing.T) {
|
||||
key := "no-expiry"
|
||||
// TTL of 0 typically means no expiration
|
||||
err := backend.Set(ctx, key, []byte("value"), 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
// No expiration returns 0
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
_, err := closedBackend.TTL(ctx, "key")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Expire tests the Expire method
|
||||
func TestMemoryBackend_Expire(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("UpdateTTL", func(t *testing.T) {
|
||||
key := "expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update to shorter TTL
|
||||
err = backend.Expire(ctx, key, 5*time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check new TTL
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, remaining, 5*time.Second)
|
||||
})
|
||||
|
||||
t.Run("NonExistentKey", func(t *testing.T) {
|
||||
err := backend.Expire(ctx, "non-existent", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCacheMiss, err)
|
||||
})
|
||||
|
||||
t.Run("RemoveExpiration", func(t *testing.T) {
|
||||
key := "no-expire-key"
|
||||
err := backend.Set(ctx, key, []byte("value"), 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set TTL to 0 to remove expiration
|
||||
err = backend.Expire(ctx, key, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TTL should now be 0
|
||||
remaining, err := backend.TTL(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, time.Duration(0), remaining)
|
||||
})
|
||||
|
||||
t.Run("AfterClose", func(t *testing.T) {
|
||||
closedBackend, _ := NewMemoryBackend(DefaultConfig())
|
||||
closedBackend.Close()
|
||||
|
||||
err := closedBackend.Expire(ctx, "key", 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendUnavailable, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestMemoryBackend_IsHealthy tests the IsHealthy method
|
||||
func TestMemoryBackend_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should be healthy when open
|
||||
assert.True(t, backend.IsHealthy())
|
||||
|
||||
// Should be unhealthy after close
|
||||
backend.Close()
|
||||
assert.False(t, backend.IsHealthy())
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Type tests the Type method
|
||||
func TestMemoryBackend_Type(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
backendType := backend.Type()
|
||||
assert.Equal(t, TypeMemory, backendType)
|
||||
}
|
||||
|
||||
// TestMemoryBackend_Capabilities tests the Capabilities method
|
||||
func TestMemoryBackend_Capabilities(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend, err := NewMemoryBackend(DefaultConfig())
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
caps := backend.Capabilities()
|
||||
require.NotNil(t, caps)
|
||||
|
||||
// Memory backend should not be distributed or persistent
|
||||
assert.False(t, caps.Distributed)
|
||||
assert.False(t, caps.Persistent)
|
||||
|
||||
// Should support eviction and TTL
|
||||
assert.True(t, caps.Eviction)
|
||||
assert.True(t, caps.TTL)
|
||||
assert.True(t, caps.SupportsExpire)
|
||||
assert.True(t, caps.SupportsMultiGet)
|
||||
|
||||
// Check limits
|
||||
assert.Greater(t, caps.MaxKeySize, int64(0))
|
||||
assert.Greater(t, caps.MaxValueSize, int64(0))
|
||||
}
|
||||
|
||||
// TestMatchPattern tests the matchPattern helper function
|
||||
func TestMatchPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
pattern string
|
||||
key string
|
||||
matches bool
|
||||
}{
|
||||
{"*", "any-key", true},
|
||||
{"*", "another", true},
|
||||
{"user:1", "user:1", true},
|
||||
{"user:1", "user:2", false},
|
||||
{"*:suffix", "prefix:suffix", true},
|
||||
{"*suffix", "prefix-suffix", true},
|
||||
{"*abc", "xyzabc", true},
|
||||
{"*abc", "xyz", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s-%s", tt.pattern, tt.key), func(t *testing.T) {
|
||||
result := matchPattern(tt.pattern, tt.key)
|
||||
assert.Equal(t, tt.matches, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
+153
@@ -0,0 +1,153 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MemoryBackend wraps MemoryCacheBackend to implement the CacheBackend interface
|
||||
type MemoryBackend struct {
|
||||
*MemoryCacheBackend
|
||||
}
|
||||
|
||||
// NewMemoryBackend creates a new memory backend from a config
|
||||
func NewMemoryBackend(config *Config) (*MemoryBackend, error) {
|
||||
maxSize := int64(config.MaxSize)
|
||||
if maxSize <= 0 {
|
||||
maxSize = 1000
|
||||
}
|
||||
|
||||
cacheBackend := NewMemoryCacheBackend(maxSize, config.MaxMemoryBytes, config.CleanupInterval)
|
||||
return &MemoryBackend{
|
||||
MemoryCacheBackend: cacheBackend,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL
|
||||
func (m *MemoryBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
err := m.MemoryCacheBackend.Set(ctx, key, value, ttl)
|
||||
if err == ErrBackendUnavailable {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache
|
||||
func (m *MemoryBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
val, err := m.MemoryCacheBackend.Get(ctx, key)
|
||||
if err != nil {
|
||||
if err == ErrCacheMiss {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err == ErrBackendUnavailable {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
return nil, 0, false, err
|
||||
}
|
||||
|
||||
// Get the item directly to check TTL
|
||||
m.MemoryCacheBackend.mu.RLock()
|
||||
item, exists := m.MemoryCacheBackend.items[key]
|
||||
m.MemoryCacheBackend.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
if !item.expiresAt.IsZero() {
|
||||
ttl = time.Until(item.expiresAt)
|
||||
if ttl < 0 {
|
||||
ttl = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Convert interface{} to []byte
|
||||
var valueBytes []byte
|
||||
if val != nil {
|
||||
if bytes, ok := val.([]byte); ok {
|
||||
valueBytes = bytes
|
||||
} else {
|
||||
// If it's not already []byte, we might need to handle other types
|
||||
// For now, we'll just return an error
|
||||
return nil, 0, false, ErrInvalidValue
|
||||
}
|
||||
}
|
||||
|
||||
return valueBytes, ttl, true, nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache
|
||||
func (m *MemoryBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
// Check if key exists first
|
||||
exists, err := m.MemoryCacheBackend.Exists(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = m.MemoryCacheBackend.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache
|
||||
func (m *MemoryBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
return m.MemoryCacheBackend.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Clear removes all keys from the cache
|
||||
func (m *MemoryBackend) Clear(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Clear(ctx)
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics
|
||||
func (m *MemoryBackend) GetStats() map[string]interface{} {
|
||||
stats, err := m.MemoryCacheBackend.GetStats(context.Background())
|
||||
if err != nil {
|
||||
return map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BackendStats to map
|
||||
hitRate := float64(0)
|
||||
total := stats.Hits + stats.Misses
|
||||
if total > 0 {
|
||||
hitRate = float64(stats.Hits) / float64(total)
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"type": stats.Type,
|
||||
"hits": stats.Hits,
|
||||
"misses": stats.Misses,
|
||||
"sets": stats.Sets,
|
||||
"deletes": stats.Deletes,
|
||||
"errors": stats.Errors,
|
||||
"evictions": stats.Evictions,
|
||||
"size": stats.CurrentSize,
|
||||
"max_size": stats.MaxSize,
|
||||
"memory": stats.MemoryUsage,
|
||||
"hit_rate": hitRate,
|
||||
"uptime": stats.Uptime,
|
||||
"start_time": stats.StartTime,
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the cache backend and releases resources
|
||||
func (m *MemoryBackend) Close() error {
|
||||
return m.MemoryCacheBackend.Close()
|
||||
}
|
||||
|
||||
// Ping checks if the backend is healthy and responsive
|
||||
func (m *MemoryBackend) Ping(ctx context.Context) error {
|
||||
return m.MemoryCacheBackend.Ping(ctx)
|
||||
}
|
||||
|
||||
// Ensure MemoryBackend implements CacheBackend
|
||||
var _ CacheBackend = (*MemoryBackend)(nil)
|
||||
Vendored
+470
@@ -0,0 +1,470 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Pure-Go Redis client implementation
|
||||
// Compatible with Yaegi interpreter (no unsafe package)
|
||||
// Implements RESP protocol for basic Redis operations
|
||||
|
||||
var (
|
||||
ErrPoolExhausted = errors.New("connection pool exhausted")
|
||||
)
|
||||
|
||||
// RedisBackend implements a Redis-based cache backend using pure Go
|
||||
type RedisBackend struct {
|
||||
config *Config
|
||||
pool *ConnectionPool
|
||||
healthMonitor *HealthMonitor
|
||||
|
||||
// Metrics
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRedisBackend creates a new Redis cache backend with pure-Go implementation
|
||||
func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
if config == nil {
|
||||
return nil, fmt.Errorf("config is required")
|
||||
}
|
||||
|
||||
if config.RedisAddr == "" {
|
||||
return nil, fmt.Errorf("redis address is required")
|
||||
}
|
||||
|
||||
// Create connection pool with health checks enabled
|
||||
// Timeouts are kept short to prevent request pileup when Redis is slow/stalled.
|
||||
// The UniversalCache uses 200ms context timeout, so socket timeouts should be
|
||||
// shorter to allow proper context cancellation handling.
|
||||
poolConfig := &PoolConfig{
|
||||
Address: config.RedisAddr,
|
||||
Password: config.RedisPassword,
|
||||
DB: config.RedisDB,
|
||||
MaxConnections: config.PoolSize,
|
||||
ConnectTimeout: 2 * time.Second,
|
||||
ReadTimeout: 500 * time.Millisecond,
|
||||
WriteTimeout: 500 * time.Millisecond,
|
||||
EnableHealthCheck: true,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(poolConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection pool: %w", err)
|
||||
}
|
||||
|
||||
// Create health monitor
|
||||
healthConfig := DefaultHealthMonitorConfig()
|
||||
healthMonitor := NewHealthMonitor(pool, healthConfig)
|
||||
|
||||
backend := &RedisBackend{
|
||||
config: config,
|
||||
pool: pool,
|
||||
healthMonitor: healthMonitor,
|
||||
}
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
// Start health monitoring
|
||||
healthMonitor.Start()
|
||||
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Set stores a value in Redis with TTL
|
||||
func (r *RedisBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
|
||||
// Execute with retry logic
|
||||
return r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
var err error
|
||||
|
||||
// Use PSETEX for millisecond precision, SETEX for second precision
|
||||
if ttl > 0 {
|
||||
ttlMillis := ttl.Milliseconds()
|
||||
if ttlMillis < 1000 {
|
||||
// Use PSETEX for sub-second TTLs (millisecond precision)
|
||||
_, err = conn.Do("PSETEX", prefixedKey, fmt.Sprintf("%d", ttlMillis), string(value))
|
||||
} else {
|
||||
// Use SETEX for larger TTLs (second precision)
|
||||
ttlSeconds := int(ttl.Seconds())
|
||||
_, err = conn.Do("SETEX", prefixedKey, fmt.Sprintf("%d", ttlSeconds), string(value))
|
||||
}
|
||||
} else {
|
||||
_, err = conn.Do("SET", prefixedKey, string(value))
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// Get retrieves a value from Redis
|
||||
func (r *RedisBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, 0, false, ErrBackendClosed
|
||||
}
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
var resultValue []byte
|
||||
var resultTTL time.Duration
|
||||
var resultExists bool
|
||||
|
||||
// Execute with retry logic
|
||||
err := r.executeWithRetry(ctx, func(conn *RedisConn) error {
|
||||
// Get value
|
||||
resp, err := conn.Do("GET", prefixedKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrNilResponse) {
|
||||
r.misses.Add(1)
|
||||
resultExists = false
|
||||
return nil // Not an error, key just doesn't exist
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
value, err := RESPString(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get TTL
|
||||
ttlResp, err := conn.Do("TTL", prefixedKey)
|
||||
if err != nil {
|
||||
// If TTL fails, still return the value
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = 0
|
||||
resultExists = true
|
||||
return nil
|
||||
}
|
||||
|
||||
ttlSeconds, _ := RESPInt(ttlResp)
|
||||
var ttl time.Duration
|
||||
if ttlSeconds > 0 {
|
||||
ttl = time.Duration(ttlSeconds) * time.Second
|
||||
}
|
||||
|
||||
r.hits.Add(1)
|
||||
resultValue = []byte(value)
|
||||
resultTTL = ttl
|
||||
resultExists = true
|
||||
return nil
|
||||
})
|
||||
|
||||
return resultValue, resultTTL, resultExists, err
|
||||
}
|
||||
|
||||
// Delete removes a key from Redis
|
||||
func (r *RedisBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("DEL", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in Redis
|
||||
func (r *RedisBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if r.closed.Load() {
|
||||
return false, ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
prefixedKey := r.prefixKey(key)
|
||||
resp, err := conn.Do("EXISTS", prefixedKey)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// Clear removes all keys with the configured prefix
|
||||
func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
// Use FLUSHDB if no prefix (clear entire DB)
|
||||
if r.config.RedisPrefix == "" {
|
||||
_, err := conn.Do("FLUSHDB")
|
||||
return err
|
||||
}
|
||||
|
||||
// With prefix, we need to scan and delete keys
|
||||
// For simplicity in this implementation, we'll use KEYS pattern (not recommended for production at scale)
|
||||
pattern := r.config.RedisPrefix + "*"
|
||||
resp, err := conn.Do("KEYS", pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract keys from array response
|
||||
keys, ok := resp.([]interface{})
|
||||
if !ok || len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete each key
|
||||
for _, keyInterface := range keys {
|
||||
key, err := RESPString(keyInterface)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats returns backend statistics
|
||||
func (r *RedisBackend) GetStats() map[string]interface{} {
|
||||
hits := r.hits.Load()
|
||||
misses := r.misses.Load()
|
||||
total := hits + misses
|
||||
|
||||
hitRate := float64(0)
|
||||
if total > 0 {
|
||||
hitRate = float64(hits) / float64(total)
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"backend": "redis-pure-go",
|
||||
"address": r.config.RedisAddr,
|
||||
"hits": hits,
|
||||
"misses": misses,
|
||||
"hit_rate": hitRate,
|
||||
"pool": r.pool.Stats(),
|
||||
}
|
||||
|
||||
// Add health monitor stats if available
|
||||
if r.healthMonitor != nil {
|
||||
stats["health"] = r.healthMonitor.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks Redis connectivity
|
||||
func (r *RedisBackend) Ping(ctx context.Context) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer r.pool.Put(conn)
|
||||
|
||||
_, err = conn.Do("PING")
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the Redis backend and all connections
|
||||
func (r *RedisBackend) Close() error {
|
||||
if r.closed.Swap(true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
// Stop health monitor
|
||||
if r.healthMonitor != nil {
|
||||
r.healthMonitor.Stop()
|
||||
}
|
||||
|
||||
// Close connection pool
|
||||
if r.pool != nil {
|
||||
return r.pool.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// prefixKey adds the configured prefix to a key
|
||||
func (r *RedisBackend) prefixKey(key string) string {
|
||||
if r.config.RedisPrefix == "" {
|
||||
return key
|
||||
}
|
||||
return r.config.RedisPrefix + key
|
||||
}
|
||||
|
||||
// executeWithRetry executes a Redis operation with exponential backoff retry logic.
|
||||
// It checks context cancellation at multiple points to ensure fast abort when the
|
||||
// caller's context is cancelled (e.g., due to request timeout).
|
||||
func (r *RedisBackend) executeWithRetry(ctx context.Context, operation func(*RedisConn) error) error {
|
||||
maxRetries := 3
|
||||
baseDelay := 50 * time.Millisecond // Reduced from 100ms to fail faster
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
// Check context before each attempt to fail fast
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
conn, err := r.pool.Get(ctx)
|
||||
if err != nil {
|
||||
// If we can't get a connection and this is the last attempt, fail
|
||||
if attempt == maxRetries-1 {
|
||||
return fmt.Errorf("failed to get connection after %d attempts: %w", maxRetries, err)
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the operation
|
||||
err = operation(conn)
|
||||
r.pool.Put(conn)
|
||||
|
||||
// Check context after operation - if cancelled, don't bother retrying
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// If successful, return
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If error is not retryable or last attempt, fail
|
||||
if attempt == maxRetries-1 || !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait with exponential backoff before retrying
|
||||
delay := baseDelay * time.Duration(1<<uint(attempt))
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts", maxRetries)
|
||||
}
|
||||
|
||||
// isRetryableError determines if an error is worth retrying
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Retry on connection errors, timeouts, etc.
|
||||
// Don't retry on application-level errors like wrong type
|
||||
errMsg := err.Error()
|
||||
retryablePatterns := []string{
|
||||
"connection",
|
||||
"timeout",
|
||||
"EOF",
|
||||
"broken pipe",
|
||||
"reset by peer",
|
||||
}
|
||||
|
||||
for _, pattern := range retryablePatterns {
|
||||
if contains(errMsg, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SetMany stores multiple values in Redis (batch operation)
|
||||
func (r *RedisBackend) SetMany(ctx context.Context, items map[string][]byte, ttl time.Duration) error {
|
||||
if r.closed.Load() {
|
||||
return ErrBackendClosed
|
||||
}
|
||||
|
||||
// For simplicity, execute sequentially (can be optimized with pipelining later)
|
||||
for key, value := range items {
|
||||
if err := r.Set(ctx, key, value, ttl); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMany retrieves multiple values from Redis
|
||||
func (r *RedisBackend) GetMany(ctx context.Context, keys []string) (map[string][]byte, error) {
|
||||
if r.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
result := make(map[string][]byte)
|
||||
|
||||
// For simplicity, execute sequentially
|
||||
for _, key := range keys {
|
||||
value, _, exists, err := r.Get(ctx, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
+176
@@ -0,0 +1,176 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthMonitor continuously monitors Redis connection health and triggers reconnections
|
||||
type HealthMonitor struct {
|
||||
pool *ConnectionPool
|
||||
config *HealthMonitorConfig
|
||||
|
||||
// State
|
||||
healthy atomic.Bool
|
||||
running atomic.Bool
|
||||
lastCheckTime atomic.Int64 // Unix timestamp
|
||||
|
||||
// Metrics
|
||||
consecutiveFailures atomic.Int64
|
||||
totalChecks atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
stopChan chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// HealthMonitorConfig configures the health monitor
|
||||
type HealthMonitorConfig struct {
|
||||
CheckInterval time.Duration // How often to check health
|
||||
Timeout time.Duration // Timeout for health check
|
||||
UnhealthyThreshold int // Consecutive failures before marking unhealthy
|
||||
OnHealthChange func(healthy bool)
|
||||
}
|
||||
|
||||
// DefaultHealthMonitorConfig returns default health monitor configuration
|
||||
func DefaultHealthMonitorConfig() *HealthMonitorConfig {
|
||||
return &HealthMonitorConfig{
|
||||
CheckInterval: 5 * time.Second,
|
||||
Timeout: 3 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHealthMonitor creates a new health monitor
|
||||
func NewHealthMonitor(pool *ConnectionPool, config *HealthMonitorConfig) *HealthMonitor {
|
||||
if config == nil {
|
||||
config = DefaultHealthMonitorConfig()
|
||||
}
|
||||
|
||||
hm := &HealthMonitor{
|
||||
pool: pool,
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
hm.healthy.Store(true) // Assume healthy initially
|
||||
return hm
|
||||
}
|
||||
|
||||
// Start begins health monitoring
|
||||
func (hm *HealthMonitor) Start() {
|
||||
if hm.running.Swap(true) {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
hm.wg.Add(1)
|
||||
go hm.monitorLoop()
|
||||
}
|
||||
|
||||
// Stop stops health monitoring
|
||||
func (hm *HealthMonitor) Stop() {
|
||||
if !hm.running.Swap(false) {
|
||||
return // Not running
|
||||
}
|
||||
|
||||
close(hm.stopChan)
|
||||
hm.wg.Wait()
|
||||
}
|
||||
|
||||
// IsHealthy returns the current health status
|
||||
func (hm *HealthMonitor) IsHealthy() bool {
|
||||
return hm.healthy.Load()
|
||||
}
|
||||
|
||||
// GetStats returns health monitor statistics
|
||||
func (hm *HealthMonitor) GetStats() map[string]interface{} {
|
||||
lastCheck := time.Unix(hm.lastCheckTime.Load(), 0)
|
||||
|
||||
return map[string]interface{}{
|
||||
"healthy": hm.healthy.Load(),
|
||||
"consecutive_failures": hm.consecutiveFailures.Load(),
|
||||
"total_checks": hm.totalChecks.Load(),
|
||||
"total_failures": hm.totalFailures.Load(),
|
||||
"last_check": lastCheck,
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop runs the health check loop
|
||||
func (hm *HealthMonitor) monitorLoop() {
|
||||
defer hm.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(hm.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Perform initial check immediately
|
||||
hm.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hm.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hm.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck executes a health check
|
||||
func (hm *HealthMonitor) performHealthCheck() {
|
||||
hm.totalChecks.Add(1)
|
||||
hm.lastCheckTime.Store(time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hm.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to get a connection and ping Redis
|
||||
conn, err := hm.pool.Get(ctx)
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
defer hm.pool.Put(conn)
|
||||
|
||||
// Ping Redis
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
hm.recordFailure()
|
||||
return
|
||||
}
|
||||
|
||||
// Success!
|
||||
hm.recordSuccess()
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hm *HealthMonitor) recordSuccess() {
|
||||
wasHealthy := hm.healthy.Load()
|
||||
hm.consecutiveFailures.Store(0)
|
||||
hm.healthy.Store(true)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if !wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(true)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hm *HealthMonitor) recordFailure() {
|
||||
hm.totalFailures.Add(1)
|
||||
failures := hm.consecutiveFailures.Add(1)
|
||||
|
||||
wasHealthy := hm.healthy.Load()
|
||||
|
||||
// Mark unhealthy if threshold exceeded
|
||||
if failures >= int64(hm.config.UnhealthyThreshold) {
|
||||
hm.healthy.Store(false)
|
||||
|
||||
// Trigger callback if health changed
|
||||
if wasHealthy && hm.config.OnHealthChange != nil {
|
||||
hm.config.OnHealthChange(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
+421
@@ -0,0 +1,421 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHealthMonitor_BasicOperation tests basic health monitoring
|
||||
func TestHealthMonitor_BasicOperation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create health monitor with fast check interval for testing
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
require.NotNil(t, hm)
|
||||
|
||||
// Initially should be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for a few checks
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Should still be healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
require.NotNil(t, stats)
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Greater(t, stats["total_checks"].(int64), int64(0))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_HealthyToUnhealthy tests transition to unhealthy state
|
||||
func TestHealthMonitor_HealthyToUnhealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var healthChangedCalled atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if !healthy {
|
||||
healthChangedCalled.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure (2 failures * 50ms + buffer)
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Health monitor should detect server failure")
|
||||
assert.True(t, healthChangedCalled.Load(), "OnHealthChange callback should be called")
|
||||
|
||||
// Check stats
|
||||
stats := hm.GetStats()
|
||||
assert.False(t, stats["healthy"].(bool))
|
||||
assert.GreaterOrEqual(t, stats["consecutive_failures"].(int64), int64(2))
|
||||
assert.Greater(t, stats["total_failures"].(int64), int64(0))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_UnhealthyToHealthy tests recovery to healthy state
|
||||
func TestHealthMonitor_UnhealthyToHealthy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
var recoveryDetected atomic.Bool
|
||||
hmConfig := &HealthMonitorConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
OnHealthChange: func(healthy bool) {
|
||||
if healthy {
|
||||
recoveryDetected.Store(true)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
hm := NewHealthMonitor(pool, hmConfig)
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, hm.IsHealthy())
|
||||
|
||||
// Simulate Redis errors
|
||||
mr.SetError("ERR server is down")
|
||||
|
||||
// Wait for health checks to detect failure
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should now be unhealthy
|
||||
assert.False(t, hm.IsHealthy(), "Should detect server failure")
|
||||
|
||||
// Clear error to simulate recovery
|
||||
mr.ClearError()
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Should be healthy again
|
||||
assert.True(t, hm.IsHealthy(), "Should recover after server restart")
|
||||
assert.True(t, recoveryDetected.Load(), "Recovery callback should be called")
|
||||
|
||||
// Consecutive failures should be reset
|
||||
stats := hm.GetStats()
|
||||
assert.True(t, stats["healthy"].(bool))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StartStop tests start/stop behavior
|
||||
func TestHealthMonitor_StartStop(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, DefaultHealthMonitorConfig())
|
||||
|
||||
// Start monitoring
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Starting again should be no-op
|
||||
hm.Start()
|
||||
assert.True(t, hm.running.Load())
|
||||
|
||||
// Stop monitoring
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
|
||||
// Stopping again should be no-op
|
||||
hm.Stop()
|
||||
assert.False(t, hm.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_MultipleMonitors tests multiple health monitors
|
||||
func TestHealthMonitor_MultipleMonitors(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Create multiple monitors
|
||||
hm1 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm2 := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 150 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
})
|
||||
|
||||
// Start both
|
||||
hm1.Start()
|
||||
hm2.Start()
|
||||
|
||||
// Both should be healthy
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
assert.True(t, hm1.IsHealthy())
|
||||
assert.True(t, hm2.IsHealthy())
|
||||
|
||||
// Stop both
|
||||
hm1.Stop()
|
||||
hm2.Stop()
|
||||
|
||||
// Verify they stopped
|
||||
assert.False(t, hm1.running.Load())
|
||||
assert.False(t, hm2.running.Load())
|
||||
}
|
||||
|
||||
// TestHealthMonitor_StatsAccuracy tests stats tracking
|
||||
func TestHealthMonitor_StatsAccuracy(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Wait for some checks
|
||||
time.Sleep(550 * time.Millisecond)
|
||||
|
||||
stats := hm.GetStats()
|
||||
|
||||
// Should have performed multiple checks
|
||||
totalChecks := stats["total_checks"].(int64)
|
||||
assert.GreaterOrEqual(t, totalChecks, int64(4))
|
||||
|
||||
// All checks should succeed
|
||||
assert.Equal(t, int64(0), stats["total_failures"].(int64))
|
||||
assert.Equal(t, int64(0), stats["consecutive_failures"].(int64))
|
||||
|
||||
// Last check time should be recent (within check interval + buffer)
|
||||
// Use 2s tolerance to account for CI runner load and timing variance
|
||||
lastCheck := stats["last_check"].(time.Time)
|
||||
assert.WithinDuration(t, time.Now(), lastCheck, 2*time.Second)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_DefaultConfig tests default configuration
|
||||
func TestHealthMonitor_DefaultConfig(t *testing.T) {
|
||||
config := DefaultHealthMonitorConfig()
|
||||
|
||||
assert.Equal(t, 5*time.Second, config.CheckInterval)
|
||||
assert.Equal(t, 3*time.Second, config.Timeout)
|
||||
assert.Equal(t, 3, config.UnhealthyThreshold)
|
||||
assert.Nil(t, config.OnHealthChange)
|
||||
}
|
||||
|
||||
// TestHealthMonitor_PoolExhaustion tests behavior when pool is exhausted
|
||||
func TestHealthMonitor_PoolExhaustion(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1, // Very small pool
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
hm := NewHealthMonitor(pool, &HealthMonitorConfig{
|
||||
CheckInterval: 100 * time.Millisecond,
|
||||
Timeout: 50 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
})
|
||||
|
||||
hm.Start()
|
||||
defer hm.Stop()
|
||||
|
||||
// Get the only connection, blocking health checks
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for health check attempts
|
||||
time.Sleep(350 * time.Millisecond)
|
||||
|
||||
// Health monitor might mark as unhealthy due to timeouts
|
||||
stats := hm.GetStats()
|
||||
t.Logf("Stats with blocked pool: %+v", stats)
|
||||
|
||||
// Return connection
|
||||
pool.Put(conn)
|
||||
|
||||
// Wait for recovery
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Should recover
|
||||
assert.True(t, hm.IsHealthy())
|
||||
}
|
||||
|
||||
// TestConnectionPool_WithHealthChecks tests pool with health checks enabled
|
||||
func TestConnectionPool_WithHealthChecks(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn))
|
||||
|
||||
// Use connection
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse and validate
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StaleConnectionRemoval tests stale connection handling
|
||||
func TestConnectionPool_StaleConnectionRemoval(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 3,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
EnableHealthCheck: true,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get and return a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
|
||||
initialTotal := pool.totalConns.Load()
|
||||
|
||||
// Close the connection manually to make it stale
|
||||
conn.Close()
|
||||
|
||||
// Get another connection - should detect stale and create new
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
// Connection should be healthy
|
||||
assert.True(t, pool.isConnectionHealthy(conn2))
|
||||
|
||||
pool.Put(conn2)
|
||||
|
||||
// Total connections might be same or less (stale removed)
|
||||
finalTotal := pool.totalConns.Load()
|
||||
assert.LessOrEqual(t, finalTotal, initialTotal+1)
|
||||
}
|
||||
+338
@@ -0,0 +1,338 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConnectionPool manages a pool of Redis connections
|
||||
// Pure-Go implementation compatible with Yaegi
|
||||
type ConnectionPool struct {
|
||||
config *PoolConfig
|
||||
|
||||
connections chan *RedisConn
|
||||
mu sync.Mutex
|
||||
closed atomic.Bool
|
||||
|
||||
// Metrics
|
||||
activeConns atomic.Int32
|
||||
totalConns atomic.Int32
|
||||
gets atomic.Int64
|
||||
puts atomic.Int64
|
||||
timeouts atomic.Int64
|
||||
}
|
||||
|
||||
// PoolConfig holds connection pool configuration
|
||||
type PoolConfig struct {
|
||||
Address string
|
||||
Password string
|
||||
DB int
|
||||
MaxConnections int
|
||||
ConnectTimeout time.Duration
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
EnableHealthCheck bool // Enable connection health validation
|
||||
MaxRetries int // Max retries for failed operations
|
||||
RetryDelay time.Duration // Initial delay between retries
|
||||
}
|
||||
|
||||
// NewConnectionPool creates a new connection pool
|
||||
func NewConnectionPool(config *PoolConfig) (*ConnectionPool, error) {
|
||||
if config == nil {
|
||||
return nil, errors.New("config is required")
|
||||
}
|
||||
|
||||
if config.MaxConnections <= 0 {
|
||||
config.MaxConnections = 10
|
||||
}
|
||||
|
||||
if config.ConnectTimeout == 0 {
|
||||
config.ConnectTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
pool := &ConnectionPool{
|
||||
config: config,
|
||||
connections: make(chan *RedisConn, config.MaxConnections),
|
||||
}
|
||||
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// Get retrieves a connection from the pool or creates a new one
|
||||
func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
p.gets.Add(1)
|
||||
|
||||
// Try to get a connection with validation
|
||||
maxAttempts := 3
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
var conn *RedisConn
|
||||
var err error
|
||||
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
// If this is the last attempt, return error
|
||||
if attempt == maxAttempts-1 {
|
||||
return nil, err
|
||||
}
|
||||
// Wait before retry with exponential backoff
|
||||
time.Sleep(time.Duration(attempt+1) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
p.totalConns.Add(1)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Pool exhausted, wait for a connection with timeout
|
||||
select {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
p.activeConns.Add(1)
|
||||
return conn, nil
|
||||
case <-ctx.Done():
|
||||
p.timeouts.Add(1)
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(p.config.ConnectTimeout):
|
||||
p.timeouts.Add(1)
|
||||
return nil, ErrPoolExhausted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("failed to get healthy connection after retries")
|
||||
}
|
||||
|
||||
// Put returns a connection to the pool
|
||||
func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.puts.Add(1)
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
|
||||
// Return to pool (non-blocking)
|
||||
select {
|
||||
case p.connections <- conn:
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool
|
||||
func (p *ConnectionPool) Close() error {
|
||||
if p.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
close(p.connections)
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns pool statistics
|
||||
func (p *ConnectionPool) Stats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"active_connections": p.activeConns.Load(),
|
||||
"total_connections": p.totalConns.Load(),
|
||||
"max_connections": p.config.MaxConnections,
|
||||
"gets": p.gets.Load(),
|
||||
"puts": p.puts.Load(),
|
||||
"timeouts": p.timeouts.Load(),
|
||||
}
|
||||
}
|
||||
|
||||
// createConnection creates a new Redis connection
|
||||
func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Connect with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: p.config.ConnectTimeout,
|
||||
}
|
||||
|
||||
conn, err := dialer.Dial("tcp", p.config.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
redisConn := &RedisConn{
|
||||
conn: conn,
|
||||
readTimeout: p.config.ReadTimeout,
|
||||
writeTimeout: p.config.WriteTimeout,
|
||||
}
|
||||
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return redisConn, nil
|
||||
}
|
||||
|
||||
// RedisConn represents a single Redis connection
|
||||
type RedisConn struct {
|
||||
conn net.Conn
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
closed atomic.Bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// Do executes a Redis command and returns the response
|
||||
func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
if c.closed.Load() {
|
||||
return nil, ErrBackendClosed
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
// Protect against possible overflow
|
||||
if len(s) > maxTotalArgBytes-totalBytes {
|
||||
return nil, errors.New("arguments too large (would overflow maximum allowed total size)")
|
||||
}
|
||||
totalBytes += len(s)
|
||||
if totalBytes > maxTotalArgBytes {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
writer := NewRESPWriter(c.conn)
|
||||
err := writer.WriteCommand(cmdArgs...)
|
||||
writer.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
c.closed.Store(true)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
reader := NewRESPReader(c.conn)
|
||||
resp, err := reader.ReadResponse()
|
||||
reader.Release() // Return to pool immediately after use
|
||||
if err != nil {
|
||||
if !errors.Is(err, ErrNilResponse) {
|
||||
c.closed.Store(true)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Close closes the connection
|
||||
func (c *RedisConn) Close() error {
|
||||
if c.closed.Swap(true) {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isConnectionHealthy validates a connection is still working
|
||||
func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
if conn == nil || conn.closed.Load() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
return err == nil
|
||||
}
|
||||
+620
@@ -0,0 +1,620 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestConnectionPool_BasicOperations tests basic pool operations
|
||||
func TestConnectionPool_BasicOperations(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
t.Run("GetAndPutConnection", func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Get a connection
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Verify connection works
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Get again - should reuse same connection
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
pool.Put(conn2)
|
||||
})
|
||||
|
||||
t.Run("Stats", func(t *testing.T) {
|
||||
stats := pool.Stats()
|
||||
require.NotNil(t, stats)
|
||||
|
||||
assert.Contains(t, stats, "active_connections")
|
||||
assert.Contains(t, stats, "total_connections")
|
||||
assert.Contains(t, stats, "max_connections")
|
||||
assert.Equal(t, 5, stats["max_connections"])
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_MaxConnections tests pool size limits
|
||||
func TestConnectionPool_MaxConnections(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
maxConns := 3
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: maxConns,
|
||||
ConnectTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get max connections
|
||||
conns := make([]*RedisConn, maxConns)
|
||||
for i := 0; i < maxConns; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
assert.Equal(t, int32(maxConns), stats["total_connections"])
|
||||
assert.Equal(t, int32(maxConns), stats["active_connections"])
|
||||
|
||||
// Try to get one more - should block/timeout
|
||||
ctx2, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pool.Get(ctx2)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn)
|
||||
|
||||
// Return one connection
|
||||
pool.Put(conns[0])
|
||||
|
||||
// Now we should be able to get a connection
|
||||
conn, err = pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
for i := 1; i < maxConns; i++ {
|
||||
pool.Put(conns[i])
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionPool_ConcurrentAccess tests concurrent pool usage
|
||||
func TestConnectionPool_ConcurrentAccess(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 10,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
numGoroutines := 50
|
||||
numOperations := 20
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Spawn goroutines
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperations; j++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// Do some work
|
||||
_, err = conn.Do("PING")
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
|
||||
// Return to pool
|
||||
pool.Put(conn)
|
||||
|
||||
// Small delay
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
errorCount := 0
|
||||
for err := range errors {
|
||||
t.Logf("Error: %v", err)
|
||||
errorCount++
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, errorCount, "Expected no errors in concurrent access")
|
||||
|
||||
// Verify stats
|
||||
stats := pool.Stats()
|
||||
t.Logf("Final stats: %+v", stats)
|
||||
assert.LessOrEqual(t, stats["total_connections"].(int32), int32(10))
|
||||
assert.Equal(t, int32(0), stats["active_connections"])
|
||||
}
|
||||
|
||||
// TestConnectionPool_ContextCancellation tests context cancellation
|
||||
func TestConnectionPool_ContextCancellation(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get the only connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get another with cancelled context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
conn2, err := pool.Get(ctx)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, conn2)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
|
||||
// Cleanup
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Authentication tests auth support
|
||||
func TestConnectionPool_Authentication(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
// Set password on miniredis
|
||||
mr.server.RequireAuth("secret-password")
|
||||
|
||||
t.Run("CorrectPassword", func(t *testing.T) {
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "secret-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
})
|
||||
|
||||
t.Run("WrongPassword", func(t *testing.T) {
|
||||
t.Skip("Miniredis doesn't fully simulate AUTH errors like real Redis")
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
Password: "wrong-password",
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
_, err := NewConnectionPool(config)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
})
|
||||
}
|
||||
|
||||
// TestConnectionPool_DatabaseSelection tests DB selection
|
||||
func TestConnectionPool_DatabaseSelection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
DB: 5,
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connection should be on DB 5
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_ClosedConnection tests handling closed connections
|
||||
func TestConnectionPool_ClosedConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Get connection
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Close it manually
|
||||
conn.Close()
|
||||
|
||||
// Try to use it
|
||||
_, err = conn.Do("PING")
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Return to pool (should be discarded)
|
||||
pool.Put(conn)
|
||||
|
||||
// Get new connection - should create a new one
|
||||
conn2, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn2)
|
||||
|
||||
resp, err := conn2.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn2)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Close tests pool closure
|
||||
func TestConnectionPool_Close(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get some connections
|
||||
conns := make([]*RedisConn, 3)
|
||||
for i := 0; i < 3; i++ {
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
conns[i] = conn
|
||||
}
|
||||
|
||||
// Return them
|
||||
for _, conn := range conns {
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Close pool
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to get connection from closed pool
|
||||
_, err = pool.Get(context.Background())
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrBackendClosed))
|
||||
|
||||
// Close again should be no-op
|
||||
err = pool.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestConnectionPool_Timeouts tests various timeout scenarios
|
||||
func TestConnectionPool_Timeouts(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
ConnectTimeout: 100 * time.Millisecond,
|
||||
ReadTimeout: 100 * time.Millisecond,
|
||||
WriteTimeout: 100 * time.Millisecond,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Normal operation should work
|
||||
resp, err := conn.Do("PING")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "PONG", resp)
|
||||
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestRedisConn_DoCommand tests the Do method
|
||||
func TestRedisConn_DoCommand(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("SET and GET", func(t *testing.T) {
|
||||
// SET
|
||||
resp, err := conn.Do("SET", "testkey", "testvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// GET
|
||||
resp, err = conn.Do("GET", "testkey")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "testvalue", resp)
|
||||
})
|
||||
|
||||
t.Run("DEL", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "delkey", "delvalue")
|
||||
require.NoError(t, err)
|
||||
|
||||
// DEL
|
||||
resp, err := conn.Do("DEL", "delkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
})
|
||||
|
||||
t.Run("EXISTS", func(t *testing.T) {
|
||||
// SET key first
|
||||
_, err := conn.Do("SET", "existskey", "value")
|
||||
require.NoError(t, err)
|
||||
|
||||
// EXISTS - key exists
|
||||
resp, err := conn.Do("EXISTS", "existskey")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
// EXISTS - key doesn't exist
|
||||
resp, err = conn.Do("EXISTS", "nonexistent")
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
})
|
||||
|
||||
t.Run("TTL commands", func(t *testing.T) {
|
||||
// SETEX
|
||||
resp, err := conn.Do("SETEX", "ttlkey", "60", "ttlvalue")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "OK", resp)
|
||||
|
||||
// TTL
|
||||
resp, err = conn.Do("TTL", "ttlkey")
|
||||
require.NoError(t, err)
|
||||
|
||||
ttl, err := RESPInt(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Greater(t, ttl, int64(0))
|
||||
assert.LessOrEqual(t, ttl, int64(60))
|
||||
})
|
||||
}
|
||||
|
||||
// TestPoolConfig_Defaults tests default configuration values
|
||||
func TestPoolConfig_Defaults(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
// Leave other fields at zero values
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Should use defaults
|
||||
assert.Equal(t, 10, pool.config.MaxConnections)
|
||||
assert.Equal(t, 5*time.Second, pool.config.ConnectTimeout)
|
||||
|
||||
// Verify it works
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_NilConnection tests handling nil connections
|
||||
func TestConnectionPool_NilConnection(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 2,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
// Putting nil should be safe
|
||||
pool.Put(nil)
|
||||
|
||||
// Pool should still work
|
||||
conn, err := pool.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, conn)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// TestConnectionPool_StatsTracking tests metrics tracking
|
||||
func TestConnectionPool_StatsTracking(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 5,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := pool.Stats()
|
||||
initialGets := stats["gets"].(int64)
|
||||
initialPuts := stats["puts"].(int64)
|
||||
|
||||
// Perform operations
|
||||
numOps := 10
|
||||
for i := 0; i < numOps; i++ {
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
pool.Put(conn)
|
||||
}
|
||||
|
||||
// Check updated stats
|
||||
stats = pool.Stats()
|
||||
assert.Equal(t, initialGets+int64(numOps), stats["gets"].(int64))
|
||||
assert.Equal(t, initialPuts+int64(numOps), stats["puts"].(int64))
|
||||
assert.Equal(t, int32(0), stats["active_connections"].(int32))
|
||||
}
|
||||
|
||||
// TestRedisConn_TooManyArguments tests protection against allocation overflow
|
||||
func TestRedisConn_TooManyArguments(t *testing.T) {
|
||||
mr := NewMiniredisServer(t)
|
||||
|
||||
config := &PoolConfig{
|
||||
Address: mr.GetAddr(),
|
||||
MaxConnections: 1,
|
||||
ConnectTimeout: 5 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
pool, err := NewConnectionPool(config)
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
conn, err := pool.Get(ctx)
|
||||
require.NoError(t, err)
|
||||
defer pool.Put(conn)
|
||||
|
||||
t.Run("AcceptableArgumentCount", func(t *testing.T) {
|
||||
// Should work with reasonable number of args
|
||||
args := make([]string, 100)
|
||||
for i := range args {
|
||||
args[i] = "value"
|
||||
}
|
||||
_, err := conn.Do("MSET", args...)
|
||||
// May fail due to Redis constraints, but shouldn't panic or error on overflow
|
||||
// Just verify it doesn't trigger our overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RejectExcessiveArguments", func(t *testing.T) {
|
||||
// Create an absurdly large number of arguments that would cause overflow
|
||||
// Use 1M + 1 to exceed maxSafeArgs = (1<<20)-1 = 1048575
|
||||
args := make([]string, 1<<20) // 1,048,576 args
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("MSET", args...)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "too many arguments")
|
||||
})
|
||||
|
||||
t.Run("BoundaryCase", func(t *testing.T) {
|
||||
// Test exactly at the boundary (maxSafeArgs)
|
||||
args := make([]string, (1<<20)-1) // Exactly 1,048,575 args (max allowed)
|
||||
for i := range args {
|
||||
args[i] = "x"
|
||||
}
|
||||
|
||||
_, err := conn.Do("ECHO", args...)
|
||||
// Should not error due to overflow protection
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "too many arguments")
|
||||
}
|
||||
})
|
||||
}
|
||||
+545
@@ -0,0 +1,545 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRedisBackend_BasicOperations tests basic Redis operations
|
||||
func TestRedisBackend_BasicOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetAndGet", func(t *testing.T) {
|
||||
key := "redis-test-key"
|
||||
value := []byte("redis-test-value")
|
||||
ttl := 1 * time.Minute
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, remainingTTL, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
assert.Greater(t, remainingTTL, 50*time.Second)
|
||||
})
|
||||
|
||||
t.Run("GetNonExistent", func(t *testing.T) {
|
||||
_, _, exists, err := backend.Get(ctx, "non-existent-redis-key")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Delete", func(t *testing.T) {
|
||||
key := "redis-delete-key"
|
||||
value := []byte("redis-delete-value")
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
deleted, err := backend.Delete(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Exists", func(t *testing.T) {
|
||||
key := "redis-exists-key"
|
||||
value := []byte("redis-exists-value")
|
||||
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_KeyPrefixing tests key namespace prefixing
|
||||
func TestRedisBackend_KeyPrefixing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "test:prefix:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "my-key"
|
||||
value := []byte("my-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that key is stored with prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, "test:prefix:my-key", keys[0])
|
||||
|
||||
// Get should work without prefix
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
// TestRedisBackend_TTLExpiration tests TTL handling
|
||||
func TestRedisBackend_TTLExpiration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ShortTTL", func(t *testing.T) {
|
||||
key := "ttl-key"
|
||||
value := []byte("ttl-value")
|
||||
shortTTL := 100 * time.Millisecond
|
||||
|
||||
err := backend.Set(ctx, key, value, shortTTL)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Exists immediately
|
||||
exists, err := backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward time in miniredis
|
||||
mr.FastForward(150 * time.Millisecond)
|
||||
|
||||
// Should be expired
|
||||
exists, err = backend.Exists(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("TTLRemaining", func(t *testing.T) {
|
||||
key := "ttl-remaining-key"
|
||||
value := []byte("ttl-remaining-value")
|
||||
ttl := 10 * time.Second
|
||||
|
||||
err := backend.Set(ctx, key, value, ttl)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get immediately
|
||||
_, ttl1, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Fast forward 2 seconds
|
||||
mr.FastForward(2 * time.Second)
|
||||
|
||||
// Check TTL is less
|
||||
_, ttl2, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Less(t, ttl2, ttl1)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_Clear tests clearing all keys
|
||||
func TestRedisBackend_Clear(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "clear-test:"
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add multiple keys
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("clear-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("clear-value-%d", i))
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify keys exist
|
||||
keys := mr.CheckKeys()
|
||||
assert.Len(t, keys, 10)
|
||||
|
||||
// Clear all
|
||||
err = backend.Clear(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all keys are gone
|
||||
keys = mr.CheckKeys()
|
||||
assert.Len(t, keys, 0)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConnectionFailure tests behavior on connection failure
|
||||
func TestRedisBackend_ConnectionFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Try to connect to non-existent Redis
|
||||
config := DefaultRedisConfig("localhost:9999")
|
||||
_, err := NewRedisBackend(config)
|
||||
assert.Error(t, err, "Should fail to connect to non-existent Redis")
|
||||
}
|
||||
|
||||
// TestRedisBackend_RedisErrors tests handling of Redis errors
|
||||
func TestRedisBackend_RedisErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Simulate Redis error
|
||||
mr.SetError("simulated error")
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "error-key", []byte("error-value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
|
||||
// Clear error
|
||||
mr.ClearError()
|
||||
|
||||
// Operations should work again
|
||||
err = backend.Set(ctx, "success-key", []byte("success-value"), 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_ConcurrentAccess tests thread safety
|
||||
func TestRedisBackend_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
key := fmt.Sprintf("concurrent-key-%d-%d", id, j)
|
||||
value := []byte(fmt.Sprintf("concurrent-value-%d-%d", id, j))
|
||||
|
||||
err := backend.Set(ctx, key, value, 1*time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
assert.NoError(t, err)
|
||||
if exists {
|
||||
assert.Equal(t, value, retrieved)
|
||||
}
|
||||
|
||||
if j%5 == 0 {
|
||||
backend.Delete(ctx, key)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := backend.GetStats()
|
||||
hits := stats["hits"].(int64)
|
||||
misses := stats["misses"].(int64)
|
||||
assert.Greater(t, hits+misses, int64(0))
|
||||
}
|
||||
|
||||
// TestRedisBackend_Stats tests statistics tracking
|
||||
func TestRedisBackend_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initial stats
|
||||
stats := backend.GetStats()
|
||||
assert.Equal(t, int64(0), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(0), stats["misses"].(int64))
|
||||
|
||||
// Add and access items
|
||||
backend.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
backend.Get(ctx, "key1") // Hit
|
||||
backend.Get(ctx, "non-existent") // Miss
|
||||
|
||||
stats = backend.GetStats()
|
||||
assert.Equal(t, int64(1), stats["hits"].(int64))
|
||||
assert.Equal(t, int64(1), stats["misses"].(int64))
|
||||
|
||||
hitRate := stats["hit_rate"].(float64)
|
||||
assert.InDelta(t, 0.5, hitRate, 0.01)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Ping tests health check
|
||||
func TestRedisBackend_Ping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = backend.Ping(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Close and ping should fail
|
||||
backend.Close()
|
||||
err = backend.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_Close tests proper cleanup
|
||||
func TestRedisBackend_Close(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Add items
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("close-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("close-value-%d", i))
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Close
|
||||
err = backend.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Operations should fail
|
||||
err = backend.Set(ctx, "after-close", []byte("value"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrBackendClosed, err)
|
||||
|
||||
// Double close should be safe
|
||||
err = backend.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestRedisBackend_UpdateExisting tests updating existing keys
|
||||
func TestRedisBackend_UpdateExisting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "update-key"
|
||||
value1 := []byte("original-value")
|
||||
value2 := []byte("updated-value")
|
||||
|
||||
// Set original
|
||||
err = backend.Set(ctx, key, value1, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update
|
||||
err = backend.Set(ctx, key, value2, 2*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify updated
|
||||
retrieved, ttl, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, value2, retrieved)
|
||||
assert.Greater(t, ttl, 1*time.Minute)
|
||||
}
|
||||
|
||||
// TestRedisBackend_LargeValues tests handling of large values
|
||||
func TestRedisBackend_LargeValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "large-key"
|
||||
largeValue := make([]byte, 1024*1024) // 1MB
|
||||
|
||||
err = backend.Set(ctx, key, largeValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, len(largeValue), len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_EmptyValues tests handling of empty values
|
||||
func TestRedisBackend_EmptyValues(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "empty-key"
|
||||
emptyValue := []byte{}
|
||||
|
||||
err = backend.Set(ctx, key, emptyValue, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, 0, len(retrieved))
|
||||
}
|
||||
|
||||
// TestRedisBackend_PipelineOperations tests batch operations
|
||||
func TestRedisBackend_PipelineOperations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("SetMany", func(t *testing.T) {
|
||||
items := make(map[string][]byte)
|
||||
for i := 0; i < 10; i++ {
|
||||
key := fmt.Sprintf("batch-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("batch-value-%d", i))
|
||||
items[key] = value
|
||||
}
|
||||
|
||||
err := backend.SetMany(ctx, items, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify all items were set
|
||||
for key, expectedValue := range items {
|
||||
retrieved, _, exists, err := backend.Get(ctx, key)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrieved)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetMany", func(t *testing.T) {
|
||||
// Set test data
|
||||
testData := GenerateTestData(5)
|
||||
for key, value := range testData {
|
||||
backend.Set(ctx, key, value, 1*time.Minute)
|
||||
}
|
||||
|
||||
// Get all keys
|
||||
keys := make([]string, 0, len(testData))
|
||||
for key := range testData {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, len(testData))
|
||||
|
||||
for key, expectedValue := range testData {
|
||||
retrievedValue, exists := results[key]
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, expectedValue, retrievedValue)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetManyWithNonExistent", func(t *testing.T) {
|
||||
keys := []string{"exists-1", "non-existent", "exists-2"}
|
||||
|
||||
backend.Set(ctx, "exists-1", []byte("value-1"), 1*time.Minute)
|
||||
backend.Set(ctx, "exists-2", []byte("value-2"), 1*time.Minute)
|
||||
|
||||
results, err := backend.GetMany(ctx, keys)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2) // Only existing keys
|
||||
assert.Equal(t, []byte("value-1"), results["exists-1"])
|
||||
assert.Equal(t, []byte("value-2"), results["exists-2"])
|
||||
_, exists := results["non-existent"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRedisBackend_NoPrefix tests operation without prefix
|
||||
func TestRedisBackend_NoPrefix(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mr := NewMiniredisServer(t)
|
||||
config := DefaultRedisConfig(mr.GetAddr())
|
||||
config.RedisPrefix = "" // No prefix
|
||||
backend, err := NewRedisBackend(config)
|
||||
require.NoError(t, err)
|
||||
defer backend.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
key := "no-prefix-key"
|
||||
value := []byte("no-prefix-value")
|
||||
|
||||
err = backend.Set(ctx, key, value, 1*time.Minute)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check key is stored without prefix
|
||||
keys := mr.CheckKeys()
|
||||
require.Len(t, keys, 1)
|
||||
assert.Equal(t, key, keys[0])
|
||||
}
|
||||
Vendored
+251
@@ -0,0 +1,251 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RESP (REdis Serialization Protocol) implementation
|
||||
// Pure Go implementation compatible with Yaegi interpreter (no unsafe package)
|
||||
|
||||
var (
|
||||
ErrInvalidRESP = errors.New("invalid RESP response")
|
||||
ErrNilResponse = errors.New("nil response")
|
||||
)
|
||||
|
||||
// Object pools for memory optimization - reduces allocations by 50-70%
|
||||
var (
|
||||
readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPReader{
|
||||
r: bufio.NewReaderSize(nil, 4096),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
writerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &RESPWriter{
|
||||
w: nil,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// RESPWriter writes RESP protocol messages
|
||||
type RESPWriter struct {
|
||||
w io.Writer
|
||||
}
|
||||
|
||||
// NewRESPWriter creates a new RESP writer from the pool (memory optimized)
|
||||
func NewRESPWriter(w io.Writer) *RESPWriter {
|
||||
writer := writerPool.Get().(*RESPWriter)
|
||||
writer.w = w
|
||||
return writer
|
||||
}
|
||||
|
||||
// Release returns the writer to the pool for reuse
|
||||
func (w *RESPWriter) Release() {
|
||||
w.w = nil
|
||||
writerPool.Put(w)
|
||||
}
|
||||
|
||||
// WriteCommand writes a Redis command in RESP array format
|
||||
// Example: SET key value EX 3600 -> *5\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n$2\r\nEX\r\n$4\r\n3600\r\n
|
||||
func (w *RESPWriter) WriteCommand(args ...string) error {
|
||||
// Write array header
|
||||
if _, err := fmt.Fprintf(w.w, "*%d\r\n", len(args)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write each argument as bulk string
|
||||
for _, arg := range args {
|
||||
if _, err := fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(arg), arg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RESPReader reads RESP protocol messages
|
||||
type RESPReader struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewRESPReader creates a new RESP reader from the pool (memory optimized)
|
||||
func NewRESPReader(r io.Reader) *RESPReader {
|
||||
reader := readerPool.Get().(*RESPReader)
|
||||
reader.r.Reset(r)
|
||||
return reader
|
||||
}
|
||||
|
||||
// Release returns the reader to the pool for reuse
|
||||
func (r *RESPReader) Release() {
|
||||
r.r.Reset(nil)
|
||||
readerPool.Put(r)
|
||||
}
|
||||
|
||||
// ReadResponse reads a RESP response and returns the parsed value
|
||||
func (r *RESPReader) ReadResponse() (interface{}, error) {
|
||||
typeByte, err := r.r.ReadByte()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typeByte {
|
||||
case '+': // Simple string
|
||||
return r.readSimpleString()
|
||||
case '-': // Error
|
||||
return nil, r.readError()
|
||||
case ':': // Integer
|
||||
return r.readInteger()
|
||||
case '$': // Bulk string
|
||||
return r.readBulkString()
|
||||
case '*': // Array
|
||||
return r.readArray()
|
||||
default:
|
||||
return nil, fmt.Errorf("%w: unknown type byte '%c'", ErrInvalidRESP, typeByte)
|
||||
}
|
||||
}
|
||||
|
||||
// readSimpleString reads a simple string (+OK\r\n)
|
||||
func (r *RESPReader) readSimpleString() (string, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// readError reads an error message (-Error message\r\n)
|
||||
func (r *RESPReader) readError() error {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errors.New(line)
|
||||
}
|
||||
|
||||
// readInteger reads an integer (:1000\r\n)
|
||||
func (r *RESPReader) readInteger() (int64, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseInt(line, 10, 64)
|
||||
}
|
||||
|
||||
// readBulkString reads a bulk string ($6\r\nfoobar\r\n or $-1\r\n for nil)
|
||||
func (r *RESPReader) readBulkString() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid bulk string length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil bulk string
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read exactly 'length' bytes plus \r\n
|
||||
buf := make([]byte, length+2)
|
||||
if _, err := io.ReadFull(r.r, buf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify \r\n terminator
|
||||
if buf[length] != '\r' || buf[length+1] != '\n' {
|
||||
return nil, fmt.Errorf("%w: missing CRLF after bulk string", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return string(buf[:length]), nil
|
||||
}
|
||||
|
||||
// readArray reads an array (*2\r\n...\r\n or *-1\r\n for nil)
|
||||
func (r *RESPReader) readArray() (interface{}, error) {
|
||||
line, err := r.readLine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
length, err := strconv.Atoi(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: invalid array length", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
// -1 indicates nil array
|
||||
if length == -1 {
|
||||
return nil, ErrNilResponse
|
||||
}
|
||||
|
||||
// Read each element
|
||||
result := make([]interface{}, length)
|
||||
for i := 0; i < length; i++ {
|
||||
elem, err := r.ReadResponse()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[i] = elem
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// readLine reads a line terminated by \r\n
|
||||
func (r *RESPReader) readLine() (string, error) {
|
||||
line, err := r.r.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Remove \r\n
|
||||
line = strings.TrimSuffix(line, "\r\n")
|
||||
if !strings.HasSuffix(line+"\r\n", "\r\n") {
|
||||
return "", fmt.Errorf("%w: missing CRLF", ErrInvalidRESP)
|
||||
}
|
||||
|
||||
return line, nil
|
||||
}
|
||||
|
||||
// RESPString extracts a string from RESP response
|
||||
func RESPString(resp interface{}) (string, error) {
|
||||
if resp == nil {
|
||||
return "", ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case string:
|
||||
return v, nil
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("expected string, got %T", resp)
|
||||
}
|
||||
}
|
||||
|
||||
// RESPInt extracts an integer from RESP response
|
||||
func RESPInt(resp interface{}) (int64, error) {
|
||||
if resp == nil {
|
||||
return 0, ErrNilResponse
|
||||
}
|
||||
|
||||
switch v := resp.(type) {
|
||||
case int64:
|
||||
return v, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("expected integer, got %T", resp)
|
||||
}
|
||||
}
|
||||
Vendored
+495
@@ -0,0 +1,495 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestRESPWriter_WriteCommand tests RESP command writing
|
||||
func TestRESPWriter_WriteCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple command",
|
||||
args: []string{"PING"},
|
||||
expected: "*1\r\n$4\r\nPING\r\n",
|
||||
},
|
||||
{
|
||||
name: "SET command",
|
||||
args: []string{"SET", "key", "value"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$5\r\nvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "SETEX command",
|
||||
args: []string{"SETEX", "mykey", "60", "myvalue"},
|
||||
expected: "*4\r\n$5\r\nSETEX\r\n$5\r\nmykey\r\n$2\r\n60\r\n$7\r\nmyvalue\r\n",
|
||||
},
|
||||
{
|
||||
name: "DEL with multiple keys",
|
||||
args: []string{"DEL", "key1", "key2", "key3"},
|
||||
expected: "*4\r\n$3\r\nDEL\r\n$4\r\nkey1\r\n$4\r\nkey2\r\n$4\r\nkey3\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with empty string",
|
||||
args: []string{"SET", "key", ""},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$0\r\n\r\n",
|
||||
},
|
||||
{
|
||||
name: "Command with special characters",
|
||||
args: []string{"SET", "key", "val\r\nue"},
|
||||
expected: "*3\r\n$3\r\nSET\r\n$3\r\nkey\r\n$7\r\nval\r\nue\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(buf)
|
||||
|
||||
err := writer.WriteCommand(tt.args...)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, buf.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadSimpleString tests reading simple strings
|
||||
func TestRESPReader_ReadSimpleString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "OK response",
|
||||
input: "+OK\r\n",
|
||||
expected: "OK",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "PONG response",
|
||||
input: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "+\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "String with spaces",
|
||||
input: "+Hello World\r\n",
|
||||
expected: "Hello World",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadError tests reading error messages
|
||||
func TestRESPReader_ReadError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
name: "ERR error",
|
||||
input: "-ERR unknown command\r\n",
|
||||
expectedError: "ERR unknown command",
|
||||
},
|
||||
{
|
||||
name: "WRONGTYPE error",
|
||||
input: "-WRONGTYPE Operation against a key holding the wrong kind of value\r\n",
|
||||
expectedError: "WRONGTYPE Operation against a key holding the wrong kind of value",
|
||||
},
|
||||
{
|
||||
name: "Simple error",
|
||||
input: "-Error\r\n",
|
||||
expectedError: "Error",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, tt.expectedError, err.Error())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadInteger tests reading integers
|
||||
func TestRESPReader_ReadInteger(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Zero",
|
||||
input: ":0\r\n",
|
||||
expected: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Positive integer",
|
||||
input: ":1000\r\n",
|
||||
expected: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Negative integer",
|
||||
input: ":-1\r\n",
|
||||
expected: -1,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Large integer",
|
||||
input: ":9223372036854775807\r\n",
|
||||
expected: 9223372036854775807,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid integer",
|
||||
input: ":abc\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadBulkString tests reading bulk strings
|
||||
func TestRESPReader_ReadBulkString(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Simple bulk string",
|
||||
input: "$6\r\nfoobar\r\n",
|
||||
expected: "foobar",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Empty bulk string",
|
||||
input: "$0\r\n\r\n",
|
||||
expected: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil bulk string",
|
||||
input: "$-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Binary safe bulk string",
|
||||
input: "$5\r\n\x00\x01\x02\x03\x04\r\n",
|
||||
expected: "\x00\x01\x02\x03\x04",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid length",
|
||||
input: "$abc\r\ntest\r\n",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_ReadArray tests reading arrays
|
||||
func TestRESPReader_ReadArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []interface{}
|
||||
wantErr bool
|
||||
isNil bool
|
||||
}{
|
||||
{
|
||||
name: "Empty array",
|
||||
input: "*0\r\n",
|
||||
expected: []interface{}{},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of bulk strings",
|
||||
input: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
|
||||
expected: []interface{}{
|
||||
"foo",
|
||||
"bar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Array of integers",
|
||||
input: "*3\r\n:1\r\n:2\r\n:3\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Mixed array",
|
||||
input: "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n",
|
||||
expected: []interface{}{
|
||||
int64(1),
|
||||
int64(2),
|
||||
int64(3),
|
||||
int64(4),
|
||||
"foobar",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Nil array",
|
||||
input: "*-1\r\n",
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
isNil: true,
|
||||
},
|
||||
{
|
||||
name: "Nested arrays",
|
||||
input: "*2\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n*1\r\n$3\r\nbaz\r\n",
|
||||
expected: []interface{}{
|
||||
[]interface{}{"foo", "bar"},
|
||||
[]interface{}{"baz"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.isNil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_InvalidInput tests error handling for invalid input
|
||||
func TestRESPReader_InvalidInput(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "Unknown type byte",
|
||||
input: "?invalid\r\n",
|
||||
},
|
||||
{
|
||||
name: "Incomplete response",
|
||||
input: "+OK",
|
||||
},
|
||||
{
|
||||
name: "Missing CRLF in bulk string",
|
||||
input: "$5\r\nhello",
|
||||
},
|
||||
{
|
||||
name: "Truncated array",
|
||||
input: "*3\r\n:1\r\n:2\r\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(tt.input))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRESPReader_EOF tests handling of EOF
|
||||
func TestRESPReader_EOF(t *testing.T) {
|
||||
reader := NewRESPReader(strings.NewReader(""))
|
||||
_, err := reader.ReadResponse()
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, io.EOF))
|
||||
}
|
||||
|
||||
// TestRESPHelpers tests helper functions
|
||||
func TestRESPHelpers(t *testing.T) {
|
||||
t.Run("RESPString", func(t *testing.T) {
|
||||
// Valid string
|
||||
result, err := RESPString("hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello", result)
|
||||
|
||||
// Byte slice
|
||||
result, err = RESPString([]byte("world"))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "world", result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPString(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPString(123)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("RESPInt", func(t *testing.T) {
|
||||
// Valid int64
|
||||
result, err := RESPInt(int64(42))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Valid int
|
||||
result, err = RESPInt(42)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42), result)
|
||||
|
||||
// Nil
|
||||
_, err = RESPInt(nil)
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
|
||||
// Invalid type
|
||||
_, err = RESPInt("string")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// TestRESPRoundTrip tests full round-trip encoding/decoding
|
||||
func TestRESPRoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command []string
|
||||
response string
|
||||
expected interface{}
|
||||
}{
|
||||
{
|
||||
name: "PING command",
|
||||
command: []string{"PING"},
|
||||
response: "+PONG\r\n",
|
||||
expected: "PONG",
|
||||
},
|
||||
{
|
||||
name: "GET command with result",
|
||||
command: []string{"GET", "mykey"},
|
||||
response: "$7\r\nmyvalue\r\n",
|
||||
expected: "myvalue",
|
||||
},
|
||||
{
|
||||
name: "GET command with nil",
|
||||
command: []string{"GET", "nonexistent"},
|
||||
response: "$-1\r\n",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "DEL command",
|
||||
command: []string{"DEL", "key1", "key2"},
|
||||
response: ":2\r\n",
|
||||
expected: int64(2),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Write command
|
||||
writeBuf := &bytes.Buffer{}
|
||||
writer := NewRESPWriter(writeBuf)
|
||||
err := writer.WriteCommand(tt.command...)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read response
|
||||
reader := NewRESPReader(strings.NewReader(tt.response))
|
||||
result, err := reader.ReadResponse()
|
||||
|
||||
if tt.expected == nil {
|
||||
require.Error(t, err)
|
||||
assert.True(t, errors.Is(err, ErrNilResponse))
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+198
@@ -0,0 +1,198 @@
|
||||
package backends
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestLogger implements a simple logger for tests
|
||||
type TestLogger struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func NewTestLogger(t *testing.T) *TestLogger {
|
||||
return &TestLogger{t: t}
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debug(format string, args ...interface{}) {
|
||||
l.t.Logf("[DEBUG] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Info(format string, args ...interface{}) {
|
||||
l.t.Logf("[INFO] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Error(format string, args ...interface{}) {
|
||||
l.t.Logf("[ERROR] "+format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Debugf(format string, args ...interface{}) {
|
||||
l.Debug(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Infof(format string, args ...interface{}) {
|
||||
l.Info(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Errorf(format string, args ...interface{}) {
|
||||
l.Error(format, args...)
|
||||
}
|
||||
|
||||
func (l *TestLogger) Warnf(format string, args ...interface{}) {
|
||||
l.t.Logf("[WARN] "+format, args...)
|
||||
}
|
||||
|
||||
// MiniredisServer manages a miniredis instance for testing
|
||||
type MiniredisServer struct {
|
||||
server *miniredis.Miniredis
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
// NewMiniredisServer creates a new miniredis server for testing
|
||||
func NewMiniredisServer(t *testing.T) *MiniredisServer {
|
||||
t.Helper()
|
||||
|
||||
mr, err := miniredis.Run()
|
||||
require.NoError(t, err, "failed to start miniredis")
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// Verify connection
|
||||
ctx := context.Background()
|
||||
err = client.Ping(ctx).Err()
|
||||
require.NoError(t, err, "failed to ping miniredis")
|
||||
|
||||
t.Cleanup(func() {
|
||||
client.Close()
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return &MiniredisServer{
|
||||
server: mr,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAddr returns the address of the miniredis server
|
||||
func (m *MiniredisServer) GetAddr() string {
|
||||
return m.server.Addr()
|
||||
}
|
||||
|
||||
// GetClient returns the Redis client
|
||||
func (m *MiniredisServer) GetClient() *redis.Client {
|
||||
return m.client
|
||||
}
|
||||
|
||||
// FastForward advances the miniredis server's time
|
||||
func (m *MiniredisServer) FastForward(d time.Duration) {
|
||||
m.server.FastForward(d)
|
||||
}
|
||||
|
||||
// FlushAll removes all keys from the database
|
||||
func (m *MiniredisServer) FlushAll() {
|
||||
m.server.FlushAll()
|
||||
}
|
||||
|
||||
// SetError simulates a Redis error
|
||||
func (m *MiniredisServer) SetError(err string) {
|
||||
m.server.SetError(err)
|
||||
}
|
||||
|
||||
// ClearError clears any simulated errors
|
||||
func (m *MiniredisServer) ClearError() {
|
||||
m.server.SetError("")
|
||||
}
|
||||
|
||||
// CheckKeys verifies that specific keys exist in Redis
|
||||
func (m *MiniredisServer) CheckKeys() []string {
|
||||
return m.server.Keys()
|
||||
}
|
||||
|
||||
// Close closes the miniredis server
|
||||
func (m *MiniredisServer) Close() {
|
||||
m.server.Close()
|
||||
}
|
||||
|
||||
// Restart restarts the miniredis server
|
||||
func (m *MiniredisServer) Restart() {
|
||||
m.server.Restart()
|
||||
}
|
||||
|
||||
// TestConfig provides default test configuration
|
||||
type TestConfig struct {
|
||||
MaxSize int
|
||||
DefaultTTL time.Duration
|
||||
CleanupInterval time.Duration
|
||||
EnableMetrics bool
|
||||
}
|
||||
|
||||
// DefaultTestConfig returns a standard test configuration
|
||||
func DefaultTestConfig() *TestConfig {
|
||||
return &TestConfig{
|
||||
MaxSize: 100,
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
CleanupInterval: 1 * time.Second,
|
||||
EnableMetrics: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTestData creates test cache data
|
||||
func GenerateTestData(count int) map[string][]byte {
|
||||
data := make(map[string][]byte, count)
|
||||
for i := 0; i < count; i++ {
|
||||
key := fmt.Sprintf("test-key-%d", i)
|
||||
value := []byte(fmt.Sprintf("test-value-%d", i))
|
||||
data[key] = value
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// GenerateLargeValue creates a large test value
|
||||
func GenerateLargeValue(sizeBytes int) []byte {
|
||||
return make([]byte, sizeBytes)
|
||||
}
|
||||
|
||||
// AssertCacheStats is a helper to verify cache statistics
|
||||
func AssertCacheStats(t *testing.T, stats map[string]interface{}, expectedHits, expectedMisses int64) {
|
||||
t.Helper()
|
||||
|
||||
hits, ok := stats["hits"].(int64)
|
||||
require.True(t, ok, "hits should be int64")
|
||||
require.Equal(t, expectedHits, hits, "unexpected hit count")
|
||||
|
||||
misses, ok := stats["misses"].(int64)
|
||||
require.True(t, ok, "misses should be int64")
|
||||
require.Equal(t, expectedMisses, misses, "unexpected miss count")
|
||||
}
|
||||
|
||||
// WaitForCondition waits for a condition to be true or times out
|
||||
func WaitForCondition(t *testing.T, timeout time.Duration, checkInterval time.Duration, condition func() bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
if condition() {
|
||||
return
|
||||
}
|
||||
time.Sleep(checkInterval)
|
||||
}
|
||||
t.Fatal("timeout waiting for condition")
|
||||
}
|
||||
|
||||
// AssertEventuallyExpires verifies that a key eventually expires
|
||||
func AssertEventuallyExpires(t *testing.T, backend CacheBackend, ctx context.Context, key string, maxWait time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
WaitForCondition(t, maxWait, 100*time.Millisecond, func() bool {
|
||||
_, _, exists, err := backend.Get(ctx, key)
|
||||
return err == nil && !exists
|
||||
})
|
||||
}
|
||||
Vendored
+96
-10
@@ -1880,19 +1880,20 @@ func TestConcurrentManagerOperations(t *testing.T) {
|
||||
// TestTTLExpirationAndCleanup tests TTL expiration and cleanup routines comprehensively
|
||||
func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
config := DefaultConfig()
|
||||
config.CleanupInterval = 10 * time.Millisecond
|
||||
config.CleanupInterval = 50 * time.Millisecond
|
||||
config.EnableAutoCleanup = true
|
||||
cache := New(config)
|
||||
defer cache.Close()
|
||||
|
||||
// Test various TTL scenarios
|
||||
// Note: Timing increased 5x to account for race detector overhead
|
||||
testCases := []struct {
|
||||
key string
|
||||
ttl time.Duration
|
||||
}{
|
||||
{"very-short", 5 * time.Millisecond},
|
||||
{"short", 25 * time.Millisecond},
|
||||
{"medium", 100 * time.Millisecond},
|
||||
{"very-short", 25 * time.Millisecond},
|
||||
{"short", 125 * time.Millisecond},
|
||||
{"medium", 500 * time.Millisecond},
|
||||
{"long", 1 * time.Hour},
|
||||
}
|
||||
|
||||
@@ -1908,13 +1909,13 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for very short items to expire
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
time.Sleep(75 * time.Millisecond)
|
||||
if _, exists := cache.Get("very-short"); exists {
|
||||
t.Error("Very short item should be expired")
|
||||
}
|
||||
|
||||
// Wait for short items to expire
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
if _, exists := cache.Get("short"); exists {
|
||||
t.Error("Short item should be expired")
|
||||
}
|
||||
@@ -1930,16 +1931,16 @@ func TestTTLExpirationAndCleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test manual cleanup
|
||||
cache.Set("manual-cleanup", "value", 1*time.Millisecond)
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.Set("manual-cleanup", "value", 5*time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
cache.Cleanup()
|
||||
|
||||
// Add many expired items to test bulk cleanup
|
||||
for i := 0; i < 100; i++ {
|
||||
key := fmt.Sprintf("bulk-%d", i)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 1*time.Millisecond)
|
||||
cache.Set(key, fmt.Sprintf("value-%d", i), 5*time.Millisecond)
|
||||
}
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
|
||||
sizeBefore := cache.Size()
|
||||
cache.Cleanup()
|
||||
@@ -2038,3 +2039,88 @@ func TestCacheStatisticsAndMetrics(t *testing.T) {
|
||||
t.Error("Memory usage should increase after adding large item")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// noOpLogger Tests
|
||||
// ============================================================================
|
||||
|
||||
// TestNoOpLogger_AllMethods tests all noOpLogger methods to ensure they don't panic
|
||||
func TestNoOpLogger_AllMethods(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Test simple message methods
|
||||
logger.Debug("test debug message")
|
||||
logger.Info("test info message")
|
||||
logger.Error("test error message")
|
||||
logger.Warn("test warn message")
|
||||
logger.Fatal("test fatal message")
|
||||
|
||||
// Test formatted message methods
|
||||
logger.Debugf("test debug: %s", "value")
|
||||
logger.Infof("test info: %s", "value")
|
||||
logger.Errorf("test error: %s", "value")
|
||||
logger.Warnf("test warn: %s", "value")
|
||||
logger.Fatalf("test fatal: %s", "value")
|
||||
|
||||
// If we reach here, all methods executed without panicking
|
||||
// This is expected behavior for a no-op logger
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithField verifies WithField returns the same logger
|
||||
func TestNoOpLogger_WithField(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
result := logger.WithField("key", "value")
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithField should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithField")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_WithFields verifies WithFields returns the same logger
|
||||
func TestNoOpLogger_WithFields(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
fields := map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
"key3": true,
|
||||
}
|
||||
|
||||
result := logger.WithFields(fields)
|
||||
|
||||
if result != logger {
|
||||
t.Error("WithFields should return the same logger instance")
|
||||
}
|
||||
|
||||
// Verify the returned logger works
|
||||
result.Info("test message after WithFields")
|
||||
}
|
||||
|
||||
// TestNoOpLogger_Chaining verifies method chaining works
|
||||
func TestNoOpLogger_Chaining(t *testing.T) {
|
||||
logger := &noOpLogger{}
|
||||
|
||||
// Use WithField and verify it returns a usable logger
|
||||
result := logger.WithField("key1", "value1")
|
||||
|
||||
// Verify the result can be used for logging (Logger interface methods)
|
||||
result.Info("info after WithField")
|
||||
result.Infof("infof after WithField: %s", "test")
|
||||
result.Debug("debug after WithField")
|
||||
result.Debugf("debugf after WithField: %d", 123)
|
||||
result.Error("error after WithField")
|
||||
result.Errorf("errorf after WithField: %v", true)
|
||||
|
||||
// Use WithFields and verify it returns a usable logger
|
||||
result2 := logger.WithFields(map[string]interface{}{
|
||||
"key2": "value2",
|
||||
"key3": 123,
|
||||
})
|
||||
|
||||
// Verify the result can be used for logging
|
||||
result2.Infof("message after WithFields: %s", "test")
|
||||
}
|
||||
|
||||
+332
@@ -0,0 +1,332 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
// ErrCircuitOpen is returned when the circuit breaker is open
|
||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||
|
||||
// ErrTooManyRequests is returned when too many requests are made in half-open state
|
||||
ErrTooManyRequests = errors.New("too many requests in half-open state")
|
||||
)
|
||||
|
||||
// State represents the state of the circuit breaker
|
||||
type State int32
|
||||
|
||||
const (
|
||||
// StateClosed allows all operations to pass through
|
||||
StateClosed State = iota
|
||||
|
||||
// StateOpen blocks all operations
|
||||
StateOpen
|
||||
|
||||
// StateHalfOpen allows a limited number of operations to test recovery
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
// String returns the string representation of the state
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
case StateOpen:
|
||||
return "open"
|
||||
case StateHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig holds configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// MaxFailures is the number of consecutive failures before opening the circuit
|
||||
MaxFailures int
|
||||
|
||||
// FailureThreshold is the failure rate threshold (0.0 to 1.0)
|
||||
FailureThreshold float64
|
||||
|
||||
// Timeout is how long the circuit stays open before trying half-open
|
||||
Timeout time.Duration
|
||||
|
||||
// HalfOpenMaxRequests is the number of requests allowed in half-open state
|
||||
HalfOpenMaxRequests int
|
||||
|
||||
// ResetTimeout is how long to wait before resetting counters in closed state
|
||||
ResetTimeout time.Duration
|
||||
|
||||
// OnStateChange is called when the circuit breaker changes state
|
||||
OnStateChange func(from, to State)
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns default configuration
|
||||
func DefaultCircuitBreakerConfig() *CircuitBreakerConfig {
|
||||
return &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
FailureThreshold: 0.6,
|
||||
Timeout: 30 * time.Second,
|
||||
HalfOpenMaxRequests: 3,
|
||||
ResetTimeout: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern
|
||||
type CircuitBreaker struct {
|
||||
config *CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state atomic.Int32
|
||||
lastStateChange time.Time
|
||||
stateMu sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures atomic.Int32
|
||||
totalRequests atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
halfOpenRequests atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
nextRetryTime time.Time
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
stateTransitions atomic.Int64
|
||||
rejectedRequests atomic.Int64
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker
|
||||
func NewCircuitBreaker(config *CircuitBreakerConfig) *CircuitBreaker {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
config: config,
|
||||
lastStateChange: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs a function through the circuit breaker
|
||||
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
|
||||
if !cb.AllowRequest() {
|
||||
cb.rejectedRequests.Add(1)
|
||||
return ErrCircuitOpen
|
||||
}
|
||||
|
||||
cb.totalRequests.Add(1)
|
||||
|
||||
err := fn()
|
||||
if err != nil {
|
||||
cb.RecordFailure()
|
||||
} else {
|
||||
cb.RecordSuccess()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AllowRequest checks if a request is allowed to proceed
|
||||
func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// Check if timeout has passed and we should try half-open
|
||||
cb.timeMu.RLock()
|
||||
shouldRetry := time.Now().After(cb.nextRetryTime)
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
if shouldRetry {
|
||||
cb.setState(StateHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
func (cb *CircuitBreaker) RecordSuccess() {
|
||||
cb.timeMu.Lock()
|
||||
cb.lastSuccessTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Reset consecutive failures
|
||||
cb.consecutiveFailures.Store(0)
|
||||
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
func (cb *CircuitBreaker) RecordFailure() {
|
||||
cb.totalFailures.Add(1)
|
||||
failures := cb.consecutiveFailures.Add(1)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = time.Now()
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
state := cb.GetState()
|
||||
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
// #nosec G115 -- MaxFailures is a small config value that fits in int32
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
// Check failure rate
|
||||
total := cb.totalRequests.Load()
|
||||
failureCount := cb.totalFailures.Load()
|
||||
if total > 10 && float64(failureCount)/float64(total) > cb.config.FailureThreshold {
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
case StateHalfOpen:
|
||||
// Any failure in half-open state reopens the circuit
|
||||
cb.openCircuit()
|
||||
}
|
||||
}
|
||||
|
||||
// openCircuit transitions to open state
|
||||
func (cb *CircuitBreaker) openCircuit() {
|
||||
cb.setState(StateOpen)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
|
||||
cb.timeMu.Lock()
|
||||
cb.nextRetryTime = time.Now().Add(cb.config.Timeout)
|
||||
cb.timeMu.Unlock()
|
||||
}
|
||||
|
||||
// GetState returns the current state
|
||||
func (cb *CircuitBreaker) GetState() State {
|
||||
return State(cb.state.Load())
|
||||
}
|
||||
|
||||
// setState changes the circuit breaker state
|
||||
func (cb *CircuitBreaker) setState(newState State) {
|
||||
oldState := State(cb.state.Swap(int32(newState)))
|
||||
|
||||
if oldState != newState {
|
||||
cb.stateTransitions.Add(1)
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMu.Unlock()
|
||||
|
||||
if cb.config.OnStateChange != nil {
|
||||
cb.config.OnStateChange(oldState, newState)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
cb.totalRequests.Store(0)
|
||||
cb.totalFailures.Store(0)
|
||||
cb.halfOpenRequests.Store(0)
|
||||
cb.rejectedRequests.Store(0)
|
||||
cb.stateTransitions.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
cb.timeMu.Lock()
|
||||
cb.lastFailureTime = now
|
||||
cb.lastSuccessTime = now
|
||||
cb.nextRetryTime = now
|
||||
cb.timeMu.Unlock()
|
||||
|
||||
cb.stateMu.Lock()
|
||||
cb.lastStateChange = now
|
||||
cb.stateMu.Unlock()
|
||||
}
|
||||
|
||||
// Stats returns circuit breaker statistics
|
||||
func (cb *CircuitBreaker) Stats() CircuitBreakerStats {
|
||||
cb.timeMu.RLock()
|
||||
lastFailure := cb.lastFailureTime
|
||||
lastSuccess := cb.lastSuccessTime
|
||||
nextRetry := cb.nextRetryTime
|
||||
cb.timeMu.RUnlock()
|
||||
|
||||
cb.stateMu.RLock()
|
||||
lastChange := cb.lastStateChange
|
||||
cb.stateMu.RUnlock()
|
||||
|
||||
totalReq := cb.totalRequests.Load()
|
||||
totalFail := cb.totalFailures.Load()
|
||||
successRate := float64(0)
|
||||
if totalReq > 0 {
|
||||
successRate = float64(totalReq-totalFail) / float64(totalReq)
|
||||
}
|
||||
|
||||
return CircuitBreakerStats{
|
||||
State: cb.GetState(),
|
||||
ConsecutiveFailures: cb.consecutiveFailures.Load(),
|
||||
TotalRequests: totalReq,
|
||||
TotalFailures: totalFail,
|
||||
SuccessRate: successRate,
|
||||
RejectedRequests: cb.rejectedRequests.Load(),
|
||||
StateTransitions: cb.stateTransitions.Load(),
|
||||
LastFailureTime: lastFailure,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastStateChange: lastChange,
|
||||
NextRetryTime: nextRetry,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats holds statistics for the circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
State State
|
||||
ConsecutiveFailures int32
|
||||
TotalRequests int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
RejectedRequests int64
|
||||
StateTransitions int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastStateChange time.Time
|
||||
NextRetryTime time.Time
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the circuit breaker is in a healthy state
|
||||
func (cb *CircuitBreaker) IsHealthy() bool {
|
||||
return cb.GetState() != StateOpen
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// CircuitBreakerBackend wraps a cache backend with circuit breaker protection
|
||||
type CircuitBreakerBackend struct {
|
||||
backend backends.CacheBackend
|
||||
cb *CircuitBreaker
|
||||
}
|
||||
|
||||
// NewCircuitBreakerBackend creates a new circuit breaker wrapped backend
|
||||
func NewCircuitBreakerBackend(b backends.CacheBackend, config *CircuitBreakerConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultCircuitBreakerConfig()
|
||||
}
|
||||
|
||||
return &CircuitBreakerBackend{
|
||||
backend: b,
|
||||
cb: NewCircuitBreaker(config),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Set(ctx, key, value, ttl)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return nil, 0, false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
value, ttl, exists, err := c.backend.Get(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
deleted, err := c.backend.Delete(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
if !c.cb.AllowRequest() {
|
||||
return false, backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
exists, err := c.backend.Exists(ctx, key)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Clear(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Clear(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including circuit breaker state
|
||||
func (c *CircuitBreakerBackend) GetStats() map[string]interface{} {
|
||||
stats := c.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
cbStats := c.cb.Stats()
|
||||
stats["circuit_breaker"] = map[string]interface{}{
|
||||
"state": cbStats.State.String(),
|
||||
"consecutive_failures": cbStats.ConsecutiveFailures,
|
||||
"total_requests": cbStats.TotalRequests,
|
||||
"total_failures": cbStats.TotalFailures,
|
||||
"success_rate": cbStats.SuccessRate,
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health with circuit breaker protection
|
||||
func (c *CircuitBreakerBackend) Ping(ctx context.Context) error {
|
||||
if !c.cb.AllowRequest() {
|
||||
return backends.ErrCircuitOpen
|
||||
}
|
||||
|
||||
err := c.backend.Ping(ctx)
|
||||
if err == nil {
|
||||
c.cb.RecordSuccess()
|
||||
} else {
|
||||
c.cb.RecordFailure()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the backend
|
||||
func (c *CircuitBreakerBackend) Close() error {
|
||||
return c.backend.Close()
|
||||
}
|
||||
@@ -0,0 +1,561 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockBackend is a simple mock implementation for testing
|
||||
type mockBackend struct {
|
||||
data map[string]mockEntry
|
||||
mu sync.RWMutex
|
||||
failSet bool
|
||||
failGet bool
|
||||
failDelete bool
|
||||
failExists bool
|
||||
failClear bool
|
||||
failPing bool
|
||||
callCount int
|
||||
}
|
||||
|
||||
type mockEntry struct {
|
||||
value []byte
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func newMockBackend() *mockBackend {
|
||||
return &mockBackend{
|
||||
data: make(map[string]mockEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failSet {
|
||||
return errors.New("mock set error")
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
if ttl == 0 {
|
||||
expiresAt = time.Now().Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
m.data[key] = mockEntry{
|
||||
value: value,
|
||||
expiresAt: expiresAt,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failGet {
|
||||
return nil, 0, false, errors.New("mock get error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
|
||||
ttl := time.Until(entry.expiresAt)
|
||||
return entry.value, ttl, true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failDelete {
|
||||
return false, errors.New("mock delete error")
|
||||
}
|
||||
|
||||
_, existed := m.data[key]
|
||||
delete(m.data, key)
|
||||
return existed, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failExists {
|
||||
return false, errors.New("mock exists error")
|
||||
}
|
||||
|
||||
entry, exists := m.data[key]
|
||||
if !exists {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if time.Now().After(entry.expiresAt) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Clear(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failClear {
|
||||
return errors.New("mock clear error")
|
||||
}
|
||||
|
||||
m.data = make(map[string]mockEntry)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) GetStats() map[string]interface{} {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"hits": int64(0),
|
||||
"misses": int64(0),
|
||||
"call_count": m.callCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Ping(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.callCount++
|
||||
|
||||
if m.failPing {
|
||||
return errors.New("mock ping error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Constructor Tests
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithDefaultConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
require.NotNil(t, cb)
|
||||
|
||||
// Verify it implements the interface (compile-time check)
|
||||
var _ backends.CacheBackend = cb
|
||||
}
|
||||
|
||||
func TestNewCircuitBreakerBackend_WithCustomConfig(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
FailureThreshold: 0.5,
|
||||
Timeout: 5 * time.Second,
|
||||
HalfOpenMaxRequests: 2,
|
||||
ResetTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
require.NotNil(t, cb)
|
||||
}
|
||||
|
||||
// Set Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, mockBE.callCount)
|
||||
|
||||
// Verify value was stored
|
||||
value, _, exists, _ := mockBE.Get(ctx, "key1")
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Set_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Circuit should be open now
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Get Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First set a value
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Now get it through circuit breaker
|
||||
value, _, exists, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
assert.Equal(t, []byte("value1"), value)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_Failure(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _, _, err := cb.Get(ctx, "key1")
|
||||
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Get_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failGet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Get(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, _, _, err := cb.Get(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Delete Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Delete through circuit breaker
|
||||
deleted, err := cb.Delete(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, deleted)
|
||||
|
||||
// Verify it's deleted
|
||||
exists, _ := mockBE.Exists(ctx, "key1")
|
||||
assert.False(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Delete_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failDelete = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Delete(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Delete(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Exists Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set a value first
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
|
||||
// Check existence through circuit breaker
|
||||
exists, err := cb.Exists(ctx, "key1")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Exists_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failExists = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Exists(ctx, "key")
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
_, err := cb.Exists(ctx, "key2")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Clear Operation Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set some values
|
||||
mockBE.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
mockBE.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
|
||||
// Clear through circuit breaker
|
||||
err := cb.Clear(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify cleared
|
||||
exists1, _ := mockBE.Exists(ctx, "key1")
|
||||
exists2, _ := mockBE.Exists(ctx, "key2")
|
||||
assert.False(t, exists1)
|
||||
assert.False(t, exists2)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Clear_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failClear = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Clear(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Clear(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// GetStats Tests
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Perform some operations
|
||||
cb.Set(ctx, "key1", []byte("value1"), 1*time.Minute)
|
||||
cb.Get(ctx, "key1")
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
|
||||
// Should have circuit breaker stats
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
|
||||
cbStats, ok := stats["circuit_breaker"].(map[string]interface{})
|
||||
require.True(t, ok)
|
||||
|
||||
// Verify circuit breaker stats fields
|
||||
assert.Contains(t, cbStats, "state")
|
||||
assert.Contains(t, cbStats, "consecutive_failures")
|
||||
assert.Contains(t, cbStats, "total_requests")
|
||||
assert.Contains(t, cbStats, "total_failures")
|
||||
assert.Contains(t, cbStats, "success_rate")
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_GetStats_NilBackendStats(t *testing.T) {
|
||||
// Create a mock backend that returns nil stats
|
||||
mockBE := &mockBackendNilStats{}
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
stats := cb.GetStats()
|
||||
|
||||
require.NotNil(t, stats)
|
||||
assert.Contains(t, stats, "circuit_breaker")
|
||||
}
|
||||
|
||||
// mockBackendNilStats returns nil from GetStats
|
||||
type mockBackendNilStats struct {
|
||||
mockBackend
|
||||
}
|
||||
|
||||
func (m *mockBackendNilStats) GetStats() map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_Success(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
err := cb.Ping(ctx)
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestCircuitBreakerBackend_Ping_CircuitOpen(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failPing = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Ping(ctx)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
err := cb.Ping(ctx)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// Close Tests
|
||||
|
||||
func TestCircuitBreakerBackend_Close(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
cb := NewCircuitBreakerBackend(mockBE, nil)
|
||||
|
||||
err := cb.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Circuit Recovery Test
|
||||
|
||||
func TestCircuitBreakerBackend_CircuitRecovery(t *testing.T) {
|
||||
mockBE := newMockBackend()
|
||||
mockBE.failSet = true
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreakerBackend(mockBE, config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 5; i++ {
|
||||
cb.Set(ctx, "key", []byte("value"), 1*time.Minute)
|
||||
}
|
||||
|
||||
// Verify circuit is open
|
||||
err := cb.Set(ctx, "key2", []byte("value2"), 1*time.Minute)
|
||||
assert.Equal(t, backends.ErrCircuitOpen, err)
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
|
||||
// Fix the backend
|
||||
mockBE.mu.Lock()
|
||||
mockBE.failSet = false
|
||||
mockBE.mu.Unlock()
|
||||
|
||||
// Circuit should be in half-open state, allow a test request
|
||||
err = cb.Set(ctx, "key3", []byte("value3"), 1*time.Minute)
|
||||
|
||||
// After success threshold is met, circuit should close
|
||||
if err == nil {
|
||||
// Circuit recovered
|
||||
err2 := cb.Set(ctx, "key4", []byte("value4"), 1*time.Minute)
|
||||
assert.NoError(t, err2, "Circuit should be closed after recovery")
|
||||
}
|
||||
}
|
||||
+553
@@ -0,0 +1,553 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestCircuitBreaker_StateTransitions tests state machine transitions
|
||||
func TestCircuitBreaker_StateTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("Initial state is closed", func(t *testing.T) {
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Closed to Open after max failures", func(t *testing.T) {
|
||||
cb.Reset()
|
||||
|
||||
// Simulate failures
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("Open to HalfOpen after timeout", func(t *testing.T) {
|
||||
// Open the circuit
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should allow request and transition to half-open
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, StateHalfOpen, cb.GetState())
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Closed after successful requests", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First request transitions to half-open and succeeds
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
// Should be in half-open after first request
|
||||
state := cb.GetState()
|
||||
assert.True(t, state == StateHalfOpen || state == StateClosed,
|
||||
"After first successful request, should be half-open or potentially closed")
|
||||
|
||||
if state == StateHalfOpen {
|
||||
// Need more successful requests to close
|
||||
// The exact number depends on implementation but should be within HalfOpenMaxRequests
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
// After multiple successful requests, should eventually close
|
||||
finalState := cb.GetState()
|
||||
assert.True(t, finalState == StateClosed || finalState == StateHalfOpen,
|
||||
"After successful requests, circuit should transition towards closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("HalfOpen to Open on failure", func(t *testing.T) {
|
||||
// Open circuit then wait for half-open
|
||||
cb.Reset()
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// First call transitions to half-open, second failure reopens
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_OpenCircuitBlocks tests that open circuit blocks requests
|
||||
func TestCircuitBreaker_OpenCircuitBlocks(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger failures to open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Requests should be blocked
|
||||
err := cb.Execute(ctx, func() error {
|
||||
t.Fatal("Should not execute function when circuit is open")
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, ErrCircuitOpen, err)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_HalfOpenMaxRequests tests max requests in half-open state
|
||||
func TestCircuitBreaker_HalfOpenMaxRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit then wait for half-open
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("test error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// After timeout, circuit should allow transition to half-open
|
||||
// Execute HalfOpenMaxRequests successful requests
|
||||
successCount := 0
|
||||
for i := 0; i < config.HalfOpenMaxRequests; i++ {
|
||||
err := cb.Execute(ctx, func() error {
|
||||
successCount++
|
||||
return nil
|
||||
})
|
||||
// Should allow up to HalfOpenMaxRequests
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify we executed the expected number
|
||||
assert.Equal(t, config.HalfOpenMaxRequests, successCount)
|
||||
|
||||
// After successful requests, circuit behavior depends on implementation
|
||||
// It could close (allowing more requests) or stay half-open (blocking)
|
||||
// The important thing is that we allowed exactly HalfOpenMaxRequests
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_SuccessResetsFailures tests failure counter reset
|
||||
func TestCircuitBreaker_SuccessResetsFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Have some failures (but less than max)
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
|
||||
// One success should reset failures
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats = cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_ConcurrentAccess tests thread safety
|
||||
func TestCircuitBreaker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 10,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 5,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
iterations := 50
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterations; j++ {
|
||||
// Mix of successes and failures
|
||||
cb.Execute(ctx, func() error {
|
||||
if (id+j)%3 == 0 {
|
||||
return errors.New("test error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Random state checks
|
||||
_ = cb.GetState()
|
||||
_ = cb.Stats()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should complete without panics
|
||||
stats := cb.Stats()
|
||||
assert.NotNil(t, stats)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Stats tests statistics tracking
|
||||
func TestCircuitBreaker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 2,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Execute some requests
|
||||
cb.Execute(ctx, func() error { return nil }) // Success
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
cb.Execute(ctx, func() error { return errors.New("error") }) // Failure
|
||||
|
||||
stats := cb.Stats()
|
||||
|
||||
assert.Equal(t, StateClosed, stats.State)
|
||||
assert.Equal(t, int64(3), stats.TotalRequests)
|
||||
assert.Equal(t, int64(2), stats.TotalFailures)
|
||||
assert.Equal(t, int32(2), stats.ConsecutiveFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_Reset tests circuit reset
|
||||
func TestCircuitBreaker_Reset(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open the circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Reset
|
||||
cb.Reset()
|
||||
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
stats := cb.Stats()
|
||||
assert.Equal(t, int32(0), stats.ConsecutiveFailures)
|
||||
assert.Equal(t, int64(0), stats.TotalRequests)
|
||||
assert.Equal(t, int64(0), stats.TotalFailures)
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateChangeCallback tests state change notifications
|
||||
func TestCircuitBreaker_StateChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 50 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
OnStateChange: func(from, to State) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Trigger state transitions
|
||||
// Closed -> Open
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
// Should be open now
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait for timeout to allow half-open transition
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Open -> HalfOpen on first request after timeout
|
||||
err := cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Execute more successful requests to trigger HalfOpen -> Closed
|
||||
for i := 0; i < config.HalfOpenMaxRequests-1; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
assert.Contains(t, transitions, "closed->open")
|
||||
assert.Contains(t, transitions, "open->half-open")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_IsHealthy tests health check
|
||||
func TestCircuitBreaker_IsHealthy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 2,
|
||||
Timeout: 100 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Initially healthy
|
||||
assert.True(t, cb.IsHealthy())
|
||||
|
||||
// Open circuit
|
||||
for i := 0; i < 2; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
assert.False(t, cb.IsHealthy(), "Should not be healthy when open")
|
||||
|
||||
// Wait for timeout and allow successful request
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Should be healthy after recovery
|
||||
assert.True(t, cb.IsHealthy(), "Should be healthy after recovery")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_RapidFailures tests rapid consecutive failures
|
||||
func TestCircuitBreaker_RapidFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 5,
|
||||
Timeout: 200 * time.Millisecond,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Rapid failures
|
||||
for i := 0; i < 10; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("rapid error")
|
||||
})
|
||||
}
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
stats := cb.Stats()
|
||||
assert.GreaterOrEqual(t, stats.TotalFailures, int64(5))
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_TimeoutAccuracy tests timeout precision
|
||||
func TestCircuitBreaker_TimeoutAccuracy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
timeout := 100 * time.Millisecond
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1,
|
||||
Timeout: timeout,
|
||||
HalfOpenMaxRequests: 1,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Open circuit
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
|
||||
assert.Equal(t, StateOpen, cb.GetState())
|
||||
|
||||
// Wait just before timeout
|
||||
time.Sleep(timeout - 20*time.Millisecond)
|
||||
assert.False(t, cb.IsHealthy())
|
||||
|
||||
// Wait until after timeout
|
||||
time.Sleep(40 * time.Millisecond)
|
||||
// After timeout, AllowRequest should return true for transition to half-open
|
||||
assert.True(t, cb.AllowRequest())
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_DefaultConfig tests default configuration
|
||||
func TestCircuitBreaker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cb := NewCircuitBreaker(nil) // Should use defaults
|
||||
|
||||
assert.NotNil(t, cb)
|
||||
assert.Equal(t, StateClosed, cb.GetState())
|
||||
|
||||
// Verify defaults by triggering circuit breaker behavior
|
||||
ctx := context.Background()
|
||||
|
||||
// Test that it takes 5 failures to open (default MaxFailures)
|
||||
for i := 0; i < 4; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
}
|
||||
assert.Equal(t, StateClosed, cb.GetState(), "Should still be closed after 4 failures")
|
||||
|
||||
// 5th failure should open it
|
||||
cb.Execute(ctx, func() error {
|
||||
return errors.New("error")
|
||||
})
|
||||
assert.Equal(t, StateOpen, cb.GetState(), "Should be open after 5 failures (default threshold)")
|
||||
}
|
||||
|
||||
// TestCircuitBreaker_StateString tests state string representation
|
||||
func TestCircuitBreaker_StateString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "closed", StateClosed.String())
|
||||
assert.Equal(t, "open", StateOpen.String())
|
||||
assert.Equal(t, "half-open", StateHalfOpen.String())
|
||||
assert.Equal(t, "unknown", State(999).String())
|
||||
}
|
||||
|
||||
// Benchmark circuit breaker performance
|
||||
func BenchmarkCircuitBreaker_Execute(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 100,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkCircuitBreaker_ExecuteWithFailures(b *testing.B) {
|
||||
config := &CircuitBreakerConfig{
|
||||
MaxFailures: 1000,
|
||||
Timeout: 1 * time.Second,
|
||||
HalfOpenMaxRequests: 10,
|
||||
}
|
||||
cb := NewCircuitBreaker(config)
|
||||
ctx := context.Background()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cb.Execute(ctx, func() error {
|
||||
if i%10 == 0 {
|
||||
return errors.New("error")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
+377
@@ -0,0 +1,377 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// HealthStatus represents the health status of a backend
|
||||
type HealthStatus int32
|
||||
|
||||
const (
|
||||
// HealthUnknown indicates unknown health status
|
||||
HealthUnknown HealthStatus = iota
|
||||
|
||||
// HealthHealthy indicates the backend is healthy
|
||||
HealthHealthy
|
||||
|
||||
// HealthDegraded indicates the backend is degraded but operational
|
||||
HealthDegraded
|
||||
|
||||
// HealthUnhealthy indicates the backend is unhealthy
|
||||
HealthUnhealthy
|
||||
)
|
||||
|
||||
// String returns the string representation of the health status
|
||||
func (h HealthStatus) String() string {
|
||||
switch h {
|
||||
case HealthHealthy:
|
||||
return "healthy"
|
||||
case HealthDegraded:
|
||||
return "degraded"
|
||||
case HealthUnhealthy:
|
||||
return "unhealthy"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckConfig holds configuration for the health checker
|
||||
type HealthCheckConfig struct {
|
||||
// CheckInterval is how often to check health
|
||||
CheckInterval time.Duration
|
||||
|
||||
// Timeout is the timeout for each health check
|
||||
Timeout time.Duration
|
||||
|
||||
// HealthyThreshold is the number of consecutive successes to become healthy
|
||||
HealthyThreshold int
|
||||
|
||||
// UnhealthyThreshold is the number of consecutive failures to become unhealthy
|
||||
UnhealthyThreshold int
|
||||
|
||||
// DegradedThreshold is the latency threshold in ms to mark as degraded
|
||||
DegradedThreshold time.Duration
|
||||
|
||||
// OnStatusChange is called when health status changes
|
||||
OnStatusChange func(from, to HealthStatus)
|
||||
|
||||
// CheckFunc is the function to check health
|
||||
CheckFunc func(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DefaultHealthCheckConfig returns default configuration
|
||||
func DefaultHealthCheckConfig() *HealthCheckConfig {
|
||||
return &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
HealthyThreshold: 3,
|
||||
UnhealthyThreshold: 3,
|
||||
DegradedThreshold: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// HealthChecker monitors the health of a backend
|
||||
type HealthChecker struct {
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Status tracking
|
||||
status atomic.Int32
|
||||
consecutiveSuccesses atomic.Int32
|
||||
consecutiveFailures atomic.Int32
|
||||
|
||||
// Timing
|
||||
lastCheckTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureTime time.Time
|
||||
averageLatency atomic.Int64
|
||||
timeMu sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalChecks atomic.Int64
|
||||
totalSuccesses atomic.Int64
|
||||
totalFailures atomic.Int64
|
||||
statusChanges atomic.Int64
|
||||
|
||||
// Lifecycle
|
||||
ticker *time.Ticker
|
||||
stopChan chan struct{}
|
||||
stopped atomic.Bool
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthChecker creates a new health checker
|
||||
func NewHealthChecker(config *HealthCheckConfig) *HealthChecker {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
hc := &HealthChecker{
|
||||
config: config,
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Start begins health checking
|
||||
func (hc *HealthChecker) Start() {
|
||||
if hc.stopped.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
hc.ticker = time.NewTicker(hc.config.CheckInterval)
|
||||
hc.wg.Add(1)
|
||||
go hc.checkLoop()
|
||||
}
|
||||
|
||||
// Stop stops health checking
|
||||
func (hc *HealthChecker) Stop() {
|
||||
if hc.stopped.Swap(true) {
|
||||
return // Already stopped
|
||||
}
|
||||
|
||||
close(hc.stopChan)
|
||||
if hc.ticker != nil {
|
||||
hc.ticker.Stop()
|
||||
}
|
||||
hc.wg.Wait()
|
||||
}
|
||||
|
||||
// checkLoop runs periodic health checks
|
||||
func (hc *HealthChecker) checkLoop() {
|
||||
defer hc.wg.Done()
|
||||
|
||||
// Initial check - log error but continue
|
||||
if err := hc.Check(context.Background()); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hc.stopChan:
|
||||
return
|
||||
case <-hc.ticker.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hc.config.Timeout)
|
||||
if err := hc.Check(ctx); err != nil {
|
||||
// Error is already tracked in Check() method, no need to log again
|
||||
_ = err
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check performs a health check
|
||||
func (hc *HealthChecker) Check(ctx context.Context) error {
|
||||
if hc.config.CheckFunc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hc.totalChecks.Add(1)
|
||||
start := time.Now()
|
||||
|
||||
// Create timeout context if not already set
|
||||
if _, hasDeadline := ctx.Deadline(); !hasDeadline {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, hc.config.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// Perform health check
|
||||
err := hc.config.CheckFunc(ctx)
|
||||
latency := time.Since(start)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Update average latency
|
||||
hc.updateAverageLatency(latency)
|
||||
|
||||
if err != nil {
|
||||
hc.recordFailure()
|
||||
} else {
|
||||
hc.recordSuccess(latency)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// recordSuccess records a successful health check
|
||||
func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
hc.totalSuccesses.Add(1)
|
||||
successes := hc.consecutiveSuccesses.Add(1)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastSuccessTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
currentStatus := hc.GetStatus()
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
} else {
|
||||
newStatus = HealthHealthy
|
||||
}
|
||||
}
|
||||
|
||||
if newStatus != currentStatus {
|
||||
hc.setStatus(newStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// recordFailure records a failed health check
|
||||
func (hc *HealthChecker) recordFailure() {
|
||||
hc.totalFailures.Add(1)
|
||||
failures := hc.consecutiveFailures.Add(1)
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
|
||||
hc.timeMu.Lock()
|
||||
hc.lastFailureTime = time.Now()
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
}
|
||||
|
||||
// updateAverageLatency updates the rolling average latency
|
||||
func (hc *HealthChecker) updateAverageLatency(latency time.Duration) {
|
||||
// Simple exponential moving average
|
||||
currentAvg := time.Duration(hc.averageLatency.Load())
|
||||
if currentAvg == 0 {
|
||||
hc.averageLatency.Store(int64(latency))
|
||||
} else {
|
||||
// Weight: 0.2 for new value, 0.8 for old average
|
||||
newAvg := (currentAvg*4 + latency) / 5
|
||||
hc.averageLatency.Store(int64(newAvg))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the current health status
|
||||
func (hc *HealthChecker) GetStatus() HealthStatus {
|
||||
return HealthStatus(hc.status.Load())
|
||||
}
|
||||
|
||||
// setStatus changes the health status
|
||||
func (hc *HealthChecker) setStatus(newStatus HealthStatus) {
|
||||
oldStatus := HealthStatus(hc.status.Swap(int32(newStatus)))
|
||||
|
||||
if oldStatus != newStatus {
|
||||
hc.statusChanges.Add(1)
|
||||
if hc.config.OnStatusChange != nil {
|
||||
hc.config.OnStatusChange(oldStatus, newStatus)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy or degraded
|
||||
func (hc *HealthChecker) IsHealthy() bool {
|
||||
status := hc.GetStatus()
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// LastCheckTime returns the time of the last health check
|
||||
func (hc *HealthChecker) LastCheckTime() time.Time {
|
||||
hc.timeMu.RLock()
|
||||
defer hc.timeMu.RUnlock()
|
||||
return hc.lastCheckTime
|
||||
}
|
||||
|
||||
// HealthScore returns a health score between 0.0 (unhealthy) and 1.0 (healthy)
|
||||
func (hc *HealthChecker) HealthScore() float64 {
|
||||
status := hc.GetStatus()
|
||||
switch status {
|
||||
case HealthHealthy:
|
||||
return 1.0
|
||||
case HealthDegraded:
|
||||
return 0.7
|
||||
case HealthUnhealthy:
|
||||
return 0.0
|
||||
default:
|
||||
return 0.5
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns health checker statistics
|
||||
func (hc *HealthChecker) Stats() HealthCheckerStats {
|
||||
hc.timeMu.RLock()
|
||||
lastCheck := hc.lastCheckTime
|
||||
lastSuccess := hc.lastSuccessTime
|
||||
lastFailure := hc.lastFailureTime
|
||||
hc.timeMu.RUnlock()
|
||||
|
||||
totalChecks := hc.totalChecks.Load()
|
||||
totalSuccesses := hc.totalSuccesses.Load()
|
||||
totalFailures := hc.totalFailures.Load()
|
||||
|
||||
successRate := float64(0)
|
||||
if totalChecks > 0 {
|
||||
successRate = float64(totalSuccesses) / float64(totalChecks)
|
||||
}
|
||||
|
||||
return HealthCheckerStats{
|
||||
Status: hc.GetStatus(),
|
||||
ConsecutiveSuccesses: hc.consecutiveSuccesses.Load(),
|
||||
ConsecutiveFailures: hc.consecutiveFailures.Load(),
|
||||
TotalChecks: totalChecks,
|
||||
TotalSuccesses: totalSuccesses,
|
||||
TotalFailures: totalFailures,
|
||||
SuccessRate: successRate,
|
||||
AverageLatency: time.Duration(hc.averageLatency.Load()),
|
||||
StatusChanges: hc.statusChanges.Load(),
|
||||
LastCheckTime: lastCheck,
|
||||
LastSuccessTime: lastSuccess,
|
||||
LastFailureTime: lastFailure,
|
||||
HealthScore: hc.HealthScore(),
|
||||
}
|
||||
}
|
||||
|
||||
// HealthCheckerStats holds statistics for the health checker
|
||||
type HealthCheckerStats struct {
|
||||
Status HealthStatus
|
||||
ConsecutiveSuccesses int32
|
||||
ConsecutiveFailures int32
|
||||
TotalChecks int64
|
||||
TotalSuccesses int64
|
||||
TotalFailures int64
|
||||
SuccessRate float64
|
||||
AverageLatency time.Duration
|
||||
StatusChanges int64
|
||||
LastCheckTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
LastFailureTime time.Time
|
||||
HealthScore float64
|
||||
}
|
||||
|
||||
// Reset resets the health checker statistics
|
||||
func (hc *HealthChecker) Reset() {
|
||||
hc.status.Store(int32(HealthUnknown))
|
||||
hc.consecutiveSuccesses.Store(0)
|
||||
hc.consecutiveFailures.Store(0)
|
||||
hc.totalChecks.Store(0)
|
||||
hc.totalSuccesses.Store(0)
|
||||
hc.totalFailures.Store(0)
|
||||
hc.statusChanges.Store(0)
|
||||
hc.averageLatency.Store(0)
|
||||
|
||||
now := time.Now()
|
||||
hc.timeMu.Lock()
|
||||
hc.lastCheckTime = now
|
||||
hc.lastSuccessTime = now
|
||||
hc.lastFailureTime = now
|
||||
hc.timeMu.Unlock()
|
||||
}
|
||||
+216
@@ -0,0 +1,216 @@
|
||||
// Package resilience provides resilience patterns for cache backends.
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/cache/backends"
|
||||
)
|
||||
|
||||
// HealthCheckBackend wraps a cache backend with health checking
|
||||
type HealthCheckBackend struct {
|
||||
backend backends.CacheBackend
|
||||
config *HealthCheckConfig
|
||||
|
||||
// Health tracking
|
||||
status atomic.Int32
|
||||
consecutiveFails atomic.Int32
|
||||
consecutiveOK atomic.Int32
|
||||
lastCheck time.Time
|
||||
checkMutex sync.RWMutex
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewHealthCheckBackend creates a new health check wrapped backend
|
||||
func NewHealthCheckBackend(b backends.CacheBackend, config *HealthCheckConfig) backends.CacheBackend {
|
||||
if config == nil {
|
||||
config = DefaultHealthCheckConfig()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
hc := &HealthCheckBackend{
|
||||
backend: b,
|
||||
config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Set initial status to healthy (optimistic)
|
||||
hc.status.Store(int32(HealthHealthy))
|
||||
|
||||
// Start health check routine
|
||||
hc.wg.Add(1)
|
||||
go hc.healthCheckLoop()
|
||||
|
||||
return hc
|
||||
}
|
||||
|
||||
// Set stores a value and tracks health
|
||||
func (h *HealthCheckBackend) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
// Allow operations even if unhealthy (may recover)
|
||||
err := h.backend.Set(ctx, key, value, ttl)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Get retrieves a value and tracks health
|
||||
func (h *HealthCheckBackend) Get(ctx context.Context, key string) ([]byte, time.Duration, bool, error) {
|
||||
value, ttl, exists, err := h.backend.Get(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return value, ttl, exists, err
|
||||
}
|
||||
|
||||
// Delete removes a key and tracks health
|
||||
func (h *HealthCheckBackend) Delete(ctx context.Context, key string) (bool, error) {
|
||||
deleted, err := h.backend.Delete(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return deleted, err
|
||||
}
|
||||
|
||||
// Exists checks if a key exists and tracks health
|
||||
func (h *HealthCheckBackend) Exists(ctx context.Context, key string) (bool, error) {
|
||||
exists, err := h.backend.Exists(ctx, key)
|
||||
h.recordResult(err == nil)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// Clear removes all keys and tracks health
|
||||
func (h *HealthCheckBackend) Clear(ctx context.Context) error {
|
||||
err := h.backend.Clear(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetStats returns statistics including health status
|
||||
func (h *HealthCheckBackend) GetStats() map[string]interface{} {
|
||||
stats := h.backend.GetStats()
|
||||
if stats == nil {
|
||||
stats = make(map[string]interface{})
|
||||
}
|
||||
|
||||
h.checkMutex.RLock()
|
||||
lastCheck := h.lastCheck
|
||||
h.checkMutex.RUnlock()
|
||||
|
||||
status := HealthStatus(h.status.Load())
|
||||
stats["health"] = map[string]interface{}{
|
||||
"status": status.String(),
|
||||
"consecutive_fails": h.consecutiveFails.Load(),
|
||||
"consecutive_ok": h.consecutiveOK.Load(),
|
||||
"last_check": lastCheck.Format(time.RFC3339),
|
||||
"time_since_check": time.Since(lastCheck).Seconds(),
|
||||
"check_interval_sec": h.config.CheckInterval.Seconds(),
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// Ping checks backend health
|
||||
func (h *HealthCheckBackend) Ping(ctx context.Context) error {
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close shuts down the health checker and backend
|
||||
func (h *HealthCheckBackend) Close() error {
|
||||
// Stop health check routine
|
||||
h.cancel()
|
||||
|
||||
// Wait for routine to finish
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
h.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Finished normally
|
||||
case <-time.After(2 * time.Second):
|
||||
// Timeout
|
||||
}
|
||||
|
||||
return h.backend.Close()
|
||||
}
|
||||
|
||||
// IsHealthy returns true if the backend is healthy
|
||||
func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
status := HealthStatus(h.status.Load())
|
||||
return status == HealthHealthy || status == HealthDegraded
|
||||
}
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
// #nosec G115 -- threshold config values are small integers that fit in int32
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
// Check if we should transition to healthy
|
||||
if fails > 0 && oks >= int32(h.config.HealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthHealthy)))
|
||||
if oldStatus != HealthHealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthHealthy)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
oks := h.consecutiveOK.Swap(0)
|
||||
fails := h.consecutiveFails.Add(1)
|
||||
|
||||
// Check if we should transition to unhealthy
|
||||
if oks > 0 && fails >= int32(h.config.UnhealthyThreshold) {
|
||||
oldStatus := HealthStatus(h.status.Swap(int32(HealthUnhealthy)))
|
||||
if oldStatus != HealthUnhealthy && h.config.OnStatusChange != nil {
|
||||
h.config.OnStatusChange(oldStatus, HealthUnhealthy)
|
||||
}
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold)*2 {
|
||||
// Severely degraded
|
||||
h.status.Store(int32(HealthUnhealthy))
|
||||
} else if fails >= int32(h.config.UnhealthyThreshold) {
|
||||
// Degraded but still trying
|
||||
h.status.Store(int32(HealthDegraded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// healthCheckLoop runs periodic health checks
|
||||
func (h *HealthCheckBackend) healthCheckLoop() {
|
||||
defer h.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(h.config.CheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Do initial check
|
||||
h.performHealthCheck()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-h.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performHealthCheck()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performHealthCheck performs a single health check
|
||||
func (h *HealthCheckBackend) performHealthCheck() {
|
||||
h.checkMutex.Lock()
|
||||
h.lastCheck = time.Now()
|
||||
h.checkMutex.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), h.config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
err := h.backend.Ping(ctx)
|
||||
h.recordResult(err == nil)
|
||||
}
|
||||
+447
@@ -0,0 +1,447 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestHealthChecker_StatusTransitions tests health status transitions
|
||||
func TestHealthChecker_StatusTransitions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Initially unknown
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy after threshold failures
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status == HealthUnhealthy || status == HealthDegraded)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should recover towards healthy
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded || finalStatus == HealthUnknown)
|
||||
}
|
||||
|
||||
// TestHealthChecker_InitialState tests initial health status
|
||||
func TestHealthChecker_InitialState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
assert.False(t, hc.IsHealthy())
|
||||
}
|
||||
|
||||
// TestHealthChecker_ForceCheck tests manual health check trigger
|
||||
func TestHealthChecker_ForceCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Second, // Long interval
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
initialCount := callCount.Load()
|
||||
|
||||
// Force check
|
||||
hc.Check(context.Background())
|
||||
|
||||
// Should have been called
|
||||
assert.Greater(t, callCount.Load(), initialCount)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusChangeCallback tests status change notifications
|
||||
func TestHealthChecker_StatusChangeCallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var transitions []string
|
||||
var mu sync.Mutex
|
||||
var shouldFail atomic.Bool
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
if shouldFail.Load() {
|
||||
return errors.New("health check failed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
transitions = append(transitions, from.String()+"->"+to.String())
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
// Trigger failures
|
||||
shouldFail.Store(true)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Recover
|
||||
shouldFail.Store(false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should have status transitions
|
||||
assert.NotEmpty(t, transitions)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Stats tests statistics tracking
|
||||
func TestHealthChecker_Stats(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
if callCount.Load()%2 == 0 {
|
||||
return errors.New("failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 5,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
stats := hc.Stats()
|
||||
|
||||
assert.Greater(t, stats.TotalChecks, int64(0))
|
||||
assert.Greater(t, stats.TotalFailures, int64(0))
|
||||
assert.Greater(t, stats.SuccessRate, 0.0)
|
||||
assert.Less(t, stats.SuccessRate, 1.0)
|
||||
}
|
||||
|
||||
// TestHealthChecker_Timeout tests check timeout handling
|
||||
func TestHealthChecker_Timeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
// Simulate slow check
|
||||
select {
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 50 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond, // Short timeout
|
||||
UnhealthyThreshold: 2,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Should be unhealthy due to timeouts
|
||||
status := hc.GetStatus()
|
||||
assert.NotEqual(t, HealthHealthy, status)
|
||||
}
|
||||
|
||||
// TestHealthChecker_ConcurrentAccess tests thread safety
|
||||
func TestHealthChecker_ConcurrentAccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Millisecond,
|
||||
Timeout: 5 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
goroutines := 20
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 50; j++ {
|
||||
_ = hc.GetStatus()
|
||||
_ = hc.IsHealthy()
|
||||
_ = hc.Stats()
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
// Should complete without panics
|
||||
}
|
||||
|
||||
// TestHealthChecker_StopAndStart tests lifecycle management
|
||||
func TestHealthChecker_StopAndStart(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
callCount.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 20 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
// Start
|
||||
hc.Start()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count1 := callCount.Load()
|
||||
assert.Greater(t, count1, int32(0))
|
||||
|
||||
// Stop
|
||||
hc.Stop()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
count2 := callCount.Load()
|
||||
|
||||
// Should not have increased significantly after stop
|
||||
assert.Less(t, count2-count1, int32(3))
|
||||
}
|
||||
|
||||
// TestHealthChecker_DegradedState tests degraded status
|
||||
func TestHealthChecker_DegradedState(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
count := callCount.Add(1)
|
||||
// Fail once, then succeed
|
||||
if count == 1 {
|
||||
return errors.New("single failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3, // Need 3 failures for unhealthy
|
||||
HealthyThreshold: 2, // Need 2 successes for healthy
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// After initial checks, status should be set (might be healthy or degraded based on execution)
|
||||
status := hc.GetStatus()
|
||||
assert.True(t, status != HealthUnknown, "Status should not be unknown after checks")
|
||||
}
|
||||
|
||||
// TestHealthChecker_DefaultConfig tests default configuration
|
||||
func TestHealthChecker_DefaultConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
assert.NotNil(t, hc)
|
||||
assert.Equal(t, HealthUnknown, hc.GetStatus())
|
||||
|
||||
// Verify default config was applied (we can't access private fields, so just check it works)
|
||||
assert.NotNil(t, hc)
|
||||
}
|
||||
|
||||
// TestHealthChecker_StatusString tests status string representation
|
||||
func TestHealthChecker_StatusString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, "healthy", HealthHealthy.String())
|
||||
assert.Equal(t, "unhealthy", HealthUnhealthy.String())
|
||||
assert.Equal(t, "degraded", HealthDegraded.String())
|
||||
assert.Equal(t, "unknown", HealthStatus(999).String())
|
||||
}
|
||||
|
||||
// TestHealthChecker_RecoveryPattern tests typical failure and recovery
|
||||
func TestHealthChecker_RecoveryPattern(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var checkNumber atomic.Int32
|
||||
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
n := checkNumber.Add(1)
|
||||
// Fail checks 3-5, succeed others
|
||||
if n >= 3 && n <= 5 {
|
||||
return errors.New("temporary failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var statusLog []HealthStatus
|
||||
var mu sync.Mutex
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 30 * time.Millisecond,
|
||||
Timeout: 10 * time.Millisecond,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
OnStatusChange: func(from, to HealthStatus) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
statusLog = append(statusLog, to)
|
||||
},
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
hc.Start()
|
||||
defer hc.Stop()
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Should see transitions through unhealthy and back to healthy
|
||||
assert.NotEmpty(t, statusLog)
|
||||
|
||||
// Final status should be healthy or degraded (recovered)
|
||||
finalStatus := hc.GetStatus()
|
||||
assert.True(t, finalStatus == HealthHealthy || finalStatus == HealthDegraded, "Should have recovered")
|
||||
}
|
||||
|
||||
// Benchmark health checker performance
|
||||
func BenchmarkHealthChecker_ForceCheck(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckInterval: 10 * time.Minute,
|
||||
Timeout: 1 * time.Second,
|
||||
UnhealthyThreshold: 3,
|
||||
HealthyThreshold: 2,
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
hc.Check(context.Background())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHealthChecker_Status(b *testing.B) {
|
||||
checkFunc := func(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
config := &HealthCheckConfig{
|
||||
CheckFunc: checkFunc,
|
||||
}
|
||||
hc := NewHealthChecker(config)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = hc.GetStatus()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,931 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Mock logger for testing
|
||||
type mockLogger struct {
|
||||
mu sync.Mutex
|
||||
logs []string
|
||||
errLogs []string
|
||||
debugLog []string
|
||||
}
|
||||
|
||||
func (m *mockLogger) Logf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.logs = append(m.logs, format)
|
||||
}
|
||||
|
||||
func (m *mockLogger) ErrorLogf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.errLogs = append(m.errLogs, format)
|
||||
}
|
||||
|
||||
func (m *mockLogger) DebugLogf(format string, args ...interface{}) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.debugLog = append(m.debugLog, format)
|
||||
}
|
||||
|
||||
func (m *mockLogger) getLogCount() int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return len(m.logs)
|
||||
}
|
||||
|
||||
// BackgroundTask tests
|
||||
func TestNewBackgroundTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
var wg sync.WaitGroup
|
||||
runCount := 0
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {
|
||||
runCount++
|
||||
}, logger, &wg)
|
||||
|
||||
if task == nil {
|
||||
t.Fatal("Expected NewBackgroundTask to return non-nil")
|
||||
}
|
||||
|
||||
if task.name != "test-task" {
|
||||
t.Errorf("Expected name 'test-task', got '%s'", task.name)
|
||||
}
|
||||
|
||||
if task.interval != 100*time.Millisecond {
|
||||
t.Errorf("Expected interval 100ms, got %v", task.interval)
|
||||
}
|
||||
|
||||
if task.IsRunning() {
|
||||
t.Error("Expected task to not be running initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgroundTask_Start(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
runCount := int32(0)
|
||||
|
||||
task := NewBackgroundTask("test-task", 50*time.Millisecond, func() {
|
||||
atomic.AddInt32(&runCount, 1)
|
||||
}, logger)
|
||||
|
||||
task.Start()
|
||||
|
||||
if !task.IsRunning() {
|
||||
t.Error("Expected task to be running after Start()")
|
||||
}
|
||||
|
||||
// Wait for at least 2 executions
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
|
||||
task.Stop()
|
||||
|
||||
count := atomic.LoadInt32(&runCount)
|
||||
if count < 2 {
|
||||
t.Errorf("Expected at least 2 executions, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgroundTask_Stop(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
task.Start()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
task.Stop()
|
||||
|
||||
if task.IsRunning() {
|
||||
t.Error("Expected task to not be running after Stop()")
|
||||
}
|
||||
|
||||
// Calling Stop again should not panic
|
||||
task.Stop()
|
||||
}
|
||||
|
||||
func TestBackgroundTask_DoubleStart(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
task.Start()
|
||||
logCountBefore := logger.getLogCount()
|
||||
|
||||
// Second start should be ignored
|
||||
task.Start()
|
||||
|
||||
logCountAfter := logger.getLogCount()
|
||||
if logCountAfter <= logCountBefore {
|
||||
t.Error("Expected log message about task already running")
|
||||
}
|
||||
|
||||
task.Stop()
|
||||
}
|
||||
|
||||
func TestBackgroundTask_ExecuteWithPanic(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
panicCount := int32(0)
|
||||
|
||||
task := NewBackgroundTask("panic-task", 50*time.Millisecond, func() {
|
||||
count := atomic.AddInt32(&panicCount, 1)
|
||||
if count == 1 {
|
||||
panic("test panic")
|
||||
}
|
||||
}, logger)
|
||||
|
||||
task.Start()
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
task.Stop()
|
||||
|
||||
// Task should recover from panic and continue
|
||||
finalCount := atomic.LoadInt32(&panicCount)
|
||||
if finalCount < 2 {
|
||||
t.Errorf("Expected task to continue after panic, got %d executions", finalCount)
|
||||
}
|
||||
|
||||
stats := task.GetStats()
|
||||
if stats["errorCount"].(int64) < 1 {
|
||||
t.Error("Expected error count to be at least 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgroundTask_GetStats(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
runCount := int32(0)
|
||||
|
||||
task := NewBackgroundTask("test-task", 50*time.Millisecond, func() {
|
||||
atomic.AddInt32(&runCount, 1)
|
||||
}, logger)
|
||||
|
||||
task.Start()
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
task.Stop()
|
||||
|
||||
stats := task.GetStats()
|
||||
|
||||
if stats["name"] != "test-task" {
|
||||
t.Errorf("Expected name 'test-task', got %v", stats["name"])
|
||||
}
|
||||
|
||||
if !stats["isRunning"].(bool) == true {
|
||||
// Task should be stopped
|
||||
}
|
||||
|
||||
if stats["runCount"].(int64) < 2 {
|
||||
t.Errorf("Expected runCount >= 2, got %v", stats["runCount"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackgroundTask_WithWaitGroup(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
var wg sync.WaitGroup
|
||||
runCount := int32(0)
|
||||
|
||||
task := NewBackgroundTask("test-task", 50*time.Millisecond, func() {
|
||||
atomic.AddInt32(&runCount, 1)
|
||||
}, logger, &wg)
|
||||
|
||||
task.Start()
|
||||
|
||||
// Wait for task to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Stop and wait
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
task.Stop()
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Timeout waiting for task to stop")
|
||||
}
|
||||
}
|
||||
|
||||
// TaskRegistry tests
|
||||
func TestNewTaskRegistry(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
if registry == nil {
|
||||
t.Fatal("Expected NewTaskRegistry to return non-nil")
|
||||
}
|
||||
|
||||
if registry.maxTasks != 10 {
|
||||
t.Errorf("Expected maxTasks 10, got %d", registry.maxTasks)
|
||||
}
|
||||
|
||||
if registry.GetTaskCount() != 0 {
|
||||
t.Error("Expected initial task count to be 0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_RegisterTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
err := registry.RegisterTask("test-task", task)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if registry.GetTaskCount() != 1 {
|
||||
t.Error("Expected task count to be 1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_RegisterTask_Duplicate(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task1 := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
task2 := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
err1 := registry.RegisterTask("test-task", task1)
|
||||
if err1 != nil {
|
||||
t.Errorf("Expected no error on first registration, got %v", err1)
|
||||
}
|
||||
|
||||
err2 := registry.RegisterTask("test-task", task2)
|
||||
if err2 == nil {
|
||||
t.Error("Expected error when registering duplicate task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_RegisterTask_Nil(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
err := registry.RegisterTask("test-task", nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when registering nil task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_RegisterTask_MaxLimit(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 2)
|
||||
|
||||
task1 := NewBackgroundTask("task1", 100*time.Millisecond, func() {}, logger)
|
||||
task2 := NewBackgroundTask("task2", 100*time.Millisecond, func() {}, logger)
|
||||
task3 := NewBackgroundTask("task3", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
registry.RegisterTask("task1", task1)
|
||||
registry.RegisterTask("task2", task2)
|
||||
err := registry.RegisterTask("task3", task3)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when exceeding max tasks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_UnregisterTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
registry.RegisterTask("test-task", task)
|
||||
|
||||
if registry.GetTaskCount() != 1 {
|
||||
t.Error("Expected task count to be 1")
|
||||
}
|
||||
|
||||
registry.UnregisterTask("test-task")
|
||||
|
||||
if registry.GetTaskCount() != 0 {
|
||||
t.Error("Expected task count to be 0 after unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_UnregisterTask_Running(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
registry.RegisterTask("test-task", task)
|
||||
task.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
registry.UnregisterTask("test-task")
|
||||
|
||||
if task.IsRunning() {
|
||||
t.Error("Expected task to be stopped after unregister")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_GetTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
registry.RegisterTask("test-task", task)
|
||||
|
||||
retrieved, exists := registry.GetTask("test-task")
|
||||
if !exists {
|
||||
t.Error("Expected task to exist")
|
||||
}
|
||||
|
||||
if retrieved != task {
|
||||
t.Error("Expected to retrieve the same task")
|
||||
}
|
||||
|
||||
_, exists = registry.GetTask("non-existent")
|
||||
if exists {
|
||||
t.Error("Expected non-existent task to not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_StopAllTasks(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task1 := NewBackgroundTask("task1", 100*time.Millisecond, func() {}, logger)
|
||||
task2 := NewBackgroundTask("task2", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
registry.RegisterTask("task1", task1)
|
||||
registry.RegisterTask("task2", task2)
|
||||
|
||||
task1.Start()
|
||||
task2.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
registry.StopAllTasks()
|
||||
|
||||
if task1.IsRunning() || task2.IsRunning() {
|
||||
t.Error("Expected all tasks to be stopped")
|
||||
}
|
||||
|
||||
if registry.GetTaskCount() != 0 {
|
||||
t.Error("Expected task count to be 0 after StopAllTasks")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_CreateSingletonTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
runCount := int32(0)
|
||||
task1, err1 := registry.CreateSingletonTask("singleton", 50*time.Millisecond, func() {
|
||||
atomic.AddInt32(&runCount, 1)
|
||||
}, logger)
|
||||
|
||||
if err1 != nil {
|
||||
t.Errorf("Expected no error, got %v", err1)
|
||||
}
|
||||
|
||||
if task1 == nil {
|
||||
t.Fatal("Expected task to be created")
|
||||
}
|
||||
|
||||
if !task1.IsRunning() {
|
||||
t.Error("Expected task to be running")
|
||||
}
|
||||
|
||||
// Try to create same task again
|
||||
task2, err2 := registry.CreateSingletonTask("singleton", 50*time.Millisecond, func() {
|
||||
atomic.AddInt32(&runCount, 1)
|
||||
}, logger)
|
||||
|
||||
if err2 != nil {
|
||||
t.Errorf("Expected no error on second call, got %v", err2)
|
||||
}
|
||||
|
||||
if task2 != task1 {
|
||||
t.Error("Expected to get the same task instance")
|
||||
}
|
||||
|
||||
time.Sleep(120 * time.Millisecond)
|
||||
task1.Stop()
|
||||
|
||||
if atomic.LoadInt32(&runCount) < 2 {
|
||||
t.Error("Expected task to have run multiple times")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_GetAllTasks(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task1 := NewBackgroundTask("task1", 100*time.Millisecond, func() {}, logger)
|
||||
task2 := NewBackgroundTask("task2", 100*time.Millisecond, func() {}, logger)
|
||||
|
||||
registry.RegisterTask("task1", task1)
|
||||
registry.RegisterTask("task2", task2)
|
||||
|
||||
allTasks := registry.GetAllTasks()
|
||||
|
||||
if len(allTasks) != 2 {
|
||||
t.Errorf("Expected 2 tasks, got %d", len(allTasks))
|
||||
}
|
||||
|
||||
if _, ok := allTasks["task1"]; !ok {
|
||||
t.Error("Expected task1 in results")
|
||||
}
|
||||
|
||||
if _, ok := allTasks["task2"]; !ok {
|
||||
t.Error("Expected task2 in results")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskRegistry_GetStats(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
registry.RegisterTask("test-task", task)
|
||||
task.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
stats := registry.GetStats()
|
||||
|
||||
if stats["totalTasks"].(int) != 1 {
|
||||
t.Errorf("Expected totalTasks 1, got %v", stats["totalTasks"])
|
||||
}
|
||||
|
||||
if stats["runningTasks"].(int) != 1 {
|
||||
t.Errorf("Expected runningTasks 1, got %v", stats["runningTasks"])
|
||||
}
|
||||
|
||||
if _, ok := stats["memory"]; !ok {
|
||||
t.Error("Expected memory stats")
|
||||
}
|
||||
|
||||
task.Stop()
|
||||
}
|
||||
|
||||
func TestGlobalTaskRegistry(t *testing.T) {
|
||||
// Reset before test
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
registry1 := GetGlobalTaskRegistry()
|
||||
registry2 := GetGlobalTaskRegistry()
|
||||
|
||||
if registry1 != registry2 {
|
||||
t.Error("Expected singleton to return same instance")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
ResetGlobalTaskRegistry()
|
||||
}
|
||||
|
||||
func TestResetGlobalTaskRegistry(t *testing.T) {
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
registry := GetGlobalTaskRegistry()
|
||||
logger := &mockLogger{}
|
||||
task := NewBackgroundTask("test-task", 100*time.Millisecond, func() {}, logger)
|
||||
registry.RegisterTask("test-task", task)
|
||||
task.Start()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
ResetGlobalTaskRegistry()
|
||||
|
||||
// Should get a new instance
|
||||
newRegistry := GetGlobalTaskRegistry()
|
||||
if newRegistry.GetTaskCount() != 0 {
|
||||
t.Error("Expected new registry to be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TaskCircuitBreaker tests
|
||||
func TestNewTaskCircuitBreaker(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(5, 30*time.Second, logger)
|
||||
|
||||
if cb == nil {
|
||||
t.Fatal("Expected NewTaskCircuitBreaker to return non-nil")
|
||||
}
|
||||
|
||||
if cb.failureThreshold != 5 {
|
||||
t.Errorf("Expected failureThreshold 5, got %d", cb.failureThreshold)
|
||||
}
|
||||
|
||||
if cb.timeout != 30*time.Second {
|
||||
t.Errorf("Expected timeout 30s, got %v", cb.timeout)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Error("Expected initial state to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCircuitBreaker_CanCreateTask(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(3, 100*time.Millisecond, logger)
|
||||
|
||||
err := cb.CanCreateTask("test-task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error initially, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCircuitBreaker_OnTaskFailure(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(3, 100*time.Millisecond, logger)
|
||||
|
||||
// Record failures
|
||||
for i := 0; i < 3; i++ {
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
}
|
||||
|
||||
// Circuit should be open
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Expected circuit breaker to be open after threshold failures")
|
||||
}
|
||||
|
||||
// Should not be able to create task
|
||||
err := cb.CanCreateTask("test-task")
|
||||
if err == nil {
|
||||
t.Error("Expected error when circuit breaker is open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCircuitBreaker_OnTaskSuccess(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(5, 100*time.Millisecond, logger)
|
||||
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
|
||||
cb.OnTaskSuccess("test-task")
|
||||
|
||||
// Task-specific failures should be reset
|
||||
err := cb.CanCreateTask("test-task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after success, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCircuitBreaker_Reset(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(2, 100*time.Millisecond, logger)
|
||||
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Expected circuit breaker to be open")
|
||||
}
|
||||
|
||||
cb.Reset()
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Error("Expected circuit breaker to be closed after reset")
|
||||
}
|
||||
|
||||
err := cb.CanCreateTask("test-task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error after reset, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskCircuitBreaker_TimeoutRecovery(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
cb := NewTaskCircuitBreaker(2, 100*time.Millisecond, logger)
|
||||
|
||||
// Open circuit breaker
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
cb.OnTaskFailure("test-task", nil)
|
||||
|
||||
if cb.GetState() != CircuitBreakerOpen {
|
||||
t.Error("Expected circuit breaker to be open")
|
||||
}
|
||||
|
||||
// Wait for timeout
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Circuit breaker should reset, but task-specific failures remain
|
||||
// Need to check with a different task name
|
||||
err := cb.CanCreateTask("different-task")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error for different task after timeout, got %v", err)
|
||||
}
|
||||
|
||||
if cb.GetState() != CircuitBreakerClosed {
|
||||
t.Error("Expected circuit breaker to be closed after timeout")
|
||||
}
|
||||
|
||||
// Original task still has too many failures
|
||||
err = cb.CanCreateTask("test-task")
|
||||
if err == nil {
|
||||
t.Error("Expected error for original task with too many failures")
|
||||
}
|
||||
}
|
||||
|
||||
// TaskMemoryMonitor tests
|
||||
func TestNewTaskMemoryMonitor(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
if monitor == nil {
|
||||
t.Fatal("Expected NewTaskMemoryMonitor to return non-nil")
|
||||
}
|
||||
|
||||
if monitor.registry != registry {
|
||||
t.Error("Expected registry to be set")
|
||||
}
|
||||
|
||||
if monitor.memoryThreshold != 1024*1024*1024 {
|
||||
t.Errorf("Expected default threshold 1GB, got %d", monitor.memoryThreshold)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskMemoryMonitor_SetMemoryThreshold(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
monitor.SetMemoryThreshold(512 * 1024 * 1024)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if stats["memoryThreshold"].(uint64) != 512*1024*1024 {
|
||||
t.Error("Expected threshold to be updated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTaskMemoryMonitor_StartStop(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
monitor.StartMonitoring()
|
||||
|
||||
stats := monitor.GetStats()
|
||||
if !stats["isMonitoring"].(bool) {
|
||||
t.Error("Expected monitor to be running")
|
||||
}
|
||||
|
||||
// Double start should be ignored
|
||||
monitor.StartMonitoring()
|
||||
|
||||
monitor.StopMonitoring()
|
||||
|
||||
stats = monitor.GetStats()
|
||||
if stats["isMonitoring"].(bool) {
|
||||
t.Error("Expected monitor to be stopped")
|
||||
}
|
||||
|
||||
// Double stop should be safe
|
||||
monitor.StopMonitoring()
|
||||
}
|
||||
|
||||
func TestTaskMemoryMonitor_GetStats(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
registry := NewTaskRegistry(logger, 10)
|
||||
monitor := NewTaskMemoryMonitor(logger, registry)
|
||||
|
||||
stats := monitor.GetStats()
|
||||
|
||||
if _, ok := stats["isMonitoring"]; !ok {
|
||||
t.Error("Expected isMonitoring in stats")
|
||||
}
|
||||
|
||||
if _, ok := stats["currentMemory"]; !ok {
|
||||
t.Error("Expected currentMemory in stats")
|
||||
}
|
||||
|
||||
if _, ok := stats["memoryThreshold"]; !ok {
|
||||
t.Error("Expected memoryThreshold in stats")
|
||||
}
|
||||
}
|
||||
|
||||
// WorkerPool tests
|
||||
func TestNewWorkerPool(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(4, 10, logger)
|
||||
|
||||
if pool == nil {
|
||||
t.Fatal("Expected NewWorkerPool to return non-nil")
|
||||
}
|
||||
|
||||
if pool.workers != 4 {
|
||||
t.Errorf("Expected 4 workers, got %d", pool.workers)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_DefaultWorkers(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(0, 0, logger)
|
||||
|
||||
// Should default to NumCPU
|
||||
if pool.workers <= 0 {
|
||||
t.Error("Expected positive number of workers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_StartStop(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(2, 5, logger)
|
||||
|
||||
pool.Start()
|
||||
|
||||
metrics := pool.GetMetrics()
|
||||
if !metrics["isRunning"].(bool) {
|
||||
t.Error("Expected worker pool to be running")
|
||||
}
|
||||
|
||||
// Double start should be ignored
|
||||
pool.Start()
|
||||
|
||||
pool.Stop()
|
||||
|
||||
metrics = pool.GetMetrics()
|
||||
if metrics["isRunning"].(bool) {
|
||||
t.Error("Expected worker pool to be stopped")
|
||||
}
|
||||
|
||||
// Double stop should be safe
|
||||
pool.Stop()
|
||||
}
|
||||
|
||||
func TestWorkerPool_Submit(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(2, 5, logger)
|
||||
|
||||
pool.Start()
|
||||
defer pool.Stop()
|
||||
|
||||
executed := int32(0)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
err := pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&executed, 1)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error submitting task, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for tasks to complete
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Timeout waiting for tasks to complete")
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&executed) != 3 {
|
||||
t.Errorf("Expected 3 tasks executed, got %d", atomic.LoadInt32(&executed))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_SubmitWhenStopped(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(2, 5, logger)
|
||||
|
||||
err := pool.Submit(func() {})
|
||||
if err == nil {
|
||||
t.Error("Expected error when submitting to stopped pool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_TaskPanic(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(2, 5, logger)
|
||||
|
||||
pool.Start()
|
||||
defer pool.Stop()
|
||||
|
||||
executed := int32(0)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(2)
|
||||
// Submit task that panics
|
||||
pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
panic("test panic")
|
||||
})
|
||||
|
||||
// Submit normal task
|
||||
pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&executed, 1)
|
||||
})
|
||||
|
||||
// Wait for tasks
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Timeout waiting for tasks")
|
||||
}
|
||||
|
||||
// Pool should still be functional
|
||||
metrics := pool.GetMetrics()
|
||||
if metrics["tasksFailed"].(int64) < 1 {
|
||||
t.Error("Expected at least one failed task")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_GetMetrics(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(2, 5, logger)
|
||||
|
||||
pool.Start()
|
||||
defer pool.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
})
|
||||
|
||||
pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
metrics := pool.GetMetrics()
|
||||
|
||||
if metrics["workers"].(int) != 2 {
|
||||
t.Errorf("Expected 2 workers, got %v", metrics["workers"])
|
||||
}
|
||||
|
||||
if metrics["tasksProcessed"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 processed tasks, got %v", metrics["tasksProcessed"])
|
||||
}
|
||||
|
||||
if metrics["tasksQueued"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 queued tasks, got %v", metrics["tasksQueued"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerPool_Concurrent(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
pool := NewWorkerPool(4, 20, logger)
|
||||
|
||||
pool.Start()
|
||||
defer pool.Stop()
|
||||
|
||||
executed := int32(0)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
taskCount := 10
|
||||
for i := 0; i < taskCount; i++ {
|
||||
wg.Add(1)
|
||||
err := pool.Submit(func() {
|
||||
defer wg.Done()
|
||||
atomic.AddInt32(&executed, 1)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
wg.Done()
|
||||
t.Errorf("Failed to submit task: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
done := make(chan bool)
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Timeout waiting for concurrent tasks")
|
||||
}
|
||||
|
||||
if atomic.LoadInt32(&executed) != int32(taskCount) {
|
||||
t.Errorf("Expected %d tasks executed, got %d", taskCount, atomic.LoadInt32(&executed))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
// Package cleanup provides background task management and cleanup functionality.
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Logger defines the logging interface
|
||||
type Logger interface {
|
||||
Logf(format string, args ...interface{})
|
||||
ErrorLogf(format string, args ...interface{})
|
||||
DebugLogf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// BackgroundTask represents a recurring background task
|
||||
type BackgroundTask struct {
|
||||
name string
|
||||
interval time.Duration
|
||||
taskFunc func()
|
||||
ticker *time.Ticker
|
||||
stopChan chan bool
|
||||
isRunning int32
|
||||
logger Logger
|
||||
waitGroup *sync.WaitGroup
|
||||
lastRun time.Time
|
||||
runCount int64
|
||||
errorCount int64
|
||||
mu sync.RWMutex
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
}
|
||||
|
||||
// NewBackgroundTask creates a new background task
|
||||
func NewBackgroundTask(name string, interval time.Duration, taskFunc func(), logger Logger, wg ...*sync.WaitGroup) *BackgroundTask {
|
||||
var waitGroup *sync.WaitGroup
|
||||
if len(wg) > 0 && wg[0] != nil {
|
||||
waitGroup = wg[0]
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &BackgroundTask{
|
||||
name: name,
|
||||
interval: interval,
|
||||
taskFunc: taskFunc,
|
||||
stopChan: make(chan bool, 1),
|
||||
isRunning: 0,
|
||||
logger: logger,
|
||||
waitGroup: waitGroup,
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins executing the background task
|
||||
func (bt *BackgroundTask) Start() {
|
||||
if !atomic.CompareAndSwapInt32(&bt.isRunning, 0, 1) {
|
||||
if bt.logger != nil {
|
||||
bt.logger.Logf("Background task %s is already running", bt.name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
bt.ticker = time.NewTicker(bt.interval)
|
||||
|
||||
if bt.waitGroup != nil {
|
||||
bt.waitGroup.Add(1)
|
||||
}
|
||||
|
||||
go bt.run()
|
||||
|
||||
if bt.logger != nil {
|
||||
bt.logger.Logf("Started background task: %s (interval: %v)", bt.name, bt.interval)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the background task
|
||||
func (bt *BackgroundTask) Stop() {
|
||||
if !atomic.CompareAndSwapInt32(&bt.isRunning, 1, 0) {
|
||||
if bt.logger != nil {
|
||||
bt.logger.Logf("Background task %s is not running", bt.name)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Cancel context
|
||||
if bt.cancelFunc != nil {
|
||||
bt.cancelFunc()
|
||||
}
|
||||
|
||||
// Stop ticker
|
||||
if bt.ticker != nil {
|
||||
bt.ticker.Stop()
|
||||
}
|
||||
|
||||
// Send stop signal
|
||||
select {
|
||||
case bt.stopChan <- true:
|
||||
case <-time.After(5 * time.Second):
|
||||
if bt.logger != nil {
|
||||
bt.logger.ErrorLogf("Timeout stopping background task: %s", bt.name)
|
||||
}
|
||||
}
|
||||
|
||||
if bt.logger != nil {
|
||||
bt.logger.Logf("Stopped background task: %s", bt.name)
|
||||
}
|
||||
}
|
||||
|
||||
// run is the main loop for the background task
|
||||
func (bt *BackgroundTask) run() {
|
||||
defer func() {
|
||||
if bt.waitGroup != nil {
|
||||
bt.waitGroup.Done()
|
||||
}
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&bt.errorCount, 1)
|
||||
if bt.logger != nil {
|
||||
bt.logger.ErrorLogf("Background task %s panicked: %v", bt.name, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Run task immediately on start
|
||||
bt.executeTask()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-bt.ticker.C:
|
||||
bt.executeTask()
|
||||
case <-bt.stopChan:
|
||||
return
|
||||
case <-bt.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeTask runs the task function with error handling
|
||||
func (bt *BackgroundTask) executeTask() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&bt.errorCount, 1)
|
||||
if bt.logger != nil {
|
||||
bt.logger.ErrorLogf("Task %s panicked: %v", bt.name, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
bt.mu.Lock()
|
||||
bt.lastRun = time.Now()
|
||||
bt.mu.Unlock()
|
||||
|
||||
atomic.AddInt64(&bt.runCount, 1)
|
||||
bt.taskFunc()
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the task
|
||||
func (bt *BackgroundTask) GetStats() map[string]interface{} {
|
||||
bt.mu.RLock()
|
||||
lastRun := bt.lastRun
|
||||
bt.mu.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"name": bt.name,
|
||||
"interval": bt.interval.String(),
|
||||
"isRunning": atomic.LoadInt32(&bt.isRunning) == 1,
|
||||
"lastRun": lastRun.Format(time.RFC3339),
|
||||
"runCount": atomic.LoadInt64(&bt.runCount),
|
||||
"errorCount": atomic.LoadInt64(&bt.errorCount),
|
||||
}
|
||||
}
|
||||
|
||||
// IsRunning returns whether the task is currently running
|
||||
func (bt *BackgroundTask) IsRunning() bool {
|
||||
return atomic.LoadInt32(&bt.isRunning) == 1
|
||||
}
|
||||
|
||||
// TaskRegistry manages all background tasks
|
||||
type TaskRegistry struct {
|
||||
tasks map[string]*BackgroundTask
|
||||
mu sync.RWMutex
|
||||
logger Logger
|
||||
maxTasks int
|
||||
circuitBreaker *TaskCircuitBreaker
|
||||
}
|
||||
|
||||
// globalTaskRegistry is the singleton task registry
|
||||
var (
|
||||
globalTaskRegistry *TaskRegistry
|
||||
registryOnce sync.Once
|
||||
registryMutex sync.Mutex
|
||||
)
|
||||
|
||||
// GetGlobalTaskRegistry returns the global task registry singleton
|
||||
func GetGlobalTaskRegistry() *TaskRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalTaskRegistry = &TaskRegistry{
|
||||
tasks: make(map[string]*BackgroundTask),
|
||||
maxTasks: 100, // Default maximum tasks
|
||||
}
|
||||
})
|
||||
return globalTaskRegistry
|
||||
}
|
||||
|
||||
// ResetGlobalTaskRegistry resets the global task registry (mainly for testing)
|
||||
func ResetGlobalTaskRegistry() {
|
||||
registryMutex.Lock()
|
||||
defer registryMutex.Unlock()
|
||||
|
||||
if globalTaskRegistry != nil {
|
||||
globalTaskRegistry.StopAllTasks()
|
||||
globalTaskRegistry = nil
|
||||
}
|
||||
registryOnce = sync.Once{}
|
||||
}
|
||||
|
||||
// NewTaskRegistry creates a new task registry
|
||||
func NewTaskRegistry(logger Logger, maxTasks int) *TaskRegistry {
|
||||
return &TaskRegistry{
|
||||
tasks: make(map[string]*BackgroundTask),
|
||||
logger: logger,
|
||||
maxTasks: maxTasks,
|
||||
circuitBreaker: NewTaskCircuitBreaker(5, 30*time.Second, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterTask registers a new background task
|
||||
func (tr *TaskRegistry) RegisterTask(name string, task *BackgroundTask) error {
|
||||
if task == nil {
|
||||
return fmt.Errorf("task cannot be nil")
|
||||
}
|
||||
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
|
||||
// Check if task already exists
|
||||
if _, exists := tr.tasks[name]; exists {
|
||||
return fmt.Errorf("task with name %s already exists", name)
|
||||
}
|
||||
|
||||
// Check task limit
|
||||
if len(tr.tasks) >= tr.maxTasks {
|
||||
return fmt.Errorf("maximum number of tasks (%d) reached", tr.maxTasks)
|
||||
}
|
||||
|
||||
// Check circuit breaker
|
||||
if tr.circuitBreaker != nil {
|
||||
if err := tr.circuitBreaker.CanCreateTask(name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
tr.tasks[name] = task
|
||||
|
||||
if tr.logger != nil {
|
||||
tr.logger.Logf("Registered task: %s", name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterTask removes a task from the registry
|
||||
func (tr *TaskRegistry) UnregisterTask(name string) {
|
||||
tr.mu.Lock()
|
||||
defer tr.mu.Unlock()
|
||||
|
||||
if task, exists := tr.tasks[name]; exists {
|
||||
if task.IsRunning() {
|
||||
task.Stop()
|
||||
}
|
||||
delete(tr.tasks, name)
|
||||
|
||||
if tr.logger != nil {
|
||||
tr.logger.Logf("Unregistered task: %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetTask returns a task by name
|
||||
func (tr *TaskRegistry) GetTask(name string) (*BackgroundTask, bool) {
|
||||
tr.mu.RLock()
|
||||
defer tr.mu.RUnlock()
|
||||
|
||||
task, exists := tr.tasks[name]
|
||||
return task, exists
|
||||
}
|
||||
|
||||
// StopAllTasks stops all registered tasks
|
||||
func (tr *TaskRegistry) StopAllTasks() {
|
||||
tr.mu.RLock()
|
||||
tasks := make([]*BackgroundTask, 0, len(tr.tasks))
|
||||
for _, task := range tr.tasks {
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
tr.mu.RUnlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, task := range tasks {
|
||||
if task.IsRunning() {
|
||||
wg.Add(1)
|
||||
go func(t *BackgroundTask) {
|
||||
defer wg.Done()
|
||||
t.Stop()
|
||||
}(task)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Clear all tasks from the registry after stopping them
|
||||
tr.mu.Lock()
|
||||
tr.tasks = make(map[string]*BackgroundTask)
|
||||
tr.mu.Unlock()
|
||||
|
||||
if tr.logger != nil {
|
||||
tr.logger.Logf("Stopped all tasks")
|
||||
}
|
||||
}
|
||||
|
||||
// GetTaskCount returns the number of registered tasks
|
||||
func (tr *TaskRegistry) GetTaskCount() int {
|
||||
tr.mu.RLock()
|
||||
defer tr.mu.RUnlock()
|
||||
return len(tr.tasks)
|
||||
}
|
||||
|
||||
// CreateSingletonTask creates or retrieves an existing task
|
||||
func (tr *TaskRegistry) CreateSingletonTask(name string, interval time.Duration,
|
||||
taskFunc func(), logger Logger, wg ...*sync.WaitGroup) (*BackgroundTask, error) {
|
||||
|
||||
// Check if task already exists
|
||||
if existingTask, exists := tr.GetTask(name); exists {
|
||||
if existingTask.IsRunning() {
|
||||
if logger != nil {
|
||||
logger.Logf("Task %s already exists and is running", name)
|
||||
}
|
||||
return existingTask, nil
|
||||
}
|
||||
// Task exists but not running, start it
|
||||
existingTask.Start()
|
||||
return existingTask, nil
|
||||
}
|
||||
|
||||
// Create new task
|
||||
task := NewBackgroundTask(name, interval, taskFunc, logger, wg...)
|
||||
|
||||
// Register task
|
||||
if err := tr.RegisterTask(name, task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Start task
|
||||
task.Start()
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// GetAllTasks returns all registered tasks
|
||||
func (tr *TaskRegistry) GetAllTasks() map[string]*BackgroundTask {
|
||||
tr.mu.RLock()
|
||||
defer tr.mu.RUnlock()
|
||||
|
||||
tasks := make(map[string]*BackgroundTask)
|
||||
for name, task := range tr.tasks {
|
||||
tasks[name] = task
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
// GetStats returns statistics for all tasks
|
||||
func (tr *TaskRegistry) GetStats() map[string]interface{} {
|
||||
tr.mu.RLock()
|
||||
defer tr.mu.RUnlock()
|
||||
|
||||
stats := make(map[string]interface{})
|
||||
stats["totalTasks"] = len(tr.tasks)
|
||||
|
||||
runningCount := 0
|
||||
taskStats := make(map[string]interface{})
|
||||
for name, task := range tr.tasks {
|
||||
if task.IsRunning() {
|
||||
runningCount++
|
||||
}
|
||||
taskStats[name] = task.GetStats()
|
||||
}
|
||||
|
||||
stats["runningTasks"] = runningCount
|
||||
stats["tasks"] = taskStats
|
||||
|
||||
// Add memory stats
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
stats["memory"] = map[string]interface{}{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlloc": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"numGC": m.NumGC,
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
@@ -0,0 +1,449 @@
|
||||
// Package cleanup provides background task management and cleanup functionality.
|
||||
package cleanup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TaskCircuitBreaker prevents task creation failures from cascading
|
||||
type TaskCircuitBreaker struct {
|
||||
failureThreshold int32
|
||||
failureCount int32
|
||||
lastFailureTime time.Time
|
||||
timeout time.Duration
|
||||
state int32 // 0: closed, 1: open
|
||||
logger Logger
|
||||
mu sync.RWMutex
|
||||
taskFailures map[string]int32
|
||||
}
|
||||
|
||||
// CircuitBreakerState represents the state of the circuit breaker
|
||||
type CircuitBreakerState int32
|
||||
|
||||
const (
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
CircuitBreakerOpen
|
||||
)
|
||||
|
||||
// NewTaskCircuitBreaker creates a new circuit breaker for task management
|
||||
func NewTaskCircuitBreaker(failureThreshold int32, timeout time.Duration, logger Logger) *TaskCircuitBreaker {
|
||||
return &TaskCircuitBreaker{
|
||||
failureThreshold: failureThreshold,
|
||||
timeout: timeout,
|
||||
logger: logger,
|
||||
taskFailures: make(map[string]int32),
|
||||
}
|
||||
}
|
||||
|
||||
// CanCreateTask checks if a new task can be created
|
||||
func (cb *TaskCircuitBreaker) CanCreateTask(taskName string) error {
|
||||
cb.mu.RLock()
|
||||
defer cb.mu.RUnlock()
|
||||
|
||||
// Check circuit breaker state
|
||||
if atomic.LoadInt32(&cb.state) == int32(CircuitBreakerOpen) {
|
||||
// Check if timeout has elapsed
|
||||
if time.Since(cb.lastFailureTime) < cb.timeout {
|
||||
return fmt.Errorf("circuit breaker open: too many task failures")
|
||||
}
|
||||
// Reset circuit breaker
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
|
||||
atomic.StoreInt32(&cb.failureCount, 0)
|
||||
if cb.logger != nil {
|
||||
cb.logger.Logf("Circuit breaker reset after timeout")
|
||||
}
|
||||
}
|
||||
|
||||
// Check task-specific failures
|
||||
if failures, exists := cb.taskFailures[taskName]; exists {
|
||||
if failures >= cb.failureThreshold {
|
||||
return fmt.Errorf("task %s has too many failures (%d)", taskName, failures)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnTaskStart records that a task has started
|
||||
func (cb *TaskCircuitBreaker) OnTaskStart(taskName string) {
|
||||
// Currently just for tracking, could add rate limiting here
|
||||
if cb.logger != nil {
|
||||
cb.logger.DebugLogf("Task %s started", taskName)
|
||||
}
|
||||
}
|
||||
|
||||
// OnTaskComplete records that a task completed (success or failure)
|
||||
func (cb *TaskCircuitBreaker) OnTaskComplete(taskName string) {
|
||||
// Currently just for tracking
|
||||
if cb.logger != nil {
|
||||
cb.logger.DebugLogf("Task %s completed", taskName)
|
||||
}
|
||||
}
|
||||
|
||||
// OnTaskSuccess records a successful task execution
|
||||
func (cb *TaskCircuitBreaker) OnTaskSuccess(taskName string) {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
// Reset task-specific failure count on success
|
||||
delete(cb.taskFailures, taskName)
|
||||
}
|
||||
|
||||
// OnTaskFailure records a task failure
|
||||
func (cb *TaskCircuitBreaker) OnTaskFailure(taskName string, err error) {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
// Increment task-specific failure count
|
||||
cb.taskFailures[taskName]++
|
||||
|
||||
// Increment overall failure count
|
||||
failures := atomic.AddInt32(&cb.failureCount, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
|
||||
if cb.logger != nil {
|
||||
cb.logger.ErrorLogf("Task %s failed: %v (failure count: %d)", taskName, err, cb.taskFailures[taskName])
|
||||
}
|
||||
|
||||
// Open circuit breaker if threshold reached
|
||||
if failures >= cb.failureThreshold {
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerOpen))
|
||||
if cb.logger != nil {
|
||||
cb.logger.ErrorLogf("Circuit breaker opened due to %d failures", failures)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker
|
||||
func (cb *TaskCircuitBreaker) Reset() {
|
||||
cb.mu.Lock()
|
||||
defer cb.mu.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
|
||||
atomic.StoreInt32(&cb.failureCount, 0)
|
||||
cb.taskFailures = make(map[string]int32)
|
||||
cb.lastFailureTime = time.Time{}
|
||||
|
||||
if cb.logger != nil {
|
||||
cb.logger.Logf("Circuit breaker reset")
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *TaskCircuitBreaker) GetState() CircuitBreakerState {
|
||||
return CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
}
|
||||
|
||||
// TaskMemoryMonitor monitors memory usage and can trigger cleanup
|
||||
type TaskMemoryMonitor struct {
|
||||
logger Logger
|
||||
registry *TaskRegistry
|
||||
memoryThreshold uint64
|
||||
checkInterval time.Duration
|
||||
isMonitoring int32
|
||||
stopChan chan bool
|
||||
lastCheck time.Time
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
globalMemoryMonitor *TaskMemoryMonitor
|
||||
monitorOnce sync.Once
|
||||
)
|
||||
|
||||
// GetGlobalTaskMemoryMonitor returns the global memory monitor singleton
|
||||
func GetGlobalTaskMemoryMonitor(logger Logger) *TaskMemoryMonitor {
|
||||
monitorOnce.Do(func() {
|
||||
globalMemoryMonitor = NewTaskMemoryMonitor(logger, GetGlobalTaskRegistry())
|
||||
})
|
||||
return globalMemoryMonitor
|
||||
}
|
||||
|
||||
// NewTaskMemoryMonitor creates a new memory monitor
|
||||
func NewTaskMemoryMonitor(logger Logger, registry *TaskRegistry) *TaskMemoryMonitor {
|
||||
return &TaskMemoryMonitor{
|
||||
logger: logger,
|
||||
registry: registry,
|
||||
memoryThreshold: 1024 * 1024 * 1024, // 1GB default
|
||||
checkInterval: 1 * time.Minute,
|
||||
stopChan: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// SetMemoryThreshold sets the memory threshold for triggering cleanup
|
||||
func (tmm *TaskMemoryMonitor) SetMemoryThreshold(bytes uint64) {
|
||||
tmm.mu.Lock()
|
||||
defer tmm.mu.Unlock()
|
||||
tmm.memoryThreshold = bytes
|
||||
}
|
||||
|
||||
// StartMonitoring starts the memory monitoring routine
|
||||
func (tmm *TaskMemoryMonitor) StartMonitoring() {
|
||||
if !atomic.CompareAndSwapInt32(&tmm.isMonitoring, 0, 1) {
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.Logf("Memory monitor is already running")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
go tmm.monitorLoop()
|
||||
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.Logf("Started memory monitoring (threshold: %d bytes, interval: %v)",
|
||||
tmm.memoryThreshold, tmm.checkInterval)
|
||||
}
|
||||
}
|
||||
|
||||
// StopMonitoring stops the memory monitoring routine
|
||||
func (tmm *TaskMemoryMonitor) StopMonitoring() {
|
||||
if !atomic.CompareAndSwapInt32(&tmm.isMonitoring, 1, 0) {
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.Logf("Memory monitor is not running")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case tmm.stopChan <- true:
|
||||
case <-time.After(5 * time.Second):
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.ErrorLogf("Timeout stopping memory monitor")
|
||||
}
|
||||
}
|
||||
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.Logf("Stopped memory monitoring")
|
||||
}
|
||||
}
|
||||
|
||||
// monitorLoop is the main monitoring loop
|
||||
func (tmm *TaskMemoryMonitor) monitorLoop() {
|
||||
ticker := time.NewTicker(tmm.checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
tmm.checkMemory()
|
||||
case <-tmm.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkMemory checks current memory usage and triggers cleanup if needed
|
||||
func (tmm *TaskMemoryMonitor) checkMemory() {
|
||||
tmm.mu.Lock()
|
||||
tmm.lastCheck = time.Now()
|
||||
tmm.mu.Unlock()
|
||||
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.DebugLogf("Memory check - Alloc: %d MB, Sys: %d MB, NumGC: %d",
|
||||
m.Alloc/1024/1024, m.Sys/1024/1024, m.NumGC)
|
||||
}
|
||||
|
||||
// Check if memory usage exceeds threshold
|
||||
if m.Alloc > tmm.memoryThreshold {
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.Logf("Memory usage (%d MB) exceeds threshold (%d MB), triggering cleanup",
|
||||
m.Alloc/1024/1024, tmm.memoryThreshold/1024/1024)
|
||||
}
|
||||
|
||||
// Trigger garbage collection
|
||||
runtime.GC()
|
||||
|
||||
// Could also trigger task-specific cleanup here
|
||||
tmm.triggerTaskCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// triggerTaskCleanup triggers cleanup operations on tasks
|
||||
func (tmm *TaskMemoryMonitor) triggerTaskCleanup() {
|
||||
if tmm.registry == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get all tasks and potentially pause non-critical ones
|
||||
tasks := tmm.registry.GetAllTasks()
|
||||
for name, task := range tasks {
|
||||
// Could implement task priority here
|
||||
if tmm.logger != nil {
|
||||
tmm.logger.DebugLogf("Checking task %s for cleanup opportunities", name)
|
||||
}
|
||||
// Tasks could implement a Cleanup() method
|
||||
_ = task // Placeholder for future cleanup logic
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats returns memory monitor statistics
|
||||
func (tmm *TaskMemoryMonitor) GetStats() map[string]interface{} {
|
||||
tmm.mu.RLock()
|
||||
lastCheck := tmm.lastCheck
|
||||
tmm.mu.RUnlock()
|
||||
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
|
||||
return map[string]interface{}{
|
||||
"isMonitoring": atomic.LoadInt32(&tmm.isMonitoring) == 1,
|
||||
"lastCheck": lastCheck.Format(time.RFC3339),
|
||||
"checkInterval": tmm.checkInterval.String(),
|
||||
"memoryThreshold": tmm.memoryThreshold,
|
||||
"currentMemory": map[string]interface{}{
|
||||
"alloc": m.Alloc,
|
||||
"totalAlloc": m.TotalAlloc,
|
||||
"sys": m.Sys,
|
||||
"mallocs": m.Mallocs,
|
||||
"frees": m.Frees,
|
||||
"numGC": m.NumGC,
|
||||
"goroutines": runtime.NumGoroutine(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WorkerPool manages a pool of worker goroutines for task execution
|
||||
type WorkerPool struct {
|
||||
workers int
|
||||
taskQueue chan func()
|
||||
workerWg sync.WaitGroup
|
||||
isRunning int32
|
||||
logger Logger
|
||||
stopChan chan bool
|
||||
metrics WorkerPoolMetrics
|
||||
}
|
||||
|
||||
// WorkerPoolMetrics tracks worker pool performance
|
||||
type WorkerPoolMetrics struct {
|
||||
tasksProcessed int64
|
||||
tasksQueued int64
|
||||
tasksFailed int64
|
||||
avgProcessTime int64 // nanoseconds
|
||||
}
|
||||
|
||||
// NewWorkerPool creates a new worker pool
|
||||
func NewWorkerPool(workers int, queueSize int, logger Logger) *WorkerPool {
|
||||
if workers <= 0 {
|
||||
workers = runtime.NumCPU()
|
||||
}
|
||||
if queueSize <= 0 {
|
||||
queueSize = workers * 10
|
||||
}
|
||||
|
||||
return &WorkerPool{
|
||||
workers: workers,
|
||||
taskQueue: make(chan func(), queueSize),
|
||||
stopChan: make(chan bool),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the worker pool
|
||||
func (wp *WorkerPool) Start() {
|
||||
if !atomic.CompareAndSwapInt32(&wp.isRunning, 0, 1) {
|
||||
if wp.logger != nil {
|
||||
wp.logger.Logf("Worker pool is already running")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < wp.workers; i++ {
|
||||
wp.workerWg.Add(1)
|
||||
go wp.worker(i)
|
||||
}
|
||||
|
||||
if wp.logger != nil {
|
||||
wp.logger.Logf("Started worker pool with %d workers", wp.workers)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the worker pool
|
||||
func (wp *WorkerPool) Stop() {
|
||||
if !atomic.CompareAndSwapInt32(&wp.isRunning, 1, 0) {
|
||||
if wp.logger != nil {
|
||||
wp.logger.Logf("Worker pool is not running")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
close(wp.stopChan)
|
||||
close(wp.taskQueue)
|
||||
wp.workerWg.Wait()
|
||||
|
||||
if wp.logger != nil {
|
||||
wp.logger.Logf("Stopped worker pool")
|
||||
}
|
||||
}
|
||||
|
||||
// Submit submits a task to the worker pool
|
||||
func (wp *WorkerPool) Submit(task func()) error {
|
||||
if atomic.LoadInt32(&wp.isRunning) != 1 {
|
||||
return fmt.Errorf("worker pool is not running")
|
||||
}
|
||||
|
||||
select {
|
||||
case wp.taskQueue <- task:
|
||||
atomic.AddInt64(&wp.metrics.tasksQueued, 1)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("worker pool queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// worker is the main worker routine
|
||||
func (wp *WorkerPool) worker(id int) {
|
||||
defer wp.workerWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case task, ok := <-wp.taskQueue:
|
||||
if !ok {
|
||||
return // Channel closed
|
||||
}
|
||||
wp.executeTask(task)
|
||||
case <-wp.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// executeTask executes a task with error handling
|
||||
func (wp *WorkerPool) executeTask(task func()) {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
atomic.AddInt64(&wp.metrics.tasksFailed, 1)
|
||||
if wp.logger != nil {
|
||||
wp.logger.ErrorLogf("Worker pool task panicked: %v", r)
|
||||
}
|
||||
}
|
||||
// Update average process time
|
||||
duration := time.Since(startTime).Nanoseconds()
|
||||
processed := atomic.AddInt64(&wp.metrics.tasksProcessed, 1)
|
||||
currentAvg := atomic.LoadInt64(&wp.metrics.avgProcessTime)
|
||||
newAvg := (currentAvg*(processed-1) + duration) / processed
|
||||
atomic.StoreInt64(&wp.metrics.avgProcessTime, newAvg)
|
||||
}()
|
||||
|
||||
task()
|
||||
}
|
||||
|
||||
// GetMetrics returns worker pool metrics
|
||||
func (wp *WorkerPool) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"workers": wp.workers,
|
||||
"isRunning": atomic.LoadInt32(&wp.isRunning) == 1,
|
||||
"queueSize": len(wp.taskQueue),
|
||||
"queueCapacity": cap(wp.taskQueue),
|
||||
"tasksProcessed": atomic.LoadInt64(&wp.metrics.tasksProcessed),
|
||||
"tasksQueued": atomic.LoadInt64(&wp.metrics.tasksQueued),
|
||||
"tasksFailed": atomic.LoadInt64(&wp.metrics.tasksFailed),
|
||||
"avgProcessTime": time.Duration(atomic.LoadInt64(&wp.metrics.avgProcessTime)),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
// Package compat provides backward compatibility layer during refactoring
|
||||
package compat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CompatibilityLayer provides backward compatibility during the migration
|
||||
type CompatibilityLayer struct {
|
||||
mappings map[string]string // old path -> new path
|
||||
converters map[string]Converter
|
||||
deprecations map[string]string // deprecated field -> warning message
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Converter is a function that converts old value format to new format
|
||||
type Converter func(oldValue interface{}) (newValue interface{}, err error)
|
||||
|
||||
// Global compatibility layer instance
|
||||
var (
|
||||
layer *CompatibilityLayer
|
||||
layerOnce sync.Once
|
||||
)
|
||||
|
||||
// GetLayer returns the global compatibility layer instance
|
||||
func GetLayer() *CompatibilityLayer {
|
||||
layerOnce.Do(func() {
|
||||
layer = &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
layer.initialize()
|
||||
})
|
||||
return layer
|
||||
}
|
||||
|
||||
// initialize sets up default compatibility mappings
|
||||
func (c *CompatibilityLayer) initialize() {
|
||||
// Configuration path mappings (old -> new)
|
||||
c.RegisterMapping("ProviderURL", "Provider.IssuerURL")
|
||||
c.RegisterMapping("ClientID", "Provider.ClientID")
|
||||
c.RegisterMapping("ClientSecret", "Provider.ClientSecret")
|
||||
c.RegisterMapping("CallbackURL", "Provider.RedirectURL")
|
||||
c.RegisterMapping("LogoutURL", "Provider.LogoutURL")
|
||||
c.RegisterMapping("SessionEncryptionKey", "Session.EncryptionKey")
|
||||
c.RegisterMapping("Scopes", "Provider.Scopes")
|
||||
c.RegisterMapping("RateLimit", "Middleware.RateLimit")
|
||||
c.RegisterMapping("RefreshGracePeriodSeconds", "Token.RefreshGracePeriod")
|
||||
|
||||
// Redis configuration mappings
|
||||
c.RegisterMapping("RedisAddr", "Redis.Addresses[0]")
|
||||
c.RegisterMapping("RedisPassword", "Redis.Password")
|
||||
c.RegisterMapping("RedisDB", "Redis.DB")
|
||||
|
||||
// Session configuration mappings
|
||||
c.RegisterMapping("SessionName", "Session.Name")
|
||||
c.RegisterMapping("SessionMaxAge", "Session.MaxAge")
|
||||
c.RegisterMapping("SessionSecret", "Session.Secret")
|
||||
c.RegisterMapping("SessionChunkSize", "Session.ChunkSize")
|
||||
|
||||
// Security configuration mappings
|
||||
c.RegisterMapping("ForceHTTPS", "Security.ForceHTTPS")
|
||||
c.RegisterMapping("EnablePKCE", "Security.EnablePKCE")
|
||||
c.RegisterMapping("AllowedUsers", "Security.AllowedUsers")
|
||||
c.RegisterMapping("AllowedUserDomains", "Security.AllowedUserDomains")
|
||||
c.RegisterMapping("AllowedRolesAndGroups", "Security.AllowedRolesAndGroups")
|
||||
c.RegisterMapping("ExcludedURLs", "Security.ExcludedURLs")
|
||||
|
||||
// Register converters for complex transformations
|
||||
c.RegisterConverter("RefreshGracePeriodSeconds", func(oldValue interface{}) (interface{}, error) {
|
||||
// Convert seconds (int) to duration string
|
||||
if seconds, ok := oldValue.(int); ok {
|
||||
return fmt.Sprintf("%ds", seconds), nil
|
||||
}
|
||||
return oldValue, nil
|
||||
})
|
||||
|
||||
// Register deprecations
|
||||
c.RegisterDeprecation("LogLevel", "LogLevel is deprecated, use Logging.Level instead")
|
||||
c.RegisterDeprecation("HTTPClient", "HTTPClient is deprecated, configure via Transport settings")
|
||||
}
|
||||
|
||||
// RegisterMapping registers a field mapping from old to new path
|
||||
func (c *CompatibilityLayer) RegisterMapping(oldPath, newPath string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.mappings[oldPath] = newPath
|
||||
}
|
||||
|
||||
// RegisterConverter registers a value converter for a field
|
||||
func (c *CompatibilityLayer) RegisterConverter(field string, converter Converter) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.converters[field] = converter
|
||||
}
|
||||
|
||||
// RegisterDeprecation registers a deprecation warning for a field
|
||||
func (c *CompatibilityLayer) RegisterDeprecation(field, message string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.deprecations[field] = message
|
||||
}
|
||||
|
||||
// GetMapping returns the new path for an old configuration path
|
||||
func (c *CompatibilityLayer) GetMapping(oldPath string) (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
newPath, exists := c.mappings[oldPath]
|
||||
return newPath, exists
|
||||
}
|
||||
|
||||
// Convert applies conversion logic to a value
|
||||
func (c *CompatibilityLayer) Convert(field string, value interface{}) (interface{}, error) {
|
||||
c.mu.RLock()
|
||||
converter, exists := c.converters[field]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return converter(value)
|
||||
}
|
||||
|
||||
// CheckDeprecation checks if a field is deprecated and returns warning message
|
||||
func (c *CompatibilityLayer) CheckDeprecation(field string) (string, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
message, deprecated := c.deprecations[field]
|
||||
return message, deprecated
|
||||
}
|
||||
|
||||
// MigrateMap migrates an old configuration map to new structure
|
||||
func (c *CompatibilityLayer) MigrateMap(oldConfig map[string]interface{}) (map[string]interface{}, []string) {
|
||||
newConfig := make(map[string]interface{})
|
||||
warnings := []string{}
|
||||
|
||||
for key, value := range oldConfig {
|
||||
// Check for deprecation
|
||||
if warning, deprecated := c.CheckDeprecation(key); deprecated {
|
||||
warnings = append(warnings, warning)
|
||||
}
|
||||
|
||||
// Get new path
|
||||
newPath, hasMappming := c.GetMapping(key)
|
||||
if !hasMappming {
|
||||
// No mapping, use as-is
|
||||
newConfig[key] = value
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply converter if exists
|
||||
convertedValue, err := c.Convert(key, value)
|
||||
if err != nil {
|
||||
warnings = append(warnings, fmt.Sprintf("Failed to convert %s: %v", key, err))
|
||||
convertedValue = value
|
||||
}
|
||||
|
||||
// Set value at new path
|
||||
setNestedValue(newConfig, newPath, convertedValue)
|
||||
}
|
||||
|
||||
return newConfig, warnings
|
||||
}
|
||||
|
||||
// setNestedValue sets a value in a nested map structure using dot notation
|
||||
func setNestedValue(m map[string]interface{}, path string, value interface{}) {
|
||||
keys := splitPath(path)
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
current := m
|
||||
for i := 0; i < len(keys)-1; i++ {
|
||||
key := keys[i]
|
||||
|
||||
// Check if this key has array notation
|
||||
if isArrayPath(key) {
|
||||
// Handle array notation (e.g., "Addresses[0]")
|
||||
continue // Skip array handling for now, will be handled in actual migration
|
||||
}
|
||||
|
||||
if _, exists := current[key]; !exists {
|
||||
current[key] = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// Ensure it's a map
|
||||
if next, ok := current[key].(map[string]interface{}); ok {
|
||||
current = next
|
||||
} else {
|
||||
// Can't traverse further, create new map
|
||||
newMap := make(map[string]interface{})
|
||||
current[key] = newMap
|
||||
current = newMap
|
||||
}
|
||||
}
|
||||
|
||||
// Set the final value
|
||||
finalKey := keys[len(keys)-1]
|
||||
current[finalKey] = value
|
||||
}
|
||||
|
||||
// splitPath splits a configuration path into segments
|
||||
func splitPath(path string) []string {
|
||||
segments := []string{}
|
||||
current := ""
|
||||
|
||||
for i := 0; i < len(path); i++ {
|
||||
if path[i] == '.' {
|
||||
if current != "" {
|
||||
segments = append(segments, current)
|
||||
current = ""
|
||||
}
|
||||
} else {
|
||||
current += string(path[i])
|
||||
}
|
||||
}
|
||||
|
||||
if current != "" {
|
||||
segments = append(segments, current)
|
||||
}
|
||||
|
||||
return segments
|
||||
}
|
||||
|
||||
// isArrayPath checks if a path segment contains array notation
|
||||
func isArrayPath(segment string) bool {
|
||||
for _, char := range segment {
|
||||
if char == '[' {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConfigAdapter provides an adapter interface for old code to work with new config
|
||||
type ConfigAdapter struct {
|
||||
newConfig interface{}
|
||||
oldPaths map[string]func() interface{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewConfigAdapter creates a new configuration adapter
|
||||
func NewConfigAdapter(newConfig interface{}) *ConfigAdapter {
|
||||
adapter := &ConfigAdapter{
|
||||
newConfig: newConfig,
|
||||
oldPaths: make(map[string]func() interface{}),
|
||||
}
|
||||
return adapter
|
||||
}
|
||||
|
||||
// RegisterGetter registers a getter function for an old path
|
||||
func (a *ConfigAdapter) RegisterGetter(oldPath string, getter func() interface{}) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.oldPaths[oldPath] = getter
|
||||
}
|
||||
|
||||
// Get retrieves a value using old path notation
|
||||
func (a *ConfigAdapter) Get(oldPath string) (interface{}, bool) {
|
||||
a.mu.RLock()
|
||||
getter, exists := a.oldPaths[oldPath]
|
||||
a.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Try to get from new config using reflection
|
||||
return a.getFromNewConfig(oldPath)
|
||||
}
|
||||
|
||||
return getter(), true
|
||||
}
|
||||
|
||||
// getFromNewConfig attempts to retrieve value from new config using reflection
|
||||
func (a *ConfigAdapter) getFromNewConfig(path string) (interface{}, bool) {
|
||||
// Check if there's a mapping for this path
|
||||
compat := GetLayer()
|
||||
if newPath, hasMappming := compat.GetMapping(path); hasMappming {
|
||||
return a.getNestedField(newPath)
|
||||
}
|
||||
|
||||
// Try direct access
|
||||
return a.getNestedField(path)
|
||||
}
|
||||
|
||||
// getNestedField retrieves a nested field value using reflection
|
||||
func (a *ConfigAdapter) getNestedField(path string) (interface{}, bool) {
|
||||
segments := splitPath(path)
|
||||
if len(segments) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
v := reflect.ValueOf(a.newConfig)
|
||||
|
||||
// Dereference pointer if needed
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
for _, segment := range segments {
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
field := v.FieldByName(segment)
|
||||
if !field.IsValid() {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
v = field
|
||||
}
|
||||
|
||||
if v.IsValid() && v.CanInterface() {
|
||||
return v.Interface(), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
@@ -0,0 +1,495 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package compat
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetLayer_Singleton(t *testing.T) {
|
||||
// Reset global state
|
||||
layerOnce = sync.Once{}
|
||||
layer = nil
|
||||
|
||||
layer1 := GetLayer()
|
||||
layer2 := GetLayer()
|
||||
|
||||
if layer1 != layer2 {
|
||||
t.Error("Expected GetLayer to return same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLayer_Initialize(t *testing.T) {
|
||||
// Reset global state
|
||||
layerOnce = sync.Once{}
|
||||
layer = nil
|
||||
|
||||
l := GetLayer()
|
||||
|
||||
// Check default mappings exist
|
||||
if _, exists := l.GetMapping("ProviderURL"); !exists {
|
||||
t.Error("Expected ProviderURL mapping to exist")
|
||||
}
|
||||
|
||||
if _, exists := l.GetMapping("ClientID"); !exists {
|
||||
t.Error("Expected ClientID mapping to exist")
|
||||
}
|
||||
|
||||
// Check deprecations exist
|
||||
if _, deprecated := l.CheckDeprecation("LogLevel"); !deprecated {
|
||||
t.Error("Expected LogLevel to be marked deprecated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMapping(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterMapping("OldField", "New.Field")
|
||||
|
||||
newPath, exists := l.GetMapping("OldField")
|
||||
if !exists {
|
||||
t.Error("Expected mapping to exist")
|
||||
}
|
||||
|
||||
if newPath != "New.Field" {
|
||||
t.Errorf("Expected 'New.Field', got '%s'", newPath)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterConverter(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
converter := func(oldValue interface{}) (interface{}, error) {
|
||||
if str, ok := oldValue.(string); ok {
|
||||
return str + "_converted", nil
|
||||
}
|
||||
return oldValue, nil
|
||||
}
|
||||
|
||||
l.RegisterConverter("TestField", converter)
|
||||
|
||||
result, err := l.Convert("TestField", "test")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if result != "test_converted" {
|
||||
t.Errorf("Expected 'test_converted', got '%v'", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvert_NoConverter(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
// No converter registered
|
||||
result, err := l.Convert("UnknownField", "value")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if result != "value" {
|
||||
t.Error("Expected original value when no converter exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterDeprecation(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterDeprecation("OldField", "This field is deprecated")
|
||||
|
||||
message, deprecated := l.CheckDeprecation("OldField")
|
||||
if !deprecated {
|
||||
t.Error("Expected field to be deprecated")
|
||||
}
|
||||
|
||||
if message != "This field is deprecated" {
|
||||
t.Errorf("Expected deprecation message, got '%s'", message)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDeprecation_NotDeprecated(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
_, deprecated := l.CheckDeprecation("NewField")
|
||||
if deprecated {
|
||||
t.Error("Expected field not to be deprecated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateMap_BasicMapping(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterMapping("OldField", "New.Field")
|
||||
|
||||
oldConfig := map[string]interface{}{
|
||||
"OldField": "value123",
|
||||
}
|
||||
|
||||
newConfig, warnings := l.MigrateMap(oldConfig)
|
||||
|
||||
if len(warnings) != 0 {
|
||||
t.Errorf("Expected no warnings, got %d", len(warnings))
|
||||
}
|
||||
|
||||
// Check nested structure
|
||||
if newMap, ok := newConfig["New"].(map[string]interface{}); ok {
|
||||
if val, exists := newMap["Field"]; !exists || val != "value123" {
|
||||
t.Errorf("Expected nested field value 'value123', got %v", val)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected nested map structure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateMap_WithDeprecation(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterMapping("DeprecatedField", "New.Field")
|
||||
l.RegisterDeprecation("DeprecatedField", "Field is deprecated")
|
||||
|
||||
oldConfig := map[string]interface{}{
|
||||
"DeprecatedField": "value",
|
||||
}
|
||||
|
||||
_, warnings := l.MigrateMap(oldConfig)
|
||||
|
||||
if len(warnings) != 1 {
|
||||
t.Errorf("Expected 1 warning, got %d", len(warnings))
|
||||
}
|
||||
|
||||
if warnings[0] != "Field is deprecated" {
|
||||
t.Errorf("Expected deprecation warning, got '%s'", warnings[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateMap_WithConverter(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterMapping("Seconds", "Duration")
|
||||
l.RegisterConverter("Seconds", func(oldValue interface{}) (interface{}, error) {
|
||||
if seconds, ok := oldValue.(int); ok {
|
||||
return seconds * 1000, nil // Convert to milliseconds
|
||||
}
|
||||
return oldValue, nil
|
||||
})
|
||||
|
||||
oldConfig := map[string]interface{}{
|
||||
"Seconds": 60,
|
||||
}
|
||||
|
||||
newConfig, _ := l.MigrateMap(oldConfig)
|
||||
|
||||
if val, ok := newConfig["Duration"]; !ok || val != 60000 {
|
||||
t.Errorf("Expected Duration to be 60000, got %v", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateMap_NoMapping(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
oldConfig := map[string]interface{}{
|
||||
"UnmappedField": "value",
|
||||
}
|
||||
|
||||
newConfig, _ := l.MigrateMap(oldConfig)
|
||||
|
||||
if val, ok := newConfig["UnmappedField"]; !ok || val != "value" {
|
||||
t.Error("Expected unmapped field to be copied as-is")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
expected []string
|
||||
}{
|
||||
{"Simple", []string{"Simple"}},
|
||||
{"Nested.Path", []string{"Nested", "Path"}},
|
||||
{"Deep.Nested.Path", []string{"Deep", "Nested", "Path"}},
|
||||
{"", []string{}},
|
||||
{"Single", []string{"Single"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := splitPath(tt.path)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Path '%s': expected %d segments, got %d", tt.path, len(tt.expected), len(result))
|
||||
continue
|
||||
}
|
||||
|
||||
for i, segment := range result {
|
||||
if segment != tt.expected[i] {
|
||||
t.Errorf("Path '%s': segment %d expected '%s', got '%s'", tt.path, i, tt.expected[i], segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsArrayPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
segment string
|
||||
expected bool
|
||||
}{
|
||||
{"Addresses[0]", true},
|
||||
{"Items[5]", true},
|
||||
{"Simple", false},
|
||||
{"NoArray", false},
|
||||
{"[start", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := isArrayPath(tt.segment)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Segment '%s': expected %v, got %v", tt.segment, tt.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetNestedValue_SingleLevel(t *testing.T) {
|
||||
m := make(map[string]interface{})
|
||||
setNestedValue(m, "Field", "value")
|
||||
|
||||
if val, ok := m["Field"]; !ok || val != "value" {
|
||||
t.Error("Expected single level field to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetNestedValue_MultiLevel(t *testing.T) {
|
||||
m := make(map[string]interface{})
|
||||
setNestedValue(m, "Parent.Child", "value")
|
||||
|
||||
parent, ok := m["Parent"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected Parent to be a map")
|
||||
}
|
||||
|
||||
if val, ok := parent["Child"]; !ok || val != "value" {
|
||||
t.Error("Expected nested field to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetNestedValue_DeepNesting(t *testing.T) {
|
||||
m := make(map[string]interface{})
|
||||
setNestedValue(m, "Level1.Level2.Level3", "deep_value")
|
||||
|
||||
level1, ok := m["Level1"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected Level1 to be a map")
|
||||
}
|
||||
|
||||
level2, ok := level1["Level2"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected Level2 to be a map")
|
||||
}
|
||||
|
||||
if val, ok := level2["Level3"]; !ok || val != "deep_value" {
|
||||
t.Error("Expected deeply nested field to be set")
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigAdapter tests
|
||||
|
||||
func TestNewConfigAdapter(t *testing.T) {
|
||||
config := map[string]interface{}{"key": "value"}
|
||||
adapter := NewConfigAdapter(config)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("Expected adapter to be created")
|
||||
}
|
||||
|
||||
if adapter.newConfig == nil {
|
||||
t.Error("Expected config to be stored")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigAdapter_RegisterGetter(t *testing.T) {
|
||||
adapter := NewConfigAdapter(nil)
|
||||
|
||||
called := false
|
||||
adapter.RegisterGetter("TestPath", func() interface{} {
|
||||
called = true
|
||||
return "test_value"
|
||||
})
|
||||
|
||||
val, exists := adapter.Get("TestPath")
|
||||
if !exists {
|
||||
t.Error("Expected getter to exist")
|
||||
}
|
||||
|
||||
if val != "test_value" {
|
||||
t.Errorf("Expected 'test_value', got %v", val)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Expected getter function to be called")
|
||||
}
|
||||
}
|
||||
|
||||
type TestConfig struct {
|
||||
Provider struct {
|
||||
IssuerURL string
|
||||
ClientID string
|
||||
}
|
||||
Session struct {
|
||||
EncryptionKey string
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigAdapter_GetNestedField(t *testing.T) {
|
||||
config := &TestConfig{}
|
||||
config.Provider.IssuerURL = "https://test.com"
|
||||
config.Provider.ClientID = "test-client"
|
||||
config.Session.EncryptionKey = "secret123"
|
||||
|
||||
adapter := NewConfigAdapter(config)
|
||||
|
||||
// Test nested field access
|
||||
val, exists := adapter.getNestedField("Provider.IssuerURL")
|
||||
if !exists {
|
||||
t.Error("Expected field to exist")
|
||||
}
|
||||
|
||||
if val != "https://test.com" {
|
||||
t.Errorf("Expected 'https://test.com', got %v", val)
|
||||
}
|
||||
|
||||
// Test another nested field
|
||||
val2, exists2 := adapter.getNestedField("Provider.ClientID")
|
||||
if !exists2 || val2 != "test-client" {
|
||||
t.Error("Expected ClientID to be accessible")
|
||||
}
|
||||
|
||||
// Test non-existent field
|
||||
_, exists3 := adapter.getNestedField("NonExistent.Field")
|
||||
if exists3 {
|
||||
t.Error("Expected non-existent field to return false")
|
||||
}
|
||||
}
|
||||
|
||||
// Race condition tests
|
||||
|
||||
func TestCompatibilityLayer_ConcurrentAccess(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent registrations
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
l.RegisterMapping(string(rune('A'+idx%26)), "New.Field")
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, _ = l.GetMapping(string(rune('A' + idx%26)))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestCompatibilityLayer_ConcurrentMigrate(t *testing.T) {
|
||||
l := &CompatibilityLayer{
|
||||
mappings: make(map[string]string),
|
||||
converters: make(map[string]Converter),
|
||||
deprecations: make(map[string]string),
|
||||
}
|
||||
|
||||
l.RegisterMapping("OldField", "New.Field")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent migrations
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
oldConfig := map[string]interface{}{
|
||||
"OldField": "value",
|
||||
}
|
||||
_, _ = l.MigrateMap(oldConfig)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConfigAdapter_ConcurrentAccess(t *testing.T) {
|
||||
config := &TestConfig{}
|
||||
config.Provider.IssuerURL = "https://test.com"
|
||||
|
||||
adapter := NewConfigAdapter(config)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent getter registrations
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
path := string(rune('A' + idx%26))
|
||||
adapter.RegisterGetter(path, func() interface{} {
|
||||
return "value"
|
||||
})
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Concurrent gets
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
path := string(rune('A' + idx%26))
|
||||
_, _ = adapter.Get(path)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
// Package features provides feature flag management for safe rollback during refactoring
|
||||
package features
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// FeatureFlag represents a feature flag for controlling new functionality
|
||||
type FeatureFlag struct {
|
||||
name string
|
||||
description string
|
||||
enabled atomic.Bool
|
||||
mu sync.RWMutex
|
||||
callbacks []func(bool)
|
||||
}
|
||||
|
||||
// FeatureManager manages all feature flags in the application
|
||||
type FeatureManager struct {
|
||||
flags map[string]*FeatureFlag
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
var (
|
||||
// Global feature manager instance
|
||||
manager *FeatureManager
|
||||
managerOnce sync.Once
|
||||
)
|
||||
|
||||
// Feature flag names
|
||||
const (
|
||||
// UseUnifiedConfig enables the new unified configuration system
|
||||
UseUnifiedConfig = "USE_UNIFIED_CONFIG"
|
||||
|
||||
// UseNewFileStructure enables the new modularized file structure
|
||||
UseNewFileStructure = "USE_NEW_FILE_STRUCTURE"
|
||||
|
||||
// UseStandardErrors enables the standardized error package
|
||||
UseStandardErrors = "USE_STANDARD_ERRORS"
|
||||
|
||||
// UseEnhancedLogging enables the enhanced logging system
|
||||
UseEnhancedLogging = "USE_ENHANCED_LOGGING"
|
||||
|
||||
// UseOptimizedTests enables the consolidated test suite
|
||||
UseOptimizedTests = "USE_OPTIMIZED_TESTS"
|
||||
|
||||
// UseRedisRESP enables the custom Redis RESP implementation
|
||||
UseRedisRESP = "USE_REDIS_RESP"
|
||||
)
|
||||
|
||||
// GetManager returns the global feature manager instance
|
||||
func GetManager() *FeatureManager {
|
||||
managerOnce.Do(func() {
|
||||
manager = &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
manager.initialize()
|
||||
})
|
||||
return manager
|
||||
}
|
||||
|
||||
// initialize sets up default feature flags
|
||||
func (m *FeatureManager) initialize() {
|
||||
// Phase 0: Feature flags setup
|
||||
m.Register(UseUnifiedConfig, "Enable unified configuration package", false)
|
||||
m.Register(UseNewFileStructure, "Enable modularized file structure", false)
|
||||
m.Register(UseStandardErrors, "Enable standardized error handling", false)
|
||||
m.Register(UseEnhancedLogging, "Enable enhanced logging system", false)
|
||||
m.Register(UseOptimizedTests, "Enable optimized test suite", false)
|
||||
m.Register(UseRedisRESP, "Enable custom Redis RESP implementation", false)
|
||||
|
||||
// Load from environment variables
|
||||
m.LoadFromEnv()
|
||||
}
|
||||
|
||||
// Register creates a new feature flag
|
||||
func (m *FeatureManager) Register(name, description string, defaultValue bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
flag := &FeatureFlag{
|
||||
name: name,
|
||||
description: description,
|
||||
callbacks: make([]func(bool), 0),
|
||||
}
|
||||
flag.enabled.Store(defaultValue)
|
||||
m.flags[name] = flag
|
||||
}
|
||||
|
||||
// IsEnabled checks if a feature flag is enabled
|
||||
func (m *FeatureManager) IsEnabled(name string) bool {
|
||||
m.mu.RLock()
|
||||
flag, exists := m.flags[name]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return flag.enabled.Load()
|
||||
}
|
||||
|
||||
// Enable turns on a feature flag
|
||||
func (m *FeatureManager) Enable(name string) {
|
||||
m.setFlag(name, true)
|
||||
}
|
||||
|
||||
// Disable turns off a feature flag
|
||||
func (m *FeatureManager) Disable(name string) {
|
||||
m.setFlag(name, false)
|
||||
}
|
||||
|
||||
// Toggle switches a feature flag state
|
||||
func (m *FeatureManager) Toggle(name string) {
|
||||
m.mu.RLock()
|
||||
flag, exists := m.flags[name]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
newValue := !flag.enabled.Load()
|
||||
m.setFlag(name, newValue)
|
||||
}
|
||||
}
|
||||
|
||||
// setFlag updates a feature flag value and triggers callbacks
|
||||
func (m *FeatureManager) setFlag(name string, value bool) {
|
||||
m.mu.RLock()
|
||||
flag, exists := m.flags[name]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
oldValue := flag.enabled.Swap(value)
|
||||
|
||||
// Only trigger callbacks if value actually changed
|
||||
if oldValue != value {
|
||||
flag.mu.RLock()
|
||||
callbacks := flag.callbacks
|
||||
flag.mu.RUnlock()
|
||||
|
||||
for _, callback := range callbacks {
|
||||
callback(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnChange registers a callback to be called when a feature flag changes
|
||||
func (m *FeatureManager) OnChange(name string, callback func(bool)) {
|
||||
m.mu.RLock()
|
||||
flag, exists := m.flags[name]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
flag.mu.Lock()
|
||||
flag.callbacks = append(flag.callbacks, callback)
|
||||
flag.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// LoadFromEnv loads feature flag values from environment variables
|
||||
func (m *FeatureManager) LoadFromEnv() {
|
||||
m.mu.RLock()
|
||||
flags := make(map[string]*FeatureFlag)
|
||||
for name, flag := range m.flags {
|
||||
flags[name] = flag
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for name, flag := range flags {
|
||||
envVar := "FEATURE_" + name
|
||||
if value := os.Getenv(envVar); value != "" {
|
||||
enabled := strings.ToLower(value) == "true" || value == "1"
|
||||
flag.enabled.Store(enabled)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAll returns all feature flags and their states
|
||||
func (m *FeatureManager) GetAll() map[string]bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]bool)
|
||||
for name, flag := range m.flags {
|
||||
result[name] = flag.enabled.Load()
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Reset resets all feature flags to their default values
|
||||
func (m *FeatureManager) Reset() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
for _, flag := range m.flags {
|
||||
flag.enabled.Store(false)
|
||||
flag.callbacks = make([]func(bool), 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for common checks
|
||||
|
||||
// IsUnifiedConfigEnabled checks if unified config is enabled
|
||||
func IsUnifiedConfigEnabled() bool {
|
||||
return GetManager().IsEnabled(UseUnifiedConfig)
|
||||
}
|
||||
|
||||
// IsNewFileStructureEnabled checks if new file structure is enabled
|
||||
func IsNewFileStructureEnabled() bool {
|
||||
return GetManager().IsEnabled(UseNewFileStructure)
|
||||
}
|
||||
|
||||
// IsStandardErrorsEnabled checks if standard errors are enabled
|
||||
func IsStandardErrorsEnabled() bool {
|
||||
return GetManager().IsEnabled(UseStandardErrors)
|
||||
}
|
||||
|
||||
// IsEnhancedLoggingEnabled checks if enhanced logging is enabled
|
||||
func IsEnhancedLoggingEnabled() bool {
|
||||
return GetManager().IsEnabled(UseEnhancedLogging)
|
||||
}
|
||||
|
||||
// IsOptimizedTestsEnabled checks if optimized tests are enabled
|
||||
func IsOptimizedTestsEnabled() bool {
|
||||
return GetManager().IsEnabled(UseOptimizedTests)
|
||||
}
|
||||
|
||||
// IsRedisRESPEnabled checks if custom Redis RESP is enabled
|
||||
func IsRedisRESPEnabled() bool {
|
||||
return GetManager().IsEnabled(UseRedisRESP)
|
||||
}
|
||||
@@ -0,0 +1,483 @@
|
||||
//go:build !yaegi
|
||||
|
||||
package features
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFeatureManager_Register(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
if !m.flags["TEST_FEATURE"].enabled.Load() == false {
|
||||
t.Error("Expected feature to be disabled by default")
|
||||
}
|
||||
|
||||
m.Register("TEST_ENABLED", "Test enabled feature", true)
|
||||
if m.flags["TEST_ENABLED"].enabled.Load() != true {
|
||||
t.Error("Expected feature to be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_IsEnabled(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", true)
|
||||
|
||||
if !m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected feature to be enabled")
|
||||
}
|
||||
|
||||
if m.IsEnabled("NON_EXISTENT") {
|
||||
t.Error("Expected non-existent feature to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_EnableDisable(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
// Enable the feature
|
||||
m.Enable("TEST_FEATURE")
|
||||
if !m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected feature to be enabled")
|
||||
}
|
||||
|
||||
// Disable the feature
|
||||
m.Disable("TEST_FEATURE")
|
||||
if m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected feature to be disabled")
|
||||
}
|
||||
|
||||
// Enable/Disable non-existent feature should not panic
|
||||
m.Enable("NON_EXISTENT")
|
||||
m.Disable("NON_EXISTENT")
|
||||
}
|
||||
|
||||
func TestFeatureManager_Toggle(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
// Toggle from false to true
|
||||
m.Toggle("TEST_FEATURE")
|
||||
if !m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected feature to be enabled after toggle")
|
||||
}
|
||||
|
||||
// Toggle from true to false
|
||||
m.Toggle("TEST_FEATURE")
|
||||
if m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected feature to be disabled after toggle")
|
||||
}
|
||||
|
||||
// Toggle non-existent feature should not panic
|
||||
m.Toggle("NON_EXISTENT")
|
||||
}
|
||||
|
||||
func TestFeatureManager_OnChange(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
var callbackCalled atomic.Bool
|
||||
var callbackValue atomic.Bool
|
||||
|
||||
m.OnChange("TEST_FEATURE", func(enabled bool) {
|
||||
callbackCalled.Store(true)
|
||||
callbackValue.Store(enabled)
|
||||
})
|
||||
|
||||
// Enable should trigger callback
|
||||
m.Enable("TEST_FEATURE")
|
||||
|
||||
// Wait briefly for callback
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if !callbackCalled.Load() {
|
||||
t.Error("Expected callback to be called")
|
||||
}
|
||||
|
||||
if !callbackValue.Load() {
|
||||
t.Error("Expected callback value to be true")
|
||||
}
|
||||
|
||||
// Setting to same value should NOT trigger callback again
|
||||
callbackCalled.Store(false)
|
||||
m.Enable("TEST_FEATURE")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if callbackCalled.Load() {
|
||||
t.Error("Expected callback NOT to be called when value doesn't change")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_LoadFromEnv(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
m.Register("TEST_FEATURE_2", "Test feature 2", false)
|
||||
|
||||
// Set environment variables
|
||||
os.Setenv("FEATURE_TEST_FEATURE", "true")
|
||||
os.Setenv("FEATURE_TEST_FEATURE_2", "1")
|
||||
defer func() {
|
||||
os.Unsetenv("FEATURE_TEST_FEATURE")
|
||||
os.Unsetenv("FEATURE_TEST_FEATURE_2")
|
||||
}()
|
||||
|
||||
m.LoadFromEnv()
|
||||
|
||||
if !m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected TEST_FEATURE to be enabled from env")
|
||||
}
|
||||
|
||||
if !m.IsEnabled("TEST_FEATURE_2") {
|
||||
t.Error("Expected TEST_FEATURE_2 to be enabled from env (value=1)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_LoadFromEnv_FalseValues(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", true) // Default true
|
||||
|
||||
// Set to false
|
||||
os.Setenv("FEATURE_TEST_FEATURE", "false")
|
||||
defer os.Unsetenv("FEATURE_TEST_FEATURE")
|
||||
|
||||
m.LoadFromEnv()
|
||||
|
||||
if m.IsEnabled("TEST_FEATURE") {
|
||||
t.Error("Expected TEST_FEATURE to be disabled from env")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_GetAll(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("FEATURE_1", "Feature 1", true)
|
||||
m.Register("FEATURE_2", "Feature 2", false)
|
||||
|
||||
all := m.GetAll()
|
||||
|
||||
if len(all) != 2 {
|
||||
t.Errorf("Expected 2 features, got %d", len(all))
|
||||
}
|
||||
|
||||
if !all["FEATURE_1"] {
|
||||
t.Error("Expected FEATURE_1 to be enabled")
|
||||
}
|
||||
|
||||
if all["FEATURE_2"] {
|
||||
t.Error("Expected FEATURE_2 to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_Reset(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("FEATURE_1", "Feature 1", true)
|
||||
m.Register("FEATURE_2", "Feature 2", true)
|
||||
|
||||
var callbackCalled atomic.Int32
|
||||
m.OnChange("FEATURE_1", func(enabled bool) {
|
||||
callbackCalled.Add(1)
|
||||
})
|
||||
|
||||
m.Reset()
|
||||
|
||||
// All features should be disabled
|
||||
if m.IsEnabled("FEATURE_1") {
|
||||
t.Error("Expected FEATURE_1 to be disabled after reset")
|
||||
}
|
||||
|
||||
if m.IsEnabled("FEATURE_2") {
|
||||
t.Error("Expected FEATURE_2 to be disabled after reset")
|
||||
}
|
||||
|
||||
// Callbacks should be cleared
|
||||
m.Enable("FEATURE_1")
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
if callbackCalled.Load() != 0 {
|
||||
t.Error("Expected callbacks to be cleared after reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetManager_Singleton(t *testing.T) {
|
||||
// Reset global state for clean test
|
||||
managerOnce = sync.Once{}
|
||||
manager = nil
|
||||
|
||||
m1 := GetManager()
|
||||
m2 := GetManager()
|
||||
|
||||
if m1 != m2 {
|
||||
t.Error("Expected GetManager to return same instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetManager_Initialize(t *testing.T) {
|
||||
// Reset global state for clean test
|
||||
managerOnce = sync.Once{}
|
||||
manager = nil
|
||||
|
||||
m := GetManager()
|
||||
|
||||
// Should have default feature flags
|
||||
all := m.GetAll()
|
||||
if len(all) < 6 {
|
||||
t.Errorf("Expected at least 6 default feature flags, got %d", len(all))
|
||||
}
|
||||
|
||||
// Check specific flags exist
|
||||
flags := []string{
|
||||
UseUnifiedConfig,
|
||||
UseNewFileStructure,
|
||||
UseStandardErrors,
|
||||
UseEnhancedLogging,
|
||||
UseOptimizedTests,
|
||||
UseRedisRESP,
|
||||
}
|
||||
|
||||
for _, flag := range flags {
|
||||
if _, exists := m.flags[flag]; !exists {
|
||||
t.Errorf("Expected default flag %s to exist", flag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHelperFunctions(t *testing.T) {
|
||||
// Reset global state
|
||||
managerOnce = sync.Once{}
|
||||
manager = nil
|
||||
|
||||
// Test IsUnifiedConfigEnabled
|
||||
if IsUnifiedConfigEnabled() {
|
||||
t.Error("Expected unified config to be disabled by default")
|
||||
}
|
||||
|
||||
GetManager().Enable(UseUnifiedConfig)
|
||||
if !IsUnifiedConfigEnabled() {
|
||||
t.Error("Expected unified config to be enabled")
|
||||
}
|
||||
|
||||
// Reset for next test
|
||||
GetManager().Reset()
|
||||
|
||||
// Test IsNewFileStructureEnabled
|
||||
if IsNewFileStructureEnabled() {
|
||||
t.Error("Expected new file structure to be disabled by default")
|
||||
}
|
||||
|
||||
GetManager().Enable(UseNewFileStructure)
|
||||
if !IsNewFileStructureEnabled() {
|
||||
t.Error("Expected new file structure to be enabled")
|
||||
}
|
||||
|
||||
// Test IsStandardErrorsEnabled
|
||||
GetManager().Reset()
|
||||
GetManager().Enable(UseStandardErrors)
|
||||
if !IsStandardErrorsEnabled() {
|
||||
t.Error("Expected standard errors to be enabled")
|
||||
}
|
||||
|
||||
// Test IsEnhancedLoggingEnabled
|
||||
GetManager().Reset()
|
||||
GetManager().Enable(UseEnhancedLogging)
|
||||
if !IsEnhancedLoggingEnabled() {
|
||||
t.Error("Expected enhanced logging to be enabled")
|
||||
}
|
||||
|
||||
// Test IsOptimizedTestsEnabled
|
||||
GetManager().Reset()
|
||||
GetManager().Enable(UseOptimizedTests)
|
||||
if !IsOptimizedTestsEnabled() {
|
||||
t.Error("Expected optimized tests to be enabled")
|
||||
}
|
||||
|
||||
// Test IsRedisRESPEnabled
|
||||
GetManager().Reset()
|
||||
GetManager().Enable(UseRedisRESP)
|
||||
if !IsRedisRESPEnabled() {
|
||||
t.Error("Expected Redis RESP to be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// Race condition tests
|
||||
func TestFeatureManager_ConcurrentAccess(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
iterations := 100
|
||||
|
||||
// Concurrent enables
|
||||
for i := 0; i < iterations; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Enable("TEST_FEATURE")
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent disables
|
||||
for i := 0; i < iterations; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Disable("TEST_FEATURE")
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent reads
|
||||
for i := 0; i < iterations; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = m.IsEnabled("TEST_FEATURE")
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should not panic - final state is not deterministic but that's ok
|
||||
}
|
||||
|
||||
func TestFeatureManager_ConcurrentCallbacks(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("TEST_FEATURE", "Test feature", false)
|
||||
|
||||
var callbackCount atomic.Int32
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Register multiple callbacks concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.OnChange("TEST_FEATURE", func(enabled bool) {
|
||||
callbackCount.Add(1)
|
||||
})
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Toggle the feature
|
||||
m.Toggle("TEST_FEATURE")
|
||||
|
||||
// Wait for callbacks
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// All 10 callbacks should have been called
|
||||
if callbackCount.Load() != 10 {
|
||||
t.Errorf("Expected 10 callbacks, got %d", callbackCount.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFeatureManager_ConcurrentGetAll(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
m.Register(string(rune('A'+i)), "Feature", false)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Concurrent GetAll calls
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
all := m.GetAll()
|
||||
if len(all) != 5 {
|
||||
t.Errorf("Expected 5 flags, got %d", len(all))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Concurrent modifications
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
flag := string(rune('A' + (idx % 5)))
|
||||
if idx%2 == 0 {
|
||||
m.Enable(flag)
|
||||
} else {
|
||||
m.Disable(flag)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestFeatureManager_LoadFromEnv_Concurrent(t *testing.T) {
|
||||
m := &FeatureManager{
|
||||
flags: make(map[string]*FeatureFlag),
|
||||
}
|
||||
|
||||
m.Register("FEATURE_1", "Feature 1", false)
|
||||
m.Register("FEATURE_2", "Feature 2", false)
|
||||
|
||||
os.Setenv("FEATURE_FEATURE_1", "true")
|
||||
os.Setenv("FEATURE_FEATURE_2", "true")
|
||||
defer func() {
|
||||
os.Unsetenv("FEATURE_FEATURE_1")
|
||||
os.Unsetenv("FEATURE_FEATURE_2")
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Load from env concurrently
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.LoadFromEnv()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Both should be enabled
|
||||
if !m.IsEnabled("FEATURE_1") || !m.IsEnabled("FEATURE_2") {
|
||||
t.Error("Expected features to be enabled from env")
|
||||
}
|
||||
}
|
||||
@@ -304,7 +304,8 @@ func (f *Factory) createSecureTLSConfig() *tls.Config {
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it to false is safe
|
||||
PreferServerCipherSuites: false, // Let client choose best cipher
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,12 +144,14 @@ func getOrCreateLogFile(filename string) io.Writer {
|
||||
}
|
||||
|
||||
// Ensure log directory exists
|
||||
// #nosec G301 -- log directory needs to be readable by monitoring tools
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
// Fall back to stderr if we can't create the directory
|
||||
return os.Stderr
|
||||
}
|
||||
|
||||
filepath := logDir + "/" + filename
|
||||
// #nosec G302 G304 -- log files need to be readable; path is from trusted env var
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
// Fall back to stderr if we can't open the file
|
||||
|
||||
@@ -107,6 +107,7 @@ const (
|
||||
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
|
||||
|
||||
// Bearer token pattern (Authorization header)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
|
||||
|
||||
// Client ID pattern (alphanumeric with common separators)
|
||||
@@ -119,6 +120,7 @@ const (
|
||||
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
|
||||
|
||||
// CSRF token pattern (base64url)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
|
||||
|
||||
// Nonce pattern (base64url)
|
||||
|
||||
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
|
||||
@@ -39,25 +39,25 @@ func (p *Auth0Provider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -40,7 +40,7 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Remove offline_access scope as Cognito doesn't use it (case-insensitive)
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if strings.ToLower(scope) != "offline_access" {
|
||||
if strings.ToLower(scope) != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
@@ -48,18 +48,18 @@ func (p *AWSCognitoProvider) BuildAuthParams(baseParams url.Values, scopes []str
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range filteredScopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
filteredScopes = append(filteredScopes, "openid")
|
||||
filteredScopes = append(filteredScopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
// Default Cognito scopes if none specified
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
|
||||
filteredScopes = append(filteredScopes, "email", "profile")
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
|
||||
filteredScopes = append(filteredScopes, ScopeEmail, ScopeProfile)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -38,13 +38,13 @@ func (p *AzureProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -102,17 +102,17 @@ func (p *BaseProvider) ValidateTokenExpiry(session Session, token string, tokenC
|
||||
}
|
||||
|
||||
// BuildAuthParams constructs authorization parameters for the provider.
|
||||
// It includes the "offline_access" scope by default for refresh token support.
|
||||
// It includes the offline_access scope by default for refresh token support.
|
||||
func (p *BaseProvider) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -38,7 +38,7 @@ func (p *GitHubProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// GitHub doesn't use offline_access scope, so remove it if present
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Remove offline_access scope as GitLab doesn't use it
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
@@ -47,18 +47,18 @@ func (p *GitLabProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
// Ensure openid scope is present for OIDC
|
||||
hasOpenID := false
|
||||
for _, scope := range filteredScopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
filteredScopes = append(filteredScopes, "openid")
|
||||
filteredScopes = append(filteredScopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
// Default GitLab scopes if none specified
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == "openid" {
|
||||
filteredScopes = append(filteredScopes, "profile", "email")
|
||||
if len(filteredScopes) == 1 && filteredScopes[0] == ScopeOpenID {
|
||||
filteredScopes = append(filteredScopes, ScopeProfile, ScopeEmail)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -36,10 +36,10 @@ func (p *GoogleProvider) BuildAuthParams(baseParams url.Values, scopes []string)
|
||||
baseParams.Set("access_type", "offline")
|
||||
baseParams.Set("prompt", "consent")
|
||||
|
||||
// Google does not use the "offline_access" scope, so we remove it if present.
|
||||
// Google does not use the ScopeOfflineAccess scope, so we remove it if present.
|
||||
var filteredScopes []string
|
||||
for _, scope := range scopes {
|
||||
if scope != "offline_access" {
|
||||
if scope != ScopeOfflineAccess {
|
||||
filteredScopes = append(filteredScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,14 @@ const (
|
||||
ProviderTypeGitLab
|
||||
)
|
||||
|
||||
// Standard OAuth2/OIDC scope constants
|
||||
const (
|
||||
ScopeOfflineAccess = "offline_access"
|
||||
ScopeOpenID = "openid"
|
||||
ScopeProfile = "profile"
|
||||
ScopeEmail = "email"
|
||||
)
|
||||
|
||||
// ProviderCapabilities defines the specific features and behaviors of an OIDC provider.
|
||||
type ProviderCapabilities struct {
|
||||
PreferredTokenValidation string
|
||||
|
||||
@@ -39,25 +39,25 @@ func (p *KeycloakProvider) BuildAuthParams(baseParams url.Values, scopes []strin
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -39,25 +39,25 @@ func (p *OktaProvider) BuildAuthParams(baseParams url.Values, scopes []string) (
|
||||
// Ensure offline_access scope is present for refresh tokens
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "offline_access" {
|
||||
if scope == ScopeOfflineAccess {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
scopes = append(scopes, "offline_access")
|
||||
scopes = append(scopes, ScopeOfflineAccess)
|
||||
}
|
||||
|
||||
// Ensure openid scope is present
|
||||
hasOpenID := false
|
||||
for _, scope := range scopes {
|
||||
if scope == "openid" {
|
||||
if scope == ScopeOpenID {
|
||||
hasOpenID = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOpenID {
|
||||
scopes = append(scopes, "openid")
|
||||
scopes = append(scopes, ScopeOpenID)
|
||||
}
|
||||
|
||||
return &AuthParams{
|
||||
|
||||
@@ -61,7 +61,7 @@ func (v *ConfigValidator) ValidateScopes(scopes []string) error {
|
||||
|
||||
hasOpenIDScope := false
|
||||
for _, scope := range scopes {
|
||||
if strings.TrimSpace(scope) == "openid" {
|
||||
if strings.TrimSpace(scope) == ScopeOpenID {
|
||||
hasOpenIDScope = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -0,0 +1,307 @@
|
||||
// Package recovery provides error recovery and resilience mechanisms for OIDC authentication.
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrorRecoveryMechanism defines the interface for error recovery strategies.
|
||||
// It provides a common contract for implementing various resilience patterns
|
||||
// such as circuit breakers, retry mechanisms, and fallback strategies.
|
||||
type ErrorRecoveryMechanism interface {
|
||||
// ExecuteWithContext runs a function with error recovery using the provided context
|
||||
ExecuteWithContext(ctx context.Context, fn func() error) error
|
||||
// Reset resets the recovery mechanism state
|
||||
Reset()
|
||||
// IsAvailable checks if the mechanism is currently available for use
|
||||
IsAvailable() bool
|
||||
// GetMetrics returns metrics about the recovery mechanism's performance
|
||||
GetMetrics() map[string]interface{}
|
||||
}
|
||||
|
||||
// Logger defines the logging interface
|
||||
type Logger interface {
|
||||
Logf(format string, args ...interface{})
|
||||
ErrorLogf(format string, args ...interface{})
|
||||
DebugLogf(format string, args ...interface{})
|
||||
}
|
||||
|
||||
// BaseRecoveryMechanism provides common functionality and metrics tracking
|
||||
// for all recovery mechanism implementations. It handles request counting,
|
||||
// success/failure tracking, and timestamp management in a thread-safe manner.
|
||||
type BaseRecoveryMechanism struct {
|
||||
// name identifies the recovery mechanism instance
|
||||
name string
|
||||
// logger provides structured logging capabilities
|
||||
logger Logger
|
||||
|
||||
// Metrics tracked with atomic operations for thread safety
|
||||
totalRequests int64
|
||||
successCount int64
|
||||
failureCount int64
|
||||
lastSuccessStr string
|
||||
lastFailureStr string
|
||||
|
||||
// mutexes for thread-safe timestamp updates
|
||||
successMutex sync.RWMutex
|
||||
failureMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewBaseRecoveryMechanism creates a new base recovery mechanism with the given name and logger.
|
||||
// This serves as the foundation for specific recovery mechanism implementations.
|
||||
// Parameters:
|
||||
// - name: Identifier for this recovery mechanism instance
|
||||
// - logger: Logger instance for outputting diagnostic information
|
||||
//
|
||||
// Returns:
|
||||
// - A new BaseRecoveryMechanism instance with initialized metrics
|
||||
func NewBaseRecoveryMechanism(name string, logger Logger) *BaseRecoveryMechanism {
|
||||
return &BaseRecoveryMechanism{
|
||||
name: name,
|
||||
logger: logger,
|
||||
totalRequests: 0,
|
||||
successCount: 0,
|
||||
failureCount: 0,
|
||||
lastSuccessStr: "never",
|
||||
lastFailureStr: "never",
|
||||
}
|
||||
}
|
||||
|
||||
// RecordRequest increments the total request counter.
|
||||
// This method is thread-safe using atomic operations.
|
||||
func (b *BaseRecoveryMechanism) RecordRequest() {
|
||||
atomic.AddInt64(&b.totalRequests, 1)
|
||||
}
|
||||
|
||||
// RecordSuccess increments the success counter and updates the last success timestamp.
|
||||
// This method is thread-safe using atomic operations for counters
|
||||
// and mutex protection for timestamp updates.
|
||||
func (b *BaseRecoveryMechanism) RecordSuccess() {
|
||||
atomic.AddInt64(&b.successCount, 1)
|
||||
b.successMutex.Lock()
|
||||
b.lastSuccessStr = time.Now().Format(time.RFC3339)
|
||||
b.successMutex.Unlock()
|
||||
}
|
||||
|
||||
// RecordFailure increments the failure counter and updates the last failure timestamp.
|
||||
// This method is thread-safe using atomic operations for counters
|
||||
// and mutex protection for timestamp updates.
|
||||
func (b *BaseRecoveryMechanism) RecordFailure() {
|
||||
atomic.AddInt64(&b.failureCount, 1)
|
||||
b.failureMutex.Lock()
|
||||
b.lastFailureStr = time.Now().Format(time.RFC3339)
|
||||
b.failureMutex.Unlock()
|
||||
}
|
||||
|
||||
// GetBaseMetrics returns comprehensive metrics about the recovery mechanism.
|
||||
// Includes request counts, success/failure rates, timing information,
|
||||
// and calculated percentages. All access is thread-safe.
|
||||
func (b *BaseRecoveryMechanism) GetBaseMetrics() map[string]interface{} {
|
||||
total := atomic.LoadInt64(&b.totalRequests)
|
||||
success := atomic.LoadInt64(&b.successCount)
|
||||
failure := atomic.LoadInt64(&b.failureCount)
|
||||
|
||||
b.successMutex.RLock()
|
||||
lastSuccess := b.lastSuccessStr
|
||||
b.successMutex.RUnlock()
|
||||
|
||||
b.failureMutex.RLock()
|
||||
lastFailure := b.lastFailureStr
|
||||
b.failureMutex.RUnlock()
|
||||
|
||||
metrics := map[string]interface{}{
|
||||
"name": b.name,
|
||||
"totalRequests": total,
|
||||
"successCount": success,
|
||||
"failureCount": failure,
|
||||
"lastSuccess": lastSuccess,
|
||||
"lastFailure": lastFailure,
|
||||
}
|
||||
|
||||
// Calculate success and failure rates
|
||||
if total > 0 {
|
||||
successRate := float64(success) / float64(total) * 100
|
||||
failureRate := float64(failure) / float64(total) * 100
|
||||
metrics["successRate"] = fmt.Sprintf("%.2f%%", successRate)
|
||||
metrics["failureRate"] = fmt.Sprintf("%.2f%%", failureRate)
|
||||
} else {
|
||||
metrics["successRate"] = "0.00%"
|
||||
metrics["failureRate"] = "0.00%"
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// LogInfo logs an informational message with the mechanism name as prefix.
|
||||
// Provides consistent logging format across all recovery mechanisms.
|
||||
func (b *BaseRecoveryMechanism) LogInfo(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.Logf("[%s] %s", b.name, fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// LogError logs an error message with the mechanism name as prefix.
|
||||
// Used for reporting failures and error conditions in recovery mechanisms.
|
||||
func (b *BaseRecoveryMechanism) LogError(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.ErrorLogf("[%s] %s", b.name, fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// LogDebug logs a debug message with the mechanism name as prefix.
|
||||
// Useful for detailed troubleshooting of recovery mechanism behavior.
|
||||
func (b *BaseRecoveryMechanism) LogDebug(format string, args ...interface{}) {
|
||||
if b.logger != nil {
|
||||
b.logger.DebugLogf("[%s] %s", b.name, fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorType represents different categories of errors
|
||||
type ErrorType int
|
||||
|
||||
const (
|
||||
// ErrorTypeUnknown represents an unknown error type
|
||||
ErrorTypeUnknown ErrorType = iota
|
||||
// ErrorTypeNetwork represents network-related errors
|
||||
ErrorTypeNetwork
|
||||
// ErrorTypeTimeout represents timeout errors
|
||||
ErrorTypeTimeout
|
||||
// ErrorTypeAuthentication represents authentication errors
|
||||
ErrorTypeAuthentication
|
||||
// ErrorTypeRateLimit represents rate limiting errors
|
||||
ErrorTypeRateLimit
|
||||
// ErrorTypeServerError represents server errors (5xx)
|
||||
ErrorTypeServerError
|
||||
// ErrorTypeClientError represents client errors (4xx)
|
||||
ErrorTypeClientError
|
||||
)
|
||||
|
||||
// HTTPError represents an HTTP error with status code and message
|
||||
type HTTPError struct {
|
||||
StatusCode int
|
||||
Message string
|
||||
Body []byte
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *HTTPError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Message)
|
||||
}
|
||||
|
||||
// IsRetryable checks if the HTTP error is retryable
|
||||
func (e *HTTPError) IsRetryable() bool {
|
||||
// Retry on 5xx errors and specific 4xx errors
|
||||
return e.StatusCode >= 500 || e.StatusCode == 429 || e.StatusCode == 408
|
||||
}
|
||||
|
||||
// OIDCError represents an OIDC-specific error
|
||||
type OIDCError struct {
|
||||
Code string
|
||||
Description string
|
||||
URI string
|
||||
State string
|
||||
}
|
||||
|
||||
// Error implements the error interface
|
||||
func (e *OIDCError) Error() string {
|
||||
if e.Description != "" {
|
||||
return fmt.Sprintf("OIDC error %s: %s", e.Code, e.Description)
|
||||
}
|
||||
return fmt.Sprintf("OIDC error: %s", e.Code)
|
||||
}
|
||||
|
||||
// IsRetryable checks if the OIDC error is retryable
|
||||
func (e *OIDCError) IsRetryable() bool {
|
||||
// Some OIDC errors are retryable
|
||||
switch e.Code {
|
||||
case "temporarily_unavailable", "server_error":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// FallbackMechanism provides a simple fallback recovery strategy
|
||||
type FallbackMechanism struct {
|
||||
*BaseRecoveryMechanism
|
||||
fallbackFunc func() error
|
||||
}
|
||||
|
||||
// NewFallbackMechanism creates a new fallback mechanism
|
||||
func NewFallbackMechanism(name string, logger Logger, fallbackFunc func() error) *FallbackMechanism {
|
||||
return &FallbackMechanism{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism(name, logger),
|
||||
fallbackFunc: fallbackFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes the primary function and falls back on error
|
||||
func (f *FallbackMechanism) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
f.RecordRequest()
|
||||
|
||||
// Check context first
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
f.RecordFailure()
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Try primary function
|
||||
if err := fn(); err != nil {
|
||||
f.LogInfo("Primary function failed: %v, trying fallback", err)
|
||||
|
||||
// Try fallback
|
||||
if f.fallbackFunc != nil {
|
||||
if fallbackErr := f.fallbackFunc(); fallbackErr == nil {
|
||||
f.RecordSuccess()
|
||||
return nil
|
||||
} else {
|
||||
f.LogError("Fallback also failed: %v", fallbackErr)
|
||||
f.RecordFailure()
|
||||
return fmt.Errorf("both primary and fallback failed: primary=%v, fallback=%v", err, fallbackErr)
|
||||
}
|
||||
}
|
||||
|
||||
f.RecordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
f.RecordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset resets the fallback mechanism state
|
||||
func (f *FallbackMechanism) Reset() {
|
||||
// Reset metrics
|
||||
atomic.StoreInt64(&f.totalRequests, 0)
|
||||
atomic.StoreInt64(&f.successCount, 0)
|
||||
atomic.StoreInt64(&f.failureCount, 0)
|
||||
|
||||
f.successMutex.Lock()
|
||||
f.lastSuccessStr = "never"
|
||||
f.successMutex.Unlock()
|
||||
|
||||
f.failureMutex.Lock()
|
||||
f.lastFailureStr = "never"
|
||||
f.failureMutex.Unlock()
|
||||
}
|
||||
|
||||
// IsAvailable checks if the fallback mechanism is available
|
||||
func (f *FallbackMechanism) IsAvailable() bool {
|
||||
// Fallback is always available
|
||||
return true
|
||||
}
|
||||
|
||||
// GetMetrics returns metrics about the fallback mechanism
|
||||
func (f *FallbackMechanism) GetMetrics() map[string]interface{} {
|
||||
metrics := f.GetBaseMetrics()
|
||||
metrics["type"] = "fallback"
|
||||
metrics["hasFallback"] = f.fallbackFunc != nil
|
||||
return metrics
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
// Package recovery provides error recovery and resilience mechanisms for OIDC authentication.
|
||||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the current state of the circuit breaker
|
||||
type CircuitBreakerState int
|
||||
|
||||
const (
|
||||
// CircuitBreakerClosed allows all requests to pass through
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen blocks all requests
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen allows limited requests for testing
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
// String returns the string representation of the circuit breaker state
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig defines configuration for the circuit breaker
|
||||
type CircuitBreakerConfig struct {
|
||||
// FailureThreshold is the number of failures before opening the circuit
|
||||
FailureThreshold int
|
||||
// SuccessThreshold is the number of successes in half-open state before closing
|
||||
SuccessThreshold int
|
||||
// Timeout is the duration to wait before transitioning from open to half-open
|
||||
Timeout time.Duration
|
||||
// MaxRequests is the maximum number of requests allowed in half-open state
|
||||
MaxRequests int
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig returns sensible default configuration
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
FailureThreshold: 5,
|
||||
SuccessThreshold: 2,
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRequests: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for fault tolerance.
|
||||
// It prevents cascading failures by temporarily blocking requests to a failing service.
|
||||
type CircuitBreaker struct {
|
||||
*BaseRecoveryMechanism
|
||||
config CircuitBreakerConfig
|
||||
|
||||
// State management
|
||||
state int32 // atomic: CircuitBreakerState
|
||||
lastStateChange time.Time
|
||||
stateMutex sync.RWMutex
|
||||
|
||||
// Failure tracking
|
||||
consecutiveFailures int32 // atomic
|
||||
consecutiveSuccesses int32 // atomic
|
||||
|
||||
// Half-open state management
|
||||
halfOpenRequests int32 // atomic
|
||||
}
|
||||
|
||||
// NewCircuitBreaker creates a new circuit breaker with the given configuration
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger Logger) *CircuitBreaker {
|
||||
return &CircuitBreaker{
|
||||
BaseRecoveryMechanism: NewBaseRecoveryMechanism("CircuitBreaker", logger),
|
||||
config: config,
|
||||
state: int32(CircuitBreakerClosed),
|
||||
lastStateChange: time.Now(),
|
||||
consecutiveFailures: 0,
|
||||
consecutiveSuccesses: 0,
|
||||
halfOpenRequests: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithContext executes a function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) ExecuteWithContext(ctx context.Context, fn func() error) error {
|
||||
cb.RecordRequest()
|
||||
|
||||
// Check if request is allowed
|
||||
if !cb.allowRequest() {
|
||||
cb.RecordFailure()
|
||||
return fmt.Errorf("circuit breaker is open")
|
||||
}
|
||||
|
||||
// Execute the function
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute executes a function with circuit breaker protection (legacy method)
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
return cb.ExecuteWithContext(context.Background(), fn)
|
||||
}
|
||||
|
||||
// allowRequest determines if a request should be allowed based on the circuit state
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
return true
|
||||
|
||||
case CircuitBreakerOpen:
|
||||
// Check if timeout has elapsed
|
||||
cb.stateMutex.RLock()
|
||||
lastChange := cb.lastStateChange
|
||||
cb.stateMutex.RUnlock()
|
||||
|
||||
if time.Since(lastChange) > cb.config.Timeout {
|
||||
// Transition to half-open
|
||||
cb.transitionToHalfOpen()
|
||||
return cb.allowHalfOpenRequest()
|
||||
}
|
||||
return false
|
||||
|
||||
case CircuitBreakerHalfOpen:
|
||||
return cb.allowHalfOpenRequest()
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// allowHalfOpenRequest checks if a request is allowed in half-open state
|
||||
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
|
||||
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
|
||||
// #nosec G115 -- MaxRequests is a small config value that fits in int32
|
||||
if current <= int32(cb.config.MaxRequests) {
|
||||
return true
|
||||
}
|
||||
atomic.AddInt32(&cb.halfOpenRequests, -1)
|
||||
return false
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.RecordFailure()
|
||||
|
||||
failures := atomic.AddInt32(&cb.consecutiveFailures, 1)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
|
||||
cb.transitionToOpen()
|
||||
} else if state == CircuitBreakerHalfOpen {
|
||||
cb.transitionToOpen()
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a success and potentially closes the circuit
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.RecordSuccess()
|
||||
|
||||
successes := atomic.AddInt32(&cb.consecutiveSuccesses, 1)
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
|
||||
cb.transitionToClosed()
|
||||
}
|
||||
}
|
||||
|
||||
// transitionToClosed transitions the circuit to closed state
|
||||
func (cb *CircuitBreaker) transitionToClosed() {
|
||||
if atomic.CompareAndSwapInt32(&cb.state, int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
atomic.StoreInt32(&cb.halfOpenRequests, 0)
|
||||
|
||||
cb.LogInfo("Circuit breaker closed")
|
||||
}
|
||||
}
|
||||
|
||||
// transitionToOpen transitions the circuit to open state
|
||||
func (cb *CircuitBreaker) transitionToOpen() {
|
||||
oldState := atomic.SwapInt32(&cb.state, int32(CircuitBreakerOpen))
|
||||
if oldState != int32(CircuitBreakerOpen) {
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
atomic.StoreInt32(&cb.halfOpenRequests, 0)
|
||||
|
||||
cb.LogError("Circuit breaker opened due to failures")
|
||||
}
|
||||
}
|
||||
|
||||
// transitionToHalfOpen transitions the circuit to half-open state
|
||||
func (cb *CircuitBreaker) transitionToHalfOpen() {
|
||||
if atomic.CompareAndSwapInt32(&cb.state, int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
atomic.StoreInt32(&cb.halfOpenRequests, 0)
|
||||
|
||||
cb.LogInfo("Circuit breaker half-open, testing recovery")
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
return CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
}
|
||||
|
||||
// Reset resets the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
|
||||
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
atomic.StoreInt32(&cb.halfOpenRequests, 0)
|
||||
|
||||
// Reset base metrics
|
||||
atomic.StoreInt64(&cb.totalRequests, 0)
|
||||
atomic.StoreInt64(&cb.successCount, 0)
|
||||
atomic.StoreInt64(&cb.failureCount, 0)
|
||||
|
||||
cb.LogInfo("Circuit breaker reset to closed state")
|
||||
}
|
||||
|
||||
// IsAvailable returns true if the circuit breaker is not fully open
|
||||
func (cb *CircuitBreaker) IsAvailable() bool {
|
||||
state := cb.GetState()
|
||||
return state != CircuitBreakerOpen || time.Since(cb.getLastStateChange()) > cb.config.Timeout
|
||||
}
|
||||
|
||||
// getLastStateChange returns the last state change time safely
|
||||
func (cb *CircuitBreaker) getLastStateChange() time.Time {
|
||||
cb.stateMutex.RLock()
|
||||
defer cb.stateMutex.RUnlock()
|
||||
return cb.lastStateChange
|
||||
}
|
||||
|
||||
// GetMetrics returns comprehensive metrics about the circuit breaker
|
||||
func (cb *CircuitBreaker) GetMetrics() map[string]interface{} {
|
||||
metrics := cb.GetBaseMetrics()
|
||||
|
||||
state := cb.GetState()
|
||||
metrics["state"] = state.String()
|
||||
metrics["consecutiveFailures"] = atomic.LoadInt32(&cb.consecutiveFailures)
|
||||
metrics["consecutiveSuccesses"] = atomic.LoadInt32(&cb.consecutiveSuccesses)
|
||||
metrics["halfOpenRequests"] = atomic.LoadInt32(&cb.halfOpenRequests)
|
||||
|
||||
cb.stateMutex.RLock()
|
||||
metrics["lastStateChange"] = cb.lastStateChange.Format(time.RFC3339)
|
||||
metrics["timeSinceLastChange"] = time.Since(cb.lastStateChange).String()
|
||||
cb.stateMutex.RUnlock()
|
||||
|
||||
// Configuration
|
||||
metrics["config"] = map[string]interface{}{
|
||||
"failureThreshold": cb.config.FailureThreshold,
|
||||
"successThreshold": cb.config.SuccessThreshold,
|
||||
"timeout": cb.config.Timeout.String(),
|
||||
"maxRequests": cb.config.MaxRequests,
|
||||
}
|
||||
|
||||
// Health indicator
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
metrics["health"] = "healthy"
|
||||
case CircuitBreakerHalfOpen:
|
||||
metrics["health"] = "recovering"
|
||||
case CircuitBreakerOpen:
|
||||
if time.Since(cb.getLastStateChange()) > cb.config.Timeout {
|
||||
metrics["health"] = "ready-to-recover"
|
||||
} else {
|
||||
metrics["health"] = "unhealthy"
|
||||
}
|
||||
}
|
||||
|
||||
return metrics
|
||||
}
|
||||
|
||||
// ForceOpen forces the circuit breaker to open state
|
||||
func (cb *CircuitBreaker) ForceOpen() {
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerOpen))
|
||||
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
cb.LogInfo("Circuit breaker forced open")
|
||||
}
|
||||
|
||||
// ForceClosed forces the circuit breaker to closed state
|
||||
func (cb *CircuitBreaker) ForceClosed() {
|
||||
atomic.StoreInt32(&cb.state, int32(CircuitBreakerClosed))
|
||||
|
||||
cb.stateMutex.Lock()
|
||||
cb.lastStateChange = time.Now()
|
||||
cb.stateMutex.Unlock()
|
||||
|
||||
atomic.StoreInt32(&cb.consecutiveFailures, 0)
|
||||
atomic.StoreInt32(&cb.consecutiveSuccesses, 0)
|
||||
atomic.StoreInt32(&cb.halfOpenRequests, 0)
|
||||
|
||||
cb.LogInfo("Circuit breaker forced closed")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user