mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9126c74723 | |||
| a750c4f5b9 | |||
| 56051779ee | |||
| 3f126d50f3 | |||
| 91f0fc9ab8 |
@@ -1,622 +0,0 @@
|
||||
name: PR Validation
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
checks: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
# Fast feedback - format and basic checks
|
||||
quick-checks:
|
||||
name: Quick Checks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Format check
|
||||
run: |
|
||||
# Exclude vendor directory from format checks
|
||||
UNFORMATTED=$(gofmt -s -l . | grep -v "^vendor/" || true)
|
||||
if [ -n "$UNFORMATTED" ]; then
|
||||
echo "Code is not formatted. Run: gofmt -s -w ."
|
||||
echo "Unformatted files:"
|
||||
echo "$UNFORMATTED"
|
||||
gofmt -s -d $(echo "$UNFORMATTED")
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Go vet
|
||||
run: go vet ./...
|
||||
|
||||
- name: Go mod verify
|
||||
run: go mod verify
|
||||
|
||||
- name: Go mod tidy check
|
||||
run: |
|
||||
go mod tidy
|
||||
git diff --exit-code go.mod go.sum
|
||||
|
||||
# Static analysis with golangci-lint (advisory - will not fail the build)
|
||||
golangci-lint:
|
||||
name: golangci-lint
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v8
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=10m
|
||||
continue-on-error: true # Allow pipeline to continue even with linting warnings
|
||||
|
||||
# Staticcheck analysis
|
||||
staticcheck:
|
||||
name: Staticcheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install staticcheck
|
||||
run: go install honnef.co/go/tools/cmd/staticcheck@latest
|
||||
|
||||
- name: Run staticcheck
|
||||
run: staticcheck ./...
|
||||
|
||||
# Security scanning with gosec
|
||||
gosec:
|
||||
name: Gosec Security Scanner
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run Gosec Security Scanner
|
||||
run: |
|
||||
go install github.com/securego/gosec/v2/cmd/gosec@latest
|
||||
gosec -no-fail -fmt sarif -out results.sarif ./... || echo "Gosec completed with warnings"
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: always() && hashFiles('results.sarif') != ''
|
||||
uses: github/codeql-action/upload-sarif@v3
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
continue-on-error: true
|
||||
|
||||
# Vulnerability scanning
|
||||
govulncheck:
|
||||
name: Vulnerability Scan
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Install govulncheck
|
||||
run: go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Run govulncheck
|
||||
run: govulncheck ./...
|
||||
|
||||
# CodeQL analysis
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Initialize CodeQL
|
||||
uses: github/codeql-action/init@v3
|
||||
with:
|
||||
languages: go
|
||||
continue-on-error: true
|
||||
|
||||
- name: Autobuild
|
||||
uses: github/codeql-action/autobuild@v3
|
||||
continue-on-error: true
|
||||
|
||||
- name: Perform CodeQL Analysis
|
||||
uses: github/codeql-action/analyze@v3
|
||||
continue-on-error: true
|
||||
|
||||
# Unit tests with race detection
|
||||
test-race:
|
||||
name: Unit Tests (Race Detector)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with race detector
|
||||
run: go test -race -timeout=15m -count=1 -v ./...
|
||||
env:
|
||||
GOMAXPROCS: 4
|
||||
|
||||
# Coverage analysis with threshold check
|
||||
test-coverage:
|
||||
name: Test Coverage
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run tests with coverage
|
||||
run: |
|
||||
go test -coverprofile=coverage.out -covermode=atomic -timeout=15m ./...
|
||||
go tool cover -func=coverage.out -o=coverage.txt
|
||||
|
||||
- name: Calculate coverage
|
||||
id: coverage
|
||||
run: |
|
||||
COVERAGE=$(go tool cover -func=coverage.out | grep total | awk '{print $3}' | sed 's/%//')
|
||||
echo "coverage=$COVERAGE" >> $GITHUB_OUTPUT
|
||||
echo "Total Coverage: $COVERAGE%"
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
file: ./coverage.out
|
||||
flags: unittests
|
||||
name: codecov-umbrella
|
||||
fail_ci_if_error: false
|
||||
continue-on-error: true
|
||||
|
||||
- name: Comment coverage on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
uses: actions/github-script@v8
|
||||
with:
|
||||
script: |
|
||||
const coverage = '${{ steps.coverage.outputs.coverage }}';
|
||||
const threshold = 70;
|
||||
const coverageNum = parseFloat(coverage);
|
||||
const emoji = coverageNum >= threshold ? '✅' : '⚠️';
|
||||
const status = coverageNum >= threshold ? 'meets' : 'below';
|
||||
|
||||
const body = `## ${emoji} Test Coverage Report
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| **Total Coverage** | ${coverage}% |
|
||||
| **Threshold** | ${threshold}% |
|
||||
| **Status** | ${emoji} Coverage ${status} threshold |`;
|
||||
|
||||
// Find existing comment
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
});
|
||||
|
||||
const botComment = comments.find(comment =>
|
||||
comment.user.type === 'Bot' &&
|
||||
comment.body.includes('Test Coverage Report')
|
||||
);
|
||||
|
||||
if (botComment) {
|
||||
await github.rest.issues.updateComment({
|
||||
comment_id: botComment.id,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
} else {
|
||||
await github.rest.issues.createComment({
|
||||
issue_number: context.issue.number,
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: body
|
||||
});
|
||||
}
|
||||
|
||||
- name: Check coverage threshold
|
||||
run: |
|
||||
COVERAGE=${{ steps.coverage.outputs.coverage }}
|
||||
THRESHOLD=70
|
||||
echo "Coverage: $COVERAGE%"
|
||||
echo "Threshold: $THRESHOLD%"
|
||||
if (( $(echo "$COVERAGE < $THRESHOLD" | bc -l) )); then
|
||||
echo "⚠️ Coverage $COVERAGE% is below threshold $THRESHOLD%"
|
||||
exit 1
|
||||
fi
|
||||
echo "✅ Coverage $COVERAGE% meets threshold $THRESHOLD%"
|
||||
|
||||
# Memory leak detection
|
||||
test-memory-leaks:
|
||||
name: Memory Leak Detection
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run goroutine leak tests
|
||||
run: |
|
||||
echo "Running goroutine leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Gg]oroutine.*[Ll]eak.*' ./... || echo "No goroutine leak tests found"
|
||||
|
||||
- name: Run memory leak tests
|
||||
run: |
|
||||
echo "Running memory leak detection tests..."
|
||||
go test -v -timeout=20m -run='.*[Mm]emory.*[Ll]eak.*' ./... || echo "No memory leak tests found"
|
||||
|
||||
- name: Run cleanup tests
|
||||
run: |
|
||||
echo "Running cleanup and resource management tests..."
|
||||
go test -v -timeout=20m -run='.*[Cc]leanup.*' ./... || echo "No cleanup tests found"
|
||||
|
||||
# Integration tests
|
||||
test-integration:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run integration tests
|
||||
run: |
|
||||
if [ -d "./integration" ]; then
|
||||
go test -v -timeout=20m ./integration/...
|
||||
else
|
||||
echo "Running integration tests from all packages..."
|
||||
go test -v -timeout=20m -run='.*[Ii]ntegration.*' ./...
|
||||
fi
|
||||
|
||||
# Regression tests
|
||||
test-regression:
|
||||
name: Regression Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
echo "Running regression tests..."
|
||||
go test -v -timeout=20m -run='.*[Rr]egression.*' ./...
|
||||
|
||||
# Provider-specific tests (parallel matrix)
|
||||
test-providers:
|
||||
name: Provider Tests (${{ matrix.provider }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
provider:
|
||||
- google
|
||||
- azure
|
||||
- auth0
|
||||
- okta
|
||||
- keycloak
|
||||
- cognito
|
||||
- gitlab
|
||||
- github
|
||||
- generic
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run ${{ matrix.provider }} provider tests
|
||||
run: |
|
||||
PROVIDER_CAP=$(echo "${{ matrix.provider }}" | sed 's/.*/\u&/')
|
||||
echo "Testing $PROVIDER_CAP provider..."
|
||||
go test -v -timeout=10m -run=".*$PROVIDER_CAP.*" ./internal/providers/... || true
|
||||
go test -v -timeout=10m -run=".*${{ matrix.provider }}.*" ./... || true
|
||||
|
||||
# Benchmark tests with performance tracking
|
||||
benchmark:
|
||||
name: Benchmark Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run benchmarks
|
||||
run: |
|
||||
echo "Running benchmark tests..."
|
||||
go test -bench=. -benchmem -benchtime=1s -run=^$ ./... | tee benchmark.txt
|
||||
|
||||
- name: Upload benchmark results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark-results
|
||||
path: benchmark.txt
|
||||
retention-days: 30
|
||||
|
||||
- name: Compare benchmarks
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
run: |
|
||||
echo "Benchmark results available in artifacts"
|
||||
echo "To compare with main branch, download previous benchmark results"
|
||||
|
||||
# Build validation across platforms
|
||||
build:
|
||||
name: Build (${{ matrix.os }}/${{ matrix.arch }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [linux, darwin]
|
||||
arch: [amd64, arm64]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Build for ${{ matrix.os }}/${{ matrix.arch }}
|
||||
env:
|
||||
GOOS: ${{ matrix.os }}
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
run: |
|
||||
echo "Building for $GOOS/$GOARCH..."
|
||||
go build -v -ldflags="-s -w" ./...
|
||||
|
||||
# Security-specific edge case tests
|
||||
test-security-edge-cases:
|
||||
name: Security Edge Cases
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run security edge case tests
|
||||
run: |
|
||||
echo "Running security edge case tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ecurity.*' ./...
|
||||
|
||||
# Session management tests
|
||||
test-session:
|
||||
name: Session Management Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run session tests
|
||||
run: |
|
||||
echo "Running session management tests..."
|
||||
go test -v -timeout=15m -run='.*[Ss]ession.*' ./...
|
||||
|
||||
# Token validation tests
|
||||
test-token:
|
||||
name: Token Validation Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run token validation tests
|
||||
run: |
|
||||
echo "Running token validation tests..."
|
||||
go test -v -timeout=15m -run='.*[Tt]oken.*' ./...
|
||||
|
||||
# CSRF and security tests
|
||||
test-csrf:
|
||||
name: CSRF and Security Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: '1.24'
|
||||
cache: true
|
||||
|
||||
- name: Run CSRF tests
|
||||
run: |
|
||||
echo "Running CSRF and security tests..."
|
||||
go test -v -timeout=15m -run='.*[Cc][Ss][Rr][Ff].*' ./...
|
||||
|
||||
# Multi-Go version compatibility
|
||||
test-go-versions:
|
||||
name: Go ${{ matrix.go-version }} Compatibility
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
go-version: ['1.24']
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Go ${{ matrix.go-version }}
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Run tests on Go ${{ matrix.go-version }}
|
||||
run: go test -short -timeout=10m ./...
|
||||
|
||||
# Final validation - all checks must pass (golangci-lint is advisory)
|
||||
all-checks-passed:
|
||||
name: ✅ All Checks Passed
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- quick-checks
|
||||
- golangci-lint
|
||||
- staticcheck
|
||||
- gosec
|
||||
- govulncheck
|
||||
- codeql
|
||||
- test-race
|
||||
- test-coverage
|
||||
- test-memory-leaks
|
||||
- test-integration
|
||||
- test-regression
|
||||
- test-providers
|
||||
- benchmark
|
||||
- build
|
||||
- test-security-edge-cases
|
||||
- test-session
|
||||
- test-token
|
||||
- test-csrf
|
||||
- test-go-versions
|
||||
if: always()
|
||||
steps:
|
||||
- name: Check all jobs status
|
||||
run: |
|
||||
echo "Checking status of all jobs..."
|
||||
|
||||
# Check critical jobs (excluding golangci-lint which is advisory)
|
||||
CRITICAL_FAILURES=false
|
||||
|
||||
if [ "${{ needs.quick-checks.result }}" == "failure" ] || \
|
||||
[ "${{ needs.staticcheck.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-race.result }}" == "failure" ] || \
|
||||
[ "${{ needs.test-coverage.result }}" == "failure" ] || \
|
||||
[ "${{ needs.build.result }}" == "failure" ]; then
|
||||
CRITICAL_FAILURES=true
|
||||
fi
|
||||
|
||||
if [ "$CRITICAL_FAILURES" == "true" ]; then
|
||||
echo "❌ Critical checks failed"
|
||||
exit 1
|
||||
elif [ "${{ contains(needs.*.result, 'cancelled') }}" == "true" ]; then
|
||||
echo "⚠️ Some checks were cancelled"
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All critical checks passed successfully!"
|
||||
if [ "${{ needs.golangci-lint.result }}" != "success" ]; then
|
||||
echo "ℹ️ Note: golangci-lint reported issues (advisory only)"
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Post summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "# PR Validation Summary" >> $GITHUB_STEP_SUMMARY
|
||||
echo "" >> $GITHUB_STEP_SUMMARY
|
||||
echo "## Job Status" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Quick Checks: ${{ needs.quick-checks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Linting (advisory): ${{ needs.golangci-lint.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Static Analysis: ${{ needs.staticcheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Scan (gosec): ${{ needs.gosec.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Vulnerability Scan: ${{ needs.govulncheck.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CodeQL: ${{ needs.codeql.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Race Detection: ${{ needs.test-race.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Coverage: ${{ needs.test-coverage.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Memory Leaks: ${{ needs.test-memory-leaks.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Integration Tests: ${{ needs.test-integration.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Regression Tests: ${{ needs.test-regression.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Provider Tests: ${{ needs.test-providers.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Benchmarks: ${{ needs.benchmark.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Build: ${{ needs.build.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Security Edge Cases: ${{ needs.test-security-edge-cases.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Session Tests: ${{ needs.test-session.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Token Tests: ${{ needs.test-token.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- CSRF Tests: ${{ needs.test-csrf.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
echo "- Go Version Compatibility: ${{ needs.test-go-versions.result }}" >> $GITHUB_STEP_SUMMARY
|
||||
@@ -0,0 +1,23 @@
|
||||
name: Pull Request
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
push:
|
||||
branches:
|
||||
- "**"
|
||||
- "!main"
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
security-events: write
|
||||
|
||||
jobs:
|
||||
pr-checks:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-pr.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
coverage-threshold: 70
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,21 @@
|
||||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "**.go"
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
release:
|
||||
uses: lukaszraczylo/shared-actions/.github/workflows/go-release.yaml@main
|
||||
with:
|
||||
go-version: "1.24.11"
|
||||
secrets: inherit
|
||||
@@ -0,0 +1,49 @@
|
||||
version: 2
|
||||
|
||||
# Traefik plugins are source-only - no binary builds
|
||||
# Traefik loads plugins via Yaegi interpreter at runtime
|
||||
builds:
|
||||
- skip: true
|
||||
|
||||
# Create source archive for GitHub releases
|
||||
archives:
|
||||
- formats: [tar.gz]
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_source"
|
||||
files:
|
||||
- "*.go"
|
||||
- "**/*.go"
|
||||
- go.mod
|
||||
- go.sum
|
||||
- .traefik.yml
|
||||
- LICENSE*
|
||||
- README*
|
||||
# Exclude test files and vendor from release archive
|
||||
- "!**/*_test.go"
|
||||
- "!vendor/**"
|
||||
- "!docker/**"
|
||||
- "!integration/**"
|
||||
- "!regression/**"
|
||||
- "!examples/**"
|
||||
- "!docs/**"
|
||||
|
||||
checksum:
|
||||
name_template: "{{ .ProjectName }}_v{{ .Version }}_checksums.txt"
|
||||
algorithm: sha256
|
||||
|
||||
changelog:
|
||||
sort: asc
|
||||
filters:
|
||||
exclude:
|
||||
- "^docs:"
|
||||
- "^test:"
|
||||
- "^Merge"
|
||||
- "^WIP"
|
||||
- "^chore:"
|
||||
|
||||
release:
|
||||
github:
|
||||
owner: lukaszraczylo
|
||||
name: traefikoidc
|
||||
name_template: "v{{ .Version }}"
|
||||
draft: false
|
||||
prerelease: auto
|
||||
+118
@@ -77,6 +77,7 @@ testData:
|
||||
# Custom claim names for Auth0 and other providers with namespaced claims
|
||||
roleClaimName: roles # JWT claim name for extracting user roles (default: "roles")
|
||||
groupClaimName: groups # JWT claim name for extracting user groups (default: "groups")
|
||||
userIdentifierClaim: email # JWT claim for user identification (default: "email", alternatives: "sub", "oid", "upn", "preferred_username")
|
||||
|
||||
# ⚠️ CRITICAL for TLS termination scenarios (AWS ALB, Cloud Load Balancers, etc.)
|
||||
# When NOT specified in config: defaults to FALSE (Go zero value)
|
||||
@@ -120,6 +121,8 @@ testData:
|
||||
allowOpaqueTokens: false # Enable opaque (non-JWT) access token support via RFC 7662 introspection
|
||||
requireTokenIntrospection: false # Force introspection for opaque tokens (requires introspection endpoint)
|
||||
disableReplayDetection: false # Disable JTI replay detection for multi-replica deployments (default: false)
|
||||
allowPrivateIPAddresses: false # Allow private IP addresses in provider URLs for internal networks (default: false)
|
||||
minimalHeaders: false # Reduce forwarded headers to prevent 431 errors (default: false)
|
||||
|
||||
# Security Headers Configuration (enabled by default with 'default' profile)
|
||||
securityHeaders:
|
||||
@@ -266,6 +269,8 @@ testDataWithRedis:
|
||||
# allowedRolesAndGroups: # Corresponds to 'Token Claim Name' in Keycloak mappers
|
||||
# - admin
|
||||
# - editor
|
||||
# # For internal Keycloak deployments with private IPs (Docker/Kubernetes internal):
|
||||
# # allowPrivateIPAddresses: true # Enable for private IP addresses like 192.168.x.x, 10.x.x.x
|
||||
# # Ensure Keycloak client mappers add 'email', 'roles', 'groups' etc. to the ID Token.
|
||||
# # See README.md "Provider Configuration Recommendations" for Keycloak.
|
||||
|
||||
@@ -287,6 +292,26 @@ testDataWithRedis:
|
||||
# - "AppRoleName"
|
||||
# # See README.md "Provider Configuration Recommendations" for Azure AD.
|
||||
|
||||
# --- Azure AD Users Without Email Example (Issue #95) ---
|
||||
# testDataAzureADNoEmail:
|
||||
# providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
# clientID: your-azure-ad-client-id
|
||||
# clientSecret: your-azure-ad-client-secret
|
||||
# callbackURL: /oauth2/callback
|
||||
# sessionEncryptionKey: "a-very-secure-key-at-least-32-bytes-long-for-azure"
|
||||
# # Use 'sub' claim instead of 'email' for user identification
|
||||
# userIdentifierClaim: sub # or "oid", "upn", "preferred_username"
|
||||
# overrideScopes: true # Remove email scope if not needed
|
||||
# scopes:
|
||||
# - openid
|
||||
# - profile
|
||||
# - groups # For group-based access control
|
||||
# # When using non-email identifiers, allowedUsers matches against the claim value
|
||||
# allowedUsers:
|
||||
# - "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID (sub or oid claim)
|
||||
# # NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
# # See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
|
||||
# --- Google Workspace / Google Cloud Identity Example ---
|
||||
# testDataGoogle:
|
||||
# providerURL: https://accounts.google.com # Standard Google OIDC endpoint
|
||||
@@ -605,6 +630,38 @@ configuration:
|
||||
items:
|
||||
type: string
|
||||
|
||||
userIdentifierClaim:
|
||||
type: string
|
||||
description: |
|
||||
Specifies the JWT claim to use as the user identifier for authentication and authorization.
|
||||
|
||||
This allows authentication for users without email addresses, such as Azure AD service
|
||||
accounts or organizational accounts that don't have email attributes configured.
|
||||
|
||||
When set to a non-email claim (e.g., "sub", "oid", "upn"):
|
||||
- AllowedUsers will match against this claim value instead of email
|
||||
- AllowedUserDomains validation is skipped (domains only apply to email addresses)
|
||||
- The session stores this identifier as the user's identity
|
||||
- If the configured claim is missing, falls back to "sub" (required by OIDC spec)
|
||||
|
||||
Common values by provider:
|
||||
- Default: "email" (standard email-based identification)
|
||||
- Azure AD: "sub", "oid" (object ID), "upn" (User Principal Name), "preferred_username"
|
||||
- Generic OIDC: "sub" (always present per OIDC specification)
|
||||
- Keycloak: "sub", "preferred_username"
|
||||
|
||||
Example for Azure AD users without email:
|
||||
```yaml
|
||||
userIdentifierClaim: sub
|
||||
allowedUsers:
|
||||
- "abc123-user-object-id"
|
||||
- "xyz789-another-user-id"
|
||||
```
|
||||
|
||||
Default: "email"
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/95
|
||||
required: false
|
||||
|
||||
revocationURL:
|
||||
type: string
|
||||
description: |
|
||||
@@ -903,6 +960,67 @@ configuration:
|
||||
Default: false (replay detection enabled)
|
||||
required: false
|
||||
|
||||
allowPrivateIPAddresses:
|
||||
type: boolean
|
||||
description: |
|
||||
Allow private IP addresses in OIDC provider URLs for internal network deployments.
|
||||
|
||||
By default, the plugin blocks URLs containing private IP address ranges
|
||||
(10.x.x.x, 172.16-31.x.x, 192.168.x.x) to prevent SSRF attacks and ensure
|
||||
OIDC providers are publicly accessible.
|
||||
|
||||
Enable this option when:
|
||||
- Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
- You don't have DNS resolution available for internal services
|
||||
- Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
|
||||
When enabled, the plugin will accept provider URLs like:
|
||||
- https://192.168.1.100:8443/auth/realms/your-realm
|
||||
- https://10.0.0.50:8080/realms/master
|
||||
- https://172.16.0.10/auth
|
||||
|
||||
Security Warning:
|
||||
Enabling this option reduces SSRF protection. Only use in trusted network
|
||||
environments where the OIDC provider is known and controlled. Loopback
|
||||
addresses (127.0.0.1, localhost, ::1) remain blocked even with this option enabled.
|
||||
|
||||
Default: false (private IPs are blocked for security)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/97
|
||||
required: false
|
||||
|
||||
minimalHeaders:
|
||||
type: boolean
|
||||
description: |
|
||||
Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors.
|
||||
|
||||
When enabled, the middleware only forwards the X-Forwarded-User header and skips
|
||||
the larger authentication headers that can cause downstream services to reject
|
||||
requests due to header size limits (typically 8KB).
|
||||
|
||||
Headers when disabled (default):
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-Auth-Request-Redirect: Original request URI
|
||||
- X-Auth-Request-User: User's email address
|
||||
- X-Auth-Request-Token: Full ID token (can be very large with many claims)
|
||||
- X-User-Groups: Comma-separated user groups (if configured)
|
||||
- X-User-Roles: Comma-separated user roles (if configured)
|
||||
|
||||
Headers when enabled:
|
||||
- X-Forwarded-User: User's email address (always set)
|
||||
- X-User-Groups: Comma-separated user groups (if configured, still forwarded)
|
||||
- X-User-Roles: Comma-separated user roles (if configured, still forwarded)
|
||||
- Custom templated headers (still processed)
|
||||
|
||||
Use this option when:
|
||||
- Downstream services return "431 Request Header Fields Too Large" errors
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- You don't need the full ID token forwarded to backend services
|
||||
- You want to reduce request overhead
|
||||
|
||||
Default: false (all headers forwarded for backward compatibility)
|
||||
See: https://github.com/lukaszraczylo/traefikoidc/issues/64
|
||||
required: false
|
||||
|
||||
headers:
|
||||
type: array
|
||||
description: |
|
||||
|
||||
@@ -124,6 +124,7 @@ The middleware supports the following configuration options:
|
||||
| `allowedRolesAndGroups` | Restricts access to users with specific roles or groups | none | `["admin", "developer"]` |
|
||||
| `roleClaimName` | JWT claim name for extracting user roles (supports namespaced claims for Auth0) | `"roles"` | `"https://myapp.com/roles"`, `"user_roles"` |
|
||||
| `groupClaimName` | JWT claim name for extracting user groups (supports namespaced claims for Auth0) | `"groups"` | `"https://myapp.com/groups"`, `"user_groups"` |
|
||||
| `userIdentifierClaim` | JWT claim to use as user identifier (for users without email, e.g., Azure AD service accounts) | `"email"` | `"sub"`, `"oid"`, `"upn"`, `"preferred_username"` |
|
||||
| `revocationURL` | The endpoint for revoking tokens | auto-discovered | `https://accounts.google.com/revoke` |
|
||||
| `oidcEndSessionURL` | The provider's end session endpoint | auto-discovered | `https://accounts.google.com/logout` |
|
||||
| `enablePKCE` | Enables PKCE (Proof Key for Code Exchange) for authorization code flow | `false` | `true`, `false` |
|
||||
@@ -138,6 +139,8 @@ The middleware supports the following configuration options:
|
||||
| `headers` | Custom HTTP headers with templates that can access OIDC claims and tokens | none | See "Templated Headers" section |
|
||||
| `securityHeaders` | Configure security headers including CSP, HSTS, CORS, and custom headers | enabled with default profile | See "Security Headers Configuration" section |
|
||||
| `disableReplayDetection` | Disable JTI-based replay attack detection for multi-replica deployments | `false` | `true` |
|
||||
| `allowPrivateIPAddresses` | Allow private IP addresses in provider URLs (for internal networks with Keycloak, etc.) | `false` | `true` |
|
||||
| `minimalHeaders` | Reduce forwarded headers to prevent "431 Request Header Fields Too Large" errors | `false` | `true` |
|
||||
| `redis` | Redis cache configuration for distributed deployments | disabled | See "Redis Cache" section |
|
||||
|
||||
> **⚠️ IMPORTANT - TLS Termination at Load Balancer:**
|
||||
@@ -1241,6 +1244,45 @@ spec:
|
||||
- "AppRoleName" # Application role names
|
||||
```
|
||||
|
||||
### Azure AD Configuration (Users Without Email)
|
||||
|
||||
For Azure AD users without email addresses (service accounts, organizational accounts without mail attributes):
|
||||
|
||||
```yaml
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-azure-no-email
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: https://login.microsoftonline.com/your-tenant-id/v2.0
|
||||
clientID: your-azure-ad-client-id
|
||||
clientSecret: your-azure-ad-client-secret
|
||||
sessionEncryptionKey: your-secure-encryption-key-min-32-chars
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
|
||||
# Use 'sub' instead of 'email' for user identification
|
||||
userIdentifierClaim: sub # Can also use: "oid", "upn", "preferred_username"
|
||||
|
||||
overrideScopes: true # Optional: Don't request email scope if not needed
|
||||
scopes:
|
||||
- openid
|
||||
- profile
|
||||
- groups
|
||||
|
||||
# When using non-email identifiers, allowedUsers matches against the claim value
|
||||
allowedUsers:
|
||||
- "abc12345-6789-0abc-def0-123456789abc" # Azure AD user object ID
|
||||
- "def67890-1234-5678-90ab-cdef12345678"
|
||||
|
||||
# NOTE: allowedUserDomains is ignored when userIdentifierClaim is not "email"
|
||||
```
|
||||
|
||||
> **Note**: When `userIdentifierClaim` is set to a non-email claim (like `sub`, `oid`, or `upn`), the `allowedUserDomains` configuration is ignored since domain-based validation only applies to email addresses. Use `allowedUsers` with the actual claim values instead.
|
||||
|
||||
### Auth0 Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1327,8 +1369,12 @@ spec:
|
||||
- admin
|
||||
- editor
|
||||
# Ensure Keycloak client mappers add necessary claims to ID Token
|
||||
# For internal Keycloak deployments with private IPs (e.g., Docker network):
|
||||
# allowPrivateIPAddresses: true
|
||||
```
|
||||
|
||||
> **Internal Network Deployment**: If your Keycloak runs on an internal network with private IP addresses (e.g., `192.168.x.x`, `10.x.x.x`, `172.16-31.x.x`) and you don't have DNS resolution available, set `allowPrivateIPAddresses: true` to allow the plugin to connect to your Keycloak instance. See [Issue #97](https://github.com/lukaszraczylo/traefikoidc/issues/97) for details.
|
||||
|
||||
### AWS Cognito Configuration
|
||||
|
||||
```yaml
|
||||
@@ -1629,12 +1675,39 @@ headers:
|
||||
|
||||
When a user is authenticated, the middleware sets the following headers for downstream services:
|
||||
|
||||
- `X-Forwarded-User`: The user's email address
|
||||
- `X-Forwarded-User`: The user's email address (always set)
|
||||
- `X-User-Groups`: Comma-separated list of user groups (if available)
|
||||
- `X-User-Roles`: Comma-separated list of user roles (if available)
|
||||
- `X-Auth-Request-Redirect`: The original request URI
|
||||
- `X-Auth-Request-User`: The user's email address
|
||||
- `X-Auth-Request-Token`: The user's access token
|
||||
- `X-Auth-Request-Token`: The user's ID token (can be large)
|
||||
|
||||
#### Minimal Headers Mode
|
||||
|
||||
If your downstream services return **"431 Request Header Fields Too Large"** errors, you can enable minimal headers mode to reduce header overhead:
|
||||
|
||||
```yaml
|
||||
http:
|
||||
middlewares:
|
||||
my-auth:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
minimalHeaders: true
|
||||
# ... other config
|
||||
```
|
||||
|
||||
When `minimalHeaders: true` is set:
|
||||
- **Only forwards**: `X-Forwarded-User`
|
||||
- **Skips**: `X-Auth-Request-Token` (the full ID token - often the largest header), `X-Auth-Request-User`, `X-Auth-Request-Redirect`
|
||||
- **Still forwards**: `X-User-Groups` and `X-User-Roles` (if configured)
|
||||
- **Still processes**: Custom templated headers
|
||||
|
||||
This is particularly useful when:
|
||||
- Your ID tokens are large (many claims, long group lists)
|
||||
- Downstream services have limited header buffer sizes (default 8KB in many servers)
|
||||
- You don't need the full token forwarded to backend services
|
||||
|
||||
See [GitHub Issue #64](https://github.com/lukaszraczylo/traefikoidc/issues/64) for details.
|
||||
|
||||
### Security Headers
|
||||
|
||||
@@ -1862,6 +1935,15 @@ logLevel: debug
|
||||
- No refresh tokens (re-authentication required on expiry)
|
||||
- Use only for GitHub API access, not user authentication
|
||||
|
||||
15. **Environment variable names containing "API" cause plugin failure** ([Issue #98](https://github.com/lukaszraczylo/traefikoidc/issues/98)):
|
||||
- When using environment variable syntax like `${OIDC_ENCRYPTION_SECRET_API}` in Traefik configuration, the plugin fails with "invalid handler type: \<nil\>" error
|
||||
- This is a **Traefik-side issue**, not a plugin bug. Traefik uses reserved environment variables starting with `TRAEFIK_API_*` for its internal API configuration, and the "API" substring in user-defined variable names may interfere with Traefik's environment variable processing
|
||||
- **Workaround**: Avoid using "API" as a substring in environment variable names. Use alternatives like:
|
||||
- `${OIDC_ENCRYPTION_SECRET_SVC}` instead of `${OIDC_ENCRYPTION_SECRET_API}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_SERVICE}`
|
||||
- `${OIDC_ENCRYPTION_SECRET_BACKEND}`
|
||||
- Any name that doesn't contain the literal substring "API"
|
||||
|
||||
### Provider Warnings and Recommendations
|
||||
|
||||
The middleware includes built-in warnings for provider-specific limitations. Check your logs for important notices about:
|
||||
|
||||
+21
-20
@@ -849,26 +849,27 @@ func TestAudienceEndToEndScenario(t *testing.T) {
|
||||
customAudience := "https://api.company.com"
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
next: nextHandler,
|
||||
name: "test",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/callback/logout",
|
||||
issuerURL: "https://auth.company.com",
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
audience: customAudience, // Set custom audience
|
||||
jwkCache: mockJWKCache,
|
||||
jwksURL: "https://auth.company.com/.well-known/jwks.json",
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"company.com": {}},
|
||||
userIdentifierClaim: "email", // Required for user identification
|
||||
excludedURLs: map[string]struct{}{},
|
||||
httpClient: &http.Client{},
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
extractClaimsFunc: extractClaims,
|
||||
}
|
||||
tOidc.jwtVerifier = tOidc
|
||||
tOidc.tokenVerifier = tOidc
|
||||
|
||||
+29
-25
@@ -18,17 +18,18 @@ type ScopeFilter interface {
|
||||
|
||||
// Handler provides core authentication functionality for OIDC flows
|
||||
type Handler struct {
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
logger Logger
|
||||
enablePKCE bool
|
||||
isGoogleProv func() bool
|
||||
isAzureProv func() bool
|
||||
clientID string
|
||||
authURL string
|
||||
issuerURL string
|
||||
scopes []string
|
||||
overrideScopes bool
|
||||
scopeFilter ScopeFilter // NEW
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -40,19 +41,20 @@ type Logger interface {
|
||||
// NewAuthHandler creates a new Handler instance
|
||||
func NewAuthHandler(logger Logger, enablePKCE bool, isGoogleProv, isAzureProv func() bool,
|
||||
clientID, authURL, issuerURL string, scopes []string, overrideScopes bool,
|
||||
scopeFilter ScopeFilter, scopesSupported []string) *Handler {
|
||||
scopeFilter ScopeFilter, scopesSupported []string, allowPrivateIPAddresses bool) *Handler {
|
||||
return &Handler{
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter, // NEW
|
||||
scopesSupported: scopesSupported, // NEW
|
||||
logger: logger,
|
||||
enablePKCE: enablePKCE,
|
||||
isGoogleProv: isGoogleProv,
|
||||
isAzureProv: isAzureProv,
|
||||
clientID: clientID,
|
||||
authURL: authURL,
|
||||
issuerURL: issuerURL,
|
||||
scopes: scopes,
|
||||
overrideScopes: overrideScopes,
|
||||
scopeFilter: scopeFilter,
|
||||
scopesSupported: scopesSupported,
|
||||
allowPrivateIPAddresses: allowPrivateIPAddresses,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,6 +349,7 @@ func (h *Handler) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname for security and reachability.
|
||||
// It prevents access to private networks and localhost addresses.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
func (h *Handler) validateHost(host string) error {
|
||||
if host == "" {
|
||||
return fmt.Errorf("empty host")
|
||||
@@ -361,7 +364,7 @@ func (h *Handler) validateHost(host string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for localhost variations
|
||||
// Check for localhost variations (always blocked, even with allowPrivateIPAddresses)
|
||||
localhostVariations := []string{
|
||||
"localhost", "127.0.0.1", "::1", "0.0.0.0",
|
||||
}
|
||||
@@ -376,7 +379,8 @@ func (h *Handler) validateHost(host string) error {
|
||||
if ip.IsLoopback() {
|
||||
return fmt.Errorf("loopback IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsPrivate() {
|
||||
// Skip private IP check if allowPrivateIPAddresses is enabled
|
||||
if !h.allowPrivateIPAddresses && ip.IsPrivate() {
|
||||
return fmt.Errorf("private IP not allowed: %s", host)
|
||||
}
|
||||
if ip.IsLinkLocalUnicast() {
|
||||
|
||||
+25
-25
@@ -86,7 +86,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
||||
"test-client-id", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -125,7 +125,7 @@ func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{redirectCount: 5} // At the limit
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -160,7 +160,7 @@ func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -191,7 +191,7 @@ func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -222,7 +222,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
@@ -253,7 +253,7 @@ func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T)
|
||||
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{saveError: &testError{"save failed"}}
|
||||
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
||||
@@ -297,7 +297,7 @@ func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
||||
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false, nil, nil, false)
|
||||
|
||||
session := &mockSessionData{}
|
||||
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
||||
@@ -400,7 +400,7 @@ func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -440,7 +440,7 @@ func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil)
|
||||
[]string{"openid", "profile", "email"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -468,7 +468,7 @@ func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -493,7 +493,7 @@ func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
||||
[]string{"openid"}, false, nil, nil)
|
||||
[]string{"openid"}, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -565,7 +565,7 @@ func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
tt.scopes, tt.overrideScopes, nil, nil)
|
||||
tt.scopes, tt.overrideScopes, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -634,7 +634,7 @@ func TestAuthHandler_BuildAuthURL_WithScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -676,7 +676,7 @@ func TestAuthHandler_BuildAuthURL_WithoutScopeFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, nil)
|
||||
scopes, false, nil, nil, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -714,7 +714,7 @@ func TestAuthHandler_BuildAuthURL_GitLabFiltersOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"gitlab-client", "https://gitlab.example.com/oauth/authorize",
|
||||
"https://gitlab.example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -756,7 +756,7 @@ func TestAuthHandler_BuildAuthURL_GoogleRemovesOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
||||
"google-client", "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
"https://accounts.google.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -797,7 +797,7 @@ func TestAuthHandler_BuildAuthURL_AzureAddsOfflineAccess(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
||||
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -831,7 +831,7 @@ func TestAuthHandler_BuildAuthURL_GenericWithFiltering(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"generic-client", "https://auth.provider.com/authorize",
|
||||
"https://auth.provider.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -870,7 +870,7 @@ func TestAuthHandler_BuildAuthURL_OverrideScopesWithFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, true, scopeFilter, scopesSupported)
|
||||
scopes, true, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -916,7 +916,7 @@ func TestAuthHandler_BuildAuthURL_DoubleFiltering(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -955,7 +955,7 @@ func TestAuthHandler_BuildAuthURL_NoScopeFilterProvided(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, nil, scopesSupported) // scopeFilter is nil
|
||||
scopes, false, nil, scopesSupported, false) // scopeFilter is nil
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -988,7 +988,7 @@ func TestAuthHandler_BuildAuthURL_EmptyScopesSupported(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
@@ -1021,7 +1021,7 @@ func TestAuthHandler_BuildAuthURL_FilteringWithPKCE(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
||||
|
||||
@@ -1064,7 +1064,7 @@ func TestAuthHandler_BuildAuthURL_ComplexScenario(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
||||
"complex-client", "https://auth.complex.com/authorize", "https://auth.complex.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
authURL := handler.BuildAuthURL("https://example.com/callback", "state-123", "nonce-456", "challenge-789")
|
||||
|
||||
@@ -1130,7 +1130,7 @@ func TestAuthHandler_BuildAuthURL_LoggingVerification(t *testing.T) {
|
||||
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com",
|
||||
scopes, false, scopeFilter, scopesSupported)
|
||||
scopes, false, scopeFilter, scopesSupported, false)
|
||||
|
||||
handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
||||
|
||||
|
||||
+103
-5
@@ -10,7 +10,7 @@ import (
|
||||
func TestAuthHandler_validateURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -185,7 +185,7 @@ func TestAuthHandler_validateURL(t *testing.T) {
|
||||
func TestAuthHandler_validateHost(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -334,7 +334,7 @@ func TestAuthHandler_validateHost(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -438,7 +438,7 @@ func TestAuthHandler_buildURLWithParams(t *testing.T) {
|
||||
func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
// Test special characters that need encoding
|
||||
params := url.Values{
|
||||
@@ -477,7 +477,7 @@ func TestAuthHandler_buildURLWithParams_ParameterEncoding(t *testing.T) {
|
||||
func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil)
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -560,3 +560,101 @@ func TestAuthHandler_validateParsedURL(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthHandler_validateHost_AllowPrivateIPAddresses tests the allowPrivateIPAddresses flag
|
||||
func TestAuthHandler_validateHost_AllowPrivateIPAddresses(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
|
||||
// Test with allowPrivateIPAddresses = false (default)
|
||||
t.Run("Private IPs blocked by default", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, false)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err == nil {
|
||||
t.Errorf("Expected private IP %s to be blocked, but it was allowed", ip)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), "private IP not allowed") {
|
||||
t.Errorf("Expected 'private IP not allowed' error for %s, got: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test with allowPrivateIPAddresses = true
|
||||
t.Run("Private IPs allowed when flag enabled", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
privateIPs := []string{
|
||||
"192.168.1.1",
|
||||
"10.0.0.1",
|
||||
"172.16.0.1",
|
||||
"172.31.255.255",
|
||||
}
|
||||
|
||||
for _, ip := range privateIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected private IP %s to be allowed with flag enabled, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that loopback is still blocked even with flag enabled
|
||||
t.Run("Loopback always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
loopbackAddresses := []string{
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
"::1",
|
||||
"0.0.0.0",
|
||||
}
|
||||
|
||||
for _, addr := range loopbackAddresses {
|
||||
err := handler.validateHost(addr)
|
||||
if err == nil {
|
||||
t.Errorf("Expected loopback address %s to be blocked even with allowPrivateIPAddresses=true", addr)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test that link-local is still blocked even with flag enabled
|
||||
t.Run("Link-local always blocked", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
err := handler.validateHost("169.254.1.1")
|
||||
if err == nil {
|
||||
t.Error("Expected link-local address to be blocked even with allowPrivateIPAddresses=true")
|
||||
}
|
||||
})
|
||||
|
||||
// Test that public IPs work with flag enabled
|
||||
t.Run("Public IPs allowed", func(t *testing.T) {
|
||||
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
||||
"test-client", "https://example.com/auth", "https://example.com", []string{}, false, nil, nil, true)
|
||||
|
||||
publicIPs := []string{
|
||||
"8.8.8.8",
|
||||
"1.1.1.1",
|
||||
"142.250.185.68",
|
||||
}
|
||||
|
||||
for _, ip := range publicIPs {
|
||||
err := handler.validateHost(ip)
|
||||
if err != nil {
|
||||
t.Errorf("Expected public IP %s to be allowed, but got error: %v", ip, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
+19
-9
@@ -223,15 +223,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[t.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if t.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
t.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", t.userIdentifierClaim)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", t.userIdentifierClaim)
|
||||
}
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -240,7 +250,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
t.sendErrorResponse(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -787,6 +787,7 @@ func (mm *TaskMemoryMonitor) ForceGC() (before, after TaskMemoryStats, err error
|
||||
}
|
||||
|
||||
if mm.logger != nil {
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freed := int64(before.HeapAlloc) - int64(after.HeapAlloc)
|
||||
mm.logger.Infof("Forced GC: freed %d bytes (%.2f MB)", freed, float64(freed)/(1024*1024))
|
||||
}
|
||||
|
||||
@@ -105,6 +105,7 @@ func (l *ConfigLoader) loadFile(path string) (*UnifiedConfig, error) {
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file %s: %w", absPath, err)
|
||||
|
||||
@@ -252,6 +252,7 @@ func (m *ConfigMigrator) MigrateFile(filePath string) (*UnifiedConfig, error) {
|
||||
}
|
||||
|
||||
// Read the file with validated path
|
||||
// #nosec G304 -- path is validated via filepath.Abs above
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
|
||||
@@ -437,6 +437,21 @@ http:
|
||||
4. Configure client scopes and mappers
|
||||
5. Generate client secret in Credentials tab
|
||||
|
||||
### Internal Network Deployment
|
||||
|
||||
If your Keycloak instance runs on an internal network with private IP addresses (e.g., Docker networks, Kubernetes internal services), set `allowPrivateIPAddresses: true`:
|
||||
|
||||
```yaml
|
||||
traefikoidc:
|
||||
providerUrl: "https://192.168.1.100:8443/auth/realms/your-realm" # Private IP
|
||||
allowPrivateIPAddresses: true # Required for private IP addresses
|
||||
clientId: "your-client-id"
|
||||
clientSecret: "your-client-secret"
|
||||
# ... other config
|
||||
```
|
||||
|
||||
> **Security Warning**: Only enable `allowPrivateIPAddresses` in trusted network environments where you control the OIDC provider. This setting reduces SSRF protection.
|
||||
|
||||
---
|
||||
|
||||
## Okta
|
||||
|
||||
@@ -348,6 +348,7 @@ func (r *DynamicClientRegistrar) saveCredentials(resp *ClientRegistrationRespons
|
||||
func (r *DynamicClientRegistrar) loadCredentials() (*ClientRegistrationResponse, error) {
|
||||
filePath := r.credentialsFilePath()
|
||||
|
||||
// #nosec G304 -- path is constructed from trusted config values via credentialsFilePath()
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
|
||||
@@ -2,10 +2,14 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -411,6 +415,31 @@ func DefaultRetryConfig() RetryConfig {
|
||||
}
|
||||
}
|
||||
|
||||
// MetadataFetchRetryConfig returns retry configuration optimized for OIDC metadata
|
||||
// fetching during startup. Uses more aggressive retry settings to handle the race
|
||||
// condition where Traefik initializes the plugin before routes are fully established,
|
||||
// or before TLS certificates are properly loaded.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func MetadataFetchRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 10, // More attempts for startup scenarios
|
||||
InitialDelay: 1 * time.Second, // 1 second between attempts as suggested
|
||||
MaxDelay: 10 * time.Second, // Cap at 10 seconds
|
||||
BackoffFactor: 1.5, // Gentler backoff for startup
|
||||
EnableJitter: true, // Prevent thundering herd
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// RetryExecutor implements retry logic with exponential backoff and jitter.
|
||||
// It automatically retries failed operations based on configurable error patterns
|
||||
// and uses exponential backoff to avoid overwhelming failing services.
|
||||
@@ -487,11 +516,29 @@ func (re *RetryExecutor) Execute(ctx context.Context, fn func() error) error {
|
||||
// isRetryableError checks if an error should trigger a retry
|
||||
// isRetryableError determines if an error should trigger a retry attempt.
|
||||
// Checks error message against configured retryable error patterns.
|
||||
// Also handles startup-specific errors like Traefik default certificate errors
|
||||
// and EOF errors that occur during service initialization.
|
||||
func (re *RetryExecutor) isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for Traefik default certificate error (startup race condition)
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
if isTraefikDefaultCertError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for EOF errors (common during startup when services aren't ready)
|
||||
if isEOFError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for certificate errors (transient during startup)
|
||||
if isCertificateError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
for _, retryableErr := range re.config.RetryableErrors {
|
||||
@@ -538,6 +585,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
delay = float64(re.config.MaxDelay)
|
||||
}
|
||||
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.EnableJitter {
|
||||
jitter := delay * 0.1 * (2.0*rand.Float64() - 1.0)
|
||||
delay += jitter
|
||||
@@ -1087,3 +1135,86 @@ func containsSubstring(s, substr string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isTraefikDefaultCertError detects when Traefik is serving its default self-signed
|
||||
// certificate during cold-start, before the real certificates are loaded.
|
||||
// This manifests as an x509.HostnameError where one of the certificate's DNS names
|
||||
// ends with "traefik.default" (the default Traefik certificate pattern).
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func isTraefikDefaultCertError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var hostnameErr x509.HostnameError
|
||||
if errors.As(err, &hostnameErr) {
|
||||
if hostnameErr.Certificate != nil {
|
||||
for _, name := range hostnameErr.Certificate.DNSNames {
|
||||
if strings.HasSuffix(name, "traefik.default") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isEOFError checks if an error is an EOF error, which can occur during
|
||||
// connection establishment when the remote end closes unexpectedly.
|
||||
// This is common during service startup when endpoints aren't fully ready.
|
||||
func isEOFError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for direct EOF
|
||||
if errors.Is(err, io.EOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for unexpected EOF
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for EOF patterns (wrapped errors)
|
||||
errStr := err.Error()
|
||||
return strings.Contains(errStr, "EOF") || strings.Contains(errStr, "unexpected EOF")
|
||||
}
|
||||
|
||||
// isCertificateError checks if an error is related to TLS certificate validation.
|
||||
// These errors are often transient during startup when services are still initializing.
|
||||
func isCertificateError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for x509 certificate errors
|
||||
var certInvalidErr x509.CertificateInvalidError
|
||||
var hostnameErr x509.HostnameError
|
||||
var unknownAuthErr x509.UnknownAuthorityError
|
||||
|
||||
if errors.As(err, &certInvalidErr) ||
|
||||
errors.As(err, &hostnameErr) ||
|
||||
errors.As(err, &unknownAuthErr) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check error message for certificate patterns
|
||||
errStr := strings.ToLower(err.Error())
|
||||
certPatterns := []string{
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
"ssl",
|
||||
}
|
||||
|
||||
for _, pattern := range certPatterns {
|
||||
if strings.Contains(errStr, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -846,3 +846,296 @@ func (e *mockNetError) Temporary() bool { return e.temporary }
|
||||
|
||||
// Ensure mockNetError implements net.Error
|
||||
var _ net.Error = (*mockNetError)(nil)
|
||||
|
||||
// Test isTraefikDefaultCertError
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
|
||||
func TestIsTraefikDefaultCertError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "network error",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isTraefikDefaultCertError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isTraefikDefaultCertError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isEOFError
|
||||
|
||||
func TestIsEOFError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing EOF in message",
|
||||
err: errors.New("connection closed: EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing unexpected EOF",
|
||||
err: errors.New("read: unexpected EOF"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "network error without EOF",
|
||||
err: &mockNetError{msg: "connection refused"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isEOFError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isEOFError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test isCertificateError
|
||||
|
||||
func TestIsCertificateError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "regular error",
|
||||
err: errors.New("some error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing certificate in message",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing x509 in message",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing tls in message",
|
||||
err: errors.New("tls handshake failed"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error containing ssl in message",
|
||||
err: errors.New("ssl connection error"),
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCertificateError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCertificateError() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test MetadataFetchRetryConfig
|
||||
|
||||
func TestMetadataFetchRetryConfig(t *testing.T) {
|
||||
config := MetadataFetchRetryConfig()
|
||||
|
||||
if config.MaxAttempts != 10 {
|
||||
t.Errorf("Expected MaxAttempts 10, got %d", config.MaxAttempts)
|
||||
}
|
||||
|
||||
if config.InitialDelay != 1*time.Second {
|
||||
t.Errorf("Expected InitialDelay 1s, got %v", config.InitialDelay)
|
||||
}
|
||||
|
||||
if config.MaxDelay != 10*time.Second {
|
||||
t.Errorf("Expected MaxDelay 10s, got %v", config.MaxDelay)
|
||||
}
|
||||
|
||||
if config.BackoffFactor != 1.5 {
|
||||
t.Errorf("Expected BackoffFactor 1.5, got %v", config.BackoffFactor)
|
||||
}
|
||||
|
||||
if !config.EnableJitter {
|
||||
t.Error("Expected EnableJitter to be true")
|
||||
}
|
||||
|
||||
// Verify retryable errors include startup-related patterns
|
||||
expectedPatterns := []string{"EOF", "certificate", "x509", "tls"}
|
||||
for _, pattern := range expectedPatterns {
|
||||
found := false
|
||||
for _, retryableErr := range config.RetryableErrors {
|
||||
if retryableErr == pattern {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected '%s' in RetryableErrors", pattern)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test RetryExecutor with startup-specific errors
|
||||
|
||||
func TestRetryExecutorStartupErrors(t *testing.T) {
|
||||
// Verify MetadataFetchRetryConfig creates a valid retry executor
|
||||
_ = NewRetryExecutor(MetadataFetchRetryConfig(), nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF error",
|
||||
err: errors.New("read tcp: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "unexpected EOF",
|
||||
err: errors.New("http: unexpected EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate error",
|
||||
err: errors.New("x509: certificate signed by unknown authority"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS error",
|
||||
err: errors.New("tls: failed to verify certificate"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "connection refused",
|
||||
err: errors.New("dial tcp: connection refused"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "permanent error",
|
||||
err: errors.New("invalid response format"),
|
||||
shouldRetry: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Use very short delays for testing
|
||||
testConfig := RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 1 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 1.5,
|
||||
EnableJitter: false,
|
||||
RetryableErrors: []string{
|
||||
"connection refused",
|
||||
"timeout",
|
||||
"temporary failure",
|
||||
"network unreachable",
|
||||
"EOF",
|
||||
"certificate",
|
||||
"x509",
|
||||
"tls",
|
||||
},
|
||||
}
|
||||
testRe := NewRetryExecutor(testConfig, nil)
|
||||
|
||||
attempts := 0
|
||||
_ = testRe.ExecuteWithContext(context.Background(), func() error {
|
||||
attempts++
|
||||
return tt.err
|
||||
})
|
||||
|
||||
expectedAttempts := 1
|
||||
if tt.shouldRetry {
|
||||
expectedAttempts = 3
|
||||
}
|
||||
|
||||
if attempts != expectedAttempts {
|
||||
t.Errorf("Expected %d attempts for '%s', got %d", expectedAttempts, tt.name, attempts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that retry executor properly uses isRetryableError with new error types
|
||||
|
||||
func TestRetryExecutorIsRetryableErrorIntegration(t *testing.T) {
|
||||
re := NewRetryExecutor(DefaultRetryConfig(), nil)
|
||||
|
||||
// Test that the enhanced isRetryableError is being used
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
shouldRetry bool
|
||||
}{
|
||||
{
|
||||
name: "EOF in error message",
|
||||
err: errors.New("connection reset by peer: EOF"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "certificate in error message",
|
||||
err: errors.New("x509: certificate has expired"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
{
|
||||
name: "TLS in error message",
|
||||
err: errors.New("tls: handshake failure"),
|
||||
shouldRetry: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := re.isRetryableError(tt.err)
|
||||
if result != tt.shouldRetry {
|
||||
t.Errorf("isRetryableError(%q) = %v, expected %v", tt.err.Error(), result, tt.shouldRetry)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/redis/go-redis/v9 v9.14.0
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
|
||||
@@ -20,8 +20,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
|
||||
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
|
||||
|
||||
+29
-12
@@ -15,7 +15,8 @@ type OAuthHandler struct {
|
||||
tokenExchanger TokenExchanger
|
||||
tokenVerifier TokenVerifier
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
isAllowedDomainFunc func(email string) bool
|
||||
isAllowedUserFunc func(userIdentifier string) bool // validates user authorization
|
||||
userIdentifierClaim string // JWT claim to use for user identification
|
||||
redirURLPath string
|
||||
sendErrorResponseFunc func(rw http.ResponseWriter, req *http.Request, message string, code int)
|
||||
}
|
||||
@@ -77,16 +78,22 @@ type TokenResponse struct {
|
||||
// NewOAuthHandler creates a new OAuth handler
|
||||
func NewOAuthHandler(logger Logger, sessionManager SessionManager, tokenExchanger TokenExchanger,
|
||||
tokenVerifier TokenVerifier, extractClaimsFunc func(string) (map[string]interface{}, error),
|
||||
isAllowedDomainFunc func(string) bool, redirURLPath string,
|
||||
isAllowedUserFunc func(string) bool, userIdentifierClaim string, redirURLPath string,
|
||||
sendErrorResponseFunc func(http.ResponseWriter, *http.Request, string, int)) *OAuthHandler {
|
||||
|
||||
// Default to "email" for backward compatibility
|
||||
if userIdentifierClaim == "" {
|
||||
userIdentifierClaim = "email"
|
||||
}
|
||||
|
||||
return &OAuthHandler{
|
||||
logger: logger,
|
||||
sessionManager: sessionManager,
|
||||
tokenExchanger: tokenExchanger,
|
||||
tokenVerifier: tokenVerifier,
|
||||
extractClaimsFunc: extractClaimsFunc,
|
||||
isAllowedDomainFunc: isAllowedDomainFunc,
|
||||
isAllowedUserFunc: isAllowedUserFunc,
|
||||
userIdentifierClaim: userIdentifierClaim,
|
||||
redirURLPath: redirURLPath,
|
||||
sendErrorResponseFunc: sendErrorResponseFunc,
|
||||
}
|
||||
@@ -225,15 +232,25 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
h.logger.Errorf("Email claim missing or empty in token during callback")
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
// Extract user identifier from the configured claim (defaults to "email" for backward compatibility)
|
||||
userIdentifier, _ := claims[h.userIdentifierClaim].(string)
|
||||
if userIdentifier == "" {
|
||||
// Try "sub" as fallback since it's required by OIDC spec
|
||||
if h.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = claims["sub"].(string)
|
||||
}
|
||||
if userIdentifier == "" {
|
||||
h.logger.Errorf("User identifier claim '%s' missing or empty in token during callback", h.userIdentifierClaim)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User identifier missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.logger.Debugf("Configured claim '%s' not found, using 'sub' claim as fallback", h.userIdentifierClaim)
|
||||
}
|
||||
if !h.isAllowedDomainFunc(email) {
|
||||
h.logger.Errorf("Disallowed email domain during callback: %s", email)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: Email domain not allowed", http.StatusForbidden)
|
||||
|
||||
// Validate user authorization
|
||||
if !h.isAllowedUserFunc(userIdentifier) {
|
||||
h.logger.Errorf("User not authorized during callback: %s", userIdentifier)
|
||||
h.sendErrorResponseFunc(rw, req, "Authentication failed: User not authorized", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -242,7 +259,7 @@ func (h *OAuthHandler) HandleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
h.sendErrorResponseFunc(rw, req, "Failed to update session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.SetEmail(email)
|
||||
session.SetEmail(userIdentifier) // SetEmail stores the user identifier (email or other claim)
|
||||
session.SetIDToken(tokenResponse.IDToken)
|
||||
session.SetAccessToken(tokenResponse.AccessToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
@@ -108,11 +108,11 @@ func TestOAuthHandler_NewOAuthHandler(t *testing.T) {
|
||||
return map[string]interface{}{"email": "test@example.com", "nonce": "test-nonce"}, nil
|
||||
}
|
||||
|
||||
isAllowed := func(email string) bool { return true }
|
||||
isAllowedUser := func(userIdentifier string) bool { return true }
|
||||
sendError := func(rw http.ResponseWriter, req *http.Request, msg string, code int) {}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowedUser, "email", "/callback", sendError)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
@@ -151,7 +151,7 @@ func TestOAuthHandler_HandleCallback_SessionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -190,7 +190,7 @@ func TestOAuthHandler_HandleCallback_ProviderError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
// Test with error parameter
|
||||
req := httptest.NewRequest("GET", "/callback?error=access_denied&error_description=User%20denied%20access", nil)
|
||||
@@ -230,7 +230,7 @@ func TestOAuthHandler_HandleCallback_MissingState(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -265,7 +265,7 @@ func TestOAuthHandler_HandleCallback_MissingCSRF(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -300,7 +300,7 @@ func TestOAuthHandler_HandleCallback_CSRFMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -335,7 +335,7 @@ func TestOAuthHandler_HandleCallback_MissingCode(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -370,7 +370,7 @@ func TestOAuthHandler_HandleCallback_TokenExchangeError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -406,7 +406,7 @@ func TestOAuthHandler_HandleCallback_TokenVerificationError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -444,7 +444,7 @@ func TestOAuthHandler_HandleCallback_ClaimsExtractionError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -483,7 +483,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInToken(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -521,7 +521,7 @@ func TestOAuthHandler_HandleCallback_MissingNonceInSession(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -559,7 +559,7 @@ func TestOAuthHandler_HandleCallback_NonceMismatch(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -591,13 +591,13 @@ func TestOAuthHandler_HandleCallback_MissingEmail(t *testing.T) {
|
||||
if code != http.StatusInternalServerError {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email missing in token") {
|
||||
t.Errorf("Expected error message to contain 'Email missing in token', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User identifier missing in token") {
|
||||
t.Errorf("Expected error message to contain 'User identifier missing in token', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -629,13 +629,13 @@ func TestOAuthHandler_HandleCallback_DisallowedDomain(t *testing.T) {
|
||||
if code != http.StatusForbidden {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusForbidden, code)
|
||||
}
|
||||
if !strings.Contains(msg, "Email domain not allowed") {
|
||||
t.Errorf("Expected error message to contain 'Email domain not allowed', got '%s'", msg)
|
||||
if !strings.Contains(msg, "User not authorized") {
|
||||
t.Errorf("Expected error message to contain 'User not authorized', got '%s'", msg)
|
||||
}
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -677,7 +677,7 @@ func TestOAuthHandler_HandleCallback_SessionSaveError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -719,7 +719,7 @@ func TestOAuthHandler_HandleCallback_SetAuthenticatedError(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -760,7 +760,7 @@ func TestOAuthHandler_HandleCallback_Success(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -843,7 +843,7 @@ func TestOAuthHandler_HandleCallback_SuccessDefaultRedirect(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
@@ -884,7 +884,7 @@ func TestOAuthHandler_HandleCallback_RedirectURLPathExcluded(t *testing.T) {
|
||||
}
|
||||
|
||||
handler := NewOAuthHandler(logger, sessionManager, tokenExchanger, tokenVerifier,
|
||||
extractClaims, isAllowed, "/callback", sendError)
|
||||
extractClaims, isAllowed, "email", "/callback", sendError)
|
||||
|
||||
req := httptest.NewRequest("GET", "/callback?code=test-code&state=test-state", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
+52
-47
@@ -15,20 +15,21 @@ import (
|
||||
// XSS, path traversal, and other injection attacks. It validates and sanitizes
|
||||
// various input types used in OIDC authentication flows.
|
||||
type InputValidator struct {
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
usernameRegex *regexp.Regexp
|
||||
tokenRegex *regexp.Regexp
|
||||
logger *Logger
|
||||
urlRegex *regexp.Regexp
|
||||
emailRegex *regexp.Regexp
|
||||
sqlInjectionPatterns []string
|
||||
pathTraversalPatterns []string
|
||||
xssPatterns []string
|
||||
maxUsernameLength int
|
||||
maxURLLength int
|
||||
maxTokenLength int
|
||||
maxEmailLength int
|
||||
maxClaimLength int
|
||||
maxHeaderLength int
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// ValidationResult encapsulates the outcome of input validation.
|
||||
@@ -46,13 +47,14 @@ type ValidationResult struct {
|
||||
// It specifies maximum lengths for various input types and controls whether
|
||||
// strict validation mode is enabled.
|
||||
type InputValidationConfig struct {
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
MaxTokenLength int `json:"max_token_length"`
|
||||
MaxURLLength int `json:"max_url_length"`
|
||||
MaxHeaderLength int `json:"max_header_length"`
|
||||
MaxClaimLength int `json:"max_claim_length"`
|
||||
MaxEmailLength int `json:"max_email_length"`
|
||||
MaxUsernameLength int `json:"max_username_length"`
|
||||
StrictMode bool `json:"strict_mode"`
|
||||
AllowPrivateIPAddresses bool `json:"allow_private_ip_addresses"` // Allow private IP addresses in URL validation
|
||||
}
|
||||
|
||||
// DefaultInputValidationConfig returns a secure default configuration
|
||||
@@ -103,16 +105,17 @@ func NewInputValidator(config InputValidationConfig, logger *Logger) (*InputVali
|
||||
}
|
||||
|
||||
return &InputValidator{
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
maxTokenLength: config.MaxTokenLength,
|
||||
maxURLLength: config.MaxURLLength,
|
||||
maxHeaderLength: config.MaxHeaderLength,
|
||||
maxClaimLength: config.MaxClaimLength,
|
||||
maxEmailLength: config.MaxEmailLength,
|
||||
maxUsernameLength: config.MaxUsernameLength,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
emailRegex: emailRegex,
|
||||
urlRegex: urlRegex,
|
||||
tokenRegex: tokenRegex,
|
||||
usernameRegex: usernameRegex,
|
||||
sqlInjectionPatterns: []string{
|
||||
"'", "\"", ";", "--", "/*", "*/", "xp_", "sp_",
|
||||
"union", "select", "insert", "update", "delete", "drop",
|
||||
@@ -335,24 +338,26 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
// Check for private IP ranges (RFC 1918) - skip if allowPrivateIPAddresses is enabled
|
||||
if !iv.allowPrivateIPAddresses {
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Vendored
+2
-2
@@ -76,7 +76,7 @@ func NewRedisBackend(config *Config) (*RedisBackend, error) {
|
||||
|
||||
// Test connectivity
|
||||
if err := backend.Ping(context.Background()); err != nil {
|
||||
pool.Close()
|
||||
_ = pool.Close()
|
||||
return nil, fmt.Errorf("failed to ping Redis: %w", err)
|
||||
}
|
||||
|
||||
@@ -263,7 +263,7 @@ func (r *RedisBackend) Clear(ctx context.Context) error {
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
conn.Do("DEL", key) // Best effort, ignore errors
|
||||
_, _ = conn.Do("DEL", key) // Best effort, ignore errors
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
+23
-22
@@ -82,7 +82,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
// Reuse existing connection - validate if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
// Connection is stale, close it and try again
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -94,6 +94,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
|
||||
default:
|
||||
// No available connection, create new one if under limit
|
||||
// #nosec G115 -- MaxConnections is a small config value that fits in int32
|
||||
if p.totalConns.Load() < int32(p.config.MaxConnections) {
|
||||
conn, err = p.createConnection()
|
||||
if err != nil {
|
||||
@@ -115,7 +116,7 @@ func (p *ConnectionPool) Get(ctx context.Context) (*RedisConn, error) {
|
||||
case conn = <-p.connections:
|
||||
// Validate connection if health check enabled
|
||||
if p.config.EnableHealthCheck && !p.isConnectionHealthy(conn) {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
continue
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
p.activeConns.Add(-1)
|
||||
|
||||
if p.closed.Load() || conn.closed.Load() {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
return
|
||||
}
|
||||
@@ -155,7 +156,7 @@ func (p *ConnectionPool) Put(conn *RedisConn) {
|
||||
// Successfully returned to pool
|
||||
default:
|
||||
// Pool full, close connection
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
p.totalConns.Add(-1)
|
||||
}
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func (p *ConnectionPool) Close() error {
|
||||
|
||||
// Close all pooled connections
|
||||
for conn := range p.connections {
|
||||
conn.Close()
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -212,7 +213,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Authenticate if password is provided
|
||||
if p.config.Password != "" {
|
||||
if _, err := redisConn.Do("AUTH", p.config.Password); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +221,7 @@ func (p *ConnectionPool) createConnection() (*RedisConn, error) {
|
||||
// Select database
|
||||
if p.config.DB != 0 {
|
||||
if _, err := redisConn.Do("SELECT", fmt.Sprintf("%d", p.config.DB)); err != nil {
|
||||
redisConn.Close()
|
||||
_ = redisConn.Close()
|
||||
return nil, fmt.Errorf("failed to select database: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -246,15 +247,15 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Build command arguments
|
||||
// Check for overflow: ensure len(args)+1 doesn't cause allocation overflow
|
||||
// Limit to a safe value that prevents integer overflow in allocation size calculation
|
||||
// (capacity * sizeof(string) must fit in int/size_t)
|
||||
argsLen := len(args)
|
||||
const maxSafeArgs = (1 << 20) - 1 // 1M args is already absurdly large for Redis commands
|
||||
if argsLen < 0 || argsLen > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments")
|
||||
// Validate argument count to prevent integer overflow in slice operations
|
||||
// maxSafeArgs is set to (1<<20)-1 = 1,048,575 which is more than any reasonable Redis command
|
||||
const maxSafeArgs = (1 << 20) - 1
|
||||
if len(args) > maxSafeArgs {
|
||||
return nil, errors.New("too many arguments: exceeds maximum safe count")
|
||||
}
|
||||
|
||||
// Build command arguments
|
||||
// Validate total argument size to prevent memory exhaustion
|
||||
const maxTotalArgBytes = 64 << 20 // 64 MiB max total size
|
||||
totalBytes := len(command)
|
||||
for _, s := range args {
|
||||
@@ -267,13 +268,13 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
return nil, errors.New("total argument size exceeds maximum allowed")
|
||||
}
|
||||
}
|
||||
cmdArgs := make([]string, 0, argsLen+1)
|
||||
cmdArgs = append(cmdArgs, command)
|
||||
cmdArgs = append(cmdArgs, args...)
|
||||
// Build command slice: prepend command to args
|
||||
// Using append avoids arithmetic on potentially large len(args)
|
||||
cmdArgs := append([]string{command}, args...)
|
||||
|
||||
// Set write timeout
|
||||
if c.writeTimeout > 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(c.writeTimeout))
|
||||
}
|
||||
|
||||
// Write command (using pooled writer for memory efficiency)
|
||||
@@ -287,7 +288,7 @@ func (c *RedisConn) Do(command string, args ...string) (interface{}, error) {
|
||||
|
||||
// Set read timeout
|
||||
if c.readTimeout > 0 {
|
||||
c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
_ = c.conn.SetReadDeadline(time.Now().Add(c.readTimeout))
|
||||
}
|
||||
|
||||
// Read response (using pooled reader for memory efficiency)
|
||||
@@ -328,8 +329,8 @@ func (p *ConnectionPool) isConnectionHealthy(conn *RedisConn) bool {
|
||||
|
||||
// Set a read deadline for the ping
|
||||
if conn.conn != nil {
|
||||
conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer conn.conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||
_ = conn.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
defer func() { _ = conn.conn.SetReadDeadline(time.Time{}) }() // Clear deadline
|
||||
}
|
||||
|
||||
_, err := conn.Do("PING")
|
||||
|
||||
@@ -158,6 +158,7 @@ func (cb *CircuitBreaker) AllowRequest() bool {
|
||||
case StateHalfOpen:
|
||||
// Allow limited requests in half-open state
|
||||
current := cb.halfOpenRequests.Add(1)
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
return current <= int32(cb.config.HalfOpenMaxRequests)
|
||||
|
||||
default:
|
||||
@@ -181,6 +182,7 @@ func (cb *CircuitBreaker) RecordSuccess() {
|
||||
case StateHalfOpen:
|
||||
// If we've had enough successful requests, close the circuit
|
||||
successfulRequests := cb.halfOpenRequests.Load()
|
||||
// #nosec G115 -- HalfOpenMaxRequests is a small config value that fits in int32
|
||||
if successfulRequests >= int32(cb.config.HalfOpenMaxRequests) {
|
||||
cb.setState(StateClosed)
|
||||
cb.consecutiveFailures.Store(0)
|
||||
@@ -203,6 +205,7 @@ func (cb *CircuitBreaker) RecordFailure() {
|
||||
switch state {
|
||||
case StateClosed:
|
||||
// Check if we should open the circuit
|
||||
// #nosec G115 -- MaxFailures is a small config value that fits in int32
|
||||
if failures >= int32(cb.config.MaxFailures) {
|
||||
cb.openCircuit()
|
||||
} else if cb.config.FailureThreshold > 0 {
|
||||
|
||||
+2
@@ -217,6 +217,7 @@ func (hc *HealthChecker) recordSuccess(latency time.Duration) {
|
||||
newStatus := currentStatus
|
||||
|
||||
// Check if we should become healthy
|
||||
// #nosec G115 -- HealthyThreshold is a small config value that fits in int32
|
||||
if successes >= int32(hc.config.HealthyThreshold) {
|
||||
if latency > hc.config.DegradedThreshold {
|
||||
newStatus = HealthDegraded
|
||||
@@ -241,6 +242,7 @@ func (hc *HealthChecker) recordFailure() {
|
||||
hc.timeMu.Unlock()
|
||||
|
||||
// Check if we should become unhealthy
|
||||
// #nosec G115 -- UnhealthyThreshold is a small config value that fits in int32
|
||||
if failures >= int32(hc.config.UnhealthyThreshold) {
|
||||
hc.setStatus(HealthUnhealthy)
|
||||
}
|
||||
|
||||
@@ -150,6 +150,7 @@ func (h *HealthCheckBackend) IsHealthy() bool {
|
||||
|
||||
// recordResult records the result of an operation for health tracking
|
||||
func (h *HealthCheckBackend) recordResult(success bool) {
|
||||
// #nosec G115 -- threshold config values are small integers that fit in int32
|
||||
if success {
|
||||
fails := h.consecutiveFails.Swap(0)
|
||||
oks := h.consecutiveOK.Add(1)
|
||||
|
||||
@@ -304,7 +304,8 @@ func (f *Factory) createSecureTLSConfig() *tls.Config {
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
InsecureSkipVerify: false, // SECURITY: Always verify certificates
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it to false is safe
|
||||
PreferServerCipherSuites: false, // Let client choose best cipher
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,12 +144,14 @@ func getOrCreateLogFile(filename string) io.Writer {
|
||||
}
|
||||
|
||||
// Ensure log directory exists
|
||||
// #nosec G301 -- log directory needs to be readable by monitoring tools
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
// Fall back to stderr if we can't create the directory
|
||||
return os.Stderr
|
||||
}
|
||||
|
||||
filepath := logDir + "/" + filename
|
||||
// #nosec G302 G304 -- log files need to be readable; path is from trusted env var
|
||||
file, err := os.OpenFile(filepath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
// Fall back to stderr if we can't open the file
|
||||
|
||||
@@ -107,6 +107,7 @@ const (
|
||||
JWTPattern = `^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`
|
||||
|
||||
// Bearer token pattern (Authorization header)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
BearerTokenPattern = `^Bearer\s+([A-Za-z0-9._~+/-]+=*)$`
|
||||
|
||||
// Client ID pattern (alphanumeric with common separators)
|
||||
@@ -119,6 +120,7 @@ const (
|
||||
SessionIDPattern = `^[a-fA-F0-9]{32,128}$`
|
||||
|
||||
// CSRF token pattern (base64url)
|
||||
// #nosec G101 -- This is a regex pattern for validation, not a hardcoded credential
|
||||
CSRFTokenPattern = `^[A-Za-z0-9_-]+$`
|
||||
|
||||
// Nonce pattern (base64url)
|
||||
|
||||
@@ -202,8 +202,10 @@ func (p *TransportPool) createTransport(config TransportConfig) *http.Transport
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
// #nosec G402 -- PreferServerCipherSuites is deprecated in Go 1.17+ but setting it is harmless
|
||||
PreferServerCipherSuites: true,
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
// #nosec G402 -- InsecureSkipVerify is configurable for testing/dev environments
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
}
|
||||
|
||||
return &http.Transport{
|
||||
|
||||
@@ -148,6 +148,7 @@ func (cb *CircuitBreaker) allowRequest() bool {
|
||||
// allowHalfOpenRequest checks if a request is allowed in half-open state
|
||||
func (cb *CircuitBreaker) allowHalfOpenRequest() bool {
|
||||
current := atomic.AddInt32(&cb.halfOpenRequests, 1)
|
||||
// #nosec G115 -- MaxRequests is a small config value that fits in int32
|
||||
if current <= int32(cb.config.MaxRequests) {
|
||||
return true
|
||||
}
|
||||
@@ -164,6 +165,7 @@ func (cb *CircuitBreaker) recordFailure() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- FailureThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerClosed && failures >= int32(cb.config.FailureThreshold) {
|
||||
cb.transitionToOpen()
|
||||
} else if state == CircuitBreakerHalfOpen {
|
||||
@@ -180,6 +182,7 @@ func (cb *CircuitBreaker) recordSuccess() {
|
||||
|
||||
state := CircuitBreakerState(atomic.LoadInt32(&cb.state))
|
||||
|
||||
// #nosec G115 -- SuccessThreshold is a small config value that fits in int32
|
||||
if state == CircuitBreakerHalfOpen && successes >= int32(cb.config.SuccessThreshold) {
|
||||
cb.transitionToClosed()
|
||||
}
|
||||
|
||||
@@ -191,6 +191,7 @@ func (re *RetryExecutor) calculateDelay(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
// Add jitter
|
||||
// #nosec G404 -- math/rand is acceptable for jitter timing, not security-sensitive
|
||||
if re.config.RandomizationFactor > 0 {
|
||||
jitter := delay * re.config.RandomizationFactor
|
||||
minDelay := delay - jitter
|
||||
|
||||
@@ -169,7 +169,10 @@ func (jwk *JWK) ToRSAPublicKey() (*rsa.PublicKey, error) {
|
||||
// Pad to 8 bytes for uint64
|
||||
paddedE := make([]byte, 8)
|
||||
copy(paddedE[8-len(eBytes):], eBytes)
|
||||
e = int(binary.BigEndian.Uint64(paddedE))
|
||||
eUint64 := binary.BigEndian.Uint64(paddedE)
|
||||
// RSA exponents are typically small (65537 is common), so overflow is not a concern
|
||||
// #nosec G115 -- RSA public exponents are small values that fit in int
|
||||
e = int(eUint64)
|
||||
} else {
|
||||
return nil, fmt.Errorf("exponent too large")
|
||||
}
|
||||
|
||||
@@ -177,6 +177,12 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}
|
||||
return "groups" // Backward compatible default
|
||||
}(),
|
||||
userIdentifierClaim: func() string {
|
||||
if config.UserIdentifierClaim != "" {
|
||||
return config.UserIdentifierClaim
|
||||
}
|
||||
return "email" // Backward compatible default
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
@@ -218,6 +224,8 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
securityHeadersApplier: config.GetSecurityHeadersApplier(),
|
||||
scopeFilter: NewScopeFilter(logger), // NEW - for discovery-based scope filtering
|
||||
dcrConfig: config.DynamicClientRegistration,
|
||||
allowPrivateIPAddresses: config.AllowPrivateIPAddresses,
|
||||
minimalHeaders: config.MinimalHeaders,
|
||||
}
|
||||
|
||||
// Log audience configuration
|
||||
|
||||
@@ -545,3 +545,168 @@ func createTestSessionManager(t *testing.T) *SessionManager {
|
||||
}
|
||||
return sm
|
||||
}
|
||||
|
||||
// TestMinimalHeaders tests the minimalHeaders configuration option
|
||||
// This addresses GitHub issue #64 - Request Header Fields Too Large
|
||||
func TestMinimalHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minimalHeaders bool
|
||||
expectForwardedUser bool
|
||||
expectAuthRequestUser bool
|
||||
expectAuthRequestRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "minimalHeaders=false (default) forwards all headers",
|
||||
minimalHeaders: false,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: true,
|
||||
expectAuthRequestRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "minimalHeaders=true only forwards X-Forwarded-User",
|
||||
minimalHeaders: true,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: false,
|
||||
expectAuthRequestRedirect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Track which headers were set
|
||||
var capturedHeaders http.Header
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
// Create request and get session properly through session manager
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set up session data
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Call processAuthorizedRequest directly
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is always set
|
||||
if tt.expectForwardedUser {
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-User
|
||||
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
|
||||
if tt.expectAuthRequestUser && !hasAuthRequestUser {
|
||||
t.Error("expected X-Auth-Request-User to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestUser && hasAuthRequestUser {
|
||||
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Redirect
|
||||
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
|
||||
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
|
||||
t.Error("expected X-Auth-Request-Redirect to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
|
||||
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
|
||||
}
|
||||
|
||||
// Note: X-Auth-Request-Token is only set if session.GetIDToken() returns non-empty.
|
||||
// Token storage has validation that may reject test tokens, so we verify the flag
|
||||
// logic through the other headers. The important behavior is that when
|
||||
// minimalHeaders=true, the token header would NOT be set even if a token existed.
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMinimalHeaders_TokenHeaderNotSet verifies that the X-Auth-Request-Token header
|
||||
// is NOT set when minimalHeaders is enabled, even if a token exists.
|
||||
func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
var capturedHeaders http.Header
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
sessionManager := createTestSessionManager(t)
|
||||
oidc := &TraefikOidc{
|
||||
next: next,
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
return map[string]interface{}{
|
||||
"email": "user@example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
oidc.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is set (always should be)
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
|
||||
// The key verification: X-Auth-Request-Token should NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-Token") != "" {
|
||||
t.Error("expected X-Auth-Request-Token to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
|
||||
// X-Auth-Request-User should also NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-User") != "" {
|
||||
t.Error("expected X-Auth-Request-User to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
|
||||
// X-Auth-Request-Redirect should also NOT be set with minimalHeaders=true
|
||||
if capturedHeaders.Get("X-Auth-Request-Redirect") != "" {
|
||||
t.Error("expected X-Auth-Request-Redirect to NOT be set with minimalHeaders=true")
|
||||
}
|
||||
}
|
||||
|
||||
+243
-19
@@ -122,22 +122,23 @@ func (ts *TestSuite) Setup() {
|
||||
|
||||
// Common TraefikOidc instance
|
||||
ts.tOidc = &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles", // Set default for backward compatibility
|
||||
groupClaimName: "groups", // Set default for backward compatibility
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
revocationURL: "https://revocation-endpoint.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
issuerURL: "https://test-issuer.com",
|
||||
clientID: "test-client-id",
|
||||
audience: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
roleClaimName: "roles", // Set default for backward compatibility
|
||||
groupClaimName: "groups", // Set default for backward compatibility
|
||||
userIdentifierClaim: "email", // Set default for backward compatibility
|
||||
jwkCache: ts.mockJWKCache,
|
||||
jwksURL: "https://test-jwks-url.com",
|
||||
revocationURL: "https://revocation-endpoint.com",
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: tokenBlacklist,
|
||||
tokenCache: tokenCache,
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
|
||||
httpClient: &http.Client{Timeout: 10 * time.Second},
|
||||
// Explicitly set paths as New() is bypassed
|
||||
redirURLPath: "/callback", // Assume default callback path for tests
|
||||
logoutURLPath: "/callback/logout", // Assume default logout path for tests
|
||||
@@ -784,7 +785,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
"Accept": "application/json",
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: Your email domain is not allowed. To log out, visit: /callback/logout","status_code":403}`,
|
||||
expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`,
|
||||
},
|
||||
{
|
||||
name: "Disallowed Domain (Accept: HTML)",
|
||||
@@ -1282,8 +1283,9 @@ func TestHandleCallback(t *testing.T) {
|
||||
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
|
||||
}
|
||||
tOidc := &TraefikOidc{
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: logger,
|
||||
userIdentifierClaim: "email", // Required for claim extraction
|
||||
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
|
||||
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
|
||||
tokenVerifier: nil, // Will be set to self below
|
||||
@@ -1438,6 +1440,228 @@ func TestIsAllowedDomain(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAllowedUser(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
allowedDomains map[string]struct{}
|
||||
allowedUsers map[string]struct{}
|
||||
userIdentifierClaim string
|
||||
name string
|
||||
userIdentifier string
|
||||
allowed bool
|
||||
}{
|
||||
// Email-based identification (default behavior)
|
||||
{
|
||||
name: "Email identifier - allowed domain",
|
||||
userIdentifier: "user@example.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Email identifier - disallowed domain",
|
||||
userIdentifier: "user@notallowed.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "Email identifier - specific user allowed",
|
||||
userIdentifier: "specific.user@otherdomain.com",
|
||||
userIdentifierClaim: "email",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}},
|
||||
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// Non-email identifier (sub claim - for Azure AD users without email)
|
||||
{
|
||||
name: "Sub identifier - allowed in allowedUsers",
|
||||
userIdentifier: "abc12345-6789-0abc-def0-123456789abc",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - not in allowedUsers",
|
||||
userIdentifier: "xyz-not-allowed-user",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
||||
allowed: false,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - allowedDomains ignored for non-email",
|
||||
userIdentifier: "user-id-12345",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored
|
||||
allowedUsers: map[string]struct{}{"user-id-12345": {}},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - no restrictions allows all",
|
||||
userIdentifier: "any-user-id",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{},
|
||||
allowed: true,
|
||||
},
|
||||
{
|
||||
name: "Sub identifier - case insensitive matching",
|
||||
userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// OID claim (Azure AD object ID)
|
||||
{
|
||||
name: "OID identifier - allowed user",
|
||||
userIdentifier: "oid-12345-67890",
|
||||
userIdentifierClaim: "oid",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"oid-12345-67890": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// UPN claim (Azure AD User Principal Name)
|
||||
{
|
||||
name: "UPN identifier - allowed user (looks like email but use sub logic)",
|
||||
userIdentifier: "user@tenant.onmicrosoft.com",
|
||||
userIdentifierClaim: "upn",
|
||||
allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored
|
||||
allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}},
|
||||
allowed: true,
|
||||
},
|
||||
|
||||
// Edge cases
|
||||
{
|
||||
name: "Empty identifier - not allowed",
|
||||
userIdentifier: "",
|
||||
userIdentifierClaim: "sub",
|
||||
allowedDomains: map[string]struct{}{},
|
||||
allowedUsers: map[string]struct{}{"some-user": {}},
|
||||
allowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Configure TraefikOidc instance for this test case
|
||||
tOidc := ts.tOidc
|
||||
tOidc.allowedUserDomains = tc.allowedDomains
|
||||
tOidc.allowedUsers = tc.allowedUsers
|
||||
tOidc.userIdentifierClaim = tc.userIdentifierClaim
|
||||
|
||||
allowed := tOidc.isAllowedUser(tc.userIdentifier)
|
||||
if allowed != tc.allowed {
|
||||
t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q",
|
||||
tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIdentifierClaimExtraction(t *testing.T) {
|
||||
// Test that the correct claim is extracted based on userIdentifierClaim config
|
||||
tests := []struct {
|
||||
name string
|
||||
userIdentifierClaim string
|
||||
claims map[string]interface{}
|
||||
expectedIdentifier string
|
||||
shouldFallbackToSub bool
|
||||
}{
|
||||
{
|
||||
name: "Email claim extraction (default)",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedIdentifier: "user@example.com",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "Sub claim extraction",
|
||||
userIdentifierClaim: "sub",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
},
|
||||
expectedIdentifier: "user-sub-id",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "OID claim extraction (Azure AD)",
|
||||
userIdentifierClaim: "oid",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"email": "user@example.com",
|
||||
"oid": "azure-object-id",
|
||||
},
|
||||
expectedIdentifier: "azure-object-id",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "UPN claim extraction (Azure AD)",
|
||||
userIdentifierClaim: "upn",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"upn": "user@tenant.onmicrosoft.com",
|
||||
},
|
||||
expectedIdentifier: "user@tenant.onmicrosoft.com",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
{
|
||||
name: "Fallback to sub when configured claim is missing",
|
||||
userIdentifierClaim: "email",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "fallback-sub-id",
|
||||
// email is missing
|
||||
},
|
||||
expectedIdentifier: "fallback-sub-id",
|
||||
shouldFallbackToSub: true,
|
||||
},
|
||||
{
|
||||
name: "preferred_username claim extraction",
|
||||
userIdentifierClaim: "preferred_username",
|
||||
claims: map[string]interface{}{
|
||||
"sub": "user-sub-id",
|
||||
"preferred_username": "jdoe",
|
||||
},
|
||||
expectedIdentifier: "jdoe",
|
||||
shouldFallbackToSub: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Extract user identifier using the same logic as auth_flow.go
|
||||
userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string)
|
||||
usedFallback := false
|
||||
|
||||
if userIdentifier == "" && tc.userIdentifierClaim != "sub" {
|
||||
userIdentifier, _ = tc.claims["sub"].(string)
|
||||
usedFallback = true
|
||||
}
|
||||
|
||||
if userIdentifier != tc.expectedIdentifier {
|
||||
t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier)
|
||||
}
|
||||
|
||||
if usedFallback != tc.shouldFallbackToSub {
|
||||
t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOIDCHandler(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
+8
-5
@@ -120,8 +120,9 @@ func NewMemoryMonitor(logger *Logger, thresholds MemoryAlertThresholds) *MemoryM
|
||||
alertThresholds: thresholds,
|
||||
baselineHeap: memStats.HeapAlloc,
|
||||
baselineGoroutines: runtime.NumGoroutine(),
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
lastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
lastGCCount: memStats.NumGC,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,9 +159,10 @@ func (mm *MemoryMonitor) GetCurrentStats() *MemoryStats {
|
||||
StackSysBytes: memStats.StackSys,
|
||||
GCSysBytes: memStats.GCSys,
|
||||
NumGoroutines: runtime.NumGoroutine(),
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: gcFrequency,
|
||||
Timestamp: now,
|
||||
// #nosec G115 -- LastGC nanoseconds fits in int64 for centuries
|
||||
LastGCTime: time.Unix(0, int64(memStats.LastGC)),
|
||||
GCFrequency: gcFrequency,
|
||||
Timestamp: now,
|
||||
}
|
||||
|
||||
// Get application-specific stats
|
||||
@@ -386,6 +388,7 @@ func (mm *MemoryMonitor) TriggerGC() {
|
||||
|
||||
after := mm.GetCurrentStats()
|
||||
|
||||
// #nosec G115 -- heap allocation bytes fit in int64 for practical purposes
|
||||
freedBytes := int64(before.HeapAllocBytes) - int64(after.HeapAllocBytes)
|
||||
freedMB := float64(freedBytes) / (1024 * 1024)
|
||||
|
||||
|
||||
+49
-4
@@ -189,11 +189,56 @@ func (mc *MetadataCache) GetMetadata(providerURL string, httpClient *http.Client
|
||||
return mc.GetProviderMetadata(ctx, providerURL, httpClient)
|
||||
}
|
||||
|
||||
// GetMetadataWithRecovery fetches metadata with recovery support
|
||||
// GetMetadataWithRecovery fetches metadata with retry support for startup scenarios.
|
||||
// This handles the race condition where Traefik initializes the plugin before the
|
||||
// OIDC provider routes are fully established, or before TLS certificates are loaded.
|
||||
// Uses aggressive retry settings (10 attempts, 1s intervals) to give the infrastructure
|
||||
// time to stabilize during cold starts.
|
||||
// See: https://github.com/lukaszraczylo/traefikoidc/issues/90
|
||||
func (mc *MetadataCache) GetMetadataWithRecovery(providerURL string, httpClient *http.Client, logger *Logger, errorRecoveryManager *ErrorRecoveryManager) (*ProviderMetadata, error) {
|
||||
// For now, just use regular GetMetadata
|
||||
// Recovery would be handled by ErrorRecoveryManager if needed
|
||||
return mc.GetMetadata(providerURL, httpClient, logger)
|
||||
// Check cache first - if we have valid cached metadata, use it
|
||||
if metadata, exists := mc.Get(providerURL); exists {
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// Create a retry executor with metadata-fetch-specific configuration
|
||||
retryConfig := MetadataFetchRetryConfig()
|
||||
retryExecutor := NewRetryExecutor(retryConfig, logger)
|
||||
|
||||
var metadata *ProviderMetadata
|
||||
var lastErr error
|
||||
|
||||
// Use context with overall timeout for the entire retry sequence
|
||||
// 10 attempts * ~10s max delay = ~100s worst case, so use 2 minute timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
err := retryExecutor.ExecuteWithContext(ctx, func() error {
|
||||
// Create per-attempt context with shorter timeout
|
||||
attemptCtx, attemptCancel := context.WithTimeout(ctx, 15*time.Second)
|
||||
defer attemptCancel()
|
||||
|
||||
var fetchErr error
|
||||
metadata, fetchErr = mc.GetProviderMetadata(attemptCtx, providerURL, httpClient)
|
||||
if fetchErr != nil {
|
||||
lastErr = fetchErr
|
||||
if logger != nil {
|
||||
logger.Debugf("Metadata fetch attempt failed: %v", fetchErr)
|
||||
}
|
||||
return fetchErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
// Return the last actual error, not the retry wrapper error
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return metadata, nil
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics for testing
|
||||
|
||||
+17
-14
@@ -125,12 +125,12 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
email := session.GetEmail()
|
||||
// Domain restriction check removed debug output
|
||||
if authenticated && email != "" {
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
||||
userIdentifier := session.GetEmail() // GetEmail returns the stored user identifier (email or other claim)
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -193,10 +193,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if refreshed {
|
||||
email = session.GetEmail()
|
||||
if email != "" && !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with refreshed token email %s is not from an allowed domain", email)
|
||||
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
||||
userIdentifier = session.GetEmail() // GetEmail returns the stored user identifier
|
||||
if userIdentifier != "" && !t.isAllowedUser(userIdentifier) {
|
||||
t.logger.Infof("User with refreshed token %s is not authorized", userIdentifier)
|
||||
errorMsg := fmt.Sprintf("Access denied: You are not authorized to access this resource. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
@@ -308,10 +308,13 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !t.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
if len(t.headerTemplates) > 0 {
|
||||
|
||||
@@ -41,6 +41,7 @@ type AuthMiddleware struct {
|
||||
goroutineWG *sync.WaitGroup
|
||||
startTokenCleanupFunc func()
|
||||
startMetadataRefreshFunc func(string)
|
||||
minimalHeaders bool
|
||||
}
|
||||
|
||||
// Logger interface for dependency injection
|
||||
@@ -120,6 +121,7 @@ func NewAuthMiddleware(
|
||||
goroutineWG *sync.WaitGroup,
|
||||
startTokenCleanupFunc func(),
|
||||
startMetadataRefreshFunc func(string),
|
||||
minimalHeaders bool,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
logger: logger,
|
||||
@@ -149,6 +151,7 @@ func NewAuthMiddleware(
|
||||
goroutineWG: goroutineWG,
|
||||
startTokenCleanupFunc: startTokenCleanupFunc,
|
||||
startMetadataRefreshFunc: startMetadataRefreshFunc,
|
||||
minimalHeaders: minimalHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -414,10 +417,14 @@ func (m *AuthMiddleware) processAuthorizedRequest(rw http.ResponseWriter, req *h
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
|
||||
// When minimalHeaders is enabled, skip extra headers to prevent 431 errors
|
||||
if !m.minimalHeaders {
|
||||
req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI())
|
||||
req.Header.Set("X-Auth-Request-User", email)
|
||||
if idToken := session.GetIDToken(); idToken != "" {
|
||||
req.Header.Set("X-Auth-Request-Token", idToken)
|
||||
}
|
||||
}
|
||||
|
||||
m.next.ServeHTTP(rw, req)
|
||||
|
||||
@@ -66,6 +66,7 @@ func TestNewAuthMiddleware(t *testing.T) {
|
||||
wg,
|
||||
startTokenCleanup,
|
||||
startMetadataRefresh,
|
||||
false, // minimalHeaders
|
||||
)
|
||||
|
||||
if m == nil {
|
||||
|
||||
@@ -802,3 +802,99 @@ func TestServeHTTP_AdditionalCoverage(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestProcessAuthorizedRequest_MinimalHeaders tests the minimalHeaders configuration
|
||||
// This addresses GitHub issue #64 - Request Header Fields Too Large
|
||||
func TestProcessAuthorizedRequest_MinimalHeaders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
minimalHeaders bool
|
||||
expectForwardedUser bool
|
||||
expectAuthRequestUser bool
|
||||
expectAuthRequestToken bool
|
||||
expectAuthRequestRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "minimalHeaders=false forwards all headers",
|
||||
minimalHeaders: false,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: true,
|
||||
expectAuthRequestToken: true,
|
||||
expectAuthRequestRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "minimalHeaders=true only forwards X-Forwarded-User",
|
||||
minimalHeaders: true,
|
||||
expectForwardedUser: true,
|
||||
expectAuthRequestUser: false,
|
||||
expectAuthRequestToken: false,
|
||||
expectAuthRequestRedirect: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logger := &mockLogger{}
|
||||
var capturedHeaders http.Header
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedHeaders = r.Header.Clone()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
session := &mockSessionData{
|
||||
email: "user@example.com",
|
||||
idToken: "test-id-token-that-could-be-very-large",
|
||||
accessToken: "test-access-token",
|
||||
}
|
||||
|
||||
m := &AuthMiddleware{
|
||||
logger: logger,
|
||||
next: nextHandler,
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) {
|
||||
return nil, nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
m.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
||||
|
||||
// Verify X-Forwarded-User is always set
|
||||
if tt.expectForwardedUser {
|
||||
if capturedHeaders.Get("X-Forwarded-User") != "user@example.com" {
|
||||
t.Errorf("expected X-Forwarded-User to be set, got %q", capturedHeaders.Get("X-Forwarded-User"))
|
||||
}
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-User
|
||||
hasAuthRequestUser := capturedHeaders.Get("X-Auth-Request-User") != ""
|
||||
if tt.expectAuthRequestUser && !hasAuthRequestUser {
|
||||
t.Error("expected X-Auth-Request-User to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestUser && hasAuthRequestUser {
|
||||
t.Errorf("expected X-Auth-Request-User to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-User"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Token (the big one that causes 431 errors)
|
||||
hasAuthRequestToken := capturedHeaders.Get("X-Auth-Request-Token") != ""
|
||||
if tt.expectAuthRequestToken && !hasAuthRequestToken {
|
||||
t.Error("expected X-Auth-Request-Token to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestToken && hasAuthRequestToken {
|
||||
t.Errorf("expected X-Auth-Request-Token to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Token"))
|
||||
}
|
||||
|
||||
// Verify X-Auth-Request-Redirect
|
||||
hasAuthRequestRedirect := capturedHeaders.Get("X-Auth-Request-Redirect") != ""
|
||||
if tt.expectAuthRequestRedirect && !hasAuthRequestRedirect {
|
||||
t.Error("expected X-Auth-Request-Redirect to be set")
|
||||
}
|
||||
if !tt.expectAuthRequestRedirect && hasAuthRequestRedirect {
|
||||
t.Errorf("expected X-Auth-Request-Redirect to NOT be set when minimalHeaders=true, got %q", capturedHeaders.Get("X-Auth-Request-Redirect"))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func generateSecureRandomString(length int) (string, error) {
|
||||
}
|
||||
|
||||
// Cookie names and configuration constants used for session management
|
||||
// #nosec G101 -- These are cookie names, not hardcoded credentials
|
||||
const (
|
||||
mainCookieName = "_oidc_raczylo_m"
|
||||
accessTokenCookie = "_oidc_raczylo_a"
|
||||
|
||||
+53
@@ -127,10 +127,63 @@ type Config struct {
|
||||
// Default: "groups"
|
||||
GroupClaimName string `json:"groupClaimName,omitempty"`
|
||||
|
||||
// UserIdentifierClaim specifies the JWT claim to use as the user identifier.
|
||||
// This allows authentication for users without email addresses (e.g., Azure AD service accounts).
|
||||
//
|
||||
// Examples:
|
||||
// - Default (backward compatible): "email"
|
||||
// - Azure AD without email: "sub", "oid", "upn", or "preferred_username"
|
||||
// - Generic OIDC: "sub" (always present per OIDC spec)
|
||||
//
|
||||
// When set to a non-email claim:
|
||||
// - AllowedUsers will match against this claim value instead of email
|
||||
// - AllowedUserDomains validation is skipped (domains only apply to email)
|
||||
// - The session will store this identifier as the user's identity
|
||||
//
|
||||
// Default: "email"
|
||||
UserIdentifierClaim string `json:"userIdentifierClaim,omitempty"`
|
||||
|
||||
// DynamicClientRegistration enables OIDC Dynamic Client Registration (RFC 7591)
|
||||
// When enabled, the middleware will automatically register as a client with
|
||||
// the OIDC provider if ClientID/ClientSecret are not provided.
|
||||
DynamicClientRegistration *DynamicClientRegistrationConfig `json:"dynamicClientRegistration,omitempty"`
|
||||
|
||||
// AllowPrivateIPAddresses disables the security check that blocks private/internal IP addresses.
|
||||
// By default, the plugin rejects URLs containing private IP ranges (10.x.x.x, 172.16-31.x.x, 192.168.x.x)
|
||||
// to prevent SSRF attacks and ensure OIDC providers are publicly accessible.
|
||||
//
|
||||
// Enable this option ONLY when:
|
||||
// - Your OIDC provider (e.g., Keycloak) runs on an internal network with private IPs
|
||||
// - You have no DNS resolution available for internal services
|
||||
// - Your entire stack runs in a Docker network or Kubernetes cluster with private addressing
|
||||
//
|
||||
// Security Warning: Enabling this option reduces SSRF protection. Only use in trusted
|
||||
// network environments where the OIDC provider is known and controlled.
|
||||
//
|
||||
// Default: false (private IPs are blocked for security)
|
||||
AllowPrivateIPAddresses bool `json:"allowPrivateIPAddresses,omitempty"`
|
||||
|
||||
// MinimalHeaders reduces the number of headers forwarded to downstream services.
|
||||
// This helps prevent "431 Request Header Fields Too Large" errors when downstream
|
||||
// services have limited header buffer sizes.
|
||||
//
|
||||
// When enabled (true):
|
||||
// - Only forwards: X-Forwarded-User
|
||||
// - Skips: X-Auth-Request-Token (full ID token), X-Auth-Request-Redirect
|
||||
// - Groups/roles headers (X-User-Groups, X-User-Roles) are still forwarded if configured
|
||||
// - Custom templated headers are still processed
|
||||
//
|
||||
// When disabled (false, default):
|
||||
// - Forwards all headers: X-Forwarded-User, X-Auth-Request-User, X-Auth-Request-Redirect,
|
||||
// X-Auth-Request-Token (full ID token)
|
||||
//
|
||||
// Use this option when:
|
||||
// - Downstream services return "431 Request Header Fields Too Large" errors
|
||||
// - You don't need the full ID token forwarded to backend services
|
||||
// - You want to reduce request overhead
|
||||
//
|
||||
// Default: false (all headers forwarded for backward compatibility)
|
||||
MinimalHeaders bool `json:"minimalHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// RedisConfig configures Redis cache backend settings for distributed caching.
|
||||
|
||||
+3
-2
@@ -51,7 +51,8 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
|
||||
}
|
||||
|
||||
return &ShardedCache{
|
||||
shards: shards,
|
||||
shards: shards,
|
||||
// #nosec G115 -- numShards is validated to be positive and small (typically 32-256)
|
||||
numShards: uint32(numShards),
|
||||
maxPerShard: maxPerShard,
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func NewShardedCache(numShards int, maxSize int) *ShardedCache {
|
||||
// FNV-1a is fast and provides good distribution.
|
||||
func (c *ShardedCache) getShard(key string) *cacheShard {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(key))
|
||||
_, _ = h.Write([]byte(key)) // hash.Hash.Write never returns an error
|
||||
return c.shards[h.Sum32()%c.numShards]
|
||||
}
|
||||
|
||||
|
||||
@@ -734,6 +734,7 @@ func (r *TestSuiteRunner) RunMemoryLeakTests(t *testing.T, tests []MemoryLeakTes
|
||||
}
|
||||
|
||||
// Check memory growth
|
||||
// #nosec G115 -- memory stats are within int64 range for practical purposes
|
||||
memoryGrowthBytes := int64(finalMem.Alloc) - int64(initialMem.Alloc)
|
||||
memoryGrowthMB := float64(memoryGrowthBytes) / (1024 * 1024)
|
||||
|
||||
|
||||
@@ -99,6 +99,7 @@ type TraefikOidc struct {
|
||||
audience string // Expected JWT audience, defaults to clientID
|
||||
roleClaimName string // JWT claim name for extracting roles, defaults to "roles"
|
||||
groupClaimName string // JWT claim name for extracting groups, defaults to "groups"
|
||||
userIdentifierClaim string // JWT claim for user identification, defaults to "email"
|
||||
name string
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
@@ -128,6 +129,8 @@ type TraefikOidc struct {
|
||||
suppressDiagnosticLogs bool
|
||||
firstRequestReceived bool
|
||||
metadataRefreshStarted bool
|
||||
allowPrivateIPAddresses bool // Allow private IP addresses in URLs (for internal networks)
|
||||
minimalHeaders bool // Reduce headers to prevent 431 errors
|
||||
securityHeadersApplier func(http.ResponseWriter, *http.Request)
|
||||
scopeFilter *ScopeFilter // NEW - for discovery-based scope filtering
|
||||
scopesSupported []string // NEW - from provider metadata
|
||||
|
||||
+8
-1
@@ -340,6 +340,7 @@ func (t *TraefikOidc) validateParsedURL(u *url.URL) error {
|
||||
|
||||
// validateHost validates a hostname or IP address for security.
|
||||
// It prevents access to localhost, private networks, and known metadata endpoints.
|
||||
// When allowPrivateIPAddresses is enabled, private IP checks are skipped.
|
||||
// Parameters:
|
||||
// - host: The host string to validate (may include port).
|
||||
//
|
||||
@@ -357,7 +358,13 @@ func (t *TraefikOidc) validateHost(host string) error {
|
||||
|
||||
ip := net.ParseIP(hostname)
|
||||
if ip != nil {
|
||||
if ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
// Always block loopback, link-local, and multicast addresses
|
||||
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||||
return fmt.Errorf("access to loopback/link-local IP addresses is not allowed: %s", ip.String())
|
||||
}
|
||||
|
||||
// Skip private IP check if allowPrivateIPAddresses is enabled
|
||||
if !t.allowPrivateIPAddresses && ip.IsPrivate() {
|
||||
return fmt.Errorf("access to private/internal IP addresses is not allowed: %s", ip.String())
|
||||
}
|
||||
|
||||
|
||||
@@ -55,6 +55,51 @@ func (t *TraefikOidc) safeLogInfo(msg string) {
|
||||
// DOMAIN VALIDATION
|
||||
// =============================================================================
|
||||
|
||||
// isAllowedUser checks if a user identifier is authorized based on the configured user identifier claim.
|
||||
// When using email as the identifier (default), it validates against allowedUsers and allowedUserDomains.
|
||||
// When using non-email identifiers (sub, oid, upn, etc.), it only validates against allowedUsers
|
||||
// since domain-based validation doesn't apply to non-email identifiers.
|
||||
//
|
||||
// Parameters:
|
||||
// - userIdentifier: The user identifier to validate (email, sub, oid, upn, etc.).
|
||||
//
|
||||
// Returns:
|
||||
// - true if the user is authorized, false otherwise.
|
||||
func (t *TraefikOidc) isAllowedUser(userIdentifier string) bool {
|
||||
// If no restrictions are configured, allow all authenticated users
|
||||
if len(t.allowedUserDomains) == 0 && len(t.allowedUsers) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if user is explicitly allowed
|
||||
if len(t.allowedUsers) > 0 {
|
||||
_, userAllowed := t.allowedUsers[strings.ToLower(userIdentifier)]
|
||||
if userAllowed {
|
||||
t.logger.Debugf("User identifier %s is explicitly allowed in allowedUsers", userIdentifier)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// For email-based identifiers, also check domain restrictions
|
||||
// Only apply domain validation if using email as identifier AND identifier looks like an email
|
||||
if t.userIdentifierClaim == "email" && strings.Contains(userIdentifier, "@") {
|
||||
return t.isAllowedDomain(userIdentifier)
|
||||
}
|
||||
|
||||
// For non-email identifiers with allowedUserDomains configured, log a warning
|
||||
if len(t.allowedUserDomains) > 0 && t.userIdentifierClaim != "email" {
|
||||
t.logger.Debugf("AllowedUserDomains is configured but userIdentifierClaim is '%s', not 'email'. Domain validation skipped for: %s",
|
||||
t.userIdentifierClaim, userIdentifier)
|
||||
}
|
||||
|
||||
// User not found in allowedUsers list
|
||||
if len(t.allowedUsers) > 0 {
|
||||
t.logger.Debugf("User identifier %s is not in the allowed users list", userIdentifier)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isAllowedDomain checks if an email address is authorized based on domain or user whitelist.
|
||||
// It validates against both allowed user domains and specific allowed users.
|
||||
// Parameters:
|
||||
|
||||
+4
@@ -9,3 +9,7 @@ coverage.txt
|
||||
**/coverage.txt
|
||||
.vscode
|
||||
tmp/*
|
||||
*.test
|
||||
|
||||
# maintenanceNotifications upgrade documentation (temporary)
|
||||
maintenanceNotifications/docs/
|
||||
|
||||
+2
-2
@@ -1,8 +1,8 @@
|
||||
GO_MOD_DIRS := $(shell find . -type f -name 'go.mod' -exec dirname {} \; | sort)
|
||||
REDIS_VERSION ?= 8.2
|
||||
REDIS_VERSION ?= 8.4
|
||||
RE_CLUSTER ?= false
|
||||
RCE_DOCKER ?= true
|
||||
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.2.1-pre
|
||||
CLIENT_LIBS_TEST_IMAGE ?= redislabs/client-libs-test:8.4.0
|
||||
|
||||
docker.start:
|
||||
export RE_CLUSTER=$(RE_CLUSTER) && \
|
||||
|
||||
+149
-14
@@ -2,7 +2,7 @@
|
||||
|
||||
[](https://github.com/redis/go-redis/actions)
|
||||
[](https://pkg.go.dev/github.com/redis/go-redis/v9?tab=doc)
|
||||
[](https://redis.uptrace.dev/)
|
||||
[](https://redis.io/docs/latest/develop/clients/go/)
|
||||
[](https://goreportcard.com/report/github.com/redis/go-redis/v9)
|
||||
[](https://codecov.io/github/redis/go-redis)
|
||||
|
||||
@@ -17,15 +17,15 @@
|
||||
## Supported versions
|
||||
|
||||
In `go-redis` we are aiming to support the last three releases of Redis. Currently, this means we do support:
|
||||
- [Redis 7.2](https://raw.githubusercontent.com/redis/redis/7.2/00-RELEASENOTES) - using Redis Stack 7.2 for modules support
|
||||
- [Redis 7.4](https://raw.githubusercontent.com/redis/redis/7.4/00-RELEASENOTES) - using Redis Stack 7.4 for modules support
|
||||
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0 where modules are included
|
||||
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2 where modules are included
|
||||
- [Redis 8.0](https://raw.githubusercontent.com/redis/redis/8.0/00-RELEASENOTES) - using Redis CE 8.0
|
||||
- [Redis 8.2](https://raw.githubusercontent.com/redis/redis/8.2/00-RELEASENOTES) - using Redis CE 8.2
|
||||
- [Redis 8.4](https://raw.githubusercontent.com/redis/redis/8.4/00-RELEASENOTES) - using Redis CE 8.4
|
||||
|
||||
Although the `go.mod` states it requires at minimum `go 1.18`, our CI is configured to run the tests against all three
|
||||
versions of Redis and latest two versions of Go ([1.23](https://go.dev/doc/devel/release#go1.23.0),
|
||||
[1.24](https://go.dev/doc/devel/release#go1.24.0)). We observe that some modules related test may not pass with
|
||||
Redis Stack 7.2 and some commands are changed with Redis CE 8.0.
|
||||
Although it is not officially supported, `go-redis/v9` should be able to work with any Redis 7.0+.
|
||||
Please do refer to the documentation and the tests if you experience any issues. We do plan to update the go version
|
||||
in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
|
||||
@@ -43,10 +43,6 @@ in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
|
||||
[Work at Redis](https://redis.com/company/careers/jobs/)
|
||||
|
||||
## Documentation
|
||||
|
||||
- [English](https://redis.uptrace.dev)
|
||||
- [简体中文](https://redis.uptrace.dev/zh/)
|
||||
|
||||
## Resources
|
||||
|
||||
@@ -55,16 +51,18 @@ in the `go.mod` to `go 1.24` in one of the next releases.
|
||||
- [Reference](https://pkg.go.dev/github.com/redis/go-redis/v9)
|
||||
- [Examples](https://pkg.go.dev/github.com/redis/go-redis/v9#pkg-examples)
|
||||
|
||||
## old documentation
|
||||
|
||||
- [English](https://redis.uptrace.dev)
|
||||
- [简体中文](https://redis.uptrace.dev/zh/)
|
||||
|
||||
## Ecosystem
|
||||
|
||||
- [Redis Mock](https://github.com/go-redis/redismock)
|
||||
- [Entra ID (Azure AD)](https://github.com/redis/go-redis-entraid)
|
||||
- [Distributed Locks](https://github.com/bsm/redislock)
|
||||
- [Redis Cache](https://github.com/go-redis/cache)
|
||||
- [Rate limiting](https://github.com/go-redis/redis_rate)
|
||||
|
||||
This client also works with [Kvrocks](https://github.com/apache/incubator-kvrocks), a distributed
|
||||
key value NoSQL database that uses RocksDB as storage engine and is compatible with Redis protocol.
|
||||
|
||||
## Features
|
||||
|
||||
- Redis commands except QUIT and SYNC.
|
||||
@@ -75,7 +73,6 @@ key value NoSQL database that uses RocksDB as storage engine and is compatible w
|
||||
- [Scripting](https://redis.uptrace.dev/guide/lua-scripting.html).
|
||||
- [Redis Sentinel](https://redis.uptrace.dev/guide/go-redis-sentinel.html).
|
||||
- [Redis Cluster](https://redis.uptrace.dev/guide/go-redis-cluster.html).
|
||||
- [Redis Ring](https://redis.uptrace.dev/guide/ring.html).
|
||||
- [Redis Performance Monitoring](https://redis.uptrace.dev/guide/redis-performance-monitoring.html).
|
||||
- [Redis Probabilistic [RedisStack]](https://redis.io/docs/data-types/probabilistic/)
|
||||
- [Customizable read and write buffers size.](#custom-buffer-sizes)
|
||||
@@ -429,6 +426,144 @@ vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello")
|
||||
res, err := rdb.Do(ctx, "set", "key", "value").Result()
|
||||
```
|
||||
|
||||
## Typed Errors
|
||||
|
||||
go-redis provides typed error checking functions for common Redis errors:
|
||||
|
||||
```go
|
||||
// Cluster and replication errors
|
||||
redis.IsLoadingError(err) // Redis is loading the dataset
|
||||
redis.IsReadOnlyError(err) // Write to read-only replica
|
||||
redis.IsClusterDownError(err) // Cluster is down
|
||||
redis.IsTryAgainError(err) // Command should be retried
|
||||
redis.IsMasterDownError(err) // Master is down
|
||||
redis.IsMovedError(err) // Returns (address, true) if key moved
|
||||
redis.IsAskError(err) // Returns (address, true) if key being migrated
|
||||
|
||||
// Connection and resource errors
|
||||
redis.IsMaxClientsError(err) // Maximum clients reached
|
||||
redis.IsAuthError(err) // Authentication failed (NOAUTH, WRONGPASS, unauthenticated)
|
||||
redis.IsPermissionError(err) // Permission denied (NOPERM)
|
||||
redis.IsOOMError(err) // Out of memory (OOM)
|
||||
|
||||
// Transaction errors
|
||||
redis.IsExecAbortError(err) // Transaction aborted (EXECABORT)
|
||||
```
|
||||
|
||||
### Error Wrapping in Hooks
|
||||
|
||||
When wrapping errors in hooks, use custom error types with `Unwrap()` method (preferred) or `fmt.Errorf` with `%w`. Always call `cmd.SetErr()` to preserve error type information:
|
||||
|
||||
```go
|
||||
// Custom error type (preferred)
|
||||
type AppError struct {
|
||||
Code string
|
||||
RequestID string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *AppError) Error() string {
|
||||
return fmt.Sprintf("[%s] request_id=%s: %v", e.Code, e.RequestID, e.Err)
|
||||
}
|
||||
|
||||
func (e *AppError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
// Hook implementation
|
||||
func (h MyHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||
return func(ctx context.Context, cmd redis.Cmder) error {
|
||||
err := next(ctx, cmd)
|
||||
if err != nil {
|
||||
// Wrap with custom error type
|
||||
wrappedErr := &AppError{
|
||||
Code: "REDIS_ERROR",
|
||||
RequestID: getRequestID(ctx),
|
||||
Err: err,
|
||||
}
|
||||
cmd.SetErr(wrappedErr)
|
||||
return wrappedErr // Return wrapped error to preserve it
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Typed error detection works through wrappers
|
||||
if redis.IsLoadingError(err) {
|
||||
// Retry logic
|
||||
}
|
||||
|
||||
// Extract custom error if needed
|
||||
var appErr *AppError
|
||||
if errors.As(err, &appErr) {
|
||||
log.Printf("Request: %s", appErr.RequestID)
|
||||
}
|
||||
```
|
||||
|
||||
Alternatively, use `fmt.Errorf` with `%w`:
|
||||
```go
|
||||
wrappedErr := fmt.Errorf("context: %w", err)
|
||||
cmd.SetErr(wrappedErr)
|
||||
```
|
||||
|
||||
### Pipeline Hook Example
|
||||
|
||||
For pipeline operations, use `ProcessPipelineHook`:
|
||||
|
||||
```go
|
||||
type PipelineLoggingHook struct{}
|
||||
|
||||
func (h PipelineLoggingHook) DialHook(next redis.DialHook) redis.DialHook {
|
||||
return next
|
||||
}
|
||||
|
||||
func (h PipelineLoggingHook) ProcessHook(next redis.ProcessHook) redis.ProcessHook {
|
||||
return next
|
||||
}
|
||||
|
||||
func (h PipelineLoggingHook) ProcessPipelineHook(next redis.ProcessPipelineHook) redis.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redis.Cmder) error {
|
||||
start := time.Now()
|
||||
|
||||
// Execute the pipeline
|
||||
err := next(ctx, cmds)
|
||||
|
||||
duration := time.Since(start)
|
||||
log.Printf("Pipeline executed %d commands in %v", len(cmds), duration)
|
||||
|
||||
// Process individual command errors
|
||||
// Note: Individual command errors are already set on each cmd by the pipeline execution
|
||||
for _, cmd := range cmds {
|
||||
if cmdErr := cmd.Err(); cmdErr != nil {
|
||||
// Check for specific error types using typed error functions
|
||||
if redis.IsAuthError(cmdErr) {
|
||||
log.Printf("Auth error in pipeline command %s: %v", cmd.Name(), cmdErr)
|
||||
} else if redis.IsPermissionError(cmdErr) {
|
||||
log.Printf("Permission error in pipeline command %s: %v", cmd.Name(), cmdErr)
|
||||
}
|
||||
|
||||
// Optionally wrap individual command errors to add context
|
||||
// The wrapped error preserves type information through errors.As()
|
||||
wrappedErr := fmt.Errorf("pipeline cmd %s failed: %w", cmd.Name(), cmdErr)
|
||||
cmd.SetErr(wrappedErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the pipeline-level error (connection errors, etc.)
|
||||
// You can wrap it if needed, or return it as-is
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Register the hook
|
||||
rdb.AddHook(PipelineLoggingHook{})
|
||||
|
||||
// Use pipeline - errors are still properly typed
|
||||
pipe := rdb.Pipeline()
|
||||
pipe.Set(ctx, "key1", "value1", 0)
|
||||
pipe.Get(ctx, "key2")
|
||||
_, err := pipe.Exec(ctx)
|
||||
```
|
||||
|
||||
## Run the test
|
||||
|
||||
|
||||
+224
@@ -1,5 +1,229 @@
|
||||
# Release Notes
|
||||
|
||||
# 9.17.2 (2025-12-01)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **Connection Pool**: Fixed critical race condition in turn management that could cause connection leaks when dial goroutines complete after request timeout ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
|
||||
- **Context Timeout**: Improved context timeout calculation to use minimum of remaining time and DialTimeout, preventing goroutines from waiting longer than necessary ([#3626](https://github.com/redis/go-redis/pull/3626)) by [@cyningsun](https://github.com/cyningsun)
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.54.0 to 0.55.0 ([#3627](https://github.com/redis/go-redis/pull/3627))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@cyningsun](https://github.com/cyningsun) and [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.1...v9.17.2
|
||||
|
||||
# 9.17.1 (2025-11-25)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- add wait to keyless commands list ([#3615](https://github.com/redis/go-redis/pull/3615)) by [@marcoferrer](https://github.com/marcoferrer)
|
||||
- fix(time): remove cached time optimization ([#3611](https://github.com/redis/go-redis/pull/3611)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump golangci/golangci-lint-action from 9.0.0 to 9.1.0 ([#3609](https://github.com/redis/go-redis/pull/3609))
|
||||
- chore(deps): bump actions/checkout from 5 to 6 ([#3610](https://github.com/redis/go-redis/pull/3610))
|
||||
- chore(script): fix help call in tag.sh ([#3606](https://github.com/redis/go-redis/pull/3606)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@marcoferrer](https://github.com/marcoferrer) and [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.17.0...v9.17.1
|
||||
|
||||
# 9.17.0 (2025-11-19)
|
||||
|
||||
## 🚀 Highlights
|
||||
|
||||
### Redis 8.4 Support
|
||||
Added support for Redis 8.4, including new commands and features ([#3572](https://github.com/redis/go-redis/pull/3572))
|
||||
|
||||
### Typed Errors
|
||||
Introduced typed errors for better error handling using `errors.As` instead of string checks. Errors can now be wrapped and set to commands in hooks without breaking library functionality ([#3602](https://github.com/redis/go-redis/pull/3602))
|
||||
|
||||
### New Commands
|
||||
- **CAS/CAD Commands**: Added support for Compare-And-Set/Compare-And-Delete operations with conditional matching (`IFEQ`, `IFNE`, `IFDEQ`, `IFDNE`) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595))
|
||||
- **MSETEX**: Atomically set multiple key-value pairs with expiration options and conditional modes ([#3580](https://github.com/redis/go-redis/pull/3580))
|
||||
- **XReadGroup CLAIM**: Consume both incoming and idle pending entries from streams in a single call ([#3578](https://github.com/redis/go-redis/pull/3578))
|
||||
- **ACL Commands**: Added `ACLGenPass`, `ACLUsers`, and `ACLWhoAmI` ([#3576](https://github.com/redis/go-redis/pull/3576))
|
||||
- **SLOWLOG Commands**: Added `SLOWLOG LEN` and `SLOWLOG RESET` ([#3585](https://github.com/redis/go-redis/pull/3585))
|
||||
- **LATENCY Commands**: Added `LATENCY LATEST` and `LATENCY RESET` ([#3584](https://github.com/redis/go-redis/pull/3584))
|
||||
|
||||
### Search & Vector Improvements
|
||||
- **Hybrid Search**: Added **EXPERIMENTAL** support for the new `FT.HYBRID` command ([#3573](https://github.com/redis/go-redis/pull/3573))
|
||||
- **Vector Range**: Added `VRANGE` command for vector sets ([#3543](https://github.com/redis/go-redis/pull/3543))
|
||||
- **FT.INFO Enhancements**: Added vector-specific attributes in FT.INFO response ([#3596](https://github.com/redis/go-redis/pull/3596))
|
||||
|
||||
### Connection Pool Improvements
|
||||
- **Improved Connection Success Rate**: Implemented FIFO queue-based fairness and context pattern for connection creation to prevent premature cancellation under high concurrency ([#3518](https://github.com/redis/go-redis/pull/3518))
|
||||
- **Connection State Machine**: Resolved race conditions and improved pool performance with proper state tracking ([#3559](https://github.com/redis/go-redis/pull/3559))
|
||||
- **Pool Performance**: Significant performance improvements with faster semaphores, lockless hook manager, and reduced allocations (47-67% faster Get/Put operations) ([#3565](https://github.com/redis/go-redis/pull/3565))
|
||||
|
||||
### Metrics & Observability
|
||||
- **Canceled Metric Attribute**: Added 'canceled' metrics attribute to distinguish context cancellation errors from other errors ([#3566](https://github.com/redis/go-redis/pull/3566))
|
||||
|
||||
## ✨ New Features
|
||||
|
||||
- Typed errors with wrapping support ([#3602](https://github.com/redis/go-redis/pull/3602)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- CAS/CAD commands (marked as experimental) ([#3583](https://github.com/redis/go-redis/pull/3583), [#3595](https://github.com/redis/go-redis/pull/3595)) by [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis)
|
||||
- MSETEX command support ([#3580](https://github.com/redis/go-redis/pull/3580)) by [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
- XReadGroup CLAIM argument ([#3578](https://github.com/redis/go-redis/pull/3578)) by [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
- ACL commands: GenPass, Users, WhoAmI ([#3576](https://github.com/redis/go-redis/pull/3576)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- SLOWLOG commands: LEN, RESET ([#3585](https://github.com/redis/go-redis/pull/3585)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- LATENCY commands: LATEST, RESET ([#3584](https://github.com/redis/go-redis/pull/3584)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- Hybrid search command (FT.HYBRID) ([#3573](https://github.com/redis/go-redis/pull/3573)) by [@htemelski-redis](https://github.com/htemelski-redis)
|
||||
- Vector range command (VRANGE) ([#3543](https://github.com/redis/go-redis/pull/3543)) by [@cxljs](https://github.com/cxljs)
|
||||
- Vector-specific attributes in FT.INFO ([#3596](https://github.com/redis/go-redis/pull/3596)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Improved connection pool success rate with FIFO queue ([#3518](https://github.com/redis/go-redis/pull/3518)) by [@cyningsun](https://github.com/cyningsun)
|
||||
- Canceled metrics attribute for context errors ([#3566](https://github.com/redis/go-redis/pull/3566)) by [@pvragov](https://github.com/pvragov)
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- Fixed Failover Client MaintNotificationsConfig ([#3600](https://github.com/redis/go-redis/pull/3600)) by [@ajax16384](https://github.com/ajax16384)
|
||||
- Fixed ACLGenPass function to use the bit parameter ([#3597](https://github.com/redis/go-redis/pull/3597)) by [@destinyoooo](https://github.com/destinyoooo)
|
||||
- Return error instead of panic from commands ([#3568](https://github.com/redis/go-redis/pull/3568)) by [@dragneelfps](https://github.com/dragneelfps)
|
||||
- Safety harness in `joinErrors` to prevent panic ([#3577](https://github.com/redis/go-redis/pull/3577)) by [@manisharma](https://github.com/manisharma)
|
||||
|
||||
## ⚡ Performance
|
||||
|
||||
- Connection state machine with race condition fixes ([#3559](https://github.com/redis/go-redis/pull/3559)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Pool performance improvements: 47-67% faster Get/Put, 33% less memory, 50% fewer allocations ([#3565](https://github.com/redis/go-redis/pull/3565)) by [@ndyakov](https://github.com/ndyakov)
|
||||
|
||||
## 🧪 Testing & Infrastructure
|
||||
|
||||
- Updated to Redis 8.4.0 image ([#3603](https://github.com/redis/go-redis/pull/3603)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Added Redis 8.4-RC1-pre to CI ([#3572](https://github.com/redis/go-redis/pull/3572)) by [@ndyakov](https://github.com/ndyakov)
|
||||
- Refactored tests for idiomatic Go ([#3561](https://github.com/redis/go-redis/pull/3561), [#3562](https://github.com/redis/go-redis/pull/3562), [#3563](https://github.com/redis/go-redis/pull/3563)) by [@12ya](https://github.com/12ya)
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@12ya](https://github.com/12ya), [@ajax16384](https://github.com/ajax16384), [@cxljs](https://github.com/cxljs), [@cyningsun](https://github.com/cyningsun), [@destinyoooo](https://github.com/destinyoooo), [@dragneelfps](https://github.com/dragneelfps), [@htemelski-redis](https://github.com/htemelski-redis), [@manisharma](https://github.com/manisharma), [@ndyakov](https://github.com/ndyakov), [@ofekshenawa](https://github.com/ofekshenawa), [@pvragov](https://github.com/pvragov)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.16.0...v9.17.0
|
||||
|
||||
# 9.16.0 (2025-10-23)
|
||||
|
||||
## 🚀 Highlights
|
||||
|
||||
### Maintenance Notifications Support
|
||||
|
||||
This release introduces comprehensive support for Redis maintenance notifications, enabling applications to handle server maintenance events gracefully. The new `maintnotifications` package provides:
|
||||
|
||||
- **RESP3 Push Notifications**: Full support for Redis RESP3 protocol push notifications
|
||||
- **Connection Handoff**: Automatic connection migration during server maintenance with configurable retry policies and circuit breakers
|
||||
- **Graceful Degradation**: Configurable timeout relaxation during maintenance windows to prevent false failures
|
||||
- **Event-Driven Architecture**: Background workers with on-demand scaling for efficient handoff processing
|
||||
- **Production-Ready**: Comprehensive E2E testing framework and monitoring capabilities
|
||||
|
||||
For detailed usage examples and configuration options, see the [maintenance notifications documentation](maintnotifications/README.md).
|
||||
|
||||
## ✨ New Features
|
||||
|
||||
- **Trace Filtering**: Add support for filtering traces for specific commands, including pipeline operations and dial operations ([#3519](https://github.com/redis/go-redis/pull/3519), [#3550](https://github.com/redis/go-redis/pull/3550))
|
||||
- New `TraceCmdFilter` option to selectively trace commands
|
||||
- Reduces overhead by excluding high-frequency or low-value commands from traces
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- **Pipeline Error Handling**: Fix issue where pipeline repeatedly sets the same error ([#3525](https://github.com/redis/go-redis/pull/3525))
|
||||
- **Connection Pool**: Ensure re-authentication does not interfere with connection handoff operations ([#3547](https://github.com/redis/go-redis/pull/3547))
|
||||
|
||||
## 🔧 Improvements
|
||||
|
||||
- **Hash Commands**: Update hash command implementations ([#3523](https://github.com/redis/go-redis/pull/3523))
|
||||
- **OpenTelemetry**: Use `metric.WithAttributeSet` to avoid unnecessary attribute copying in redisotel ([#3552](https://github.com/redis/go-redis/pull/3552))
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- **Cluster Client**: Add explanation for why `MaxRetries` is disabled for `ClusterClient` ([#3551](https://github.com/redis/go-redis/pull/3551))
|
||||
|
||||
## 🧪 Testing & Infrastructure
|
||||
|
||||
- **E2E Testing**: Upgrade E2E testing framework with improved reliability and coverage ([#3541](https://github.com/redis/go-redis/pull/3541))
|
||||
- **Release Process**: Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
|
||||
|
||||
## 📦 Dependencies
|
||||
|
||||
- Bump `rojopolis/spellcheck-github-actions` from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
|
||||
- Bump `github/codeql-action` from 3 to 4 ([#3544](https://github.com/redis/go-redis/pull/3544))
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@Sovietaced](https://github.com/Sovietaced), [@Udhayarajan](https://github.com/Udhayarajan), [@boekkooi-impossiblecloud](https://github.com/boekkooi-impossiblecloud), [@Pika-Gopher](https://github.com/Pika-Gopher), [@cxljs](https://github.com/cxljs), [@huiyifyj](https://github.com/huiyifyj), [@omid-h70](https://github.com/omid-h70)
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: https://github.com/redis/go-redis/compare/v9.14.0...v9.16.0
|
||||
|
||||
|
||||
# 9.15.0 was accidentally released. Please use version 9.16.0 instead.
|
||||
|
||||
# 9.15.0-beta.3 (2025-09-26)
|
||||
|
||||
## Highlights
|
||||
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
|
||||
|
||||
# Changes
|
||||
|
||||
- chore: Update hash_commands.go ([#3523](https://github.com/redis/go-redis/pull/3523))
|
||||
|
||||
## 🚀 New Features
|
||||
|
||||
- feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## 🐛 Bug Fixes
|
||||
|
||||
- fix: pipeline repeatedly sets the error ([#3525](https://github.com/redis/go-redis/pull/3525))
|
||||
|
||||
## 🧰 Maintenance
|
||||
|
||||
- chore(deps): bump rojopolis/spellcheck-github-actions from 0.51.0 to 0.52.0 ([#3520](https://github.com/redis/go-redis/pull/3520))
|
||||
- feat(e2e-testing): maintnotifications e2e and refactor ([#3526](https://github.com/redis/go-redis/pull/3526))
|
||||
- feat(tag.sh): Improved resiliency of the release process ([#3530](https://github.com/redis/go-redis/pull/3530))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@cxljs](https://github.com/cxljs), [@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), and [@omid-h70](https://github.com/omid-h70)
|
||||
|
||||
|
||||
# 9.15.0-beta.1 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
This beta release includes a pre-production version of processing push notifications and hitless upgrades.
|
||||
|
||||
### Hitless Upgrades
|
||||
Hitless upgrades is a major new feature that allows for zero-downtime upgrades in Redis clusters.
|
||||
You can find more information in the [Hitless Upgrades documentation](https://github.com/redis/go-redis/tree/master/hitless).
|
||||
|
||||
# Changes
|
||||
|
||||
## 🚀 New Features
|
||||
- [CAE-1088] & [CAE-1072] feat: RESP3 notifications support & Hitless notifications handling ([#3418](https://github.com/redis/go-redis/pull/3418))
|
||||
|
||||
## Contributors
|
||||
We'd like to thank all the contributors who worked on this release!
|
||||
|
||||
[@ndyakov](https://github.com/ndyakov), [@htemelski-redis](https://github.com/htemelski-redis), [@ofekshenawa](https://github.com/ofekshenawa)
|
||||
|
||||
|
||||
# 9.14.0 (2025-09-10)
|
||||
|
||||
## Highlights
|
||||
|
||||
+27
@@ -8,8 +8,12 @@ type ACLCmdable interface {
|
||||
ACLLog(ctx context.Context, count int64) *ACLLogCmd
|
||||
ACLLogReset(ctx context.Context) *StatusCmd
|
||||
|
||||
ACLGenPass(ctx context.Context, bit int) *StringCmd
|
||||
|
||||
ACLSetUser(ctx context.Context, username string, rules ...string) *StatusCmd
|
||||
ACLDelUser(ctx context.Context, username string) *IntCmd
|
||||
ACLUsers(ctx context.Context) *StringSliceCmd
|
||||
ACLWhoAmI(ctx context.Context) *StringCmd
|
||||
ACLList(ctx context.Context) *StringSliceCmd
|
||||
|
||||
ACLCat(ctx context.Context) *StringSliceCmd
|
||||
@@ -65,6 +69,29 @@ func (c cmdable) ACLSetUser(ctx context.Context, username string, rules ...strin
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLGenPass(ctx context.Context, bit int) *StringCmd {
|
||||
args := make([]interface{}, 0, 3)
|
||||
args = append(args, "acl", "genpass")
|
||||
if bit > 0 {
|
||||
args = append(args, bit)
|
||||
}
|
||||
cmd := NewStringCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLUsers(ctx context.Context) *StringSliceCmd {
|
||||
cmd := NewStringSliceCmd(ctx, "acl", "users")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLWhoAmI(ctx context.Context) *StringCmd {
|
||||
cmd := NewStringCmd(ctx, "acl", "whoami")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) ACLList(ctx context.Context) *StringSliceCmd {
|
||||
cmd := NewStringSliceCmd(ctx, "acl", "list")
|
||||
_ = c(ctx, cmd)
|
||||
|
||||
+111
@@ -0,0 +1,111 @@
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand.
|
||||
var ErrInvalidCommand = errors.New("invalid command type")
|
||||
|
||||
// ErrInvalidPool is returned when the pool type is not supported.
|
||||
var ErrInvalidPool = errors.New("invalid pool type")
|
||||
|
||||
// newClientAdapter creates a new client adapter for regular Redis clients.
|
||||
func newClientAdapter(client *baseClient) interfaces.ClientInterface {
|
||||
return &clientAdapter{client: client}
|
||||
}
|
||||
|
||||
// clientAdapter adapts a Redis client to implement interfaces.ClientInterface.
|
||||
type clientAdapter struct {
|
||||
client *baseClient
|
||||
}
|
||||
|
||||
// GetOptions returns the client options.
|
||||
func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface {
|
||||
return &optionsAdapter{options: ca.client.opt}
|
||||
}
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor {
|
||||
return &pushProcessorAdapter{processor: ca.client.pushProcessor}
|
||||
}
|
||||
|
||||
// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface.
|
||||
type optionsAdapter struct {
|
||||
options *Options
|
||||
}
|
||||
|
||||
// GetReadTimeout returns the read timeout.
|
||||
func (oa *optionsAdapter) GetReadTimeout() time.Duration {
|
||||
return oa.options.ReadTimeout
|
||||
}
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
func (oa *optionsAdapter) GetWriteTimeout() time.Duration {
|
||||
return oa.options.WriteTimeout
|
||||
}
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
func (oa *optionsAdapter) GetNetwork() string {
|
||||
return oa.options.Network
|
||||
}
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
func (oa *optionsAdapter) GetAddr() string {
|
||||
return oa.options.Addr
|
||||
}
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
func (oa *optionsAdapter) IsTLSEnabled() bool {
|
||||
return oa.options.TLSConfig != nil
|
||||
}
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
func (oa *optionsAdapter) GetProtocol() int {
|
||||
return oa.options.Protocol
|
||||
}
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
func (oa *optionsAdapter) GetPoolSize() int {
|
||||
return oa.options.PoolSize
|
||||
}
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) {
|
||||
baseDialer := oa.options.NewDialer()
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Extract network and address from the options
|
||||
network := oa.options.Network
|
||||
addr := oa.options.Addr
|
||||
return baseDialer(ctx, network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor.
|
||||
type pushProcessorAdapter struct {
|
||||
processor push.NotificationProcessor
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error {
|
||||
if pushHandler, ok := handler.(push.NotificationHandler); ok {
|
||||
return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected)
|
||||
}
|
||||
return errors.New("handler must implement push.NotificationHandler")
|
||||
}
|
||||
|
||||
// UnregisterHandler removes a handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error {
|
||||
return ppa.processor.UnregisterHandler(pushNotificationName)
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for a specific push notification name.
|
||||
func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} {
|
||||
return ppa.processor.GetHandler(pushNotificationName)
|
||||
}
|
||||
+6
-2
@@ -141,7 +141,9 @@ func (c cmdable) BitPos(ctx context.Context, key string, bit int64, pos ...int64
|
||||
args[3] = pos[0]
|
||||
args[4] = pos[1]
|
||||
default:
|
||||
panic("too many arguments")
|
||||
cmd := NewIntCmd(ctx)
|
||||
cmd.SetErr(errors.New("too many arguments"))
|
||||
return cmd
|
||||
}
|
||||
cmd := NewIntCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
@@ -182,7 +184,9 @@ func (c cmdable) BitFieldRO(ctx context.Context, key string, values ...interface
|
||||
args[0] = "BITFIELD_RO"
|
||||
args[1] = key
|
||||
if len(values)%2 != 0 {
|
||||
panic("BitFieldRO: invalid number of arguments, must be even")
|
||||
c := NewIntSliceCmd(ctx)
|
||||
c.SetErr(errors.New("BitFieldRO: invalid number of arguments, must be even"))
|
||||
return c
|
||||
}
|
||||
for i := 0; i < len(values); i += 2 {
|
||||
args = append(args, "GET", values[i], values[i+1])
|
||||
|
||||
+169
-3
@@ -64,6 +64,7 @@ var keylessCommands = map[string]struct{}{
|
||||
"sync": {},
|
||||
"unsubscribe": {},
|
||||
"unwatch": {},
|
||||
"wait": {},
|
||||
}
|
||||
|
||||
type Cmder interface {
|
||||
@@ -698,6 +699,68 @@ func (cmd *IntCmd) readReply(rd *proto.Reader) (err error) {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// DigestCmd is a command that returns a uint64 xxh3 hash digest.
|
||||
//
|
||||
// This command is specifically designed for the Redis DIGEST command,
|
||||
// which returns the xxh3 hash of a key's value as a hex string.
|
||||
// The hex string is automatically parsed to a uint64 value.
|
||||
//
|
||||
// The digest can be used for optimistic locking with SetIFDEQ, SetIFDNE,
|
||||
// and DelExArgs commands.
|
||||
//
|
||||
// For examples of client-side digest generation and usage patterns, see:
|
||||
// example/digest-optimistic-locking/
|
||||
//
|
||||
// Redis 8.4+. See https://redis.io/commands/digest/
|
||||
type DigestCmd struct {
|
||||
baseCmd
|
||||
|
||||
val uint64
|
||||
}
|
||||
|
||||
var _ Cmder = (*DigestCmd)(nil)
|
||||
|
||||
func NewDigestCmd(ctx context.Context, args ...interface{}) *DigestCmd {
|
||||
return &DigestCmd{
|
||||
baseCmd: baseCmd{
|
||||
ctx: ctx,
|
||||
args: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) SetVal(val uint64) {
|
||||
cmd.val = val
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) Val() uint64 {
|
||||
return cmd.val
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) Result() (uint64, error) {
|
||||
return cmd.val, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) String() string {
|
||||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *DigestCmd) readReply(rd *proto.Reader) (err error) {
|
||||
// Redis DIGEST command returns a hex string (e.g., "a1b2c3d4e5f67890")
|
||||
// We parse it as a uint64 xxh3 hash value
|
||||
var hexStr string
|
||||
hexStr, err = rd.ReadString()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Parse hex string to uint64
|
||||
cmd.val, err = strconv.ParseUint(hexStr, 16, 64)
|
||||
return err
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
type IntSliceCmd struct {
|
||||
baseCmd
|
||||
|
||||
@@ -1585,6 +1648,12 @@ func (cmd *StringStructMapCmd) readReply(rd *proto.Reader) error {
|
||||
type XMessage struct {
|
||||
ID string
|
||||
Values map[string]interface{}
|
||||
// MillisElapsedFromDelivery is the number of milliseconds since the entry was last delivered.
|
||||
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
|
||||
MillisElapsedFromDelivery int64
|
||||
// DeliveredCount is the number of times the entry was delivered.
|
||||
// Only populated when using XREADGROUP with CLAIM argument for claimed entries.
|
||||
DeliveredCount int64
|
||||
}
|
||||
|
||||
type XMessageSliceCmd struct {
|
||||
@@ -1641,10 +1710,16 @@ func readXMessageSlice(rd *proto.Reader) ([]XMessage, error) {
|
||||
}
|
||||
|
||||
func readXMessage(rd *proto.Reader) (XMessage, error) {
|
||||
if err := rd.ReadFixedArrayLen(2); err != nil {
|
||||
// Read array length can be 2 or 4 (with CLAIM metadata)
|
||||
n, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
|
||||
if n != 2 && n != 4 {
|
||||
return XMessage{}, fmt.Errorf("redis: got %d elements in the XMessage array, expected 2 or 4", n)
|
||||
}
|
||||
|
||||
id, err := rd.ReadString()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
@@ -1657,10 +1732,24 @@ func readXMessage(rd *proto.Reader) (XMessage, error) {
|
||||
}
|
||||
}
|
||||
|
||||
return XMessage{
|
||||
msg := XMessage{
|
||||
ID: id,
|
||||
Values: v,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if n == 4 {
|
||||
msg.MillisElapsedFromDelivery, err = rd.ReadInt()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
|
||||
msg.DeliveredCount, err = rd.ReadInt()
|
||||
if err != nil {
|
||||
return XMessage{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func stringInterfaceMapParser(rd *proto.Reader) (map[string]interface{}, error) {
|
||||
@@ -3768,6 +3857,83 @@ func (cmd *SlowLogCmd) readReply(rd *proto.Reader) error {
|
||||
|
||||
//-----------------------------------------------------------------------
|
||||
|
||||
type Latency struct {
|
||||
Name string
|
||||
Time time.Time
|
||||
Latest time.Duration
|
||||
Max time.Duration
|
||||
}
|
||||
|
||||
type LatencyCmd struct {
|
||||
baseCmd
|
||||
val []Latency
|
||||
}
|
||||
|
||||
var _ Cmder = (*LatencyCmd)(nil)
|
||||
|
||||
func NewLatencyCmd(ctx context.Context, args ...interface{}) *LatencyCmd {
|
||||
return &LatencyCmd{
|
||||
baseCmd: baseCmd{
|
||||
ctx: ctx,
|
||||
args: args,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) SetVal(val []Latency) {
|
||||
cmd.val = val
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) Val() []Latency {
|
||||
return cmd.val
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) Result() ([]Latency, error) {
|
||||
return cmd.val, cmd.err
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) String() string {
|
||||
return cmdString(cmd, cmd.val)
|
||||
}
|
||||
|
||||
func (cmd *LatencyCmd) readReply(rd *proto.Reader) error {
|
||||
n, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val = make([]Latency, n)
|
||||
for i := 0; i < len(cmd.val); i++ {
|
||||
nn, err := rd.ReadArrayLen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nn < 3 {
|
||||
return fmt.Errorf("redis: got %d elements in latency get, expected at least 3", nn)
|
||||
}
|
||||
if cmd.val[i].Name, err = rd.ReadString(); err != nil {
|
||||
return err
|
||||
}
|
||||
createdAt, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Time = time.Unix(createdAt, 0)
|
||||
latest, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Latest = time.Duration(latest) * time.Millisecond
|
||||
maximum, err := rd.ReadInt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.val[i].Max = time.Duration(maximum) * time.Millisecond
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------
|
||||
|
||||
type MapStringInterfaceCmd struct {
|
||||
baseCmd
|
||||
|
||||
|
||||
+53
-1
@@ -193,6 +193,7 @@ type Cmdable interface {
|
||||
ClientID(ctx context.Context) *IntCmd
|
||||
ClientUnblock(ctx context.Context, id int64) *IntCmd
|
||||
ClientUnblockWithError(ctx context.Context, id int64) *IntCmd
|
||||
ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd
|
||||
ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd
|
||||
ConfigResetStat(ctx context.Context) *StatusCmd
|
||||
ConfigSet(ctx context.Context, parameter, value string) *StatusCmd
|
||||
@@ -210,9 +211,13 @@ type Cmdable interface {
|
||||
ShutdownNoSave(ctx context.Context) *StatusCmd
|
||||
SlaveOf(ctx context.Context, host, port string) *StatusCmd
|
||||
SlowLogGet(ctx context.Context, num int64) *SlowLogCmd
|
||||
SlowLogLen(ctx context.Context) *IntCmd
|
||||
SlowLogReset(ctx context.Context) *StatusCmd
|
||||
Time(ctx context.Context) *TimeCmd
|
||||
DebugObject(ctx context.Context, key string) *StringCmd
|
||||
MemoryUsage(ctx context.Context, key string, samples ...int) *IntCmd
|
||||
Latency(ctx context.Context) *LatencyCmd
|
||||
LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd
|
||||
|
||||
ModuleLoadex(ctx context.Context, conf *ModuleLoadexConfig) *StringCmd
|
||||
|
||||
@@ -519,6 +524,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ClientMaintNotifications enables or disables maintenance notifications for maintenance upgrades.
|
||||
// When enabled, the client will receive push notifications about Redis maintenance events.
|
||||
func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd {
|
||||
args := []interface{}{"client", "maint_notifications"}
|
||||
if enabled {
|
||||
if endpointType == "" {
|
||||
endpointType = "none"
|
||||
}
|
||||
args = append(args, "on", "moving-endpoint-type", endpointType)
|
||||
} else {
|
||||
args = append(args, "off")
|
||||
}
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
|
||||
func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd {
|
||||
@@ -655,6 +677,34 @@ func (c cmdable) SlowLogGet(ctx context.Context, num int64) *SlowLogCmd {
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SlowLogLen(ctx context.Context) *IntCmd {
|
||||
cmd := NewIntCmd(ctx, "slowlog", "len")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) SlowLogReset(ctx context.Context) *StatusCmd {
|
||||
cmd := NewStatusCmd(ctx, "slowlog", "reset")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) Latency(ctx context.Context) *LatencyCmd {
|
||||
cmd := NewLatencyCmd(ctx, "latency", "latest")
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) LatencyReset(ctx context.Context, events ...interface{}) *StatusCmd {
|
||||
args := make([]interface{}, 2+len(events))
|
||||
args[0] = "latency"
|
||||
args[1] = "reset"
|
||||
copy(args[2:], events)
|
||||
cmd := NewStatusCmd(ctx, args...)
|
||||
_ = c(ctx, cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (c cmdable) Sync(_ context.Context) {
|
||||
panic("not implemented")
|
||||
}
|
||||
@@ -675,7 +725,9 @@ func (c cmdable) MemoryUsage(ctx context.Context, key string, samples ...int) *I
|
||||
args := []interface{}{"memory", "usage", key}
|
||||
if len(samples) > 0 {
|
||||
if len(samples) != 1 {
|
||||
panic("MemoryUsage expects single sample count")
|
||||
cmd := NewIntCmd(ctx)
|
||||
cmd.SetErr(errors.New("MemoryUsage expects single sample count"))
|
||||
return cmd
|
||||
}
|
||||
args = append(args, "SAMPLES", samples[0])
|
||||
}
|
||||
|
||||
+5
-5
@@ -2,7 +2,7 @@
|
||||
|
||||
services:
|
||||
redis:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-standalone
|
||||
environment:
|
||||
@@ -23,7 +23,7 @@ services:
|
||||
- all
|
||||
|
||||
osscluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-osscluster
|
||||
environment:
|
||||
@@ -40,7 +40,7 @@ services:
|
||||
- all
|
||||
|
||||
sentinel-cluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-sentinel-cluster
|
||||
network_mode: "host"
|
||||
@@ -60,7 +60,7 @@ services:
|
||||
- all
|
||||
|
||||
sentinel:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-sentinel
|
||||
depends_on:
|
||||
@@ -84,7 +84,7 @@ services:
|
||||
- all
|
||||
|
||||
ring-cluster:
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.2.1-pre}
|
||||
image: ${CLIENT_LIBS_TEST_IMAGE:-redislabs/client-libs-test:8.4.0}
|
||||
platform: linux/amd64
|
||||
container_name: redis-ring-cluster
|
||||
environment:
|
||||
|
||||
+207
-45
@@ -52,34 +52,82 @@ type Error interface {
|
||||
var _ Error = proto.RedisError("")
|
||||
|
||||
func isContextError(err error) bool {
|
||||
switch err {
|
||||
case context.Canceled, context.DeadlineExceeded:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
// Check for wrapped context errors using errors.Is
|
||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// isTimeoutError checks if an error is a timeout error, even if wrapped.
|
||||
// Returns (isTimeout, shouldRetryOnTimeout) where:
|
||||
// - isTimeout: true if the error is any kind of timeout error
|
||||
// - shouldRetryOnTimeout: true if Timeout() method returns true
|
||||
func isTimeoutError(err error) (isTimeout bool, hasTimeoutFlag bool) {
|
||||
// Check for timeoutError interface (works with wrapped errors)
|
||||
var te timeoutError
|
||||
if errors.As(err, &te) {
|
||||
return true, te.Timeout()
|
||||
}
|
||||
|
||||
// Check for net.Error specifically (common case for network timeouts)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
return true, netErr.Timeout()
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func shouldRetry(err error, retryTimeout bool) bool {
|
||||
switch err {
|
||||
case io.EOF, io.ErrUnexpectedEOF:
|
||||
return true
|
||||
case nil, context.Canceled, context.DeadlineExceeded:
|
||||
if err == nil {
|
||||
return false
|
||||
case pool.ErrPoolTimeout:
|
||||
}
|
||||
|
||||
// Check for EOF errors (works with wrapped errors)
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for context errors (works with wrapped errors)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check for pool timeout (works with wrapped errors)
|
||||
if errors.Is(err, pool.ErrPoolTimeout) {
|
||||
// connection pool timeout, increase retries. #3289
|
||||
return true
|
||||
}
|
||||
|
||||
if v, ok := err.(timeoutError); ok {
|
||||
if v.Timeout() {
|
||||
// Check for timeout errors (works with wrapped errors)
|
||||
if isTimeout, hasTimeoutFlag := isTimeoutError(err); isTimeout {
|
||||
if hasTimeoutFlag {
|
||||
return retryTimeout
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for typed Redis errors using errors.As (works with wrapped errors)
|
||||
if proto.IsMaxClientsError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsLoadingError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsReadOnlyError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsMasterDownError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsClusterDownError(err) {
|
||||
return true
|
||||
}
|
||||
if proto.IsTryAgainError(err) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Fallback to string checking for backward compatibility with plain errors
|
||||
s := err.Error()
|
||||
if s == "ERR max number of clients reached" {
|
||||
if strings.HasPrefix(s, "ERR max number of clients reached") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "LOADING ") {
|
||||
@@ -88,29 +136,42 @@ func shouldRetry(err error, retryTimeout bool) bool {
|
||||
if strings.HasPrefix(s, "READONLY ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "CLUSTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "TRYAGAIN ") {
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(s, "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isRedisError(err error) bool {
|
||||
_, ok := err.(proto.RedisError)
|
||||
return ok
|
||||
// Check if error implements the Error interface (works with wrapped errors)
|
||||
var redisErr Error
|
||||
if errors.As(err, &redisErr) {
|
||||
return true
|
||||
}
|
||||
// Also check for proto.RedisError specifically
|
||||
var protoRedisErr proto.RedisError
|
||||
return errors.As(err, &protoRedisErr)
|
||||
}
|
||||
|
||||
func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
switch err {
|
||||
case nil:
|
||||
if err == nil {
|
||||
return false
|
||||
case context.Canceled, context.DeadlineExceeded:
|
||||
}
|
||||
|
||||
// Check for context errors (works with wrapped errors)
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for pool timeout errors (works with wrapped errors)
|
||||
if errors.Is(err, pool.ErrConnUnusableTimeout) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -131,7 +192,9 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
}
|
||||
|
||||
if allowTimeout {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Check for network timeout errors (works with wrapped errors)
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -140,44 +203,143 @@ func isBadConn(err error, allowTimeout bool, addr string) bool {
|
||||
}
|
||||
|
||||
func isMovedError(err error) (moved bool, ask bool, addr string) {
|
||||
if !isRedisError(err) {
|
||||
return
|
||||
// Check for typed MovedError
|
||||
if movedErr, ok := proto.IsMovedError(err); ok {
|
||||
addr = movedErr.Addr()
|
||||
addr = internal.GetAddr(addr)
|
||||
return true, false, addr
|
||||
}
|
||||
|
||||
// Check for typed AskError
|
||||
if askErr, ok := proto.IsAskError(err); ok {
|
||||
addr = askErr.Addr()
|
||||
addr = internal.GetAddr(addr)
|
||||
return false, true, addr
|
||||
}
|
||||
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
switch {
|
||||
case strings.HasPrefix(s, "MOVED "):
|
||||
moved = true
|
||||
case strings.HasPrefix(s, "ASK "):
|
||||
ask = true
|
||||
default:
|
||||
return
|
||||
if strings.HasPrefix(s, "MOVED ") {
|
||||
// Parse: MOVED 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
addr = internal.GetAddr(parts[2])
|
||||
return true, false, addr
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(s, "ASK ") {
|
||||
// Parse: ASK 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
addr = internal.GetAddr(parts[2])
|
||||
return false, true, addr
|
||||
}
|
||||
}
|
||||
|
||||
ind := strings.LastIndex(s, " ")
|
||||
if ind == -1 {
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
addr = s[ind+1:]
|
||||
addr = internal.GetAddr(addr)
|
||||
return
|
||||
return false, false, ""
|
||||
}
|
||||
|
||||
func isLoadingError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "LOADING ")
|
||||
return proto.IsLoadingError(err)
|
||||
}
|
||||
|
||||
func isReadOnlyError(err error) bool {
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
return proto.IsReadOnlyError(err)
|
||||
}
|
||||
|
||||
func isMovedSameConnAddr(err error, addr string) bool {
|
||||
redisError := err.Error()
|
||||
if !strings.HasPrefix(redisError, "MOVED ") {
|
||||
return false
|
||||
if movedErr, ok := proto.IsMovedError(err); ok {
|
||||
return strings.HasSuffix(movedErr.Addr(), addr)
|
||||
}
|
||||
return strings.HasSuffix(redisError, " "+addr)
|
||||
return false
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Typed error checking functions for public use.
|
||||
// These functions work correctly even when errors are wrapped in hooks.
|
||||
|
||||
// IsLoadingError checks if an error is a Redis LOADING error, even if wrapped.
|
||||
// LOADING errors occur when Redis is loading the dataset in memory.
|
||||
func IsLoadingError(err error) bool {
|
||||
return proto.IsLoadingError(err)
|
||||
}
|
||||
|
||||
// IsReadOnlyError checks if an error is a Redis READONLY error, even if wrapped.
|
||||
// READONLY errors occur when trying to write to a read-only replica.
|
||||
func IsReadOnlyError(err error) bool {
|
||||
return proto.IsReadOnlyError(err)
|
||||
}
|
||||
|
||||
// IsClusterDownError checks if an error is a Redis CLUSTERDOWN error, even if wrapped.
|
||||
// CLUSTERDOWN errors occur when the cluster is down.
|
||||
func IsClusterDownError(err error) bool {
|
||||
return proto.IsClusterDownError(err)
|
||||
}
|
||||
|
||||
// IsTryAgainError checks if an error is a Redis TRYAGAIN error, even if wrapped.
|
||||
// TRYAGAIN errors occur when a command cannot be processed and should be retried.
|
||||
func IsTryAgainError(err error) bool {
|
||||
return proto.IsTryAgainError(err)
|
||||
}
|
||||
|
||||
// IsMasterDownError checks if an error is a Redis MASTERDOWN error, even if wrapped.
|
||||
// MASTERDOWN errors occur when the master is down.
|
||||
func IsMasterDownError(err error) bool {
|
||||
return proto.IsMasterDownError(err)
|
||||
}
|
||||
|
||||
// IsMaxClientsError checks if an error is a Redis max clients error, even if wrapped.
|
||||
// This error occurs when the maximum number of clients has been reached.
|
||||
func IsMaxClientsError(err error) bool {
|
||||
return proto.IsMaxClientsError(err)
|
||||
}
|
||||
|
||||
// IsMovedError checks if an error is a Redis MOVED error, even if wrapped.
|
||||
// MOVED errors occur in cluster mode when a key has been moved to a different node.
|
||||
// Returns the address of the node where the key has been moved and a boolean indicating if it's a MOVED error.
|
||||
func IsMovedError(err error) (addr string, ok bool) {
|
||||
if movedErr, isMovedErr := proto.IsMovedError(err); isMovedErr {
|
||||
return movedErr.Addr(), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsAskError checks if an error is a Redis ASK error, even if wrapped.
|
||||
// ASK errors occur in cluster mode when a key is being migrated and the client should ask another node.
|
||||
// Returns the address of the node to ask and a boolean indicating if it's an ASK error.
|
||||
func IsAskError(err error) (addr string, ok bool) {
|
||||
if askErr, isAskErr := proto.IsAskError(err); isAskErr {
|
||||
return askErr.Addr(), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is a Redis authentication error, even if wrapped.
|
||||
// Authentication errors occur when:
|
||||
// - NOAUTH: Redis requires authentication but none was provided
|
||||
// - WRONGPASS: Redis authentication failed due to incorrect password
|
||||
// - unauthenticated: Error returned when password changed
|
||||
func IsAuthError(err error) bool {
|
||||
return proto.IsAuthError(err)
|
||||
}
|
||||
|
||||
// IsPermissionError checks if an error is a Redis permission error, even if wrapped.
|
||||
// Permission errors (NOPERM) occur when a user does not have permission to execute a command.
|
||||
func IsPermissionError(err error) bool {
|
||||
return proto.IsPermissionError(err)
|
||||
}
|
||||
|
||||
// IsExecAbortError checks if an error is a Redis EXECABORT error, even if wrapped.
|
||||
// EXECABORT errors occur when a transaction is aborted.
|
||||
func IsExecAbortError(err error) bool {
|
||||
return proto.IsExecAbortError(err)
|
||||
}
|
||||
|
||||
// IsOOMError checks if an error is a Redis OOM (Out Of Memory) error, even if wrapped.
|
||||
// OOM errors occur when Redis is out of memory.
|
||||
func IsOOMError(err error) bool {
|
||||
return proto.IsOOMError(err)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
+4
-4
@@ -116,16 +116,16 @@ func (c cmdable) HMGet(ctx context.Context, key string, fields ...string) *Slice
|
||||
|
||||
// HSet accepts values in following formats:
|
||||
//
|
||||
// - HSet("myhash", "key1", "value1", "key2", "value2")
|
||||
// - HSet(ctx, "myhash", "key1", "value1", "key2", "value2")
|
||||
//
|
||||
// - HSet("myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
// - HSet(ctx, "myhash", []string{"key1", "value1", "key2", "value2"})
|
||||
//
|
||||
// - HSet("myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
// - HSet(ctx, "myhash", map[string]interface{}{"key1": "value1", "key2": "value2"})
|
||||
//
|
||||
// Playing struct With "redis" tag.
|
||||
// type MyHash struct { Key1 string `redis:"key1"`; Key2 int `redis:"key2"` }
|
||||
//
|
||||
// - HSet("myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
// - HSet(ctx, "myhash", MyHash{"value1", "value2"}) Warn: redis-server >= 4.0
|
||||
//
|
||||
// For struct, can be a structure pointer type, we only parse the field whose tag is redis.
|
||||
// if you don't want the field to be read, you can use the `redis:"-"` flag to ignore it,
|
||||
|
||||
Generated
Vendored
+100
@@ -0,0 +1,100 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// ConnReAuthCredentialsListener is a credentials listener for a specific connection
|
||||
// that triggers re-authentication when credentials change.
|
||||
//
|
||||
// This listener implements the auth.CredentialsListener interface and is subscribed
|
||||
// to a StreamingCredentialsProvider. When new credentials are received via OnNext,
|
||||
// it marks the connection for re-authentication through the manager.
|
||||
//
|
||||
// The re-authentication is always performed asynchronously to avoid blocking the
|
||||
// credentials provider and to prevent potential deadlocks with the pool semaphore.
|
||||
// The actual re-auth happens when the connection is returned to the pool in an idle state.
|
||||
//
|
||||
// Lifecycle:
|
||||
// - Created during connection initialization via Manager.Listener()
|
||||
// - Subscribed to the StreamingCredentialsProvider
|
||||
// - Receives credential updates via OnNext()
|
||||
// - Cleaned up when connection is removed from pool via Manager.RemoveListener()
|
||||
type ConnReAuthCredentialsListener struct {
|
||||
// reAuth is the function to re-authenticate the connection with new credentials
|
||||
reAuth func(conn *pool.Conn, credentials auth.Credentials) error
|
||||
|
||||
// onErr is the function to call when re-authentication or acquisition fails
|
||||
onErr func(conn *pool.Conn, err error)
|
||||
|
||||
// conn is the connection this listener is associated with
|
||||
conn *pool.Conn
|
||||
|
||||
// manager is the streaming credentials manager for coordinating re-auth
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// OnNext is called when new credentials are received from the StreamingCredentialsProvider.
|
||||
//
|
||||
// This method marks the connection for asynchronous re-authentication. The actual
|
||||
// re-authentication happens in the background when the connection is returned to the
|
||||
// pool and is in an idle state.
|
||||
//
|
||||
// Asynchronous re-auth is used to:
|
||||
// - Avoid blocking the credentials provider's notification goroutine
|
||||
// - Prevent deadlocks with the pool's semaphore (especially with small pool sizes)
|
||||
// - Ensure re-auth happens when the connection is safe to use (not processing commands)
|
||||
//
|
||||
// The reAuthFn callback receives:
|
||||
// - nil if the connection was successfully acquired for re-auth
|
||||
// - error if acquisition timed out or failed
|
||||
//
|
||||
// Thread-safe: Called by the credentials provider's notification goroutine.
|
||||
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
|
||||
if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Always use async reauth to avoid complex pool semaphore issues
|
||||
// The synchronous path can cause deadlocks in the pool's semaphore mechanism
|
||||
// when called from the Subscribe goroutine, especially with small pool sizes.
|
||||
// The connection pool hook will re-authenticate the connection when it is
|
||||
// returned to the pool in a clean, idle state.
|
||||
c.manager.MarkForReAuth(c.conn, func(err error) {
|
||||
// err is from connection acquisition (timeout, etc.)
|
||||
if err != nil {
|
||||
// Log the error
|
||||
c.OnError(err)
|
||||
return
|
||||
}
|
||||
// err is from reauth command execution
|
||||
err = c.reAuth(c.conn, credentials)
|
||||
if err != nil {
|
||||
// Log the error
|
||||
c.OnError(err)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// OnError is called when an error occurs during credential streaming or re-authentication.
|
||||
//
|
||||
// This method can be called from:
|
||||
// - The StreamingCredentialsProvider when there's an error in the credentials stream
|
||||
// - The re-auth process when connection acquisition times out
|
||||
// - The re-auth process when the AUTH command fails
|
||||
//
|
||||
// The error is delegated to the onErr callback provided during listener creation.
|
||||
//
|
||||
// Thread-safe: Can be called from multiple goroutines (provider, re-auth worker).
|
||||
func (c *ConnReAuthCredentialsListener) OnError(err error) {
|
||||
if c.onErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.onErr(c.conn, err)
|
||||
}
|
||||
|
||||
// Ensure ConnReAuthCredentialsListener implements the CredentialsListener interface.
|
||||
var _ auth.CredentialsListener = (*ConnReAuthCredentialsListener)(nil)
|
||||
+77
@@ -0,0 +1,77 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
)
|
||||
|
||||
// CredentialsListeners is a thread-safe collection of credentials listeners
|
||||
// indexed by connection ID.
|
||||
//
|
||||
// This collection is used by the Manager to maintain a registry of listeners
|
||||
// for each connection in the pool. Listeners are reused when connections are
|
||||
// reinitialized (e.g., after a handoff) to avoid creating duplicate subscriptions
|
||||
// to the StreamingCredentialsProvider.
|
||||
//
|
||||
// The collection supports concurrent access from multiple goroutines during
|
||||
// connection initialization, credential updates, and connection removal.
|
||||
type CredentialsListeners struct {
|
||||
// listeners maps connection ID to credentials listener
|
||||
listeners map[uint64]auth.CredentialsListener
|
||||
|
||||
// lock protects concurrent access to the listeners map
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCredentialsListeners creates a new thread-safe credentials listeners collection.
|
||||
func NewCredentialsListeners() *CredentialsListeners {
|
||||
return &CredentialsListeners{
|
||||
listeners: make(map[uint64]auth.CredentialsListener),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds or updates a credentials listener for a connection.
|
||||
//
|
||||
// If a listener already exists for the connection ID, it is replaced.
|
||||
// This is safe because the old listener should have been unsubscribed
|
||||
// before the connection was reinitialized.
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (c *CredentialsListeners) Add(connID uint64, listener auth.CredentialsListener) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
if c.listeners == nil {
|
||||
c.listeners = make(map[uint64]auth.CredentialsListener)
|
||||
}
|
||||
c.listeners[connID] = listener
|
||||
}
|
||||
|
||||
// Get retrieves the credentials listener for a connection.
|
||||
//
|
||||
// Returns:
|
||||
// - listener: The credentials listener for the connection, or nil if not found
|
||||
// - ok: true if a listener exists for the connection ID, false otherwise
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (c *CredentialsListeners) Get(connID uint64) (auth.CredentialsListener, bool) {
|
||||
c.lock.RLock()
|
||||
defer c.lock.RUnlock()
|
||||
if len(c.listeners) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
listener, ok := c.listeners[connID]
|
||||
return listener, ok
|
||||
}
|
||||
|
||||
// Remove removes the credentials listener for a connection.
|
||||
//
|
||||
// This is called when a connection is removed from the pool to prevent
|
||||
// memory leaks. If no listener exists for the connection ID, this is a no-op.
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (c *CredentialsListeners) Remove(connID uint64) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
delete(c.listeners, connID)
|
||||
}
|
||||
+137
@@ -0,0 +1,137 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// Manager coordinates streaming credentials and re-authentication for a connection pool.
|
||||
//
|
||||
// The manager is responsible for:
|
||||
// - Creating and managing per-connection credentials listeners
|
||||
// - Providing the pool hook for re-authentication
|
||||
// - Coordinating between credentials updates and pool operations
|
||||
//
|
||||
// When credentials change via a StreamingCredentialsProvider:
|
||||
// 1. The credentials listener (ConnReAuthCredentialsListener) receives the update
|
||||
// 2. It calls MarkForReAuth on the manager
|
||||
// 3. The manager delegates to the pool hook
|
||||
// 4. The pool hook schedules background re-authentication
|
||||
//
|
||||
// The manager maintains a registry of credentials listeners indexed by connection ID,
|
||||
// allowing listener reuse when connections are reinitialized (e.g., after handoff).
|
||||
type Manager struct {
|
||||
// credentialsListeners maps connection ID to credentials listener
|
||||
credentialsListeners *CredentialsListeners
|
||||
|
||||
// pool is the connection pool being managed
|
||||
pool pool.Pooler
|
||||
|
||||
// poolHookRef is the re-authentication pool hook
|
||||
poolHookRef *ReAuthPoolHook
|
||||
}
|
||||
|
||||
// NewManager creates a new streaming credentials manager.
|
||||
//
|
||||
// Parameters:
|
||||
// - pl: The connection pool to manage
|
||||
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
|
||||
//
|
||||
// The manager creates a ReAuthPoolHook sized to match the pool size, ensuring that
|
||||
// re-auth operations don't exhaust the connection pool.
|
||||
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
|
||||
m := &Manager{
|
||||
pool: pl,
|
||||
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
|
||||
credentialsListeners: NewCredentialsListeners(),
|
||||
}
|
||||
m.poolHookRef.manager = m
|
||||
return m
|
||||
}
|
||||
|
||||
// PoolHook returns the pool hook for re-authentication.
|
||||
//
|
||||
// This hook should be registered with the connection pool to enable
|
||||
// automatic re-authentication when credentials change.
|
||||
func (m *Manager) PoolHook() pool.PoolHook {
|
||||
return m.poolHookRef
|
||||
}
|
||||
|
||||
// Listener returns or creates a credentials listener for a connection.
|
||||
//
|
||||
// This method is called during connection initialization to set up the
|
||||
// credentials listener. If a listener already exists for the connection ID
|
||||
// (e.g., after a handoff), it is reused.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolCn: The connection to create/get a listener for
|
||||
// - reAuth: Function to re-authenticate the connection with new credentials
|
||||
// - onErr: Function to call when re-authentication fails
|
||||
//
|
||||
// Returns:
|
||||
// - auth.CredentialsListener: The listener to subscribe to the credentials provider
|
||||
// - error: Non-nil if poolCn is nil
|
||||
//
|
||||
// Note: The reAuth and onErr callbacks are captured once when the listener is
|
||||
// created and reused for the connection's lifetime. They should not change.
|
||||
//
|
||||
// Thread-safe: Can be called concurrently during connection initialization.
|
||||
func (m *Manager) Listener(
|
||||
poolCn *pool.Conn,
|
||||
reAuth func(*pool.Conn, auth.Credentials) error,
|
||||
onErr func(*pool.Conn, error),
|
||||
) (auth.CredentialsListener, error) {
|
||||
if poolCn == nil {
|
||||
return nil, errors.New("poolCn cannot be nil")
|
||||
}
|
||||
connID := poolCn.GetID()
|
||||
// if we reconnect the underlying network connection, the streaming credentials listener will continue to work
|
||||
// so we can get the old listener from the cache and use it.
|
||||
// subscribing the same (an already subscribed) listener for a StreamingCredentialsProvider SHOULD be a no-op
|
||||
listener, ok := m.credentialsListeners.Get(connID)
|
||||
if !ok || listener == nil {
|
||||
// Create new listener for this connection
|
||||
// Note: Callbacks (reAuth, onErr) are captured once and reused for the connection's lifetime
|
||||
newCredListener := &ConnReAuthCredentialsListener{
|
||||
conn: poolCn,
|
||||
reAuth: reAuth,
|
||||
onErr: onErr,
|
||||
manager: m,
|
||||
}
|
||||
|
||||
m.credentialsListeners.Add(connID, newCredListener)
|
||||
listener = newCredListener
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// MarkForReAuth marks a connection for re-authentication.
|
||||
//
|
||||
// This method is called by the credentials listener when new credentials are
|
||||
// received. It delegates to the pool hook to schedule background re-authentication.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolCn: The connection to re-authenticate
|
||||
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
|
||||
//
|
||||
// Thread-safe: Called by credentials listeners when credentials change.
|
||||
func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
|
||||
connID := poolCn.GetID()
|
||||
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
|
||||
}
|
||||
|
||||
// RemoveListener removes the credentials listener for a connection.
|
||||
//
|
||||
// This method is called by the pool hook's OnRemove to clean up listeners
|
||||
// when connections are removed from the pool.
|
||||
//
|
||||
// Parameters:
|
||||
// - connID: The connection ID whose listener should be removed
|
||||
//
|
||||
// Thread-safe: Called during connection removal.
|
||||
func (m *Manager) RemoveListener(connID uint64) {
|
||||
m.credentialsListeners.Remove(connID)
|
||||
}
|
||||
+241
@@ -0,0 +1,241 @@
|
||||
package streaming
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// ReAuthPoolHook is a pool hook that manages background re-authentication of connections
|
||||
// when credentials change via a streaming credentials provider.
|
||||
//
|
||||
// The hook uses a semaphore-based worker pool to limit concurrent re-authentication
|
||||
// operations and prevent pool exhaustion. When credentials change, connections are
|
||||
// marked for re-authentication and processed asynchronously in the background.
|
||||
//
|
||||
// The re-authentication process:
|
||||
// 1. OnPut: When a connection is returned to the pool, check if it needs re-auth
|
||||
// 2. If yes, schedule it for background processing (move from shouldReAuth to scheduledReAuth)
|
||||
// 3. A worker goroutine acquires the connection (waits until it's not in use)
|
||||
// 4. Executes the re-auth function while holding the connection
|
||||
// 5. Releases the connection back to the pool
|
||||
//
|
||||
// The hook ensures that:
|
||||
// - Only one re-auth operation runs per connection at a time
|
||||
// - Connections are not used for commands during re-authentication
|
||||
// - Re-auth operations timeout if they can't acquire the connection
|
||||
// - Resources are properly cleaned up on connection removal
|
||||
type ReAuthPoolHook struct {
|
||||
// shouldReAuth maps connection ID to re-auth function
|
||||
// Connections in this map need re-authentication but haven't been scheduled yet
|
||||
shouldReAuth map[uint64]func(error)
|
||||
shouldReAuthLock sync.RWMutex
|
||||
|
||||
// workers is a semaphore limiting concurrent re-auth operations
|
||||
// Initialized with poolSize tokens to prevent pool exhaustion
|
||||
// Uses FastSemaphore for better performance with eventual fairness
|
||||
workers *internal.FastSemaphore
|
||||
|
||||
// reAuthTimeout is the maximum time to wait for acquiring a connection for re-auth
|
||||
reAuthTimeout time.Duration
|
||||
|
||||
// scheduledReAuth maps connection ID to scheduled status
|
||||
// Connections in this map have a background worker attempting re-authentication
|
||||
scheduledReAuth map[uint64]bool
|
||||
scheduledLock sync.RWMutex
|
||||
|
||||
// manager is a back-reference for cleanup operations
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// NewReAuthPoolHook creates a new re-authentication pool hook.
|
||||
//
|
||||
// Parameters:
|
||||
// - poolSize: Maximum number of concurrent re-auth operations (typically matches pool size)
|
||||
// - reAuthTimeout: Maximum time to wait for acquiring a connection for re-authentication
|
||||
//
|
||||
// The poolSize parameter is used to initialize the worker semaphore, ensuring that
|
||||
// re-auth operations don't exhaust the connection pool.
|
||||
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
|
||||
return &ReAuthPoolHook{
|
||||
shouldReAuth: make(map[uint64]func(error)),
|
||||
scheduledReAuth: make(map[uint64]bool),
|
||||
workers: internal.NewFastSemaphore(int32(poolSize)),
|
||||
reAuthTimeout: reAuthTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// MarkForReAuth marks a connection for re-authentication.
|
||||
//
|
||||
// This method is called when credentials change and a connection needs to be
|
||||
// re-authenticated. The actual re-authentication happens asynchronously when
|
||||
// the connection is returned to the pool (in OnPut).
|
||||
//
|
||||
// Parameters:
|
||||
// - connID: The connection ID to mark for re-authentication
|
||||
// - reAuthFn: Function to call for re-authentication, receives error if acquisition fails
|
||||
//
|
||||
// Thread-safe: Can be called concurrently from multiple goroutines.
|
||||
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
|
||||
r.shouldReAuthLock.Lock()
|
||||
defer r.shouldReAuthLock.Unlock()
|
||||
r.shouldReAuth[connID] = reAuthFn
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
//
|
||||
// This hook checks if the connection needs re-authentication or has a scheduled
|
||||
// re-auth operation. If so, it rejects the connection (returns accept=false),
|
||||
// causing the pool to try another connection.
|
||||
//
|
||||
// Returns:
|
||||
// - accept: false if connection needs re-auth, true otherwise
|
||||
// - err: always nil (errors are not used in this hook)
|
||||
//
|
||||
// Thread-safe: Called concurrently by multiple goroutines getting connections.
|
||||
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
connID := conn.GetID()
|
||||
r.shouldReAuthLock.RLock()
|
||||
_, shouldReAuth := r.shouldReAuth[connID]
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
// This connection was marked for reauth while in the pool,
|
||||
// reject the connection
|
||||
if shouldReAuth {
|
||||
// simply reject the connection, it will be re-authenticated in OnPut
|
||||
return false, nil
|
||||
}
|
||||
r.scheduledLock.RLock()
|
||||
_, hasScheduled := r.scheduledReAuth[connID]
|
||||
r.scheduledLock.RUnlock()
|
||||
// has scheduled reauth, reject the connection
|
||||
if hasScheduled {
|
||||
// simply reject the connection, it currently has a reauth scheduled
|
||||
// and the worker is waiting for slot to execute the reauth
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
//
|
||||
// This hook checks if the connection needs re-authentication. If so, it schedules
|
||||
// a background goroutine to perform the re-auth asynchronously. The goroutine:
|
||||
// 1. Waits for a worker slot (semaphore)
|
||||
// 2. Acquires the connection (waits until not in use)
|
||||
// 3. Executes the re-auth function
|
||||
// 4. Releases the connection and worker slot
|
||||
//
|
||||
// The connection is always pooled (not removed) since re-auth happens in background.
|
||||
//
|
||||
// Returns:
|
||||
// - shouldPool: always true (connection stays in pool during background re-auth)
|
||||
// - shouldRemove: always false
|
||||
// - err: always nil
|
||||
//
|
||||
// Thread-safe: Called concurrently by multiple goroutines returning connections.
|
||||
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
|
||||
if conn == nil {
|
||||
// noop
|
||||
return true, false, nil
|
||||
}
|
||||
connID := conn.GetID()
|
||||
// Check if reauth is needed and get the function with proper locking
|
||||
r.shouldReAuthLock.RLock()
|
||||
reAuthFn, ok := r.shouldReAuth[connID]
|
||||
r.shouldReAuthLock.RUnlock()
|
||||
|
||||
if ok {
|
||||
// Acquire both locks to atomically move from shouldReAuth to scheduledReAuth
|
||||
// This prevents race conditions where OnGet might miss the transition
|
||||
r.shouldReAuthLock.Lock()
|
||||
r.scheduledLock.Lock()
|
||||
r.scheduledReAuth[connID] = true
|
||||
delete(r.shouldReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Unlock()
|
||||
go func() {
|
||||
r.workers.AcquireBlocking()
|
||||
// safety first
|
||||
if conn == nil || (conn != nil && conn.IsClosed()) {
|
||||
r.workers.Release()
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
// once again - safety first
|
||||
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
|
||||
}
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.workers.Release()
|
||||
}()
|
||||
|
||||
// Create timeout context for connection acquisition
|
||||
// This prevents indefinite waiting if the connection is stuck
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.reAuthTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Try to acquire the connection for re-authentication
|
||||
// We need to ensure the connection is IDLE (not IN_USE) before transitioning to UNUSABLE
|
||||
// This prevents re-authentication from interfering with active commands
|
||||
// Use AwaitAndTransition to wait for the connection to become IDLE
|
||||
stateMachine := conn.GetStateMachine()
|
||||
if stateMachine == nil {
|
||||
// No state machine - should not happen, but handle gracefully
|
||||
reAuthFn(pool.ErrConnUnusableTimeout)
|
||||
return
|
||||
}
|
||||
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := stateMachine.AwaitAndTransition(ctx, pool.ValidFromIdle(), pool.StateUnusable)
|
||||
if err != nil {
|
||||
// Timeout or other error occurred, cannot acquire connection
|
||||
reAuthFn(err)
|
||||
return
|
||||
}
|
||||
|
||||
// safety first
|
||||
if !conn.IsClosed() {
|
||||
// Successfully acquired the connection, perform reauth
|
||||
reAuthFn(nil)
|
||||
}
|
||||
|
||||
// Release the connection: transition from UNUSABLE back to IDLE
|
||||
stateMachine.Transition(pool.StateIdle)
|
||||
}()
|
||||
}
|
||||
|
||||
// the reauth will happen in background, as far as the pool is concerned:
|
||||
// pool the connection, don't remove it, no error
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// OnRemove is called when a connection is removed from the pool.
|
||||
//
|
||||
// This hook cleans up all state associated with the connection:
|
||||
// - Removes from shouldReAuth map (pending re-auth)
|
||||
// - Removes from scheduledReAuth map (active re-auth)
|
||||
// - Removes credentials listener from manager
|
||||
//
|
||||
// This prevents memory leaks and ensures that removed connections don't have
|
||||
// lingering re-auth operations or listeners.
|
||||
//
|
||||
// Thread-safe: Called when connections are removed due to errors, timeouts, or pool closure.
|
||||
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
|
||||
connID := conn.GetID()
|
||||
r.shouldReAuthLock.Lock()
|
||||
r.scheduledLock.Lock()
|
||||
delete(r.scheduledReAuth, connID)
|
||||
delete(r.shouldReAuth, connID)
|
||||
r.scheduledLock.Unlock()
|
||||
r.shouldReAuthLock.Unlock()
|
||||
if r.manager != nil {
|
||||
r.manager.RemoveListener(connID)
|
||||
}
|
||||
}
|
||||
|
||||
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)
|
||||
+54
@@ -0,0 +1,54 @@
|
||||
// Package interfaces provides shared interfaces used by both the main redis package
|
||||
// and the maintnotifications upgrade package to avoid circular dependencies.
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NotificationProcessor is (most probably) a push.NotificationProcessor
|
||||
// forward declaration to avoid circular imports
|
||||
type NotificationProcessor interface {
|
||||
RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error
|
||||
UnregisterHandler(pushNotificationName string) error
|
||||
GetHandler(pushNotificationName string) interface{}
|
||||
}
|
||||
|
||||
// ClientInterface defines the interface that clients must implement for maintnotifications upgrades.
|
||||
type ClientInterface interface {
|
||||
// GetOptions returns the client options.
|
||||
GetOptions() OptionsInterface
|
||||
|
||||
// GetPushProcessor returns the client's push notification processor.
|
||||
GetPushProcessor() NotificationProcessor
|
||||
}
|
||||
|
||||
// OptionsInterface defines the interface for client options.
|
||||
// Uses an adapter pattern to avoid circular dependencies.
|
||||
type OptionsInterface interface {
|
||||
// GetReadTimeout returns the read timeout.
|
||||
GetReadTimeout() time.Duration
|
||||
|
||||
// GetWriteTimeout returns the write timeout.
|
||||
GetWriteTimeout() time.Duration
|
||||
|
||||
// GetNetwork returns the network type.
|
||||
GetNetwork() string
|
||||
|
||||
// GetAddr returns the connection address.
|
||||
GetAddr() string
|
||||
|
||||
// IsTLSEnabled returns true if TLS is enabled.
|
||||
IsTLSEnabled() bool
|
||||
|
||||
// GetProtocol returns the protocol version.
|
||||
GetProtocol() int
|
||||
|
||||
// GetPoolSize returns the connection pool size.
|
||||
GetPoolSize() int
|
||||
|
||||
// NewDialer returns a new dialer function for the connection.
|
||||
NewDialer() func(context.Context) (net.Conn, error)
|
||||
}
|
||||
+57
-4
@@ -7,20 +7,73 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// TODO (ned): Revisit logging
|
||||
// Add more standardized approach with log levels and configurability
|
||||
|
||||
type Logging interface {
|
||||
Printf(ctx context.Context, format string, v ...interface{})
|
||||
}
|
||||
|
||||
type logger struct {
|
||||
type DefaultLogger struct {
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
func (l *DefaultLogger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
_ = l.log.Output(2, fmt.Sprintf(format, v...))
|
||||
}
|
||||
|
||||
func NewDefaultLogger() Logging {
|
||||
return &DefaultLogger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
}
|
||||
}
|
||||
|
||||
// Logger calls Output to print to the stderr.
|
||||
// Arguments are handled in the manner of fmt.Print.
|
||||
var Logger Logging = &logger{
|
||||
log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
|
||||
var Logger Logging = NewDefaultLogger()
|
||||
|
||||
var LogLevel LogLevelT = LogLevelError
|
||||
|
||||
// LogLevelT represents the logging level
|
||||
type LogLevelT int
|
||||
|
||||
// Log level constants for the entire go-redis library
|
||||
const (
|
||||
LogLevelError LogLevelT = iota // 0 - errors only
|
||||
LogLevelWarn // 1 - warnings and errors
|
||||
LogLevelInfo // 2 - info, warnings, and errors
|
||||
LogLevelDebug // 3 - debug, info, warnings, and errors
|
||||
)
|
||||
|
||||
// String returns the string representation of the log level
|
||||
func (l LogLevelT) String() string {
|
||||
switch l {
|
||||
case LogLevelError:
|
||||
return "ERROR"
|
||||
case LogLevelWarn:
|
||||
return "WARN"
|
||||
case LogLevelInfo:
|
||||
return "INFO"
|
||||
case LogLevelDebug:
|
||||
return "DEBUG"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the log level is valid
|
||||
func (l LogLevelT) IsValid() bool {
|
||||
return l >= LogLevelError && l <= LogLevelDebug
|
||||
}
|
||||
|
||||
func (l LogLevelT) WarnOrAbove() bool {
|
||||
return l >= LogLevelWarn
|
||||
}
|
||||
|
||||
func (l LogLevelT) InfoOrAbove() bool {
|
||||
return l >= LogLevelInfo
|
||||
}
|
||||
|
||||
func (l LogLevelT) DebugOrAbove() bool {
|
||||
return l >= LogLevelDebug
|
||||
}
|
||||
|
||||
Generated
Vendored
+625
@@ -0,0 +1,625 @@
|
||||
package logs
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
)
|
||||
|
||||
// appendJSONIfDebug appends JSON data to a message only if the global log level is Debug
|
||||
func appendJSONIfDebug(message string, data map[string]interface{}) string {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
jsonData, _ := json.Marshal(data)
|
||||
return fmt.Sprintf("%s %s", message, string(jsonData))
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
const (
|
||||
// ========================================
|
||||
// CIRCUIT_BREAKER.GO - Circuit breaker management
|
||||
// ========================================
|
||||
CircuitBreakerTransitioningToHalfOpenMessage = "circuit breaker transitioning to half-open"
|
||||
CircuitBreakerOpenedMessage = "circuit breaker opened"
|
||||
CircuitBreakerReopenedMessage = "circuit breaker reopened"
|
||||
CircuitBreakerClosedMessage = "circuit breaker closed"
|
||||
CircuitBreakerCleanupMessage = "circuit breaker cleanup"
|
||||
CircuitBreakerOpenMessage = "circuit breaker is open, failing fast"
|
||||
|
||||
// ========================================
|
||||
// CONFIG.GO - Configuration and debug
|
||||
// ========================================
|
||||
DebugLoggingEnabledMessage = "debug logging enabled"
|
||||
ConfigDebugMessage = "config debug"
|
||||
|
||||
// ========================================
|
||||
// ERRORS.GO - Error message constants
|
||||
// ========================================
|
||||
InvalidRelaxedTimeoutErrorMessage = "relaxed timeout must be greater than 0"
|
||||
InvalidHandoffTimeoutErrorMessage = "handoff timeout must be greater than 0"
|
||||
InvalidHandoffWorkersErrorMessage = "MaxWorkers must be greater than or equal to 0"
|
||||
InvalidHandoffQueueSizeErrorMessage = "handoff queue size must be greater than 0"
|
||||
InvalidPostHandoffRelaxedDurationErrorMessage = "post-handoff relaxed duration must be greater than or equal to 0"
|
||||
InvalidEndpointTypeErrorMessage = "invalid endpoint type"
|
||||
InvalidMaintNotificationsErrorMessage = "invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')"
|
||||
InvalidHandoffRetriesErrorMessage = "MaxHandoffRetries must be between 1 and 10"
|
||||
InvalidClientErrorMessage = "invalid client type"
|
||||
InvalidNotificationErrorMessage = "invalid notification format"
|
||||
MaxHandoffRetriesReachedErrorMessage = "max handoff retries reached"
|
||||
HandoffQueueFullErrorMessage = "handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration"
|
||||
InvalidCircuitBreakerFailureThresholdErrorMessage = "circuit breaker failure threshold must be >= 1"
|
||||
InvalidCircuitBreakerResetTimeoutErrorMessage = "circuit breaker reset timeout must be >= 0"
|
||||
InvalidCircuitBreakerMaxRequestsErrorMessage = "circuit breaker max requests must be >= 1"
|
||||
ConnectionMarkedForHandoffErrorMessage = "connection marked for handoff"
|
||||
ConnectionInvalidHandoffStateErrorMessage = "connection is in invalid state for handoff"
|
||||
ShutdownErrorMessage = "shutdown"
|
||||
CircuitBreakerOpenErrorMessage = "circuit breaker is open, failing fast"
|
||||
|
||||
// ========================================
|
||||
// EXAMPLE_HOOKS.GO - Example metrics hooks
|
||||
// ========================================
|
||||
MetricsHookProcessingNotificationMessage = "metrics hook processing"
|
||||
MetricsHookRecordedErrorMessage = "metrics hook recorded error"
|
||||
|
||||
// ========================================
|
||||
// HANDOFF_WORKER.GO - Connection handoff processing
|
||||
// ========================================
|
||||
HandoffStartedMessage = "handoff started"
|
||||
HandoffFailedMessage = "handoff failed"
|
||||
ConnectionNotMarkedForHandoffMessage = "is not marked for handoff and has no retries"
|
||||
ConnectionNotMarkedForHandoffErrorMessage = "is not marked for handoff"
|
||||
HandoffRetryAttemptMessage = "Performing handoff"
|
||||
CannotQueueHandoffForRetryMessage = "can't queue handoff for retry"
|
||||
HandoffQueueFullMessage = "handoff queue is full"
|
||||
FailedToDialNewEndpointMessage = "failed to dial new endpoint"
|
||||
ApplyingRelaxedTimeoutDueToPostHandoffMessage = "applying relaxed timeout due to post-handoff"
|
||||
HandoffSuccessMessage = "handoff succeeded"
|
||||
RemovingConnectionFromPoolMessage = "removing connection from pool"
|
||||
NoPoolProvidedMessageCannotRemoveMessage = "no pool provided, cannot remove connection, closing it"
|
||||
WorkerExitingDueToShutdownMessage = "worker exiting due to shutdown"
|
||||
WorkerExitingDueToShutdownWhileProcessingMessage = "worker exiting due to shutdown while processing request"
|
||||
WorkerPanicRecoveredMessage = "worker panic recovered"
|
||||
WorkerExitingDueToInactivityTimeoutMessage = "worker exiting due to inactivity timeout"
|
||||
ReachedMaxHandoffRetriesMessage = "reached max handoff retries"
|
||||
|
||||
// ========================================
|
||||
// MANAGER.GO - Moving operation tracking and handler registration
|
||||
// ========================================
|
||||
DuplicateMovingOperationMessage = "duplicate MOVING operation ignored"
|
||||
TrackingMovingOperationMessage = "tracking MOVING operation"
|
||||
UntrackingMovingOperationMessage = "untracking MOVING operation"
|
||||
OperationNotTrackedMessage = "operation not tracked"
|
||||
FailedToRegisterHandlerMessage = "failed to register handler"
|
||||
|
||||
// ========================================
|
||||
// HOOKS.GO - Notification processing hooks
|
||||
// ========================================
|
||||
ProcessingNotificationMessage = "processing notification started"
|
||||
ProcessingNotificationFailedMessage = "proccessing notification failed"
|
||||
ProcessingNotificationSucceededMessage = "processing notification succeeded"
|
||||
|
||||
// ========================================
|
||||
// POOL_HOOK.GO - Pool connection management
|
||||
// ========================================
|
||||
FailedToQueueHandoffMessage = "failed to queue handoff"
|
||||
MarkedForHandoffMessage = "connection marked for handoff"
|
||||
|
||||
// ========================================
|
||||
// PUSH_NOTIFICATION_HANDLER.GO - Push notification validation and processing
|
||||
// ========================================
|
||||
InvalidNotificationFormatMessage = "invalid notification format"
|
||||
InvalidNotificationTypeFormatMessage = "invalid notification type format"
|
||||
InvalidSeqIDInMovingNotificationMessage = "invalid seqID in MOVING notification"
|
||||
InvalidTimeSInMovingNotificationMessage = "invalid timeS in MOVING notification"
|
||||
InvalidNewEndpointInMovingNotificationMessage = "invalid newEndpoint in MOVING notification"
|
||||
NoConnectionInHandlerContextMessage = "no connection in handler context"
|
||||
InvalidConnectionTypeInHandlerContextMessage = "invalid connection type in handler context"
|
||||
SchedulingHandoffToCurrentEndpointMessage = "scheduling handoff to current endpoint"
|
||||
RelaxedTimeoutDueToNotificationMessage = "applying relaxed timeout due to notification"
|
||||
UnrelaxedTimeoutMessage = "clearing relaxed timeout"
|
||||
ManagerNotInitializedMessage = "manager not initialized"
|
||||
FailedToMarkForHandoffMessage = "failed to mark connection for handoff"
|
||||
|
||||
// ========================================
|
||||
// used in pool/conn
|
||||
// ========================================
|
||||
UnrelaxedTimeoutAfterDeadlineMessage = "clearing relaxed timeout after deadline"
|
||||
)
|
||||
|
||||
func HandoffStarted(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffStartedMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffFailed(connID uint64, newEndpoint string, attempt int, maxAttempts int, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (attempt %d/%d): %v", connID, HandoffFailedMessage, newEndpoint, attempt, maxAttempts, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
"attempt": attempt,
|
||||
"maxAttempts": maxAttempts,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func HandoffSucceeded(connID uint64, newEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s", connID, HandoffSuccessMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": newEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Timeout-related log functions
|
||||
func RelaxedTimeoutDueToNotification(connID uint64, notificationType string, timeout interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s (%v)", connID, RelaxedTimeoutDueToNotificationMessage, notificationType, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeout(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func UnrelaxedTimeoutAfterDeadline(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, UnrelaxedTimeoutAfterDeadlineMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff queue and marking functions
|
||||
func HandoffQueueFull(queueLen, queueCap int) string {
|
||||
message := fmt.Sprintf("%s (%d/%d), cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration", HandoffQueueFullMessage, queueLen, queueCap)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"queueLen": queueLen,
|
||||
"queueCap": queueCap,
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToQueueHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToQueueHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToMarkForHandoff(connID uint64, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s: %v", connID, FailedToMarkForHandoffMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func FailedToDialNewEndpoint(connID uint64, endpoint string, err error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v", connID, FailedToDialNewEndpointMessage, endpoint, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ReachedMaxHandoffRetries(connID uint64, endpoint string, maxRetries int) string {
|
||||
message := fmt.Sprintf("conn[%d] %s to %s (max retries: %d)", connID, ReachedMaxHandoffRetriesMessage, endpoint, maxRetries)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"maxRetries": maxRetries,
|
||||
})
|
||||
}
|
||||
|
||||
// Notification processing functions
|
||||
func ProcessingNotification(connID uint64, seqID int64, notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] seqID[%d] %s %s: %v", connID, seqID, ProcessingNotificationMessage, notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationFailed(connID uint64, notificationType string, err error, notification interface{}) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s: %v - %v", connID, ProcessingNotificationFailedMessage, notificationType, err, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func ProcessingNotificationSucceeded(connID uint64, notificationType string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s %s", connID, ProcessingNotificationSucceededMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
// Moving operation tracking functions
|
||||
func DuplicateMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, DuplicateMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func TrackingMovingOperation(connID uint64, endpoint string, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s seqID[%d]", connID, TrackingMovingOperationMessage, endpoint, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func UntrackingMovingOperation(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, UntrackingMovingOperationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
func OperationNotTracked(connID uint64, seqID int64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s seqID[%d]", connID, OperationNotTrackedMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seqID": seqID,
|
||||
})
|
||||
}
|
||||
|
||||
// Connection pool functions
|
||||
func RemovingConnectionFromPool(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, RemovingConnectionFromPoolMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func NoPoolProvidedCannotRemove(connID uint64, reason error) string {
|
||||
message := fmt.Sprintf("conn[%d] %s due to: %v", connID, NoPoolProvidedMessageCannotRemoveMessage, reason)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"reason": reason.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker functions
|
||||
func CircuitBreakerOpen(connID uint64, endpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s for %s", connID, CircuitBreakerOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Additional handoff functions for specific cases
|
||||
func ConnectionNotMarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func ConnectionNotMarkedForHandoffError(connID uint64) string {
|
||||
return fmt.Sprintf("conn[%d] %s", connID, ConnectionNotMarkedForHandoffErrorMessage)
|
||||
}
|
||||
|
||||
func HandoffRetryAttempt(connID uint64, retries int, newEndpoint string, oldEndpoint string) string {
|
||||
message := fmt.Sprintf("conn[%d] Retry %d: %s to %s(was %s)", connID, retries, HandoffRetryAttemptMessage, newEndpoint, oldEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"retries": retries,
|
||||
"newEndpoint": newEndpoint,
|
||||
"oldEndpoint": oldEndpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CannotQueueHandoffForRetry(err error) string {
|
||||
message := fmt.Sprintf("%s: %v", CannotQueueHandoffForRetryMessage, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Validation and error functions
|
||||
func InvalidNotificationFormat(notification interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationFormatMessage, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNotificationTypeFormat(notificationType interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNotificationTypeFormatMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": fmt.Sprintf("%v", notificationType),
|
||||
})
|
||||
}
|
||||
|
||||
// InvalidNotification creates a log message for invalid notifications of any type
|
||||
func InvalidNotification(notificationType string, notification interface{}) string {
|
||||
message := fmt.Sprintf("invalid %s notification: %v", notificationType, notification)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"notification": fmt.Sprintf("%v", notification),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidSeqIDInMovingNotification(seqID interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidSeqIDInMovingNotificationMessage, seqID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"seqID": fmt.Sprintf("%v", seqID),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidTimeSInMovingNotification(timeS interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidTimeSInMovingNotificationMessage, timeS)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeS": fmt.Sprintf("%v", timeS),
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidNewEndpointInMovingNotification(newEndpoint interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", InvalidNewEndpointInMovingNotificationMessage, newEndpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"newEndpoint": fmt.Sprintf("%v", newEndpoint),
|
||||
})
|
||||
}
|
||||
|
||||
func NoConnectionInHandlerContext(notificationType string) string {
|
||||
message := fmt.Sprintf("%s for %s notification", NoConnectionInHandlerContextMessage, notificationType)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
})
|
||||
}
|
||||
|
||||
func InvalidConnectionTypeInHandlerContext(notificationType string, conn interface{}, handlerCtx interface{}) string {
|
||||
message := fmt.Sprintf("%s for %s notification - %T %#v", InvalidConnectionTypeInHandlerContextMessage, notificationType, conn, handlerCtx)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connType": fmt.Sprintf("%T", conn),
|
||||
})
|
||||
}
|
||||
|
||||
func SchedulingHandoffToCurrentEndpoint(connID uint64, seconds float64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s in %v seconds", connID, SchedulingHandoffToCurrentEndpointMessage, seconds)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"seconds": seconds,
|
||||
})
|
||||
}
|
||||
|
||||
func ManagerNotInitialized() string {
|
||||
return appendJSONIfDebug(ManagerNotInitializedMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func FailedToRegisterHandler(notificationType string, err error) string {
|
||||
message := fmt.Sprintf("%s for %s: %v", FailedToRegisterHandlerMessage, notificationType, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
func ShutdownError() string {
|
||||
return appendJSONIfDebug(ShutdownErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration validation error functions
|
||||
func InvalidRelaxedTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidRelaxedTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffWorkersError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffWorkersErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffQueueSizeError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffQueueSizeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidPostHandoffRelaxedDurationError() string {
|
||||
return appendJSONIfDebug(InvalidPostHandoffRelaxedDurationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidEndpointTypeError() string {
|
||||
return appendJSONIfDebug(InvalidEndpointTypeErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidMaintNotificationsError() string {
|
||||
return appendJSONIfDebug(InvalidMaintNotificationsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidHandoffRetriesError() string {
|
||||
return appendJSONIfDebug(InvalidHandoffRetriesErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidClientError() string {
|
||||
return appendJSONIfDebug(InvalidClientErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidNotificationError() string {
|
||||
return appendJSONIfDebug(InvalidNotificationErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func MaxHandoffRetriesReachedError() string {
|
||||
return appendJSONIfDebug(MaxHandoffRetriesReachedErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func HandoffQueueFullError() string {
|
||||
return appendJSONIfDebug(HandoffQueueFullErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerFailureThresholdError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerFailureThresholdErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerResetTimeoutError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerResetTimeoutErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func InvalidCircuitBreakerMaxRequestsError() string {
|
||||
return appendJSONIfDebug(InvalidCircuitBreakerMaxRequestsErrorMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
// Configuration and debug functions
|
||||
func DebugLoggingEnabled() string {
|
||||
return appendJSONIfDebug(DebugLoggingEnabledMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func ConfigDebug(config interface{}) string {
|
||||
message := fmt.Sprintf("%s: %+v", ConfigDebugMessage, config)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"config": fmt.Sprintf("%+v", config),
|
||||
})
|
||||
}
|
||||
|
||||
// Handoff worker functions
|
||||
func WorkerExitingDueToShutdown() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToShutdownWhileProcessing() string {
|
||||
return appendJSONIfDebug(WorkerExitingDueToShutdownWhileProcessingMessage, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func WorkerPanicRecovered(panicValue interface{}) string {
|
||||
message := fmt.Sprintf("%s: %v", WorkerPanicRecoveredMessage, panicValue)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"panic": fmt.Sprintf("%v", panicValue),
|
||||
})
|
||||
}
|
||||
|
||||
func WorkerExitingDueToInactivityTimeout(timeout interface{}) string {
|
||||
message := fmt.Sprintf("%s (%v)", WorkerExitingDueToInactivityTimeoutMessage, timeout)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
})
|
||||
}
|
||||
|
||||
func ApplyingRelaxedTimeoutDueToPostHandoff(connID uint64, timeout interface{}, until string) string {
|
||||
message := fmt.Sprintf("conn[%d] %s (%v) until %s", connID, ApplyingRelaxedTimeoutDueToPostHandoffMessage, timeout, until)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
"timeout": fmt.Sprintf("%v", timeout),
|
||||
"until": until,
|
||||
})
|
||||
}
|
||||
|
||||
// Example hooks functions
|
||||
func MetricsHookProcessingNotification(notificationType string, connID uint64) string {
|
||||
message := fmt.Sprintf("%s %s notification on conn[%d]", MetricsHookProcessingNotificationMessage, notificationType, connID)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
func MetricsHookRecordedError(notificationType string, connID uint64, err error) string {
|
||||
message := fmt.Sprintf("%s for %s notification on conn[%d]: %v", MetricsHookRecordedErrorMessage, notificationType, connID, err)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"notificationType": notificationType,
|
||||
"connID": connID,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
// Pool hook functions
|
||||
func MarkedForHandoff(connID uint64) string {
|
||||
message := fmt.Sprintf("conn[%d] %s", connID, MarkedForHandoffMessage)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"connID": connID,
|
||||
})
|
||||
}
|
||||
|
||||
// Circuit breaker additional functions
|
||||
func CircuitBreakerTransitioningToHalfOpen(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for %s", CircuitBreakerTransitioningToHalfOpenMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerOpened(endpoint string, failures int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d failures", CircuitBreakerOpenedMessage, endpoint, failures)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"failures": failures,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerReopened(endpoint string) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s due to failure in half-open state", CircuitBreakerReopenedMessage, endpoint)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerClosed(endpoint string, successes int64) string {
|
||||
message := fmt.Sprintf("%s for endpoint %s after %d successful requests", CircuitBreakerClosedMessage, endpoint, successes)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"endpoint": endpoint,
|
||||
"successes": successes,
|
||||
})
|
||||
}
|
||||
|
||||
func CircuitBreakerCleanup(removed int, total int) string {
|
||||
message := fmt.Sprintf("%s removed %d/%d entries", CircuitBreakerCleanupMessage, removed, total)
|
||||
return appendJSONIfDebug(message, map[string]interface{}{
|
||||
"removed": removed,
|
||||
"total": total,
|
||||
})
|
||||
}
|
||||
|
||||
// ExtractDataFromLogMessage extracts structured data from maintnotifications log messages
|
||||
// Returns a map containing the parsed key-value pairs from the structured data section
|
||||
// Example: "conn[123] handoff started to localhost:6379 {"connID":123,"endpoint":"localhost:6379"}"
|
||||
// Returns: map[string]interface{}{"connID": 123, "endpoint": "localhost:6379"}
|
||||
func ExtractDataFromLogMessage(logMessage string) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// Find the JSON data section at the end of the message
|
||||
re := regexp.MustCompile(`(\{.*\})$`)
|
||||
matches := re.FindStringSubmatch(logMessage)
|
||||
if len(matches) < 2 {
|
||||
return result
|
||||
}
|
||||
|
||||
jsonStr := matches[1]
|
||||
if jsonStr == "" {
|
||||
return result
|
||||
}
|
||||
|
||||
// Parse the JSON directly
|
||||
var jsonResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(jsonStr), &jsonResult); err == nil {
|
||||
return jsonResult
|
||||
}
|
||||
|
||||
// If JSON parsing fails, return empty map
|
||||
return result
|
||||
}
|
||||
+789
-21
@@ -1,28 +1,124 @@
|
||||
// Package pool implements the pool management
|
||||
package pool
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
)
|
||||
|
||||
var noDeadline = time.Time{}
|
||||
|
||||
// Preallocated errors for hot paths to avoid allocations
|
||||
var (
|
||||
errAlreadyMarkedForHandoff = errors.New("connection is already marked for handoff")
|
||||
errNotMarkedForHandoff = errors.New("connection was not marked for handoff")
|
||||
errHandoffStateChanged = errors.New("handoff state changed during marking")
|
||||
errConnectionNotAvailable = errors.New("redis: connection not available")
|
||||
errConnNotAvailableForWrite = errors.New("redis: connection not available for write operation")
|
||||
)
|
||||
|
||||
// getCachedTimeNs returns the current time in nanoseconds.
|
||||
// This function previously used a global cache updated by a background goroutine,
|
||||
// but that caused unnecessary CPU usage when the client was idle (ticker waking up
|
||||
// the scheduler every 50ms). We now use time.Now() directly, which is fast enough
|
||||
// on modern systems (vDSO on Linux) and only adds ~1-2% overhead in extreme
|
||||
// high-concurrency benchmarks while eliminating idle CPU usage.
|
||||
func getCachedTimeNs() int64 {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
|
||||
// GetCachedTimeNs returns the current time in nanoseconds.
|
||||
// Exported for use by other packages that need fast time access.
|
||||
func GetCachedTimeNs() int64 {
|
||||
return getCachedTimeNs()
|
||||
}
|
||||
|
||||
// Global atomic counter for connection IDs
|
||||
var connIDCounter uint64
|
||||
|
||||
// HandoffState represents the atomic state for connection handoffs
|
||||
// This struct is stored atomically to prevent race conditions between
|
||||
// checking handoff status and reading handoff parameters
|
||||
type HandoffState struct {
|
||||
ShouldHandoff bool // Whether connection should be handed off
|
||||
Endpoint string // New endpoint for handoff
|
||||
SeqID int64 // Sequence ID from MOVING notification
|
||||
}
|
||||
|
||||
// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value
|
||||
type atomicNetConn struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
// generateConnID generates a fast unique identifier for a connection with zero allocations
|
||||
func generateConnID() uint64 {
|
||||
return atomic.AddUint64(&connIDCounter, 1)
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
usedAt int64 // atomic
|
||||
netConn net.Conn
|
||||
// Connection identifier for unique tracking
|
||||
id uint64
|
||||
|
||||
usedAt atomic.Int64
|
||||
lastPutAt atomic.Int64
|
||||
|
||||
// Lock-free netConn access using atomic.Value
|
||||
// Contains *atomicNetConn wrapper, accessed atomically for better performance
|
||||
netConnAtomic atomic.Value // stores *atomicNetConn
|
||||
|
||||
rd *proto.Reader
|
||||
bw *bufio.Writer
|
||||
wr *proto.Writer
|
||||
|
||||
Inited bool
|
||||
// Lightweight mutex to protect reader operations during handoff
|
||||
// Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe
|
||||
readerMu sync.RWMutex
|
||||
|
||||
// State machine for connection state management
|
||||
// Replaces: usable, Inited, used
|
||||
// Provides thread-safe state transitions with FIFO waiting queue
|
||||
// States: CREATED → INITIALIZING → IDLE ⇄ IN_USE
|
||||
// ↓
|
||||
// UNUSABLE (handoff/reauth)
|
||||
// ↓
|
||||
// IDLE/CLOSED
|
||||
stateMachine *ConnStateMachine
|
||||
|
||||
// Handoff metadata - managed separately from state machine
|
||||
// These are atomic for lock-free access during handoff operations
|
||||
handoffStateAtomic atomic.Value // stores *HandoffState
|
||||
handoffRetriesAtomic atomic.Uint32 // retry counter
|
||||
|
||||
pooled bool
|
||||
pubsub bool
|
||||
closed atomic.Bool
|
||||
createdAt time.Time
|
||||
expiresAt time.Time
|
||||
|
||||
// maintenanceNotifications upgrade support: relaxed timeouts during migrations/failovers
|
||||
|
||||
// Using atomic operations for lock-free access to avoid mutex contention
|
||||
relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds
|
||||
relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch
|
||||
|
||||
// Counter to track multiple relaxed timeout setters if we have nested calls
|
||||
// will be decremented when ClearRelaxedTimeout is called or deadline is reached
|
||||
// if counter reaches 0, we clear the relaxed timeouts
|
||||
relaxedCounter atomic.Int32
|
||||
|
||||
// Connection initialization function for reconnections
|
||||
initConnFunc func(context.Context, *Conn) error
|
||||
|
||||
onClose func() error
|
||||
}
|
||||
@@ -32,9 +128,11 @@ func NewConn(netConn net.Conn) *Conn {
|
||||
}
|
||||
|
||||
func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn {
|
||||
now := time.Now()
|
||||
cn := &Conn{
|
||||
netConn: netConn,
|
||||
createdAt: time.Now(),
|
||||
createdAt: now,
|
||||
id: generateConnID(), // Generate unique ID for this connection
|
||||
stateMachine: NewConnStateMachine(),
|
||||
}
|
||||
|
||||
// Use specified buffer sizes, or fall back to 32KiB defaults if 0
|
||||
@@ -50,37 +148,656 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
|
||||
cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize)
|
||||
}
|
||||
|
||||
// Store netConn atomically for lock-free access using wrapper
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
|
||||
cn.wr = proto.NewWriter(cn.bw)
|
||||
cn.SetUsedAt(time.Now())
|
||||
cn.SetUsedAt(now)
|
||||
// Initialize handoff state atomically
|
||||
initialHandoffState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
}
|
||||
cn.handoffStateAtomic.Store(initialHandoffState)
|
||||
return cn
|
||||
}
|
||||
|
||||
func (cn *Conn) UsedAt() time.Time {
|
||||
unix := atomic.LoadInt64(&cn.usedAt)
|
||||
return time.Unix(unix, 0)
|
||||
return time.Unix(0, cn.usedAt.Load())
|
||||
}
|
||||
func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
cn.usedAt.Store(tm.UnixNano())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetUsedAt(tm time.Time) {
|
||||
atomic.StoreInt64(&cn.usedAt, tm.Unix())
|
||||
func (cn *Conn) UsedAtNs() int64 {
|
||||
return cn.usedAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetUsedAtNs(ns int64) {
|
||||
cn.usedAt.Store(ns)
|
||||
}
|
||||
|
||||
func (cn *Conn) LastPutAtNs() int64 {
|
||||
return cn.lastPutAt.Load()
|
||||
}
|
||||
func (cn *Conn) SetLastPutAtNs(ns int64) {
|
||||
cn.lastPutAt.Store(ns)
|
||||
}
|
||||
|
||||
// Backward-compatible wrapper methods for state machine
|
||||
// These maintain the existing API while using the new state machine internally
|
||||
|
||||
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
|
||||
//
|
||||
// This is used by background operations (handoff, re-auth) to acquire exclusive
|
||||
// access to a connection. The operation sets usable to false, preventing the pool
|
||||
// from returning the connection to clients.
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
//
|
||||
// Implementation note: This is a compatibility wrapper around the state machine.
|
||||
// It checks if the current state is "usable" (IDLE or IN_USE) and transitions accordingly.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
|
||||
currentState := cn.stateMachine.GetState()
|
||||
|
||||
// Check if current state matches the "old" usable value
|
||||
currentUsable := (currentState == StateIdle || currentState == StateInUse)
|
||||
if currentUsable != old {
|
||||
return false
|
||||
}
|
||||
|
||||
// If we're trying to set to the same value, succeed immediately
|
||||
if old == new {
|
||||
return true
|
||||
}
|
||||
|
||||
// Transition based on new value
|
||||
if new {
|
||||
// Trying to make usable - transition from UNUSABLE to IDLE
|
||||
// This should only work from UNUSABLE or INITIALIZING states
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromInitializingOrUnusable,
|
||||
StateIdle,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
// Trying to make unusable - transition from IDLE to UNUSABLE
|
||||
// This is typically for acquiring the connection for background operations
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(
|
||||
validFromIdle,
|
||||
StateUnusable,
|
||||
)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
|
||||
//
|
||||
// A connection is "usable" when it's in a stable state and can be returned to clients.
|
||||
// It becomes unusable during:
|
||||
// - Handoff operations (network connection replacement)
|
||||
// - Re-authentication (credential updates)
|
||||
// - Other background operations that need exclusive access
|
||||
//
|
||||
// Note: CREATED state is considered usable because new connections need to pass OnGet() hook
|
||||
// before initialization. The initialization happens after OnGet() in the client code.
|
||||
func (cn *Conn) IsUsable() bool {
|
||||
state := cn.stateMachine.GetState()
|
||||
// CREATED, IDLE, and IN_USE states are considered usable
|
||||
// CREATED: new connection, not yet initialized (will be initialized by client)
|
||||
// IDLE: initialized and ready to be acquired
|
||||
// IN_USE: usable but currently acquired by someone
|
||||
return state == StateCreated || state == StateIdle || state == StateInUse
|
||||
}
|
||||
|
||||
// SetUsable sets the usable flag for the connection (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This should be called to mark a connection as usable after initialization or
|
||||
// to release it after a background operation completes.
|
||||
//
|
||||
// Prefer CompareAndSwapUsable() when acquiring exclusive access to avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsable(usable bool) {
|
||||
if usable {
|
||||
// Transition to IDLE state (ready to be acquired)
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
} else {
|
||||
// Transition to UNUSABLE state (for background operations)
|
||||
cn.stateMachine.Transition(StateUnusable)
|
||||
}
|
||||
}
|
||||
|
||||
// IsInited returns true if the connection has been initialized.
|
||||
// This is a backward-compatible wrapper around the state machine.
|
||||
func (cn *Conn) IsInited() bool {
|
||||
state := cn.stateMachine.GetState()
|
||||
// Connection is initialized if it's in IDLE or any post-initialization state
|
||||
return state != StateCreated && state != StateInitializing && state != StateClosed
|
||||
}
|
||||
|
||||
// Used - State machine based implementation
|
||||
|
||||
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// This is the preferred method for acquiring a connection from the pool, as it
|
||||
// ensures that only one goroutine marks the connection as used.
|
||||
//
|
||||
// Implementation: Uses state machine transitions IDLE ⇄ IN_USE
|
||||
//
|
||||
// Returns true if the swap was successful (old value matched), false otherwise.
|
||||
// Deprecated: Use GetStateMachine().TryTransition() directly for better state management.
|
||||
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
|
||||
if old == new {
|
||||
// No change needed
|
||||
currentState := cn.stateMachine.GetState()
|
||||
currentUsed := (currentState == StateInUse)
|
||||
return currentUsed == old
|
||||
}
|
||||
|
||||
if !old && new {
|
||||
// Acquiring: IDLE → IN_USE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromCreatedOrIdle, StateInUse)
|
||||
return err == nil
|
||||
} else {
|
||||
// Releasing: IN_USE → IDLE
|
||||
// Use predefined slice to avoid allocation
|
||||
_, err := cn.stateMachine.TryTransition(validFromInUse, StateIdle)
|
||||
return err == nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsUsed returns true if the connection is currently in use (lock-free).
|
||||
//
|
||||
// Deprecated: Use GetStateMachine().GetState() == StateInUse directly for better clarity.
|
||||
// This method is kept for backwards compatibility.
|
||||
//
|
||||
// A connection is "used" when it has been retrieved from the pool and is
|
||||
// actively processing a command. Background operations (like re-auth) should
|
||||
// wait until the connection is not used before executing commands.
|
||||
func (cn *Conn) IsUsed() bool {
|
||||
return cn.stateMachine.GetState() == StateInUse
|
||||
}
|
||||
|
||||
// SetUsed sets the used flag for the connection (lock-free).
|
||||
//
|
||||
// This should be called when returning a connection to the pool (set to false)
|
||||
// or when a single-connection pool retrieves its connection (set to true).
|
||||
//
|
||||
// Prefer CompareAndSwapUsed() when acquiring from a multi-connection pool to
|
||||
// avoid race conditions.
|
||||
// Deprecated: Use GetStateMachine().Transition() directly for better state management.
|
||||
func (cn *Conn) SetUsed(val bool) {
|
||||
if val {
|
||||
cn.stateMachine.Transition(StateInUse)
|
||||
} else {
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
}
|
||||
|
||||
// getNetConn returns the current network connection using atomic load (lock-free).
|
||||
// This is the fast path for accessing netConn without mutex overhead.
|
||||
func (cn *Conn) getNetConn() net.Conn {
|
||||
if v := cn.netConnAtomic.Load(); v != nil {
|
||||
if wrapper, ok := v.(*atomicNetConn); ok {
|
||||
return wrapper.conn
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// setNetConn stores the network connection atomically (lock-free).
|
||||
// This is used for the fast path of connection replacement.
|
||||
func (cn *Conn) setNetConn(netConn net.Conn) {
|
||||
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
|
||||
}
|
||||
|
||||
// Handoff state management - atomic access to handoff metadata
|
||||
|
||||
// ShouldHandoff returns true if connection needs handoff (lock-free).
|
||||
func (cn *Conn) ShouldHandoff() bool {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).ShouldHandoff
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetHandoffEndpoint returns the new endpoint for handoff (lock-free).
|
||||
func (cn *Conn) GetHandoffEndpoint() string {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).Endpoint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free).
|
||||
func (cn *Conn) GetMovingSeqID() int64 {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
return v.(*HandoffState).SeqID
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetHandoffInfo returns all handoff information atomically (lock-free).
|
||||
// This method prevents race conditions by returning all handoff state in a single atomic operation.
|
||||
// Returns (shouldHandoff, endpoint, seqID).
|
||||
func (cn *Conn) GetHandoffInfo() (bool, string, int64) {
|
||||
if v := cn.handoffStateAtomic.Load(); v != nil {
|
||||
state := v.(*HandoffState)
|
||||
return state.ShouldHandoff, state.Endpoint, state.SeqID
|
||||
}
|
||||
return false, "", 0
|
||||
}
|
||||
|
||||
// HandoffRetries returns the current handoff retry count (lock-free).
|
||||
func (cn *Conn) HandoffRetries() int {
|
||||
return int(cn.handoffRetriesAtomic.Load())
|
||||
}
|
||||
|
||||
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).
|
||||
func (cn *Conn) IncrementAndGetHandoffRetries(n int) int {
|
||||
return int(cn.handoffRetriesAtomic.Add(uint32(n)))
|
||||
}
|
||||
|
||||
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
|
||||
func (cn *Conn) IsPooled() bool {
|
||||
return cn.pooled
|
||||
}
|
||||
|
||||
// IsPubSub returns true if the connection is used for PubSub.
|
||||
func (cn *Conn) IsPubSub() bool {
|
||||
return cn.pubsub
|
||||
}
|
||||
|
||||
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
|
||||
// These timeouts will be used for all subsequent commands until the deadline expires.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) {
|
||||
cn.relaxedCounter.Add(1)
|
||||
cn.relaxedReadTimeoutNs.Store(int64(readTimeout))
|
||||
cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout))
|
||||
}
|
||||
|
||||
// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline.
|
||||
// After the deadline, timeouts automatically revert to normal values.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) {
|
||||
cn.SetRelaxedTimeout(readTimeout, writeTimeout)
|
||||
cn.relaxedDeadlineNs.Store(deadline.UnixNano())
|
||||
}
|
||||
|
||||
// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior.
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) ClearRelaxedTimeout() {
|
||||
// Atomically decrement counter and check if we should clear
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
if newCount <= 0 && (deadlineNs == 0 || time.Now().UnixNano() >= deadlineNs) {
|
||||
// Use atomic load to get current value for CAS to avoid stale value race
|
||||
current := cn.relaxedCounter.Load()
|
||||
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) clearRelaxedTimeout() {
|
||||
cn.relaxedReadTimeoutNs.Store(0)
|
||||
cn.relaxedWriteTimeoutNs.Store(0)
|
||||
cn.relaxedDeadlineNs.Store(0)
|
||||
cn.relaxedCounter.Store(0)
|
||||
}
|
||||
|
||||
// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection.
|
||||
// This checks both the timeout values and the deadline (if set).
|
||||
// Uses atomic operations for lock-free access.
|
||||
func (cn *Conn) HasRelaxedTimeout() bool {
|
||||
// Fast path: no relaxed timeouts are set
|
||||
if cn.relaxedCounter.Load() <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// If no relaxed timeouts are set, return false
|
||||
if readTimeoutNs <= 0 && writeTimeoutNs <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, relaxed timeouts are active
|
||||
if deadlineNs == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If deadline is set, check if it's still in the future
|
||||
return time.Now().UnixNano() < deadlineNs
|
||||
}
|
||||
|
||||
// getEffectiveReadTimeout returns the timeout to use for read operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration {
|
||||
readTimeoutNs := cn.relaxedReadTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if readTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(readTimeoutNs)
|
||||
}
|
||||
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(readTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// getEffectiveWriteTimeout returns the timeout to use for write operations.
|
||||
// If relaxed timeout is set and not expired, it takes precedence over the provided timeout.
|
||||
// This method automatically clears expired relaxed timeouts using atomic operations.
|
||||
func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration {
|
||||
writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load()
|
||||
|
||||
// Fast path: no relaxed timeout set
|
||||
if writeTimeoutNs <= 0 {
|
||||
return normalTimeout
|
||||
}
|
||||
|
||||
deadlineNs := cn.relaxedDeadlineNs.Load()
|
||||
// If no deadline is set, use relaxed timeout
|
||||
if deadlineNs == 0 {
|
||||
return time.Duration(writeTimeoutNs)
|
||||
}
|
||||
|
||||
// Use cached time to avoid expensive syscall (max 50ms staleness is acceptable for timeout checks)
|
||||
nowNs := getCachedTimeNs()
|
||||
// Check if deadline has passed
|
||||
if nowNs < deadlineNs {
|
||||
// Deadline is in the future, use relaxed timeout
|
||||
return time.Duration(writeTimeoutNs)
|
||||
} else {
|
||||
// Deadline has passed, clear relaxed timeouts atomically and use normal timeout
|
||||
newCount := cn.relaxedCounter.Add(-1)
|
||||
if newCount <= 0 {
|
||||
internal.Logger.Printf(context.Background(), logs.UnrelaxedTimeoutAfterDeadline(cn.GetID()))
|
||||
cn.clearRelaxedTimeout()
|
||||
}
|
||||
return normalTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (cn *Conn) SetOnClose(fn func() error) {
|
||||
cn.onClose = fn
|
||||
}
|
||||
|
||||
// SetInitConnFunc sets the connection initialization function to be called on reconnections.
|
||||
func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) {
|
||||
cn.initConnFunc = fn
|
||||
}
|
||||
|
||||
// ExecuteInitConn runs the stored connection initialization function if available.
|
||||
func (cn *Conn) ExecuteInitConn(ctx context.Context) error {
|
||||
if cn.initConnFunc != nil {
|
||||
return cn.initConnFunc(ctx, cn)
|
||||
}
|
||||
return fmt.Errorf("redis: no initConnFunc set for conn[%d]", cn.GetID())
|
||||
}
|
||||
|
||||
func (cn *Conn) SetNetConn(netConn net.Conn) {
|
||||
cn.netConn = netConn
|
||||
// Store the new connection atomically first (lock-free)
|
||||
cn.setNetConn(netConn)
|
||||
// Protect reader reset operations to avoid data races
|
||||
// Use write lock since we're modifying the reader state
|
||||
cn.readerMu.Lock()
|
||||
cn.rd.Reset(netConn)
|
||||
cn.readerMu.Unlock()
|
||||
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
|
||||
// GetNetConn safely returns the current network connection using atomic load (lock-free).
|
||||
// This method is used by the pool for health checks and provides better performance.
|
||||
func (cn *Conn) GetNetConn() net.Conn {
|
||||
return cn.getNetConn()
|
||||
}
|
||||
|
||||
// SetNetConnAndInitConn replaces the underlying connection and executes the initialization.
|
||||
// This method ensures only one initialization can happen at a time by using atomic state transitions.
|
||||
// If another goroutine is currently initializing, this will wait for it to complete.
|
||||
func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error {
|
||||
// Wait for and transition to INITIALIZING state - this prevents concurrent initializations
|
||||
// Valid from states: CREATED (first init), IDLE (reconnect), UNUSABLE (handoff/reauth)
|
||||
// If another goroutine is initializing, we'll wait for it to finish
|
||||
// if the context has a deadline, use that, otherwise use the connection read (relaxed) timeout
|
||||
// which should be set during handoff. If it is not set, use a 5 second default
|
||||
deadline, ok := ctx.Deadline()
|
||||
if !ok {
|
||||
deadline = time.Now().Add(cn.getEffectiveReadTimeout(5 * time.Second))
|
||||
}
|
||||
waitCtx, cancel := context.WithDeadline(ctx, deadline)
|
||||
defer cancel()
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.AwaitAndTransition(
|
||||
waitCtx,
|
||||
validFromCreatedIdleOrUnusable,
|
||||
StateInitializing,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot initialize connection from state %s: %w", finalState, err)
|
||||
}
|
||||
|
||||
// Replace the underlying connection
|
||||
cn.SetNetConn(netConn)
|
||||
|
||||
// Execute initialization
|
||||
// NOTE: ExecuteInitConn (via baseClient.initConn) will transition to IDLE on success
|
||||
// or CLOSED on failure. We don't need to do it here.
|
||||
// NOTE: Initconn returns conn in IDLE state
|
||||
initErr := cn.ExecuteInitConn(ctx)
|
||||
if initErr != nil {
|
||||
// ExecuteInitConn already transitioned to CLOSED, just return the error
|
||||
return initErr
|
||||
}
|
||||
|
||||
// ExecuteInitConn already transitioned to IDLE
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkForHandoff marks the connection for handoff due to MOVING notification.
|
||||
// Returns an error if the connection is already marked for handoff.
|
||||
// Note: This only sets metadata - the connection state is not changed until OnPut.
|
||||
// This allows the current user to finish using the connection before handoff.
|
||||
func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error {
|
||||
// Check if already marked for handoff
|
||||
if cn.ShouldHandoff() {
|
||||
return errAlreadyMarkedForHandoff
|
||||
}
|
||||
|
||||
// Set handoff metadata atomically
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: true,
|
||||
Endpoint: newEndpoint,
|
||||
SeqID: seqID,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkQueuedForHandoff marks the connection as queued for handoff processing.
|
||||
// This makes the connection unusable until handoff completes.
|
||||
// This is called from OnPut hook, where the connection is typically in IN_USE state.
|
||||
// The pool will preserve the UNUSABLE state and not overwrite it with IDLE.
|
||||
func (cn *Conn) MarkQueuedForHandoff() error {
|
||||
// Get current handoff state
|
||||
currentState := cn.handoffStateAtomic.Load()
|
||||
if currentState == nil {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
state := currentState.(*HandoffState)
|
||||
if !state.ShouldHandoff {
|
||||
return errNotMarkedForHandoff
|
||||
}
|
||||
|
||||
// Create new state with ShouldHandoff=false but preserve endpoint and seqID
|
||||
// This prevents the connection from being queued multiple times while still
|
||||
// allowing the worker to access the handoff metadata
|
||||
newState := &HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: state.Endpoint, // Preserve endpoint for handoff processing
|
||||
SeqID: state.SeqID, // Preserve seqID for handoff processing
|
||||
}
|
||||
|
||||
// Atomic compare-and-swap to update state
|
||||
if !cn.handoffStateAtomic.CompareAndSwap(currentState, newState) {
|
||||
// State changed between load and CAS - retry or return error
|
||||
return errHandoffStateChanged
|
||||
}
|
||||
|
||||
// Transition to UNUSABLE from IN_USE (normal flow), IDLE (edge cases), or CREATED (tests/uninitialized)
|
||||
// The connection is typically in IN_USE state when OnPut is called (normal Put flow)
|
||||
// But in some edge cases or tests, it might be in IDLE or CREATED state
|
||||
// The pool will detect this state change and preserve it (not overwrite with IDLE)
|
||||
// Use predefined slice to avoid allocation
|
||||
finalState, err := cn.stateMachine.TryTransition(validFromCreatedInUseOrIdle, StateUnusable)
|
||||
if err != nil {
|
||||
// Check if already in UNUSABLE state (race condition or retry)
|
||||
// ShouldHandoff should be false now, but check just in case
|
||||
if finalState == StateUnusable && !cn.ShouldHandoff() {
|
||||
// Already unusable - this is fine, keep the new handoff state
|
||||
return nil
|
||||
}
|
||||
// Restore the original state if transition fails for other reasons
|
||||
cn.handoffStateAtomic.Store(currentState)
|
||||
return fmt.Errorf("failed to mark connection as unusable: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetID returns the unique identifier for this connection.
|
||||
func (cn *Conn) GetID() uint64 {
|
||||
return cn.id
|
||||
}
|
||||
|
||||
// GetStateMachine returns the connection's state machine for advanced state management.
|
||||
// This is primarily used by internal packages like maintnotifications for handoff processing.
|
||||
func (cn *Conn) GetStateMachine() *ConnStateMachine {
|
||||
return cn.stateMachine
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire the connection for use.
|
||||
// This is an optimized inline method for the hot path (Get operation).
|
||||
//
|
||||
// It tries to transition from IDLE -> IN_USE or CREATED -> CREATED.
|
||||
// Returns true if the connection was successfully acquired, false otherwise.
|
||||
// The CREATED->CREATED is done so we can keep the state correct for later
|
||||
// initialization of the connection in initConn.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast()
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// The IDLE->IN_USE and CREATED->CREATED transitions don't need
|
||||
// waiter notification, and benchmarks show 1-3% improvement. If the state machine ever
|
||||
// needs to notify waiters on these transitions, update this to use TryTransitionFast().
|
||||
func (cn *Conn) TryAcquire() bool {
|
||||
// The || operator short-circuits, so only 1 CAS in the common case
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateIdle), uint32(StateInUse)) ||
|
||||
cn.stateMachine.state.CompareAndSwap(uint32(StateCreated), uint32(StateCreated))
|
||||
}
|
||||
|
||||
// Release releases the connection back to the pool.
|
||||
// This is an optimized inline method for the hot path (Put operation).
|
||||
//
|
||||
// It tries to transition from IN_USE -> IDLE.
|
||||
// Returns true if the connection was successfully released, false otherwise.
|
||||
//
|
||||
// Performance: This is faster than calling GetStateMachine() + TryTransitionFast().
|
||||
//
|
||||
// NOTE: We directly access cn.stateMachine.state here instead of using the state machine's
|
||||
// methods. This breaks encapsulation but is necessary for performance.
|
||||
// If the state machine ever needs to notify waiters
|
||||
// on this transition, update this to use TryTransitionFast().
|
||||
func (cn *Conn) Release() bool {
|
||||
// Inline the hot path - single CAS operation
|
||||
return cn.stateMachine.state.CompareAndSwap(uint32(StateInUse), uint32(StateIdle))
|
||||
}
|
||||
|
||||
// ClearHandoffState clears the handoff state after successful handoff.
|
||||
// Makes the connection usable again.
|
||||
func (cn *Conn) ClearHandoffState() {
|
||||
// Clear handoff metadata
|
||||
cn.handoffStateAtomic.Store(&HandoffState{
|
||||
ShouldHandoff: false,
|
||||
Endpoint: "",
|
||||
SeqID: 0,
|
||||
})
|
||||
|
||||
// Reset retry counter
|
||||
cn.handoffRetriesAtomic.Store(0)
|
||||
|
||||
// Mark connection as usable again
|
||||
// Use state machine directly instead of deprecated SetUsable
|
||||
// probably done by initConn
|
||||
cn.stateMachine.Transition(StateIdle)
|
||||
}
|
||||
|
||||
// HasBufferedData safely checks if the connection has buffered data.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) HasBufferedData() bool {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
return cn.rd.Buffered() > 0
|
||||
}
|
||||
|
||||
// PeekReplyTypeSafe safely peeks at the reply type.
|
||||
// This method is used to avoid data races when checking for push notifications.
|
||||
func (cn *Conn) PeekReplyTypeSafe() (byte, error) {
|
||||
// Use read lock for concurrent access to reader state
|
||||
cn.readerMu.RLock()
|
||||
defer cn.readerMu.RUnlock()
|
||||
|
||||
if cn.rd.Buffered() <= 0 {
|
||||
return 0, fmt.Errorf("redis: can't peek reply type, no data available")
|
||||
}
|
||||
return cn.rd.PeekReplyType()
|
||||
}
|
||||
|
||||
func (cn *Conn) Write(b []byte) (int, error) {
|
||||
return cn.netConn.Write(b)
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Write(b)
|
||||
}
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) RemoteAddr() net.Addr {
|
||||
if cn.netConn != nil {
|
||||
return cn.netConn.RemoteAddr()
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -89,7 +806,16 @@ func (cn *Conn) WithReader(
|
||||
ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveReadTimeout(timeout)
|
||||
|
||||
// Get the connection directly from atomic storage
|
||||
netConn := cn.getNetConn()
|
||||
if netConn == nil {
|
||||
return errConnectionNotAvailable
|
||||
}
|
||||
|
||||
if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -100,13 +826,25 @@ func (cn *Conn) WithWriter(
|
||||
ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
|
||||
) error {
|
||||
if timeout >= 0 {
|
||||
if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
|
||||
return err
|
||||
// Use relaxed timeout if set, otherwise use provided timeout
|
||||
effectiveTimeout := cn.getEffectiveWriteTimeout(timeout)
|
||||
|
||||
// Set write deadline on the connection
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Connection is not available - return preallocated error
|
||||
return errConnNotAvailableForWrite
|
||||
}
|
||||
}
|
||||
|
||||
// Reset the buffered writer if needed, should not happen
|
||||
if cn.bw.Buffered() > 0 {
|
||||
cn.bw.Reset(cn.netConn)
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
cn.bw.Reset(netConn)
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(cn.wr); err != nil {
|
||||
@@ -116,17 +854,47 @@ func (cn *Conn) WithWriter(
|
||||
return cn.bw.Flush()
|
||||
}
|
||||
|
||||
func (cn *Conn) IsClosed() bool {
|
||||
return cn.closed.Load() || cn.stateMachine.GetState() == StateClosed
|
||||
}
|
||||
|
||||
func (cn *Conn) Close() error {
|
||||
cn.closed.Store(true)
|
||||
|
||||
// Transition to CLOSED state
|
||||
cn.stateMachine.Transition(StateClosed)
|
||||
|
||||
if cn.onClose != nil {
|
||||
// ignore error
|
||||
_ = cn.onClose()
|
||||
}
|
||||
return cn.netConn.Close()
|
||||
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return netConn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MaybeHasData tries to peek at the next byte in the socket without consuming it
|
||||
// This is used to check if there are push notifications available
|
||||
// Important: This will work on Linux, but not on Windows
|
||||
func (cn *Conn) MaybeHasData() bool {
|
||||
// Lock-free netConn access for better performance
|
||||
if netConn := cn.getNetConn(); netConn != nil {
|
||||
return maybeHasData(netConn)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deadline computes the effective deadline time based on context and timeout.
|
||||
// It updates the usedAt timestamp to now.
|
||||
// Uses cached time to avoid expensive syscall (max 50ms staleness is acceptable for deadline calculation).
|
||||
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
|
||||
tm := time.Now()
|
||||
cn.SetUsedAt(tm)
|
||||
// Use cached time for deadline calculation (called 2x per command: read + write)
|
||||
nowNs := getCachedTimeNs()
|
||||
cn.SetUsedAtNs(nowNs)
|
||||
tm := time.Unix(0, nowNs)
|
||||
|
||||
if timeout > 0 {
|
||||
tm = tm.Add(timeout)
|
||||
|
||||
+11
-1
@@ -12,6 +12,9 @@ import (
|
||||
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
// connCheck checks if the connection is still alive and if there is data in the socket
|
||||
// it will try to peek at the next byte without consuming it since we may want to work with it
|
||||
// later on (e.g. push notifications)
|
||||
func connCheck(conn net.Conn) error {
|
||||
// Reset previous timeout.
|
||||
_ = conn.SetDeadline(time.Time{})
|
||||
@@ -29,7 +32,9 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
if err := rawConn.Read(func(fd uintptr) bool {
|
||||
var buf [1]byte
|
||||
n, err := syscall.Read(int(fd), buf[:])
|
||||
// Use MSG_PEEK to peek at data without consuming it
|
||||
n, _, err := syscall.Recvfrom(int(fd), buf[:], syscall.MSG_PEEK|syscall.MSG_DONTWAIT)
|
||||
|
||||
switch {
|
||||
case n == 0 && err == nil:
|
||||
sysErr = io.EOF
|
||||
@@ -47,3 +52,8 @@ func connCheck(conn net.Conn) error {
|
||||
|
||||
return sysErr
|
||||
}
|
||||
|
||||
// maybeHasData checks if there is data in the socket without consuming it
|
||||
func maybeHasData(conn net.Conn) bool {
|
||||
return connCheck(conn) == errUnexpectedRead
|
||||
}
|
||||
|
||||
+13
-2
@@ -2,8 +2,19 @@
|
||||
|
||||
package pool
|
||||
|
||||
import "net"
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func connCheck(conn net.Conn) error {
|
||||
// errUnexpectedRead is placeholder error variable for non-unix build constraints
|
||||
var errUnexpectedRead = errors.New("unexpected read from socket")
|
||||
|
||||
func connCheck(_ net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// since we can't check for data on the socket, we just assume there is some
|
||||
func maybeHasData(_ net.Conn) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
+343
@@ -0,0 +1,343 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnState represents the connection state in the state machine.
|
||||
// States are designed to be lightweight and fast to check.
|
||||
//
|
||||
// State Transitions:
|
||||
//
|
||||
// CREATED → INITIALIZING → IDLE ⇄ IN_USE
|
||||
// ↓
|
||||
// UNUSABLE (handoff/reauth)
|
||||
// ↓
|
||||
// IDLE/CLOSED
|
||||
type ConnState uint32
|
||||
|
||||
const (
|
||||
// StateCreated - Connection just created, not yet initialized
|
||||
StateCreated ConnState = iota
|
||||
|
||||
// StateInitializing - Connection initialization in progress
|
||||
StateInitializing
|
||||
|
||||
// StateIdle - Connection initialized and idle in pool, ready to be acquired
|
||||
StateIdle
|
||||
|
||||
// StateInUse - Connection actively processing a command (retrieved from pool)
|
||||
StateInUse
|
||||
|
||||
// StateUnusable - Connection temporarily unusable due to background operation
|
||||
// (handoff, reauth, etc.). Cannot be acquired from pool.
|
||||
StateUnusable
|
||||
|
||||
// StateClosed - Connection closed
|
||||
StateClosed
|
||||
)
|
||||
|
||||
// Predefined state slices to avoid allocations in hot paths
|
||||
var (
|
||||
validFromInUse = []ConnState{StateInUse}
|
||||
validFromCreatedOrIdle = []ConnState{StateCreated, StateIdle}
|
||||
validFromCreatedInUseOrIdle = []ConnState{StateCreated, StateInUse, StateIdle}
|
||||
// For AwaitAndTransition calls
|
||||
validFromCreatedIdleOrUnusable = []ConnState{StateCreated, StateIdle, StateUnusable}
|
||||
validFromIdle = []ConnState{StateIdle}
|
||||
// For CompareAndSwapUsable
|
||||
validFromInitializingOrUnusable = []ConnState{StateInitializing, StateUnusable}
|
||||
)
|
||||
|
||||
// Accessor functions for predefined slices to avoid allocations in external packages
|
||||
// These return the same slice instance, so they're zero-allocation
|
||||
|
||||
// ValidFromIdle returns a predefined slice containing only StateIdle.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromIdle() []ConnState {
|
||||
return validFromIdle
|
||||
}
|
||||
|
||||
// ValidFromCreatedIdleOrUnusable returns a predefined slice for initialization transitions.
|
||||
// Use this to avoid allocations when calling AwaitAndTransition or TryTransition.
|
||||
func ValidFromCreatedIdleOrUnusable() []ConnState {
|
||||
return validFromCreatedIdleOrUnusable
|
||||
}
|
||||
|
||||
// String returns a human-readable string representation of the state.
|
||||
func (s ConnState) String() string {
|
||||
switch s {
|
||||
case StateCreated:
|
||||
return "CREATED"
|
||||
case StateInitializing:
|
||||
return "INITIALIZING"
|
||||
case StateIdle:
|
||||
return "IDLE"
|
||||
case StateInUse:
|
||||
return "IN_USE"
|
||||
case StateUnusable:
|
||||
return "UNUSABLE"
|
||||
case StateClosed:
|
||||
return "CLOSED"
|
||||
default:
|
||||
return fmt.Sprintf("UNKNOWN(%d)", s)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrInvalidStateTransition is returned when a state transition is not allowed
|
||||
ErrInvalidStateTransition = errors.New("invalid state transition")
|
||||
|
||||
// ErrStateMachineClosed is returned when operating on a closed state machine
|
||||
ErrStateMachineClosed = errors.New("state machine is closed")
|
||||
|
||||
// ErrTimeout is returned when a state transition times out
|
||||
ErrTimeout = errors.New("state transition timeout")
|
||||
)
|
||||
|
||||
// waiter represents a goroutine waiting for a state transition.
|
||||
// Designed for minimal allocations and fast processing.
|
||||
type waiter struct {
|
||||
validStates map[ConnState]struct{} // States we're waiting for
|
||||
targetState ConnState // State to transition to
|
||||
done chan error // Signaled when transition completes or times out
|
||||
}
|
||||
|
||||
// ConnStateMachine manages connection state transitions with FIFO waiting queue.
|
||||
// Optimized for:
|
||||
// - Lock-free reads (hot path)
|
||||
// - Minimal allocations
|
||||
// - Fast state transitions
|
||||
// - FIFO fairness for waiters
|
||||
// Note: Handoff metadata (endpoint, seqID, retries) is managed separately in the Conn struct.
|
||||
type ConnStateMachine struct {
|
||||
// Current state - atomic for lock-free reads
|
||||
state atomic.Uint32
|
||||
|
||||
// FIFO queue for waiters - only locked during waiter add/remove/notify
|
||||
mu sync.Mutex
|
||||
waiters *list.List // List of *waiter
|
||||
waiterCount atomic.Int32 // Fast lock-free check for waiters (avoids mutex in hot path)
|
||||
}
|
||||
|
||||
// NewConnStateMachine creates a new connection state machine.
|
||||
// Initial state is StateCreated.
|
||||
func NewConnStateMachine() *ConnStateMachine {
|
||||
sm := &ConnStateMachine{
|
||||
waiters: list.New(),
|
||||
}
|
||||
sm.state.Store(uint32(StateCreated))
|
||||
return sm
|
||||
}
|
||||
|
||||
// GetState returns the current state (lock-free read).
|
||||
// This is the hot path - optimized for zero allocations and minimal overhead.
|
||||
// Note: Zero allocations applies to state reads; converting the returned state to a string
|
||||
// (via String()) may allocate if the state is unknown.
|
||||
func (sm *ConnStateMachine) GetState() ConnState {
|
||||
return ConnState(sm.state.Load())
|
||||
}
|
||||
|
||||
// TryTransitionFast is an optimized version for the hot path (Get/Put operations).
|
||||
// It only handles simple state transitions without waiter notification.
|
||||
// This is safe because:
|
||||
// 1. Get/Put don't need to wait for state changes
|
||||
// 2. Background operations (handoff/reauth) use UNUSABLE state, which this won't match
|
||||
// 3. If a background operation is in progress (state is UNUSABLE), this fails fast
|
||||
//
|
||||
// Returns true if transition succeeded, false otherwise.
|
||||
// Use this for performance-critical paths where you don't need error details.
|
||||
//
|
||||
// Performance: Single CAS operation - as fast as the old atomic bool!
|
||||
// For multiple from states, use: sm.TryTransitionFast(State1, Target) || sm.TryTransitionFast(State2, Target)
|
||||
// The || operator short-circuits, so only 1 CAS is executed in the common case.
|
||||
func (sm *ConnStateMachine) TryTransitionFast(fromState, targetState ConnState) bool {
|
||||
return sm.state.CompareAndSwap(uint32(fromState), uint32(targetState))
|
||||
}
|
||||
|
||||
// TryTransition attempts an immediate state transition without waiting.
|
||||
// Returns the current state after the transition attempt and an error if the transition failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// This is faster than AwaitAndTransition when you don't need to wait.
|
||||
// Uses compare-and-swap to atomically transition, preventing concurrent transitions.
|
||||
// This method does NOT wait - it fails immediately if the transition cannot be performed.
|
||||
//
|
||||
// Performance: Zero allocations on success path (hot path).
|
||||
func (sm *ConnStateMachine) TryTransition(validFromStates []ConnState, targetState ConnState) (ConnState, error) {
|
||||
// Try each valid from state with CAS
|
||||
// This ensures only ONE goroutine can successfully transition at a time
|
||||
for _, fromState := range validFromStates {
|
||||
// Try to atomically swap from fromState to targetState
|
||||
// If successful, we won the race and can proceed
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
// Hot path optimization: only check for waiters if transition succeeded
|
||||
// This avoids atomic load on every Get/Put when no waiters exist
|
||||
if sm.waiterCount.Load() > 0 {
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
|
||||
// All CAS attempts failed - state is not valid for this transition
|
||||
// Return the current state so caller can decide what to do
|
||||
// Note: This error path allocates, but it's the exceptional case
|
||||
currentState := sm.GetState()
|
||||
return currentState, fmt.Errorf("%w: cannot transition from %s to %s (valid from: %v)",
|
||||
ErrInvalidStateTransition, currentState, targetState, validFromStates)
|
||||
}
|
||||
|
||||
// Transition unconditionally transitions to the target state.
|
||||
// Use with caution - prefer AwaitAndTransition or TryTransition for safety.
|
||||
// This is useful for error paths or when you know the transition is valid.
|
||||
func (sm *ConnStateMachine) Transition(targetState ConnState) {
|
||||
sm.state.Store(uint32(targetState))
|
||||
sm.notifyWaiters()
|
||||
}
|
||||
|
||||
// AwaitAndTransition waits for the connection to reach one of the valid states,
|
||||
// then atomically transitions to the target state.
|
||||
// Returns the current state after the transition attempt and an error if the operation failed.
|
||||
// The returned state is the CURRENT state (after the attempt), not the previous state.
|
||||
// Returns error if timeout expires or context is cancelled.
|
||||
//
|
||||
// This method implements FIFO fairness - the first caller to wait gets priority
|
||||
// when the state becomes available.
|
||||
//
|
||||
// Performance notes:
|
||||
// - If already in a valid state, this is very fast (no allocation, no waiting)
|
||||
// - If waiting is required, allocates one waiter struct and one channel
|
||||
func (sm *ConnStateMachine) AwaitAndTransition(
|
||||
ctx context.Context,
|
||||
validFromStates []ConnState,
|
||||
targetState ConnState,
|
||||
) (ConnState, error) {
|
||||
// Fast path: try immediate transition with CAS to prevent race conditions
|
||||
// BUT: only if there are no waiters in the queue (to maintain FIFO ordering)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
for _, fromState := range validFromStates {
|
||||
// Check if we're already in target state
|
||||
if fromState == targetState && sm.GetState() == targetState {
|
||||
return targetState, nil
|
||||
}
|
||||
|
||||
// Try to atomically swap from fromState to targetState
|
||||
if sm.state.CompareAndSwap(uint32(fromState), uint32(targetState)) {
|
||||
// Success! We transitioned atomically
|
||||
sm.notifyWaiters()
|
||||
return targetState, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast path failed - check if we should wait or fail
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if closed
|
||||
if currentState == StateClosed {
|
||||
return currentState, ErrStateMachineClosed
|
||||
}
|
||||
|
||||
// Slow path: need to wait for state change
|
||||
// Create waiter with valid states map for fast lookup
|
||||
validStatesMap := make(map[ConnState]struct{}, len(validFromStates))
|
||||
for _, s := range validFromStates {
|
||||
validStatesMap[s] = struct{}{}
|
||||
}
|
||||
|
||||
w := &waiter{
|
||||
validStates: validStatesMap,
|
||||
targetState: targetState,
|
||||
done: make(chan error, 1), // Buffered to avoid goroutine leak
|
||||
}
|
||||
|
||||
// Add to FIFO queue
|
||||
sm.mu.Lock()
|
||||
elem := sm.waiters.PushBack(w)
|
||||
sm.waiterCount.Add(1)
|
||||
sm.mu.Unlock()
|
||||
|
||||
// Wait for state change or timeout
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Timeout or cancellation - remove from queue
|
||||
sm.mu.Lock()
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
sm.mu.Unlock()
|
||||
return sm.GetState(), ctx.Err()
|
||||
case err := <-w.done:
|
||||
// Transition completed (or failed)
|
||||
// Note: waiterCount is decremented either in notifyWaiters (when the waiter is notified and removed)
|
||||
// or here (on timeout/cancellation).
|
||||
return sm.GetState(), err
|
||||
}
|
||||
}
|
||||
|
||||
// notifyWaiters checks if any waiters can proceed and notifies them in FIFO order.
|
||||
// This is called after every state transition.
|
||||
func (sm *ConnStateMachine) notifyWaiters() {
|
||||
// Fast path: check atomic counter without acquiring lock
|
||||
// This eliminates mutex overhead in the common case (no waiters)
|
||||
if sm.waiterCount.Load() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring lock (waiters might have been processed)
|
||||
if sm.waiters.Len() == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Process waiters in FIFO order until no more can be processed
|
||||
// We loop instead of recursing to avoid stack overflow and mutex issues
|
||||
for {
|
||||
processed := false
|
||||
|
||||
// Find the first waiter that can proceed
|
||||
for elem := sm.waiters.Front(); elem != nil; elem = elem.Next() {
|
||||
w := elem.Value.(*waiter)
|
||||
|
||||
// Read current state inside the loop to get the latest value
|
||||
currentState := sm.GetState()
|
||||
|
||||
// Check if current state is valid for this waiter
|
||||
if _, valid := w.validStates[currentState]; valid {
|
||||
// Remove from queue first
|
||||
sm.waiters.Remove(elem)
|
||||
sm.waiterCount.Add(-1)
|
||||
|
||||
// Use CAS to ensure state hasn't changed since we checked
|
||||
// This prevents race condition where another thread changes state
|
||||
// between our check and our transition
|
||||
if sm.state.CompareAndSwap(uint32(currentState), uint32(w.targetState)) {
|
||||
// Successfully transitioned - notify waiter
|
||||
w.done <- nil
|
||||
processed = true
|
||||
break
|
||||
} else {
|
||||
// State changed - re-add waiter to front of queue to maintain FIFO ordering
|
||||
// This waiter was first in line and should retain priority
|
||||
sm.waiters.PushFront(w)
|
||||
sm.waiterCount.Add(1)
|
||||
// Continue to next iteration to re-read state
|
||||
processed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we didn't process any waiter, we're done
|
||||
if !processed {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
+165
@@ -0,0 +1,165 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// PoolHook defines the interface for connection lifecycle hooks.
|
||||
type PoolHook interface {
|
||||
// OnGet is called when a connection is retrieved from the pool.
|
||||
// It can modify the connection or return an error to prevent its use.
|
||||
// The accept flag can be used to prevent the connection from being used.
|
||||
// On Accept = false the connection is rejected and returned to the pool.
|
||||
// The error can be used to prevent the connection from being used and returned to the pool.
|
||||
// On Errors, the connection is removed from the pool.
|
||||
// It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool)
|
||||
// The flag can be used for gathering metrics on pool hit/miss ratio.
|
||||
OnGet(ctx context.Context, conn *Conn, isNewConn bool) (accept bool, err error)
|
||||
|
||||
// OnPut is called when a connection is returned to the pool.
|
||||
// It returns whether the connection should be pooled and whether it should be removed.
|
||||
OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error)
|
||||
|
||||
// OnRemove is called when a connection is removed from the pool.
|
||||
// This happens when:
|
||||
// - Connection fails health check
|
||||
// - Connection exceeds max lifetime
|
||||
// - Pool is being closed
|
||||
// - Connection encounters an error
|
||||
// Implementations should clean up any per-connection state.
|
||||
// The reason parameter indicates why the connection was removed.
|
||||
OnRemove(ctx context.Context, conn *Conn, reason error)
|
||||
}
|
||||
|
||||
// PoolHookManager manages multiple pool hooks.
|
||||
type PoolHookManager struct {
|
||||
hooks []PoolHook
|
||||
hooksMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewPoolHookManager creates a new pool hook manager.
|
||||
func NewPoolHookManager() *PoolHookManager {
|
||||
return &PoolHookManager{
|
||||
hooks: make([]PoolHook, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddHook adds a pool hook to the manager.
|
||||
// Hooks are called in the order they were added.
|
||||
func (phm *PoolHookManager) AddHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
phm.hooks = append(phm.hooks, hook)
|
||||
}
|
||||
|
||||
// RemoveHook removes a pool hook from the manager.
|
||||
func (phm *PoolHookManager) RemoveHook(hook PoolHook) {
|
||||
phm.hooksMu.Lock()
|
||||
defer phm.hooksMu.Unlock()
|
||||
|
||||
for i, h := range phm.hooks {
|
||||
if h == hook {
|
||||
// Remove hook by swapping with last element and truncating
|
||||
phm.hooks[i] = phm.hooks[len(phm.hooks)-1]
|
||||
phm.hooks = phm.hooks[:len(phm.hooks)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessOnGet calls all OnGet hooks in order.
|
||||
// If any hook returns an error, processing stops and the error is returned.
|
||||
func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) (acceptConn bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
acceptConn, err := hook.OnGet(ctx, conn, isNewConn)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !acceptConn {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ProcessOnPut calls all OnPut hooks in order.
|
||||
// The first hook that returns shouldRemove=true or shouldPool=false will stop processing.
|
||||
func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
shouldPool = true // Default to pooling the connection
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn)
|
||||
|
||||
if hookErr != nil {
|
||||
return false, true, hookErr
|
||||
}
|
||||
|
||||
// If any hook says to remove or not pool, respect that decision
|
||||
if hookShouldRemove {
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
if !hookShouldPool {
|
||||
shouldPool = false
|
||||
}
|
||||
}
|
||||
|
||||
return shouldPool, false, nil
|
||||
}
|
||||
|
||||
// ProcessOnRemove calls all OnRemove hooks in order.
|
||||
func (phm *PoolHookManager) ProcessOnRemove(ctx context.Context, conn *Conn, reason error) {
|
||||
// Copy slice reference while holding lock (fast)
|
||||
phm.hooksMu.RLock()
|
||||
hooks := phm.hooks
|
||||
phm.hooksMu.RUnlock()
|
||||
|
||||
// Call hooks without holding lock (slow operations)
|
||||
for _, hook := range hooks {
|
||||
hook.OnRemove(ctx, conn, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHookCount returns the number of registered hooks (for testing).
|
||||
func (phm *PoolHookManager) GetHookCount() int {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
return len(phm.hooks)
|
||||
}
|
||||
|
||||
// GetHooks returns a copy of all registered hooks.
|
||||
func (phm *PoolHookManager) GetHooks() []PoolHook {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
hooks := make([]PoolHook, len(phm.hooks))
|
||||
copy(hooks, phm.hooks)
|
||||
return hooks
|
||||
}
|
||||
|
||||
// Clone creates a copy of the hook manager with the same hooks.
|
||||
// This is used for lock-free atomic updates of the hook manager.
|
||||
func (phm *PoolHookManager) Clone() *PoolHookManager {
|
||||
phm.hooksMu.RLock()
|
||||
defer phm.hooksMu.RUnlock()
|
||||
|
||||
newManager := &PoolHookManager{
|
||||
hooks: make([]PoolHook, len(phm.hooks)),
|
||||
}
|
||||
copy(newManager.hooks, phm.hooks)
|
||||
return newManager
|
||||
}
|
||||
+656
-148
File diff suppressed because it is too large
Load Diff
+50
-4
@@ -1,7 +1,13 @@
|
||||
package pool
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SingleConnPool is a pool that always returns the same connection.
|
||||
// Note: This pool is not thread-safe.
|
||||
// It is intended to be used by clients that need a single connection.
|
||||
type SingleConnPool struct {
|
||||
pool Pooler
|
||||
cn *Conn
|
||||
@@ -10,6 +16,12 @@ type SingleConnPool struct {
|
||||
|
||||
var _ Pooler = (*SingleConnPool)(nil)
|
||||
|
||||
// NewSingleConnPool creates a new single connection pool.
|
||||
// The pool will always return the same connection.
|
||||
// The pool will not:
|
||||
// - Close the connection
|
||||
// - Reconnect the connection
|
||||
// - Track the connection in any way
|
||||
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
|
||||
return &SingleConnPool{
|
||||
pool: pool,
|
||||
@@ -25,20 +37,47 @@ func (p *SingleConnPool) CloseConn(cn *Conn) error {
|
||||
return p.pool.CloseConn(cn)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) {
|
||||
func (p *SingleConnPool) Get(_ context.Context) (*Conn, error) {
|
||||
if p.stickyErr != nil {
|
||||
return nil, p.stickyErr
|
||||
}
|
||||
if p.cn == nil {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
// NOTE: SingleConnPool is NOT thread-safe by design and is used in special scenarios:
|
||||
// - During initialization (connection is in INITIALIZING state)
|
||||
// - During re-authentication (connection is in UNUSABLE state)
|
||||
// - For transactions (connection might be in various states)
|
||||
// We use SetUsed() which forces the transition, rather than TryTransition() which
|
||||
// would fail if the connection is not in IDLE/CREATED state.
|
||||
p.cn.SetUsed(true)
|
||||
p.cn.SetUsedAt(time.Now())
|
||||
return p.cn, nil
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
|
||||
func (p *SingleConnPool) Put(_ context.Context, cn *Conn) {
|
||||
if p.cn == nil {
|
||||
return
|
||||
}
|
||||
if p.cn != cn {
|
||||
return
|
||||
}
|
||||
p.cn.SetUsed(false)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
func (p *SingleConnPool) Remove(_ context.Context, cn *Conn, reason error) {
|
||||
cn.SetUsed(false)
|
||||
p.cn = nil
|
||||
p.stickyErr = reason
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for SingleConnPool
|
||||
// since SingleConnPool doesn't use a turn-based queue system.
|
||||
func (p *SingleConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) Close() error {
|
||||
p.cn = nil
|
||||
p.stickyErr = ErrClosed
|
||||
@@ -53,6 +92,13 @@ func (p *SingleConnPool) IdleLen() int {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Size returns the maximum pool size, which is always 1 for SingleConnPool.
|
||||
func (p *SingleConnPool) Size() int { return 1 }
|
||||
|
||||
func (p *SingleConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *SingleConnPool) AddPoolHook(_ PoolHook) {}
|
||||
|
||||
func (p *SingleConnPool) RemovePoolHook(_ PoolHook) {}
|
||||
|
||||
+13
@@ -123,6 +123,12 @@ func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
|
||||
p.ch <- cn
|
||||
}
|
||||
|
||||
// RemoveWithoutTurn has the same behavior as Remove for StickyConnPool
|
||||
// since StickyConnPool doesn't use a turn-based queue system.
|
||||
func (p *StickyConnPool) RemoveWithoutTurn(ctx context.Context, cn *Conn, reason error) {
|
||||
p.Remove(ctx, cn, reason)
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) Close() error {
|
||||
if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
|
||||
return nil
|
||||
@@ -196,6 +202,13 @@ func (p *StickyConnPool) IdleLen() int {
|
||||
return len(p.ch)
|
||||
}
|
||||
|
||||
// Size returns the maximum pool size, which is always 1 for StickyConnPool.
|
||||
func (p *StickyConnPool) Size() int { return 1 }
|
||||
|
||||
func (p *StickyConnPool) Stats() *Stats {
|
||||
return &Stats{}
|
||||
}
|
||||
|
||||
func (p *StickyConnPool) AddPoolHook(hook PoolHook) {}
|
||||
|
||||
func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {}
|
||||
|
||||
+80
@@ -0,0 +1,80 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type PubSubStats struct {
|
||||
Created uint32
|
||||
Untracked uint32
|
||||
Active uint32
|
||||
}
|
||||
|
||||
// PubSubPool manages a pool of PubSub connections.
|
||||
type PubSubPool struct {
|
||||
opt *Options
|
||||
netDialer func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// Map to track active PubSub connections
|
||||
activeConns sync.Map // map[uint64]*Conn (connID -> conn)
|
||||
closed atomic.Bool
|
||||
stats PubSubStats
|
||||
}
|
||||
|
||||
// NewPubSubPool implements a pool for PubSub connections.
|
||||
// It intentionally does not implement the Pooler interface
|
||||
func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool {
|
||||
return &PubSubPool{
|
||||
opt: opt,
|
||||
netDialer: netDialer,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) {
|
||||
if p.closed.Load() {
|
||||
return nil, ErrClosed
|
||||
}
|
||||
|
||||
netConn, err := p.netDialer(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize)
|
||||
cn.pubsub = true
|
||||
atomic.AddUint32(&p.stats.Created, 1)
|
||||
return cn, nil
|
||||
|
||||
}
|
||||
|
||||
func (p *PubSubPool) TrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, 1)
|
||||
p.activeConns.Store(cn.GetID(), cn)
|
||||
}
|
||||
|
||||
func (p *PubSubPool) UntrackConn(cn *Conn) {
|
||||
atomic.AddUint32(&p.stats.Active, ^uint32(0))
|
||||
atomic.AddUint32(&p.stats.Untracked, 1)
|
||||
p.activeConns.Delete(cn.GetID())
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Close() error {
|
||||
p.closed.Store(true)
|
||||
p.activeConns.Range(func(key, value interface{}) bool {
|
||||
cn := value.(*Conn)
|
||||
_ = cn.Close()
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PubSubPool) Stats() *PubSubStats {
|
||||
// load stats atomically
|
||||
return &PubSubStats{
|
||||
Created: atomic.LoadUint32(&p.stats.Created),
|
||||
Untracked: atomic.LoadUint32(&p.stats.Untracked),
|
||||
Active: atomic.LoadUint32(&p.stats.Active),
|
||||
}
|
||||
}
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type wantConn struct {
|
||||
mu sync.Mutex // protects ctx, done and sending of the result
|
||||
ctx context.Context // context for dial, cleared after delivered or canceled
|
||||
cancelCtx context.CancelFunc
|
||||
done bool // true after delivered or canceled
|
||||
result chan wantConnResult // channel to deliver connection or error
|
||||
}
|
||||
|
||||
// getCtxForDial returns context for dial or nil if connection was delivered or canceled.
|
||||
func (w *wantConn) getCtxForDial() context.Context {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
return w.ctx
|
||||
}
|
||||
|
||||
func (w *wantConn) tryDeliver(cn *Conn, err error) bool {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.done {
|
||||
return false
|
||||
}
|
||||
|
||||
w.done = true
|
||||
w.ctx = nil
|
||||
|
||||
w.result <- wantConnResult{cn: cn, err: err}
|
||||
close(w.result)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *wantConn) cancel() *Conn {
|
||||
w.mu.Lock()
|
||||
var cn *Conn
|
||||
if w.done {
|
||||
select {
|
||||
case result := <-w.result:
|
||||
cn = result.cn
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
close(w.result)
|
||||
}
|
||||
|
||||
w.done = true
|
||||
w.ctx = nil
|
||||
w.mu.Unlock()
|
||||
|
||||
return cn
|
||||
}
|
||||
|
||||
type wantConnResult struct {
|
||||
cn *Conn
|
||||
err error
|
||||
}
|
||||
|
||||
type wantConnQueue struct {
|
||||
mu sync.RWMutex
|
||||
items []*wantConn
|
||||
}
|
||||
|
||||
func newWantConnQueue() *wantConnQueue {
|
||||
return &wantConnQueue{
|
||||
items: make([]*wantConn, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *wantConnQueue) enqueue(w *wantConn) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
q.items = append(q.items, w)
|
||||
}
|
||||
|
||||
func (q *wantConnQueue) dequeue() (*wantConn, bool) {
|
||||
q.mu.Lock()
|
||||
defer q.mu.Unlock()
|
||||
|
||||
if len(q.items) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
item := q.items[0]
|
||||
q.items = q.items[1:]
|
||||
return item, true
|
||||
}
|
||||
+89
-2
@@ -50,7 +50,8 @@ func (e RedisError) Error() string { return string(e) }
|
||||
func (RedisError) RedisError() {}
|
||||
|
||||
func ParseErrorReply(line []byte) error {
|
||||
return RedisError(line[1:])
|
||||
msg := string(line[1:])
|
||||
return parseTypedRedisError(msg)
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
@@ -99,6 +100,92 @@ func (r *Reader) PeekReplyType() (byte, error) {
|
||||
return b[0], nil
|
||||
}
|
||||
|
||||
func (r *Reader) PeekPushNotificationName() (string, error) {
|
||||
// "prime" the buffer by peeking at the next byte
|
||||
c, err := r.Peek(1)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if c[0] != RespPush {
|
||||
return "", fmt.Errorf("redis: can't peek push notification name, next reply is not a push notification")
|
||||
}
|
||||
|
||||
// peek 36 bytes at most, should be enough to read the push notification name
|
||||
toPeek := 36
|
||||
buffered := r.Buffered()
|
||||
if buffered == 0 {
|
||||
return "", fmt.Errorf("redis: can't peek push notification name, no data available")
|
||||
}
|
||||
if buffered < toPeek {
|
||||
toPeek = buffered
|
||||
}
|
||||
buf, err := r.rd.Peek(toPeek)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if buf[0] != RespPush {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
|
||||
if len(buf) < 3 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
|
||||
// remove push notification type
|
||||
buf = buf[1:]
|
||||
// remove first line - e.g. >2\r\n
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification: %q", buf)
|
||||
}
|
||||
// next line should be $<length><string>\r\n or +<length><string>\r\n
|
||||
// should have the type of the push notification name and it's length
|
||||
if buf[0] != RespString && buf[0] != RespStatus {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
typeOfName := buf[0]
|
||||
// remove the type of the push notification name
|
||||
buf = buf[1:]
|
||||
if typeOfName == RespString {
|
||||
// remove the length of the string
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[i+2:]
|
||||
break
|
||||
} else {
|
||||
if buf[i] < '0' || buf[i] > '9' {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(buf) < 2 {
|
||||
return "", fmt.Errorf("redis: can't parse push notification name: %q", buf)
|
||||
}
|
||||
// keep only the notification name
|
||||
for i := 0; i < len(buf)-1; i++ {
|
||||
if buf[i] == '\r' && buf[i+1] == '\n' {
|
||||
buf = buf[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return util.BytesToString(buf), nil
|
||||
}
|
||||
|
||||
// ReadLine Return a valid reply, it will check the protocol or redis error,
|
||||
// and discard the attribute type.
|
||||
func (r *Reader) ReadLine() ([]byte, error) {
|
||||
@@ -115,7 +202,7 @@ func (r *Reader) ReadLine() ([]byte, error) {
|
||||
var blobErr string
|
||||
blobErr, err = r.readStringReply(line)
|
||||
if err == nil {
|
||||
err = RedisError(blobErr)
|
||||
err = parseTypedRedisError(blobErr)
|
||||
}
|
||||
return nil, err
|
||||
case RespAttr:
|
||||
|
||||
+488
@@ -0,0 +1,488 @@
|
||||
package proto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Typed Redis errors for better error handling with wrapping support.
|
||||
// These errors maintain backward compatibility by keeping the same error messages.
|
||||
|
||||
// LoadingError is returned when Redis is loading the dataset in memory.
|
||||
type LoadingError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *LoadingError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *LoadingError) RedisError() {}
|
||||
|
||||
// NewLoadingError creates a new LoadingError with the given message.
|
||||
func NewLoadingError(msg string) *LoadingError {
|
||||
return &LoadingError{msg: msg}
|
||||
}
|
||||
|
||||
// ReadOnlyError is returned when trying to write to a read-only replica.
|
||||
type ReadOnlyError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *ReadOnlyError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *ReadOnlyError) RedisError() {}
|
||||
|
||||
// NewReadOnlyError creates a new ReadOnlyError with the given message.
|
||||
func NewReadOnlyError(msg string) *ReadOnlyError {
|
||||
return &ReadOnlyError{msg: msg}
|
||||
}
|
||||
|
||||
// MovedError is returned when a key has been moved to a different node in a cluster.
|
||||
type MovedError struct {
|
||||
msg string
|
||||
addr string
|
||||
}
|
||||
|
||||
func (e *MovedError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *MovedError) RedisError() {}
|
||||
|
||||
// Addr returns the address of the node where the key has been moved.
|
||||
func (e *MovedError) Addr() string {
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// NewMovedError creates a new MovedError with the given message and address.
|
||||
func NewMovedError(msg string, addr string) *MovedError {
|
||||
return &MovedError{msg: msg, addr: addr}
|
||||
}
|
||||
|
||||
// AskError is returned when a key is being migrated and the client should ask another node.
|
||||
type AskError struct {
|
||||
msg string
|
||||
addr string
|
||||
}
|
||||
|
||||
func (e *AskError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *AskError) RedisError() {}
|
||||
|
||||
// Addr returns the address of the node to ask.
|
||||
func (e *AskError) Addr() string {
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// NewAskError creates a new AskError with the given message and address.
|
||||
func NewAskError(msg string, addr string) *AskError {
|
||||
return &AskError{msg: msg, addr: addr}
|
||||
}
|
||||
|
||||
// ClusterDownError is returned when the cluster is down.
|
||||
type ClusterDownError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *ClusterDownError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *ClusterDownError) RedisError() {}
|
||||
|
||||
// NewClusterDownError creates a new ClusterDownError with the given message.
|
||||
func NewClusterDownError(msg string) *ClusterDownError {
|
||||
return &ClusterDownError{msg: msg}
|
||||
}
|
||||
|
||||
// TryAgainError is returned when a command cannot be processed and should be retried.
|
||||
type TryAgainError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *TryAgainError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *TryAgainError) RedisError() {}
|
||||
|
||||
// NewTryAgainError creates a new TryAgainError with the given message.
|
||||
func NewTryAgainError(msg string) *TryAgainError {
|
||||
return &TryAgainError{msg: msg}
|
||||
}
|
||||
|
||||
// MasterDownError is returned when the master is down.
|
||||
type MasterDownError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *MasterDownError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *MasterDownError) RedisError() {}
|
||||
|
||||
// NewMasterDownError creates a new MasterDownError with the given message.
|
||||
func NewMasterDownError(msg string) *MasterDownError {
|
||||
return &MasterDownError{msg: msg}
|
||||
}
|
||||
|
||||
// MaxClientsError is returned when the maximum number of clients has been reached.
|
||||
type MaxClientsError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *MaxClientsError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *MaxClientsError) RedisError() {}
|
||||
|
||||
// NewMaxClientsError creates a new MaxClientsError with the given message.
|
||||
func NewMaxClientsError(msg string) *MaxClientsError {
|
||||
return &MaxClientsError{msg: msg}
|
||||
}
|
||||
|
||||
// AuthError is returned when authentication fails.
|
||||
type AuthError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *AuthError) RedisError() {}
|
||||
|
||||
// NewAuthError creates a new AuthError with the given message.
|
||||
func NewAuthError(msg string) *AuthError {
|
||||
return &AuthError{msg: msg}
|
||||
}
|
||||
|
||||
// PermissionError is returned when a user lacks required permissions.
|
||||
type PermissionError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *PermissionError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *PermissionError) RedisError() {}
|
||||
|
||||
// NewPermissionError creates a new PermissionError with the given message.
|
||||
func NewPermissionError(msg string) *PermissionError {
|
||||
return &PermissionError{msg: msg}
|
||||
}
|
||||
|
||||
// ExecAbortError is returned when a transaction is aborted.
|
||||
type ExecAbortError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *ExecAbortError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *ExecAbortError) RedisError() {}
|
||||
|
||||
// NewExecAbortError creates a new ExecAbortError with the given message.
|
||||
func NewExecAbortError(msg string) *ExecAbortError {
|
||||
return &ExecAbortError{msg: msg}
|
||||
}
|
||||
|
||||
// OOMError is returned when Redis is out of memory.
|
||||
type OOMError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *OOMError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *OOMError) RedisError() {}
|
||||
|
||||
// NewOOMError creates a new OOMError with the given message.
|
||||
func NewOOMError(msg string) *OOMError {
|
||||
return &OOMError{msg: msg}
|
||||
}
|
||||
|
||||
// parseTypedRedisError parses a Redis error message and returns a typed error if applicable.
|
||||
// This function maintains backward compatibility by keeping the same error messages.
|
||||
func parseTypedRedisError(msg string) error {
|
||||
// Check for specific error patterns and return typed errors
|
||||
switch {
|
||||
case strings.HasPrefix(msg, "LOADING "):
|
||||
return NewLoadingError(msg)
|
||||
case strings.HasPrefix(msg, "READONLY "):
|
||||
return NewReadOnlyError(msg)
|
||||
case strings.HasPrefix(msg, "MOVED "):
|
||||
// Extract address from "MOVED <slot> <addr>"
|
||||
addr := extractAddr(msg)
|
||||
return NewMovedError(msg, addr)
|
||||
case strings.HasPrefix(msg, "ASK "):
|
||||
// Extract address from "ASK <slot> <addr>"
|
||||
addr := extractAddr(msg)
|
||||
return NewAskError(msg, addr)
|
||||
case strings.HasPrefix(msg, "CLUSTERDOWN "):
|
||||
return NewClusterDownError(msg)
|
||||
case strings.HasPrefix(msg, "TRYAGAIN "):
|
||||
return NewTryAgainError(msg)
|
||||
case strings.HasPrefix(msg, "MASTERDOWN "):
|
||||
return NewMasterDownError(msg)
|
||||
case msg == "ERR max number of clients reached":
|
||||
return NewMaxClientsError(msg)
|
||||
case strings.HasPrefix(msg, "NOAUTH "), strings.HasPrefix(msg, "WRONGPASS "), strings.Contains(msg, "unauthenticated"):
|
||||
return NewAuthError(msg)
|
||||
case strings.HasPrefix(msg, "NOPERM "):
|
||||
return NewPermissionError(msg)
|
||||
case strings.HasPrefix(msg, "EXECABORT "):
|
||||
return NewExecAbortError(msg)
|
||||
case strings.HasPrefix(msg, "OOM "):
|
||||
return NewOOMError(msg)
|
||||
default:
|
||||
// Return generic RedisError for unknown error types
|
||||
return RedisError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// extractAddr extracts the address from MOVED/ASK error messages.
|
||||
// Format: "MOVED <slot> <addr>" or "ASK <slot> <addr>"
|
||||
func extractAddr(msg string) string {
|
||||
ind := strings.LastIndex(msg, " ")
|
||||
if ind == -1 {
|
||||
return ""
|
||||
}
|
||||
return msg[ind+1:]
|
||||
}
|
||||
|
||||
// IsLoadingError checks if an error is a LoadingError, even if wrapped.
|
||||
func IsLoadingError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var loadingErr *LoadingError
|
||||
if errors.As(err, &loadingErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with LOADING prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "LOADING ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "LOADING ")
|
||||
}
|
||||
|
||||
// IsReadOnlyError checks if an error is a ReadOnlyError, even if wrapped.
|
||||
func IsReadOnlyError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var readOnlyErr *ReadOnlyError
|
||||
if errors.As(err, &readOnlyErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with READONLY prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "READONLY ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "READONLY ")
|
||||
}
|
||||
|
||||
// IsMovedError checks if an error is a MovedError, even if wrapped.
|
||||
// Returns the error and a boolean indicating if it's a MovedError.
|
||||
func IsMovedError(err error) (*MovedError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var movedErr *MovedError
|
||||
if errors.As(err, &movedErr) {
|
||||
return movedErr, true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
if strings.HasPrefix(s, "MOVED ") {
|
||||
// Parse: MOVED 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
return &MovedError{msg: s, addr: parts[2]}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsAskError checks if an error is an AskError, even if wrapped.
|
||||
// Returns the error and a boolean indicating if it's an AskError.
|
||||
func IsAskError(err error) (*AskError, bool) {
|
||||
if err == nil {
|
||||
return nil, false
|
||||
}
|
||||
var askErr *AskError
|
||||
if errors.As(err, &askErr) {
|
||||
return askErr, true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
if strings.HasPrefix(s, "ASK ") {
|
||||
// Parse: ASK 3999 127.0.0.1:6381
|
||||
parts := strings.Split(s, " ")
|
||||
if len(parts) == 3 {
|
||||
return &AskError{msg: s, addr: parts[2]}, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// IsClusterDownError checks if an error is a ClusterDownError, even if wrapped.
|
||||
func IsClusterDownError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var clusterDownErr *ClusterDownError
|
||||
if errors.As(err, &clusterDownErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with CLUSTERDOWN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "CLUSTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "CLUSTERDOWN ")
|
||||
}
|
||||
|
||||
// IsTryAgainError checks if an error is a TryAgainError, even if wrapped.
|
||||
func IsTryAgainError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var tryAgainErr *TryAgainError
|
||||
if errors.As(err, &tryAgainErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with TRYAGAIN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "TRYAGAIN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "TRYAGAIN ")
|
||||
}
|
||||
|
||||
// IsMasterDownError checks if an error is a MasterDownError, even if wrapped.
|
||||
func IsMasterDownError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var masterDownErr *MasterDownError
|
||||
if errors.As(err, &masterDownErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with MASTERDOWN prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "MASTERDOWN ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "MASTERDOWN ")
|
||||
}
|
||||
|
||||
// IsMaxClientsError checks if an error is a MaxClientsError, even if wrapped.
|
||||
func IsMaxClientsError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var maxClientsErr *MaxClientsError
|
||||
if errors.As(err, &maxClientsErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with max clients prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "ERR max number of clients reached") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "ERR max number of clients reached")
|
||||
}
|
||||
|
||||
// IsAuthError checks if an error is an AuthError, even if wrapped.
|
||||
func IsAuthError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var authErr *AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with auth error prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) {
|
||||
s := redisErr.Error()
|
||||
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
s := err.Error()
|
||||
return strings.HasPrefix(s, "NOAUTH ") || strings.HasPrefix(s, "WRONGPASS ") || strings.Contains(s, "unauthenticated")
|
||||
}
|
||||
|
||||
// IsPermissionError checks if an error is a PermissionError, even if wrapped.
|
||||
func IsPermissionError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var permErr *PermissionError
|
||||
if errors.As(err, &permErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with NOPERM prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "NOPERM ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "NOPERM ")
|
||||
}
|
||||
|
||||
// IsExecAbortError checks if an error is an ExecAbortError, even if wrapped.
|
||||
func IsExecAbortError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var execAbortErr *ExecAbortError
|
||||
if errors.As(err, &execAbortErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with EXECABORT prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "EXECABORT ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "EXECABORT ")
|
||||
}
|
||||
|
||||
// IsOOMError checks if an error is an OOMError, even if wrapped.
|
||||
func IsOOMError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var oomErr *OOMError
|
||||
if errors.As(err, &oomErr) {
|
||||
return true
|
||||
}
|
||||
// Check if wrapped error is a RedisError with OOM prefix
|
||||
var redisErr RedisError
|
||||
if errors.As(err, &redisErr) && strings.HasPrefix(redisErr.Error(), "OOM ") {
|
||||
return true
|
||||
}
|
||||
// Fallback to string checking for backward compatibility
|
||||
return strings.HasPrefix(err.Error(), "OOM ")
|
||||
}
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
package internal
|
||||
|
||||
const RedisNull = "<nil>"
|
||||
+193
@@ -0,0 +1,193 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var semTimers = sync.Pool{
|
||||
New: func() interface{} {
|
||||
t := time.NewTimer(time.Hour)
|
||||
t.Stop()
|
||||
return t
|
||||
},
|
||||
}
|
||||
|
||||
// FastSemaphore is a channel-based semaphore optimized for performance.
|
||||
// It uses a fast path that avoids timer allocation when tokens are available.
|
||||
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
|
||||
// Closing the semaphore unblocks all waiting goroutines.
|
||||
//
|
||||
// Performance: ~30 ns/op with zero allocations on fast path.
|
||||
// Fairness: Eventual fairness (no starvation) but not strict FIFO.
|
||||
type FastSemaphore struct {
|
||||
tokens chan struct{}
|
||||
max int32
|
||||
}
|
||||
|
||||
// NewFastSemaphore creates a new fast semaphore with the given capacity.
|
||||
func NewFastSemaphore(capacity int32) *FastSemaphore {
|
||||
ch := make(chan struct{}, capacity)
|
||||
// Pre-fill with tokens
|
||||
for i := int32(0); i < capacity; i++ {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
return &FastSemaphore{
|
||||
tokens: ch,
|
||||
max: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a token without blocking.
|
||||
// Returns true if successful, false if no tokens available.
|
||||
func (s *FastSemaphore) TryAcquire() bool {
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire acquires a token, blocking if necessary until one is available.
|
||||
// Returns an error if the context is cancelled or the timeout expires.
|
||||
// Uses a fast path to avoid timer allocation when tokens are immediately available.
|
||||
func (s *FastSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
|
||||
// Check context first
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Try fast path first (no timer needed)
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
// Slow path: need to wait with timeout
|
||||
timer := semTimers.Get().(*time.Timer)
|
||||
defer semTimers.Put(timer)
|
||||
timer.Reset(timeout)
|
||||
|
||||
select {
|
||||
case <-s.tokens:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return timeoutErr
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
|
||||
func (s *FastSemaphore) AcquireBlocking() {
|
||||
<-s.tokens
|
||||
}
|
||||
|
||||
// Release releases a token back to the semaphore.
|
||||
func (s *FastSemaphore) Release() {
|
||||
s.tokens <- struct{}{}
|
||||
}
|
||||
|
||||
// Close closes the semaphore, unblocking all waiting goroutines.
|
||||
// After close, all Acquire calls will receive a closed channel signal.
|
||||
func (s *FastSemaphore) Close() {
|
||||
close(s.tokens)
|
||||
}
|
||||
|
||||
// Len returns the current number of acquired tokens.
|
||||
func (s *FastSemaphore) Len() int32 {
|
||||
return s.max - int32(len(s.tokens))
|
||||
}
|
||||
|
||||
// FIFOSemaphore is a channel-based semaphore with strict FIFO ordering.
|
||||
// Unlike FastSemaphore, this guarantees that threads are served in the exact order they call Acquire().
|
||||
// The channel is pre-filled with tokens: Acquire = receive, Release = send.
|
||||
// Closing the semaphore unblocks all waiting goroutines.
|
||||
//
|
||||
// Performance: ~115 ns/op with zero allocations (slower than FastSemaphore due to timer allocation).
|
||||
// Fairness: Strict FIFO ordering guaranteed by Go runtime.
|
||||
type FIFOSemaphore struct {
|
||||
tokens chan struct{}
|
||||
max int32
|
||||
}
|
||||
|
||||
// NewFIFOSemaphore creates a new FIFO semaphore with the given capacity.
|
||||
func NewFIFOSemaphore(capacity int32) *FIFOSemaphore {
|
||||
ch := make(chan struct{}, capacity)
|
||||
// Pre-fill with tokens
|
||||
for i := int32(0); i < capacity; i++ {
|
||||
ch <- struct{}{}
|
||||
}
|
||||
return &FIFOSemaphore{
|
||||
tokens: ch,
|
||||
max: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
// TryAcquire attempts to acquire a token without blocking.
|
||||
// Returns true if successful, false if no tokens available.
|
||||
func (s *FIFOSemaphore) TryAcquire() bool {
|
||||
select {
|
||||
case <-s.tokens:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire acquires a token, blocking if necessary until one is available.
|
||||
// Returns an error if the context is cancelled or the timeout expires.
|
||||
// Always uses timer to guarantee FIFO ordering (no fast path).
|
||||
func (s *FIFOSemaphore) Acquire(ctx context.Context, timeout time.Duration, timeoutErr error) error {
|
||||
// No fast path - always use timer to guarantee FIFO
|
||||
timer := semTimers.Get().(*time.Timer)
|
||||
defer semTimers.Put(timer)
|
||||
timer.Reset(timeout)
|
||||
|
||||
select {
|
||||
case <-s.tokens:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return timeoutErr
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireBlocking acquires a token, blocking indefinitely until one is available.
|
||||
func (s *FIFOSemaphore) AcquireBlocking() {
|
||||
<-s.tokens
|
||||
}
|
||||
|
||||
// Release releases a token back to the semaphore.
|
||||
func (s *FIFOSemaphore) Release() {
|
||||
s.tokens <- struct{}{}
|
||||
}
|
||||
|
||||
// Close closes the semaphore, unblocking all waiting goroutines.
|
||||
// After close, all Acquire calls will receive a closed channel signal.
|
||||
func (s *FIFOSemaphore) Close() {
|
||||
close(s.tokens)
|
||||
}
|
||||
|
||||
// Len returns the current number of acquired tokens.
|
||||
func (s *FIFOSemaphore) Len() int32 {
|
||||
return s.max - int32(len(s.tokens))
|
||||
}
|
||||
+11
@@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 {
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur.
|
||||
func SafeIntToInt32(value int, fieldName string) (int32, error) {
|
||||
if value > math.MaxInt32 {
|
||||
return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32)
|
||||
}
|
||||
if value < math.MinInt32 {
|
||||
return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32)
|
||||
}
|
||||
return int32(value), nil
|
||||
}
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
package util
|
||||
|
||||
// Max returns the maximum of two integers
|
||||
func Max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Min returns the minimum of two integers
|
||||
func Min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
+218
@@ -0,0 +1,218 @@
|
||||
# Maintenance Notifications - FEATURES
|
||||
|
||||
## Overview
|
||||
|
||||
The Maintenance Notifications feature enables seamless Redis connection handoffs during cluster maintenance operations without dropping active connections. This feature leverages Redis RESP3 push notifications to provide zero-downtime maintenance for Redis Enterprise and compatible Redis deployments.
|
||||
|
||||
## Important
|
||||
|
||||
Using Maintenance Notifications may affect the read and write timeouts by relaxing them during maintenance operations.
|
||||
This is necessary to prevent false failures due to increased latency during handoffs. The relaxed timeouts are automatically applied and removed as needed.
|
||||
|
||||
## Key Features
|
||||
|
||||
### Seamless Connection Handoffs
|
||||
- **Zero-Downtime Maintenance**: Automatically handles connection transitions during cluster operations
|
||||
- **Active Operation Preservation**: Transfers in-flight operations to new connections without interruption
|
||||
- **Graceful Degradation**: Falls back to standard reconnection if handoff fails
|
||||
|
||||
### Push Notification Support
|
||||
Supports all Redis Enterprise maintenance notification types:
|
||||
- **MOVING** - Slot moving to a new node
|
||||
- **MIGRATING** - Slot in migration state
|
||||
- **MIGRATED** - Migration completed
|
||||
- **FAILING_OVER** - Node failing over
|
||||
- **FAILED_OVER** - Failover completed
|
||||
|
||||
### Circuit Breaker Pattern
|
||||
- **Endpoint-Specific Failure Tracking**: Prevents repeated connection attempts to failing endpoints
|
||||
- **Automatic Recovery Testing**: Half-open state allows gradual recovery validation
|
||||
- **Configurable Thresholds**: Customize failure thresholds and reset timeouts
|
||||
|
||||
### Flexible Configuration
|
||||
- **Auto-Detection Mode**: Automatically detects server support for maintenance notifications
|
||||
- **Multiple Endpoint Types**: Support for internal/external IP/FQDN endpoint resolution
|
||||
- **Auto-Scaling Workers**: Automatically sizes worker pool based on connection pool size
|
||||
- **Timeout Management**: Separate timeouts for relaxed (during maintenance) and normal operations
|
||||
|
||||
### Extensible Hook System
|
||||
- **Pre/Post Processing Hooks**: Monitor and customize notification handling
|
||||
- **Built-in Hooks**: Logging and metrics collection hooks included
|
||||
- **Custom Hook Support**: Implement custom business logic around maintenance events
|
||||
|
||||
### Comprehensive Monitoring
|
||||
- **Metrics Collection**: Track notification counts, processing times, and error rates
|
||||
- **Circuit Breaker Stats**: Monitor endpoint health and circuit breaker states
|
||||
- **Operation Tracking**: Track active handoff operations and their lifecycle
|
||||
|
||||
## Architecture Highlights
|
||||
|
||||
### Event-Driven Handoff System
|
||||
- **Asynchronous Processing**: Non-blocking handoff operations using worker pool pattern
|
||||
- **Queue-Based Architecture**: Configurable queue size with auto-scaling support
|
||||
- **Retry Mechanism**: Configurable retry attempts with exponential backoff
|
||||
|
||||
### Connection Pool Integration
|
||||
- **Pool Hook Interface**: Seamless integration with go-redis connection pool
|
||||
- **Connection State Management**: Atomic flags for connection usability tracking
|
||||
- **Graceful Shutdown**: Ensures all in-flight handoffs complete before shutdown
|
||||
|
||||
### Thread-Safe Design
|
||||
- **Lock-Free Operations**: Atomic operations for high-performance state tracking
|
||||
- **Concurrent-Safe Maps**: sync.Map for tracking active operations
|
||||
- **Minimal Lock Contention**: Read-write locks only where necessary
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Operation Modes
|
||||
- **`ModeDisabled`**: Maintenance notifications completely disabled
|
||||
- **`ModeEnabled`**: Forcefully enabled (fails if server doesn't support)
|
||||
- **`ModeAuto`**: Auto-detect server support (recommended default)
|
||||
|
||||
### Endpoint Types
|
||||
- **`EndpointTypeAuto`**: Auto-detect based on current connection
|
||||
- **`EndpointTypeInternalIP`**: Use internal IP addresses
|
||||
- **`EndpointTypeInternalFQDN`**: Use internal fully qualified domain names
|
||||
- **`EndpointTypeExternalIP`**: Use external IP addresses
|
||||
- **`EndpointTypeExternalFQDN`**: Use external fully qualified domain names
|
||||
- **`EndpointTypeNone`**: No endpoint (reconnect with current configuration)
|
||||
|
||||
### Timeout Configuration
|
||||
- **`RelaxedTimeout`**: Extended timeout during maintenance operations (default: 10s)
|
||||
- **`HandoffTimeout`**: Maximum time for handoff completion (default: 15s)
|
||||
- **`PostHandoffRelaxedDuration`**: Relaxed period after handoff (default: 2×RelaxedTimeout)
|
||||
|
||||
### Worker Pool Configuration
|
||||
- **`MaxWorkers`**: Maximum concurrent handoff workers (auto-calculated if 0)
|
||||
- **`HandoffQueueSize`**: Handoff queue capacity (auto-calculated if 0)
|
||||
- **`MaxHandoffRetries`**: Maximum retry attempts for failed handoffs (default: 3)
|
||||
|
||||
### Circuit Breaker Configuration
|
||||
- **`CircuitBreakerFailureThreshold`**: Failures before opening circuit (default: 5)
|
||||
- **`CircuitBreakerResetTimeout`**: Time before testing recovery (default: 60s)
|
||||
- **`CircuitBreakerMaxRequests`**: Max requests in half-open state (default: 3)
|
||||
|
||||
## Auto-Scaling Formulas
|
||||
|
||||
### Worker Pool Sizing
|
||||
When `MaxWorkers = 0` (auto-calculate):
|
||||
```
|
||||
MaxWorkers = min(PoolSize/2, max(10, PoolSize/3))
|
||||
```
|
||||
|
||||
### Queue Sizing
|
||||
When `HandoffQueueSize = 0` (auto-calculate):
|
||||
```
|
||||
QueueSize = max(20 × MaxWorkers, PoolSize)
|
||||
Capped by: min(MaxActiveConns + 1, 5 × PoolSize)
|
||||
```
|
||||
|
||||
### Examples
|
||||
- **Pool Size 100**: 33 workers, 660 queue (capped at 500)
|
||||
- **Pool Size 100 + MaxActiveConns 150**: 33 workers, 151 queue
|
||||
- **Pool Size 50**: 16 workers, 320 queue (capped at 250)
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
### Throughput
|
||||
- **Non-Blocking Handoffs**: Client operations continue during handoffs
|
||||
- **Concurrent Processing**: Multiple handoffs processed in parallel
|
||||
- **Minimal Overhead**: Lock-free atomic operations for state tracking
|
||||
|
||||
### Latency
|
||||
- **Relaxed Timeouts**: Extended timeouts during maintenance prevent false failures
|
||||
- **Fast Path**: Connections not undergoing handoff have zero overhead
|
||||
- **Graceful Degradation**: Failed handoffs fall back to standard reconnection
|
||||
|
||||
### Resource Usage
|
||||
- **Memory Efficient**: Bounded queue sizes prevent memory exhaustion
|
||||
- **Worker Pool**: Fixed worker count prevents goroutine explosion
|
||||
- **Connection Reuse**: Handoff reuses existing connection objects
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Tests
|
||||
- Comprehensive unit test coverage for all components
|
||||
- Mock-based testing for isolation
|
||||
- Concurrent operation testing
|
||||
|
||||
### Integration Tests
|
||||
- Pool integration tests with real connection handoffs
|
||||
- Circuit breaker behavior validation
|
||||
- Hook system integration testing
|
||||
|
||||
### E2E Tests
|
||||
- Real Redis Enterprise cluster testing
|
||||
- Multiple scenario coverage (timeouts, endpoint types, stress tests)
|
||||
- Fault injection testing
|
||||
- TLS configuration testing
|
||||
|
||||
## Compatibility
|
||||
|
||||
### Requirements
|
||||
- **Redis Protocol**: RESP3 required for push notifications
|
||||
- **Redis Version**: Redis Enterprise or compatible Redis with maintenance notifications
|
||||
- **Go Version**: Go 1.18+ (uses generics and atomic types)
|
||||
|
||||
### Client Support
|
||||
#### Currently Supported
|
||||
- **Standalone Client** (`redis.NewClient`)
|
||||
|
||||
#### Planned Support
|
||||
- **Cluster Client** (not yet supported)
|
||||
|
||||
#### Will Not Support
|
||||
- **Failover Client** (no planned support)
|
||||
- **Ring Client** (no planned support)
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### Enabling Maintenance Notifications
|
||||
|
||||
**Before:**
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 2, // RESP2
|
||||
})
|
||||
```
|
||||
|
||||
**After:**
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 3, // RESP3 required
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeAuto,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
### Adding Monitoring
|
||||
|
||||
```go
|
||||
// Get the manager from the client
|
||||
manager := client.GetMaintNotificationsManager()
|
||||
if manager != nil {
|
||||
// Add logging hook
|
||||
loggingHook := maintnotifications.NewLoggingHook(2) // Info level
|
||||
manager.AddNotificationHook(loggingHook)
|
||||
|
||||
// Add metrics hook
|
||||
metricsHook := maintnotifications.NewMetricsHook()
|
||||
manager.AddNotificationHook(metricsHook)
|
||||
}
|
||||
```
|
||||
|
||||
## Known Limitations
|
||||
|
||||
1. **Standalone Only**: Currently only supported in standalone Redis clients
|
||||
2. **RESP3 Required**: Push notifications require RESP3 protocol
|
||||
3. **Server Support**: Requires Redis Enterprise or compatible Redis with maintenance notifications
|
||||
4. **Single Connection Commands**: Some commands (MULTI/EXEC, WATCH) may need special handling
|
||||
5. **No Failover/Ring Client Support**: Failover and Ring clients are not supported and there are no plans to add support
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- Cluster client support
|
||||
- Enhanced metrics and observability
|
||||
+67
@@ -0,0 +1,67 @@
|
||||
# Maintenance Notifications
|
||||
|
||||
Seamless Redis connection handoffs during cluster maintenance operations without dropping connections.
|
||||
|
||||
## ⚠️ **Important Note**
|
||||
**Maintenance notifications are currently supported only in standalone Redis clients.** Cluster clients (ClusterClient, FailoverClient, etc.) do not yet support this functionality.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: "localhost:6379",
|
||||
Protocol: 3, // RESP3 required
|
||||
MaintNotificationsConfig: &maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeEnabled,
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Modes
|
||||
|
||||
- **`ModeDisabled`** - Maintenance notifications disabled
|
||||
- **`ModeEnabled`** - Forcefully enabled (fails if server doesn't support)
|
||||
- **`ModeAuto`** - Auto-detect server support (default)
|
||||
|
||||
## Configuration
|
||||
|
||||
```go
|
||||
&maintnotifications.Config{
|
||||
Mode: maintnotifications.ModeAuto,
|
||||
EndpointType: maintnotifications.EndpointTypeAuto,
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxHandoffRetries: 3,
|
||||
MaxWorkers: 0, // Auto-calculated
|
||||
HandoffQueueSize: 0, // Auto-calculated
|
||||
PostHandoffRelaxedDuration: 0, // 2 * RelaxedTimeout
|
||||
}
|
||||
```
|
||||
|
||||
### Endpoint Types
|
||||
|
||||
- **`EndpointTypeAuto`** - Auto-detect based on connection (default)
|
||||
- **`EndpointTypeInternalIP`** - Internal IP address
|
||||
- **`EndpointTypeInternalFQDN`** - Internal FQDN
|
||||
- **`EndpointTypeExternalIP`** - External IP address
|
||||
- **`EndpointTypeExternalFQDN`** - External FQDN
|
||||
- **`EndpointTypeNone`** - No endpoint (reconnect with current config)
|
||||
|
||||
### Auto-Scaling
|
||||
|
||||
**Workers**: `min(PoolSize/2, max(10, PoolSize/3))` when auto-calculated
|
||||
**Queue**: `max(20×Workers, PoolSize)` capped by `MaxActiveConns+1` or `5×PoolSize`
|
||||
|
||||
**Examples:**
|
||||
- Pool 100: 33 workers, 660 queue (capped at 500)
|
||||
- Pool 100 + MaxActiveConns 150: 33 workers, 151 queue
|
||||
|
||||
## How It Works
|
||||
|
||||
1. Redis sends push notifications about cluster maintenance operations
|
||||
2. Client creates new connections to updated endpoints
|
||||
3. Active operations transfer to new connections
|
||||
4. Old connections close gracefully
|
||||
|
||||
|
||||
## For more information, see [FEATURES](FEATURES.md)
|
||||
+353
@@ -0,0 +1,353 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// CircuitBreakerState represents the state of a circuit breaker
|
||||
type CircuitBreakerState int32
|
||||
|
||||
const (
|
||||
// CircuitBreakerClosed - normal operation, requests allowed
|
||||
CircuitBreakerClosed CircuitBreakerState = iota
|
||||
// CircuitBreakerOpen - failing fast, requests rejected
|
||||
CircuitBreakerOpen
|
||||
// CircuitBreakerHalfOpen - testing if service recovered
|
||||
CircuitBreakerHalfOpen
|
||||
)
|
||||
|
||||
func (s CircuitBreakerState) String() string {
|
||||
switch s {
|
||||
case CircuitBreakerClosed:
|
||||
return "closed"
|
||||
case CircuitBreakerOpen:
|
||||
return "open"
|
||||
case CircuitBreakerHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker implements the circuit breaker pattern for endpoint-specific failure handling
|
||||
type CircuitBreaker struct {
|
||||
// Configuration
|
||||
failureThreshold int // Number of failures before opening
|
||||
resetTimeout time.Duration // How long to stay open before testing
|
||||
maxRequests int // Max requests allowed in half-open state
|
||||
|
||||
// State tracking (atomic for lock-free access)
|
||||
state atomic.Int32 // CircuitBreakerState
|
||||
failures atomic.Int64 // Current failure count
|
||||
successes atomic.Int64 // Success count in half-open state
|
||||
requests atomic.Int64 // Request count in half-open state
|
||||
lastFailureTime atomic.Int64 // Unix timestamp of last failure
|
||||
lastSuccessTime atomic.Int64 // Unix timestamp of last success
|
||||
|
||||
// Endpoint identification
|
||||
endpoint string
|
||||
config *Config
|
||||
}
|
||||
|
||||
// newCircuitBreaker creates a new circuit breaker for an endpoint
|
||||
func newCircuitBreaker(endpoint string, config *Config) *CircuitBreaker {
|
||||
// Use configuration values with sensible defaults
|
||||
failureThreshold := 5
|
||||
resetTimeout := 60 * time.Second
|
||||
maxRequests := 3
|
||||
|
||||
if config != nil {
|
||||
failureThreshold = config.CircuitBreakerFailureThreshold
|
||||
resetTimeout = config.CircuitBreakerResetTimeout
|
||||
maxRequests = config.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
return &CircuitBreaker{
|
||||
failureThreshold: failureThreshold,
|
||||
resetTimeout: resetTimeout,
|
||||
maxRequests: maxRequests,
|
||||
endpoint: endpoint,
|
||||
config: config,
|
||||
state: atomic.Int32{}, // Defaults to CircuitBreakerClosed (0)
|
||||
}
|
||||
}
|
||||
|
||||
// IsOpen returns true if the circuit breaker is open (rejecting requests)
|
||||
func (cb *CircuitBreaker) IsOpen() bool {
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
return state == CircuitBreakerOpen
|
||||
}
|
||||
|
||||
// shouldAttemptReset checks if enough time has passed to attempt reset
|
||||
func (cb *CircuitBreaker) shouldAttemptReset() bool {
|
||||
lastFailure := time.Unix(cb.lastFailureTime.Load(), 0)
|
||||
return time.Since(lastFailure) >= cb.resetTimeout
|
||||
}
|
||||
|
||||
// Execute runs the given function with circuit breaker protection
|
||||
func (cb *CircuitBreaker) Execute(fn func() error) error {
|
||||
// Single atomic state load for consistency
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerOpen:
|
||||
if cb.shouldAttemptReset() {
|
||||
// Attempt transition to half-open
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerOpen), int32(CircuitBreakerHalfOpen)) {
|
||||
cb.requests.Store(0)
|
||||
cb.successes.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerTransitioningToHalfOpen(cb.endpoint))
|
||||
}
|
||||
// Fall through to half-open logic
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
} else {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
fallthrough
|
||||
case CircuitBreakerHalfOpen:
|
||||
requests := cb.requests.Add(1)
|
||||
if requests > int64(cb.maxRequests) {
|
||||
cb.requests.Add(-1) // Revert the increment
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the function with consistent state
|
||||
err := fn()
|
||||
|
||||
if err != nil {
|
||||
cb.recordFailure()
|
||||
return err
|
||||
}
|
||||
|
||||
cb.recordSuccess()
|
||||
return nil
|
||||
}
|
||||
|
||||
// recordFailure records a failure and potentially opens the circuit
|
||||
func (cb *CircuitBreaker) recordFailure() {
|
||||
cb.lastFailureTime.Store(time.Now().Unix())
|
||||
failures := cb.failures.Add(1)
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
if failures >= int64(cb.failureThreshold) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerClosed), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerOpened(cb.endpoint, failures))
|
||||
}
|
||||
}
|
||||
}
|
||||
case CircuitBreakerHalfOpen:
|
||||
// Any failure in half-open state immediately opens the circuit
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerOpen)) {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerReopened(cb.endpoint))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recordSuccess records a success and potentially closes the circuit
|
||||
func (cb *CircuitBreaker) recordSuccess() {
|
||||
cb.lastSuccessTime.Store(time.Now().Unix())
|
||||
|
||||
state := CircuitBreakerState(cb.state.Load())
|
||||
|
||||
switch state {
|
||||
case CircuitBreakerClosed:
|
||||
// Reset failure count on success in closed state
|
||||
cb.failures.Store(0)
|
||||
case CircuitBreakerHalfOpen:
|
||||
successes := cb.successes.Add(1)
|
||||
|
||||
// If we've had enough successful requests, close the circuit
|
||||
if successes >= int64(cb.maxRequests) {
|
||||
if cb.state.CompareAndSwap(int32(CircuitBreakerHalfOpen), int32(CircuitBreakerClosed)) {
|
||||
cb.failures.Store(0)
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerClosed(cb.endpoint, successes))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetState returns the current state of the circuit breaker
|
||||
func (cb *CircuitBreaker) GetState() CircuitBreakerState {
|
||||
return CircuitBreakerState(cb.state.Load())
|
||||
}
|
||||
|
||||
// GetStats returns current statistics for monitoring
|
||||
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
|
||||
return CircuitBreakerStats{
|
||||
Endpoint: cb.endpoint,
|
||||
State: cb.GetState(),
|
||||
Failures: cb.failures.Load(),
|
||||
Successes: cb.successes.Load(),
|
||||
Requests: cb.requests.Load(),
|
||||
LastFailureTime: time.Unix(cb.lastFailureTime.Load(), 0),
|
||||
LastSuccessTime: time.Unix(cb.lastSuccessTime.Load(), 0),
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerStats provides statistics about a circuit breaker
|
||||
type CircuitBreakerStats struct {
|
||||
Endpoint string
|
||||
State CircuitBreakerState
|
||||
Failures int64
|
||||
Successes int64
|
||||
Requests int64
|
||||
LastFailureTime time.Time
|
||||
LastSuccessTime time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerEntry wraps a circuit breaker with access tracking
|
||||
type CircuitBreakerEntry struct {
|
||||
breaker *CircuitBreaker
|
||||
lastAccess atomic.Int64 // Unix timestamp
|
||||
created time.Time
|
||||
}
|
||||
|
||||
// CircuitBreakerManager manages circuit breakers for multiple endpoints
|
||||
type CircuitBreakerManager struct {
|
||||
breakers sync.Map // map[string]*CircuitBreakerEntry
|
||||
config *Config
|
||||
cleanupStop chan struct{}
|
||||
cleanupMu sync.Mutex
|
||||
lastCleanup atomic.Int64 // Unix timestamp
|
||||
}
|
||||
|
||||
// newCircuitBreakerManager creates a new circuit breaker manager
|
||||
func newCircuitBreakerManager(config *Config) *CircuitBreakerManager {
|
||||
cbm := &CircuitBreakerManager{
|
||||
config: config,
|
||||
cleanupStop: make(chan struct{}),
|
||||
}
|
||||
cbm.lastCleanup.Store(time.Now().Unix())
|
||||
|
||||
// Start background cleanup goroutine
|
||||
go cbm.cleanupLoop()
|
||||
|
||||
return cbm
|
||||
}
|
||||
|
||||
// GetCircuitBreaker returns the circuit breaker for an endpoint, creating it if necessary
|
||||
func (cbm *CircuitBreakerManager) GetCircuitBreaker(endpoint string) *CircuitBreaker {
|
||||
now := time.Now().Unix()
|
||||
|
||||
if entry, ok := cbm.breakers.Load(endpoint); ok {
|
||||
cbEntry := entry.(*CircuitBreakerEntry)
|
||||
cbEntry.lastAccess.Store(now)
|
||||
return cbEntry.breaker
|
||||
}
|
||||
|
||||
// Create new circuit breaker with metadata
|
||||
newBreaker := newCircuitBreaker(endpoint, cbm.config)
|
||||
newEntry := &CircuitBreakerEntry{
|
||||
breaker: newBreaker,
|
||||
created: time.Now(),
|
||||
}
|
||||
newEntry.lastAccess.Store(now)
|
||||
|
||||
actual, _ := cbm.breakers.LoadOrStore(endpoint, newEntry)
|
||||
return actual.(*CircuitBreakerEntry).breaker
|
||||
}
|
||||
|
||||
// GetAllStats returns statistics for all circuit breakers
|
||||
func (cbm *CircuitBreakerManager) GetAllStats() []CircuitBreakerStats {
|
||||
var stats []CircuitBreakerStats
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
stats = append(stats, entry.breaker.GetStats())
|
||||
return true
|
||||
})
|
||||
return stats
|
||||
}
|
||||
|
||||
// cleanupLoop runs background cleanup of unused circuit breakers
|
||||
func (cbm *CircuitBreakerManager) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute) // Cleanup every 5 minutes
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cbm.cleanup()
|
||||
case <-cbm.cleanupStop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes circuit breakers that haven't been accessed recently
|
||||
func (cbm *CircuitBreakerManager) cleanup() {
|
||||
// Prevent concurrent cleanups
|
||||
if !cbm.cleanupMu.TryLock() {
|
||||
return
|
||||
}
|
||||
defer cbm.cleanupMu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-30 * time.Minute).Unix() // 30 minute TTL
|
||||
|
||||
var toDelete []string
|
||||
count := 0
|
||||
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
endpoint := key.(string)
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
|
||||
count++
|
||||
|
||||
// Remove if not accessed recently
|
||||
if entry.lastAccess.Load() < cutoff {
|
||||
toDelete = append(toDelete, endpoint)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Delete expired entries
|
||||
for _, endpoint := range toDelete {
|
||||
cbm.breakers.Delete(endpoint)
|
||||
}
|
||||
|
||||
// Log cleanup results
|
||||
if len(toDelete) > 0 && internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CircuitBreakerCleanup(len(toDelete), count))
|
||||
}
|
||||
|
||||
cbm.lastCleanup.Store(now.Unix())
|
||||
}
|
||||
|
||||
// Shutdown stops the cleanup goroutine
|
||||
func (cbm *CircuitBreakerManager) Shutdown() {
|
||||
close(cbm.cleanupStop)
|
||||
}
|
||||
|
||||
// Reset resets all circuit breakers (useful for testing)
|
||||
func (cbm *CircuitBreakerManager) Reset() {
|
||||
cbm.breakers.Range(func(key, value interface{}) bool {
|
||||
entry := value.(*CircuitBreakerEntry)
|
||||
breaker := entry.breaker
|
||||
breaker.state.Store(int32(CircuitBreakerClosed))
|
||||
breaker.failures.Store(0)
|
||||
breaker.successes.Store(0)
|
||||
breaker.requests.Store(0)
|
||||
breaker.lastFailureTime.Store(0)
|
||||
breaker.lastSuccessTime.Store(0)
|
||||
return true
|
||||
})
|
||||
}
|
||||
+458
@@ -0,0 +1,458 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
)
|
||||
|
||||
// Mode represents the maintenance notifications mode
|
||||
type Mode string
|
||||
|
||||
// Constants for maintenance push notifications modes
|
||||
const (
|
||||
ModeDisabled Mode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command
|
||||
ModeEnabled Mode = "enabled" // Client forcefully sends command, interrupts connection on error
|
||||
ModeAuto Mode = "auto" // Client tries to send command, disables feature on error
|
||||
)
|
||||
|
||||
// IsValid returns true if the maintenance notifications mode is valid
|
||||
func (m Mode) IsValid() bool {
|
||||
switch m {
|
||||
case ModeDisabled, ModeEnabled, ModeAuto:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of the mode
|
||||
func (m Mode) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
||||
// EndpointType represents the type of endpoint to request in MOVING notifications
|
||||
type EndpointType string
|
||||
|
||||
// Constants for endpoint types
|
||||
const (
|
||||
EndpointTypeAuto EndpointType = "auto" // Auto-detect based on connection
|
||||
EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address
|
||||
EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN
|
||||
EndpointTypeExternalIP EndpointType = "external-ip" // External IP address
|
||||
EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN
|
||||
EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config)
|
||||
)
|
||||
|
||||
// IsValid returns true if the endpoint type is valid
|
||||
func (e EndpointType) IsValid() bool {
|
||||
switch e {
|
||||
case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
|
||||
EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of the endpoint type
|
||||
func (e EndpointType) String() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// Config provides configuration options for maintenance notifications
|
||||
type Config struct {
|
||||
// Mode controls how client maintenance notifications are handled.
|
||||
// Valid values: ModeDisabled, ModeEnabled, ModeAuto
|
||||
// Default: ModeAuto
|
||||
Mode Mode
|
||||
|
||||
// EndpointType specifies the type of endpoint to request in MOVING notifications.
|
||||
// Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN,
|
||||
// EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone
|
||||
// Default: EndpointTypeAuto
|
||||
EndpointType EndpointType
|
||||
|
||||
// RelaxedTimeout is the concrete timeout value to use during
|
||||
// MIGRATING/FAILING_OVER states to accommodate increased latency.
|
||||
// This applies to both read and write timeouts.
|
||||
// Default: 10 seconds
|
||||
RelaxedTimeout time.Duration
|
||||
|
||||
// HandoffTimeout is the maximum time to wait for connection handoff to complete.
|
||||
// If handoff takes longer than this, the old connection will be forcibly closed.
|
||||
// Default: 15 seconds (matches server-side eviction timeout)
|
||||
HandoffTimeout time.Duration
|
||||
|
||||
// MaxWorkers is the maximum number of worker goroutines for processing handoff requests.
|
||||
// Workers are created on-demand and automatically cleaned up when idle.
|
||||
// If zero, defaults to min(10, PoolSize/2) to handle bursts effectively.
|
||||
// If explicitly set, enforces minimum of PoolSize/2
|
||||
//
|
||||
// Default: min(PoolSize/2, max(10, PoolSize/3)), Minimum when set: PoolSize/2
|
||||
MaxWorkers int
|
||||
|
||||
// HandoffQueueSize is the size of the buffered channel used to queue handoff requests.
|
||||
// If the queue is full, new handoff requests will be rejected.
|
||||
// Scales with both worker count and pool size for better burst handling.
|
||||
//
|
||||
// Default: max(20×MaxWorkers, PoolSize), capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
// When set: minimum 200, capped by MaxActiveConns+1 (if set) or 5×PoolSize
|
||||
HandoffQueueSize int
|
||||
|
||||
// PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection
|
||||
// after a handoff completes. This provides additional resilience during cluster transitions.
|
||||
// Default: 2 * RelaxedTimeout
|
||||
PostHandoffRelaxedDuration time.Duration
|
||||
|
||||
// Circuit breaker configuration for endpoint failure handling
|
||||
// CircuitBreakerFailureThreshold is the number of failures before opening the circuit.
|
||||
// Default: 5
|
||||
CircuitBreakerFailureThreshold int
|
||||
|
||||
// CircuitBreakerResetTimeout is how long to wait before testing if the endpoint recovered.
|
||||
// Default: 60 seconds
|
||||
CircuitBreakerResetTimeout time.Duration
|
||||
|
||||
// CircuitBreakerMaxRequests is the maximum number of requests allowed in half-open state.
|
||||
// Default: 3
|
||||
CircuitBreakerMaxRequests int
|
||||
|
||||
// MaxHandoffRetries is the maximum number of times to retry a failed handoff.
|
||||
// After this many retries, the connection will be removed from the pool.
|
||||
// Default: 3
|
||||
MaxHandoffRetries int
|
||||
}
|
||||
|
||||
func (c *Config) IsEnabled() bool {
|
||||
return c != nil && c.Mode != ModeDisabled
|
||||
}
|
||||
|
||||
// DefaultConfig returns a Config with sensible defaults.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Mode: ModeAuto, // Enable by default for Redis Cloud
|
||||
EndpointType: EndpointTypeAuto, // Auto-detect based on connection
|
||||
RelaxedTimeout: 10 * time.Second,
|
||||
HandoffTimeout: 15 * time.Second,
|
||||
MaxWorkers: 0, // Auto-calculated based on pool size
|
||||
HandoffQueueSize: 0, // Auto-calculated based on max workers
|
||||
PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: 5,
|
||||
CircuitBreakerResetTimeout: 60 * time.Second,
|
||||
CircuitBreakerMaxRequests: 3,
|
||||
|
||||
// Connection Handoff Configuration
|
||||
MaxHandoffRetries: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks if the configuration is valid.
|
||||
func (c *Config) Validate() error {
|
||||
if c.RelaxedTimeout <= 0 {
|
||||
return ErrInvalidRelaxedTimeout
|
||||
}
|
||||
if c.HandoffTimeout <= 0 {
|
||||
return ErrInvalidHandoffTimeout
|
||||
}
|
||||
// Validate worker configuration
|
||||
// Allow 0 for auto-calculation, but negative values are invalid
|
||||
if c.MaxWorkers < 0 {
|
||||
return ErrInvalidHandoffWorkers
|
||||
}
|
||||
// HandoffQueueSize validation - allow 0 for auto-calculation
|
||||
if c.HandoffQueueSize < 0 {
|
||||
return ErrInvalidHandoffQueueSize
|
||||
}
|
||||
if c.PostHandoffRelaxedDuration < 0 {
|
||||
return ErrInvalidPostHandoffRelaxedDuration
|
||||
}
|
||||
|
||||
// Circuit breaker validation
|
||||
if c.CircuitBreakerFailureThreshold < 1 {
|
||||
return ErrInvalidCircuitBreakerFailureThreshold
|
||||
}
|
||||
if c.CircuitBreakerResetTimeout < 0 {
|
||||
return ErrInvalidCircuitBreakerResetTimeout
|
||||
}
|
||||
if c.CircuitBreakerMaxRequests < 1 {
|
||||
return ErrInvalidCircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
// Validate Mode (maintenance notifications mode)
|
||||
if !c.Mode.IsValid() {
|
||||
return ErrInvalidMaintNotifications
|
||||
}
|
||||
|
||||
// Validate EndpointType
|
||||
if !c.EndpointType.IsValid() {
|
||||
return ErrInvalidEndpointType
|
||||
}
|
||||
|
||||
// Validate configuration fields
|
||||
if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 {
|
||||
return ErrInvalidHandoffRetries
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyDefaults applies default values to any zero-value fields in the configuration.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaults() *Config {
|
||||
return c.ApplyDefaultsWithPoolSize(0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size to calculate worker defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config {
|
||||
return c.ApplyDefaultsWithPoolConfig(poolSize, 0)
|
||||
}
|
||||
|
||||
// ApplyDefaultsWithPoolConfig applies default values to any zero-value fields in the configuration,
|
||||
// using the provided pool size and max active connections to calculate worker and queue defaults.
|
||||
// This ensures that partially configured structs get sensible defaults for missing fields.
|
||||
func (c *Config) ApplyDefaultsWithPoolConfig(poolSize int, maxActiveConns int) *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
defaults := DefaultConfig()
|
||||
result := &Config{}
|
||||
|
||||
// Apply defaults for enum fields (empty/zero means not set)
|
||||
result.Mode = defaults.Mode
|
||||
if c.Mode != "" {
|
||||
result.Mode = c.Mode
|
||||
}
|
||||
|
||||
result.EndpointType = defaults.EndpointType
|
||||
if c.EndpointType != "" {
|
||||
result.EndpointType = c.EndpointType
|
||||
}
|
||||
|
||||
// Apply defaults for duration fields (zero means not set)
|
||||
result.RelaxedTimeout = defaults.RelaxedTimeout
|
||||
if c.RelaxedTimeout > 0 {
|
||||
result.RelaxedTimeout = c.RelaxedTimeout
|
||||
}
|
||||
|
||||
result.HandoffTimeout = defaults.HandoffTimeout
|
||||
if c.HandoffTimeout > 0 {
|
||||
result.HandoffTimeout = c.HandoffTimeout
|
||||
}
|
||||
|
||||
// Copy worker configuration
|
||||
result.MaxWorkers = c.MaxWorkers
|
||||
|
||||
// Apply worker defaults based on pool size
|
||||
result.applyWorkerDefaults(poolSize)
|
||||
|
||||
// Apply queue size defaults with new scaling approach
|
||||
// Default: max(20x workers, PoolSize), capped by maxActiveConns or 5x pool size
|
||||
workerBasedSize := result.MaxWorkers * 20
|
||||
poolBasedSize := poolSize
|
||||
result.HandoffQueueSize = util.Max(workerBasedSize, poolBasedSize)
|
||||
if c.HandoffQueueSize > 0 {
|
||||
// When explicitly set: enforce minimum of 200
|
||||
result.HandoffQueueSize = util.Max(200, c.HandoffQueueSize)
|
||||
}
|
||||
|
||||
// Cap queue size: use maxActiveConns+1 if set, otherwise 5x pool size
|
||||
var queueCap int
|
||||
if maxActiveConns > 0 {
|
||||
queueCap = maxActiveConns + 1
|
||||
// Ensure queue cap is at least 2 for very small maxActiveConns
|
||||
if queueCap < 2 {
|
||||
queueCap = 2
|
||||
}
|
||||
} else {
|
||||
queueCap = poolSize * 5
|
||||
}
|
||||
result.HandoffQueueSize = util.Min(result.HandoffQueueSize, queueCap)
|
||||
|
||||
// Ensure minimum queue size of 2 (fallback for very small pools)
|
||||
if result.HandoffQueueSize < 2 {
|
||||
result.HandoffQueueSize = 2
|
||||
}
|
||||
|
||||
result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2
|
||||
if c.PostHandoffRelaxedDuration > 0 {
|
||||
result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration
|
||||
}
|
||||
|
||||
// Apply defaults for configuration fields
|
||||
result.MaxHandoffRetries = defaults.MaxHandoffRetries
|
||||
if c.MaxHandoffRetries > 0 {
|
||||
result.MaxHandoffRetries = c.MaxHandoffRetries
|
||||
}
|
||||
|
||||
// Circuit breaker configuration
|
||||
result.CircuitBreakerFailureThreshold = defaults.CircuitBreakerFailureThreshold
|
||||
if c.CircuitBreakerFailureThreshold > 0 {
|
||||
result.CircuitBreakerFailureThreshold = c.CircuitBreakerFailureThreshold
|
||||
}
|
||||
|
||||
result.CircuitBreakerResetTimeout = defaults.CircuitBreakerResetTimeout
|
||||
if c.CircuitBreakerResetTimeout > 0 {
|
||||
result.CircuitBreakerResetTimeout = c.CircuitBreakerResetTimeout
|
||||
}
|
||||
|
||||
result.CircuitBreakerMaxRequests = defaults.CircuitBreakerMaxRequests
|
||||
if c.CircuitBreakerMaxRequests > 0 {
|
||||
result.CircuitBreakerMaxRequests = c.CircuitBreakerMaxRequests
|
||||
}
|
||||
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.DebugLoggingEnabled())
|
||||
internal.Logger.Printf(context.Background(), logs.ConfigDebug(result))
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the configuration.
|
||||
func (c *Config) Clone() *Config {
|
||||
if c == nil {
|
||||
return DefaultConfig()
|
||||
}
|
||||
|
||||
return &Config{
|
||||
Mode: c.Mode,
|
||||
EndpointType: c.EndpointType,
|
||||
RelaxedTimeout: c.RelaxedTimeout,
|
||||
HandoffTimeout: c.HandoffTimeout,
|
||||
MaxWorkers: c.MaxWorkers,
|
||||
HandoffQueueSize: c.HandoffQueueSize,
|
||||
PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration,
|
||||
|
||||
// Circuit breaker configuration
|
||||
CircuitBreakerFailureThreshold: c.CircuitBreakerFailureThreshold,
|
||||
CircuitBreakerResetTimeout: c.CircuitBreakerResetTimeout,
|
||||
CircuitBreakerMaxRequests: c.CircuitBreakerMaxRequests,
|
||||
|
||||
// Configuration fields
|
||||
MaxHandoffRetries: c.MaxHandoffRetries,
|
||||
}
|
||||
}
|
||||
|
||||
// applyWorkerDefaults calculates and applies worker defaults based on pool size
|
||||
func (c *Config) applyWorkerDefaults(poolSize int) {
|
||||
// Calculate defaults based on pool size
|
||||
if poolSize <= 0 {
|
||||
poolSize = 10 * runtime.GOMAXPROCS(0)
|
||||
}
|
||||
|
||||
// When not set: min(poolSize/2, max(10, poolSize/3)) - balanced scaling approach
|
||||
originalMaxWorkers := c.MaxWorkers
|
||||
c.MaxWorkers = util.Min(poolSize/2, util.Max(10, poolSize/3))
|
||||
if originalMaxWorkers != 0 {
|
||||
// When explicitly set: max(poolSize/2, set_value) - ensure at least poolSize/2 workers
|
||||
c.MaxWorkers = util.Max(poolSize/2, originalMaxWorkers)
|
||||
}
|
||||
|
||||
// Ensure minimum of 1 worker (fallback for very small pools)
|
||||
if c.MaxWorkers < 1 {
|
||||
c.MaxWorkers = 1
|
||||
}
|
||||
}
|
||||
|
||||
// DetectEndpointType automatically detects the appropriate endpoint type
|
||||
// based on the connection address and TLS configuration.
|
||||
//
|
||||
// For IP addresses:
|
||||
// - If TLS is enabled: requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// For hostnames:
|
||||
// - If TLS is enabled: always requests FQDN for proper certificate validation
|
||||
// - If TLS is disabled: requests IP for better performance
|
||||
//
|
||||
// Internal vs External detection:
|
||||
// - For IPs: uses private IP range detection
|
||||
// - For hostnames: uses heuristics based on common internal naming patterns
|
||||
func DetectEndpointType(addr string, tlsEnabled bool) EndpointType {
|
||||
// Extract host from "host:port" format
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr // Assume no port
|
||||
}
|
||||
|
||||
// Check if the host is an IP address or hostname
|
||||
ip := net.ParseIP(host)
|
||||
isIPAddress := ip != nil
|
||||
var endpointType EndpointType
|
||||
|
||||
if isIPAddress {
|
||||
// Address is an IP - determine if it's private or public
|
||||
isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast()
|
||||
|
||||
if tlsEnabled {
|
||||
// TLS with IP addresses - still prefer FQDN for certificate validation
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
} else {
|
||||
// No TLS - can use IP addresses directly
|
||||
if isPrivate {
|
||||
endpointType = EndpointTypeInternalIP
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalIP
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Address is a hostname
|
||||
isInternalHostname := isInternalHostname(host)
|
||||
if isInternalHostname {
|
||||
endpointType = EndpointTypeInternalFQDN
|
||||
} else {
|
||||
endpointType = EndpointTypeExternalFQDN
|
||||
}
|
||||
}
|
||||
|
||||
return endpointType
|
||||
}
|
||||
|
||||
// isInternalHostname determines if a hostname appears to be internal/private.
|
||||
// This is a heuristic based on common naming patterns.
|
||||
func isInternalHostname(hostname string) bool {
|
||||
// Convert to lowercase for comparison
|
||||
hostname = strings.ToLower(hostname)
|
||||
|
||||
// Common internal hostname patterns
|
||||
internalPatterns := []string{
|
||||
"localhost",
|
||||
".local",
|
||||
".internal",
|
||||
".corp",
|
||||
".lan",
|
||||
".intranet",
|
||||
".private",
|
||||
}
|
||||
|
||||
// Check for exact match or suffix match
|
||||
for _, pattern := range internalPatterns {
|
||||
if hostname == pattern || strings.HasSuffix(hostname, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.)
|
||||
// If hostname doesn't contain dots, it's likely internal
|
||||
if !strings.Contains(hostname, ".") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Default to external for fully qualified domain names
|
||||
return false
|
||||
}
|
||||
+76
@@ -0,0 +1,76 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
)
|
||||
|
||||
// Configuration errors
|
||||
var (
|
||||
ErrInvalidRelaxedTimeout = errors.New(logs.InvalidRelaxedTimeoutError())
|
||||
ErrInvalidHandoffTimeout = errors.New(logs.InvalidHandoffTimeoutError())
|
||||
ErrInvalidHandoffWorkers = errors.New(logs.InvalidHandoffWorkersError())
|
||||
ErrInvalidHandoffQueueSize = errors.New(logs.InvalidHandoffQueueSizeError())
|
||||
ErrInvalidPostHandoffRelaxedDuration = errors.New(logs.InvalidPostHandoffRelaxedDurationError())
|
||||
ErrInvalidEndpointType = errors.New(logs.InvalidEndpointTypeError())
|
||||
ErrInvalidMaintNotifications = errors.New(logs.InvalidMaintNotificationsError())
|
||||
ErrMaxHandoffRetriesReached = errors.New(logs.MaxHandoffRetriesReachedError())
|
||||
|
||||
// Configuration validation errors
|
||||
|
||||
// ErrInvalidHandoffRetries is returned when the number of handoff retries is invalid
|
||||
ErrInvalidHandoffRetries = errors.New(logs.InvalidHandoffRetriesError())
|
||||
)
|
||||
|
||||
// Integration errors
|
||||
var (
|
||||
// ErrInvalidClient is returned when the client does not support push notifications
|
||||
ErrInvalidClient = errors.New(logs.InvalidClientError())
|
||||
)
|
||||
|
||||
// Handoff errors
|
||||
var (
|
||||
// ErrHandoffQueueFull is returned when the handoff queue is full
|
||||
ErrHandoffQueueFull = errors.New(logs.HandoffQueueFullError())
|
||||
)
|
||||
|
||||
// Notification errors
|
||||
var (
|
||||
// ErrInvalidNotification is returned when a notification is in an invalid format
|
||||
ErrInvalidNotification = errors.New(logs.InvalidNotificationError())
|
||||
)
|
||||
|
||||
// connection handoff errors
|
||||
var (
|
||||
// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoff = errors.New(logs.ConnectionMarkedForHandoffErrorMessage)
|
||||
// ErrConnectionMarkedForHandoffWithState is returned when a connection is marked for handoff
|
||||
// and should not be used until the handoff is complete
|
||||
ErrConnectionMarkedForHandoffWithState = errors.New(logs.ConnectionMarkedForHandoffErrorMessage + " with state")
|
||||
// ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff
|
||||
ErrConnectionInvalidHandoffState = errors.New(logs.ConnectionInvalidHandoffStateErrorMessage)
|
||||
)
|
||||
|
||||
// shutdown errors
|
||||
var (
|
||||
// ErrShutdown is returned when the maintnotifications manager is shutdown
|
||||
ErrShutdown = errors.New(logs.ShutdownError())
|
||||
)
|
||||
|
||||
// circuit breaker errors
|
||||
var (
|
||||
// ErrCircuitBreakerOpen is returned when the circuit breaker is open
|
||||
ErrCircuitBreakerOpen = errors.New(logs.CircuitBreakerOpenErrorMessage)
|
||||
)
|
||||
|
||||
// circuit breaker configuration errors
|
||||
var (
|
||||
// ErrInvalidCircuitBreakerFailureThreshold is returned when the circuit breaker failure threshold is invalid
|
||||
ErrInvalidCircuitBreakerFailureThreshold = errors.New(logs.InvalidCircuitBreakerFailureThresholdError())
|
||||
// ErrInvalidCircuitBreakerResetTimeout is returned when the circuit breaker reset timeout is invalid
|
||||
ErrInvalidCircuitBreakerResetTimeout = errors.New(logs.InvalidCircuitBreakerResetTimeoutError())
|
||||
// ErrInvalidCircuitBreakerMaxRequests is returned when the circuit breaker max requests is invalid
|
||||
ErrInvalidCircuitBreakerMaxRequests = errors.New(logs.InvalidCircuitBreakerMaxRequestsError())
|
||||
)
|
||||
+101
@@ -0,0 +1,101 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
startTimeKey contextKey = "maint_notif_start_time"
|
||||
)
|
||||
|
||||
// MetricsHook collects metrics about notification processing.
|
||||
type MetricsHook struct {
|
||||
NotificationCounts map[string]int64
|
||||
ProcessingTimes map[string]time.Duration
|
||||
ErrorCounts map[string]int64
|
||||
HandoffCounts int64 // Total handoffs initiated
|
||||
HandoffSuccesses int64 // Successful handoffs
|
||||
HandoffFailures int64 // Failed handoffs
|
||||
}
|
||||
|
||||
// NewMetricsHook creates a new metrics collection hook.
|
||||
func NewMetricsHook() *MetricsHook {
|
||||
return &MetricsHook{
|
||||
NotificationCounts: make(map[string]int64),
|
||||
ProcessingTimes: make(map[string]time.Duration),
|
||||
ErrorCounts: make(map[string]int64),
|
||||
}
|
||||
}
|
||||
|
||||
// PreHook records the start time for processing metrics.
|
||||
func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
mh.NotificationCounts[notificationType]++
|
||||
|
||||
// Log connection information if available
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, logs.MetricsHookProcessingNotification(notificationType, conn.GetID()))
|
||||
}
|
||||
|
||||
// Store start time in context for duration calculation
|
||||
startTime := time.Now()
|
||||
_ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further
|
||||
|
||||
return notification, true
|
||||
}
|
||||
|
||||
// PostHook records processing completion and any errors.
|
||||
func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
// Calculate processing duration
|
||||
if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok {
|
||||
duration := time.Since(startTime)
|
||||
mh.ProcessingTimes[notificationType] = duration
|
||||
}
|
||||
|
||||
// Record errors
|
||||
if result != nil {
|
||||
mh.ErrorCounts[notificationType]++
|
||||
|
||||
// Log error details with connection information
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
internal.Logger.Printf(ctx, logs.MetricsHookRecordedError(notificationType, conn.GetID(), result))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetMetrics returns a summary of collected metrics.
|
||||
func (mh *MetricsHook) GetMetrics() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"notification_counts": mh.NotificationCounts,
|
||||
"processing_times": mh.ProcessingTimes,
|
||||
"error_counts": mh.ErrorCounts,
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleCircuitBreakerMonitor demonstrates how to monitor circuit breaker status
|
||||
func ExampleCircuitBreakerMonitor(poolHook *PoolHook) {
|
||||
// Get circuit breaker statistics
|
||||
stats := poolHook.GetCircuitBreakerStats()
|
||||
|
||||
for _, stat := range stats {
|
||||
fmt.Printf("Circuit Breaker for %s:\n", stat.Endpoint)
|
||||
fmt.Printf(" State: %s\n", stat.State)
|
||||
fmt.Printf(" Failures: %d\n", stat.Failures)
|
||||
fmt.Printf(" Last Failure: %v\n", stat.LastFailureTime)
|
||||
fmt.Printf(" Last Success: %v\n", stat.LastSuccessTime)
|
||||
|
||||
// Alert if circuit breaker is open
|
||||
if stat.State.String() == "open" {
|
||||
fmt.Printf(" ⚠️ ALERT: Circuit breaker is OPEN for %s\n", stat.Endpoint)
|
||||
}
|
||||
}
|
||||
}
|
||||
+512
@@ -0,0 +1,512 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// handoffWorkerManager manages background workers and queue for connection handoffs
|
||||
type handoffWorkerManager struct {
|
||||
// Event-driven handoff support
|
||||
handoffQueue chan HandoffRequest // Queue for handoff requests
|
||||
shutdown chan struct{} // Shutdown signal
|
||||
shutdownOnce sync.Once // Ensure clean shutdown
|
||||
workerWg sync.WaitGroup // Track worker goroutines
|
||||
|
||||
// On-demand worker management
|
||||
maxWorkers int
|
||||
activeWorkers atomic.Int32
|
||||
workerTimeout time.Duration // How long workers wait for work before exiting
|
||||
workersScaling atomic.Bool
|
||||
|
||||
// Simple state tracking
|
||||
pending sync.Map // map[uint64]int64 (connID -> seqID)
|
||||
|
||||
// Configuration for the maintenance notifications
|
||||
config *Config
|
||||
|
||||
// Pool hook reference for handoff processing
|
||||
poolHook *PoolHook
|
||||
|
||||
// Circuit breaker manager for endpoint failure handling
|
||||
circuitBreakerManager *CircuitBreakerManager
|
||||
}
|
||||
|
||||
// newHandoffWorkerManager creates a new handoff worker manager
|
||||
func newHandoffWorkerManager(config *Config, poolHook *PoolHook) *handoffWorkerManager {
|
||||
return &handoffWorkerManager{
|
||||
handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize),
|
||||
shutdown: make(chan struct{}),
|
||||
maxWorkers: config.MaxWorkers,
|
||||
activeWorkers: atomic.Int32{}, // Start with no workers - create on demand
|
||||
workerTimeout: 15 * time.Second, // Workers exit after 15s of inactivity
|
||||
config: config,
|
||||
poolHook: poolHook,
|
||||
circuitBreakerManager: newCircuitBreakerManager(config),
|
||||
}
|
||||
}
|
||||
|
||||
// getCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (hwm *handoffWorkerManager) getCurrentWorkers() int {
|
||||
return int(hwm.activeWorkers.Load())
|
||||
}
|
||||
|
||||
// getPendingMap returns the pending map for testing purposes
|
||||
func (hwm *handoffWorkerManager) getPendingMap() *sync.Map {
|
||||
return &hwm.pending
|
||||
}
|
||||
|
||||
// getMaxWorkers returns the max workers for testing purposes
|
||||
func (hwm *handoffWorkerManager) getMaxWorkers() int {
|
||||
return hwm.maxWorkers
|
||||
}
|
||||
|
||||
// getHandoffQueue returns the handoff queue for testing purposes
|
||||
func (hwm *handoffWorkerManager) getHandoffQueue() chan HandoffRequest {
|
||||
return hwm.handoffQueue
|
||||
}
|
||||
|
||||
// getCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (hwm *handoffWorkerManager) getCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return hwm.circuitBreakerManager.GetAllStats()
|
||||
}
|
||||
|
||||
// resetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (hwm *handoffWorkerManager) resetCircuitBreakers() {
|
||||
hwm.circuitBreakerManager.Reset()
|
||||
}
|
||||
|
||||
// isHandoffPending returns true if the given connection has a pending handoff
|
||||
func (hwm *handoffWorkerManager) isHandoffPending(conn *pool.Conn) bool {
|
||||
_, pending := hwm.pending.Load(conn.GetID())
|
||||
return pending
|
||||
}
|
||||
|
||||
// ensureWorkerAvailable ensures at least one worker is available to process requests
|
||||
// Creates a new worker if needed and under the max limit
|
||||
func (hwm *handoffWorkerManager) ensureWorkerAvailable() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return
|
||||
default:
|
||||
if hwm.workersScaling.CompareAndSwap(false, true) {
|
||||
defer hwm.workersScaling.Store(false)
|
||||
// Check if we need a new worker
|
||||
currentWorkers := hwm.activeWorkers.Load()
|
||||
workersWas := currentWorkers
|
||||
for currentWorkers < int32(hwm.maxWorkers) {
|
||||
hwm.workerWg.Add(1)
|
||||
go hwm.onDemandWorker()
|
||||
currentWorkers++
|
||||
}
|
||||
// workersWas is always <= currentWorkers
|
||||
// currentWorkers will be maxWorkers, but if we have a worker that was closed
|
||||
// while we were creating new workers, just add the difference between
|
||||
// the currentWorkers and the number of workers we observed initially (i.e. the number of workers we created)
|
||||
hwm.activeWorkers.Add(currentWorkers - workersWas)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onDemandWorker processes handoff requests and exits when idle
|
||||
func (hwm *handoffWorkerManager) onDemandWorker() {
|
||||
defer func() {
|
||||
// Handle panics to ensure proper cleanup
|
||||
if r := recover(); r != nil {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerPanicRecovered(r))
|
||||
}
|
||||
|
||||
// Decrement active worker count when exiting
|
||||
hwm.activeWorkers.Add(-1)
|
||||
hwm.workerWg.Done()
|
||||
}()
|
||||
|
||||
// Create reusable timer to prevent timer leaks
|
||||
timer := time.NewTimer(hwm.workerTimeout)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
// Reset timer for next iteration
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
timer.Reset(hwm.workerTimeout)
|
||||
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdown())
|
||||
}
|
||||
return
|
||||
case <-timer.C:
|
||||
// Worker has been idle for too long, exit to save resources
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToInactivityTimeout(hwm.workerTimeout))
|
||||
}
|
||||
return
|
||||
case request := <-hwm.handoffQueue:
|
||||
// Check for shutdown before processing
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.WorkerExitingDueToShutdownWhileProcessing())
|
||||
}
|
||||
// Clean up the request before exiting
|
||||
hwm.pending.Delete(request.ConnID)
|
||||
return
|
||||
default:
|
||||
// Process the request
|
||||
hwm.processHandoffRequest(request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processHandoffRequest processes a single handoff request
|
||||
func (hwm *handoffWorkerManager) processHandoffRequest(request HandoffRequest) {
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffStarted(request.Conn.GetID(), request.Endpoint))
|
||||
}
|
||||
|
||||
// Create a context with handoff timeout from config
|
||||
handoffTimeout := 15 * time.Second // Default timeout
|
||||
if hwm.config != nil && hwm.config.HandoffTimeout > 0 {
|
||||
handoffTimeout = hwm.config.HandoffTimeout
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Create a context that also respects the shutdown signal
|
||||
shutdownCtx, shutdownCancel := context.WithCancel(ctx)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Monitor shutdown signal in a separate goroutine
|
||||
go func() {
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
shutdownCancel()
|
||||
case <-shutdownCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Perform the handoff with cancellable context
|
||||
shouldRetry, err := hwm.performConnectionHandoff(shutdownCtx, request.Conn)
|
||||
minRetryBackoff := 500 * time.Millisecond
|
||||
if err != nil {
|
||||
if shouldRetry {
|
||||
now := time.Now()
|
||||
deadline, ok := shutdownCtx.Deadline()
|
||||
thirdOfTimeout := handoffTimeout / 3
|
||||
if !ok || deadline.Before(now) {
|
||||
// wait half the timeout before retrying if no deadline or deadline has passed
|
||||
deadline = now.Add(thirdOfTimeout)
|
||||
}
|
||||
afterTime := deadline.Sub(now)
|
||||
if afterTime < minRetryBackoff {
|
||||
afterTime = minRetryBackoff
|
||||
}
|
||||
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
// Get current retry count for better logging
|
||||
currentRetries := request.Conn.HandoffRetries()
|
||||
maxRetries := 3 // Default fallback
|
||||
if hwm.config != nil {
|
||||
maxRetries = hwm.config.MaxHandoffRetries
|
||||
}
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffFailed(request.ConnID, request.Endpoint, currentRetries, maxRetries, err))
|
||||
}
|
||||
// Schedule retry - keep connection in pending map until retry is queued
|
||||
time.AfterFunc(afterTime, func() {
|
||||
if err := hwm.queueHandoff(request.Conn); err != nil {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.CannotQueueHandoffForRetry(err))
|
||||
}
|
||||
// Failed to queue retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
hwm.closeConnFromRequest(context.Background(), request, err)
|
||||
} else {
|
||||
// Successfully queued retry - remove from pending (will be re-added by queueHandoff)
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
})
|
||||
return
|
||||
} else {
|
||||
// Won't retry - remove from pending and close connection
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
go hwm.closeConnFromRequest(ctx, request, err)
|
||||
}
|
||||
|
||||
// Clear handoff state if not returned for retry
|
||||
seqID := request.Conn.GetMovingSeqID()
|
||||
connID := request.Conn.GetID()
|
||||
if hwm.poolHook.operationsManager != nil {
|
||||
hwm.poolHook.operationsManager.UntrackOperationWithConnID(seqID, connID)
|
||||
}
|
||||
} else {
|
||||
// Success - remove from pending map
|
||||
hwm.pending.Delete(request.Conn.GetID())
|
||||
}
|
||||
}
|
||||
|
||||
// queueHandoff queues a handoff request for processing
|
||||
// if err is returned, connection will be removed from pool
|
||||
func (hwm *handoffWorkerManager) queueHandoff(conn *pool.Conn) error {
|
||||
// Get handoff info atomically to prevent race conditions
|
||||
shouldHandoff, endpoint, seqID := conn.GetHandoffInfo()
|
||||
|
||||
// on retries the connection will not be marked for handoff, but it will have retries > 0
|
||||
// if shouldHandoff is false and retries is 0, then we are not retrying and not do a handoff
|
||||
if !shouldHandoff && conn.HandoffRetries() == 0 {
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.ConnectionNotMarkedForHandoff(conn.GetID()))
|
||||
}
|
||||
return errors.New(logs.ConnectionNotMarkedForHandoffError(conn.GetID()))
|
||||
}
|
||||
|
||||
// Create handoff request with atomically retrieved data
|
||||
request := HandoffRequest{
|
||||
Conn: conn,
|
||||
ConnID: conn.GetID(),
|
||||
Endpoint: endpoint,
|
||||
SeqID: seqID,
|
||||
Pool: hwm.poolHook.pool, // Include pool for connection removal on failure
|
||||
}
|
||||
|
||||
select {
|
||||
// priority to shutdown
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
default:
|
||||
select {
|
||||
case <-hwm.shutdown:
|
||||
return ErrShutdown
|
||||
case hwm.handoffQueue <- request:
|
||||
// Store in pending map
|
||||
hwm.pending.Store(request.ConnID, request.SeqID)
|
||||
// Ensure we have a worker to process this request
|
||||
hwm.ensureWorkerAvailable()
|
||||
return nil
|
||||
case <-time.After(100 * time.Millisecond): // give workers a chance to process
|
||||
// Queue is full - log and attempt scaling
|
||||
queueLen := len(hwm.handoffQueue)
|
||||
queueCap := cap(hwm.handoffQueue)
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.HandoffQueueFull(queueLen, queueCap))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have workers available to handle the load
|
||||
hwm.ensureWorkerAvailable()
|
||||
return ErrHandoffQueueFull
|
||||
}
|
||||
|
||||
// shutdownWorkers gracefully shuts down the worker manager, waiting for workers to complete
|
||||
func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
|
||||
hwm.shutdownOnce.Do(func() {
|
||||
close(hwm.shutdown)
|
||||
// workers will exit when they finish their current request
|
||||
|
||||
// Shutdown circuit breaker manager cleanup goroutine
|
||||
if hwm.circuitBreakerManager != nil {
|
||||
hwm.circuitBreakerManager.Shutdown()
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for workers to complete
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
hwm.workerWg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// performConnectionHandoff performs the actual connection handoff
|
||||
// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached
|
||||
func (hwm *handoffWorkerManager) performConnectionHandoff(ctx context.Context, conn *pool.Conn) (shouldRetry bool, err error) {
|
||||
// Clear handoff state after successful handoff
|
||||
connID := conn.GetID()
|
||||
|
||||
newEndpoint := conn.GetHandoffEndpoint()
|
||||
if newEndpoint == "" {
|
||||
return false, ErrConnectionInvalidHandoffState
|
||||
}
|
||||
|
||||
// Use circuit breaker to protect against failing endpoints
|
||||
circuitBreaker := hwm.circuitBreakerManager.GetCircuitBreaker(newEndpoint)
|
||||
|
||||
// Check if circuit breaker is open before attempting handoff
|
||||
if circuitBreaker.IsOpen() {
|
||||
internal.Logger.Printf(ctx, logs.CircuitBreakerOpen(connID, newEndpoint))
|
||||
return false, ErrCircuitBreakerOpen // Don't retry when circuit breaker is open
|
||||
}
|
||||
|
||||
// Perform the handoff
|
||||
shouldRetry, err = hwm.performHandoffInternal(ctx, conn, newEndpoint, connID)
|
||||
|
||||
// Update circuit breaker based on result
|
||||
if err != nil {
|
||||
// Only track dial/network errors in circuit breaker, not initialization errors
|
||||
if shouldRetry {
|
||||
circuitBreaker.recordFailure()
|
||||
}
|
||||
return shouldRetry, err
|
||||
}
|
||||
|
||||
// Success - record in circuit breaker
|
||||
circuitBreaker.recordSuccess()
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// performHandoffInternal performs the actual handoff logic (extracted for circuit breaker integration)
|
||||
func (hwm *handoffWorkerManager) performHandoffInternal(
|
||||
ctx context.Context,
|
||||
conn *pool.Conn,
|
||||
newEndpoint string,
|
||||
connID uint64,
|
||||
) (shouldRetry bool, err error) {
|
||||
retries := conn.IncrementAndGetHandoffRetries(1)
|
||||
internal.Logger.Printf(ctx, logs.HandoffRetryAttempt(connID, retries, newEndpoint, conn.RemoteAddr().String()))
|
||||
maxRetries := 3 // Default fallback
|
||||
if hwm.config != nil {
|
||||
maxRetries = hwm.config.MaxHandoffRetries
|
||||
}
|
||||
|
||||
if retries > maxRetries {
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.ReachedMaxHandoffRetries(connID, newEndpoint, maxRetries))
|
||||
}
|
||||
// won't retry on ErrMaxHandoffRetriesReached
|
||||
return false, ErrMaxHandoffRetriesReached
|
||||
}
|
||||
|
||||
// Create endpoint-specific dialer
|
||||
endpointDialer := hwm.createEndpointDialer(newEndpoint)
|
||||
|
||||
// Create new connection to the new endpoint
|
||||
newNetConn, err := endpointDialer(ctx)
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, logs.FailedToDialNewEndpoint(connID, newEndpoint, err))
|
||||
// will retry
|
||||
// Maybe a network error - retry after a delay
|
||||
return true, err
|
||||
}
|
||||
|
||||
// Get the old connection
|
||||
oldConn := conn.GetNetConn()
|
||||
|
||||
// Apply relaxed timeout to the new connection for the configured post-handoff duration
|
||||
// This gives the new connection more time to handle operations during cluster transition
|
||||
// Setting this here (before initing the connection) ensures that the connection is going
|
||||
// to use the relaxed timeout for the first operation (auth/ACL select)
|
||||
if hwm.config != nil && hwm.config.PostHandoffRelaxedDuration > 0 {
|
||||
relaxedTimeout := hwm.config.RelaxedTimeout
|
||||
// Set relaxed timeout with deadline - no background goroutine needed
|
||||
deadline := time.Now().Add(hwm.config.PostHandoffRelaxedDuration)
|
||||
conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline)
|
||||
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(context.Background(), logs.ApplyingRelaxedTimeoutDueToPostHandoff(connID, relaxedTimeout, deadline.Format("15:04:05.000")))
|
||||
}
|
||||
}
|
||||
|
||||
// Replace the connection and execute initialization
|
||||
err = conn.SetNetConnAndInitConn(ctx, newNetConn)
|
||||
if err != nil {
|
||||
// won't retry
|
||||
// Initialization failed - remove the connection
|
||||
return false, err
|
||||
}
|
||||
defer func() {
|
||||
if oldConn != nil {
|
||||
oldConn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Clear handoff state will:
|
||||
// - set the connection as usable again
|
||||
// - clear the handoff state (shouldHandoff, endpoint, seqID)
|
||||
// - reset the handoff retries to 0
|
||||
// Note: Theoretically there may be a short window where the connection is in the pool
|
||||
// and IDLE (initConn completed) but still has handoff state set.
|
||||
conn.ClearHandoffState()
|
||||
internal.Logger.Printf(ctx, logs.HandoffSucceeded(connID, newEndpoint))
|
||||
|
||||
// successfully completed the handoff, no retry needed and no error
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// createEndpointDialer creates a dialer function that connects to a specific endpoint
|
||||
func (hwm *handoffWorkerManager) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) {
|
||||
return func(ctx context.Context) (net.Conn, error) {
|
||||
// Parse endpoint to extract host and port
|
||||
host, port, err := net.SplitHostPort(endpoint)
|
||||
if err != nil {
|
||||
// If no port specified, assume default Redis port
|
||||
host = endpoint
|
||||
if port == "" {
|
||||
port = "6379"
|
||||
}
|
||||
}
|
||||
|
||||
// Use the base dialer to connect to the new endpoint
|
||||
return hwm.poolHook.baseDialer(ctx, hwm.poolHook.network, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
|
||||
// closeConnFromRequest closes the connection and logs the reason
|
||||
func (hwm *handoffWorkerManager) closeConnFromRequest(ctx context.Context, request HandoffRequest, err error) {
|
||||
pooler := request.Pool
|
||||
conn := request.Conn
|
||||
|
||||
// Clear handoff state before closing
|
||||
conn.ClearHandoffState()
|
||||
|
||||
if pooler != nil {
|
||||
// Use RemoveWithoutTurn instead of Remove to avoid freeing a turn that we don't have.
|
||||
// The handoff worker doesn't call Get(), so it doesn't have a turn to free.
|
||||
// Remove() is meant to be called after Get() and frees a turn.
|
||||
// RemoveWithoutTurn() removes and closes the connection without affecting the queue.
|
||||
pooler.RemoveWithoutTurn(ctx, conn, err)
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RemovingConnectionFromPool(conn.GetID(), err))
|
||||
}
|
||||
} else {
|
||||
err := conn.Close() // Close the connection if no pool provided
|
||||
if err != nil {
|
||||
internal.Logger.Printf(ctx, "redis: failed to close connection: %v", err)
|
||||
}
|
||||
if internal.LogLevel.WarnOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.NoPoolProvidedCannotRemove(conn.GetID(), err))
|
||||
}
|
||||
}
|
||||
}
|
||||
+60
@@ -0,0 +1,60 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// LoggingHook is an example hook implementation that logs all notifications.
|
||||
type LoggingHook struct {
|
||||
LogLevel int // 0=Error, 1=Warn, 2=Info, 3=Debug
|
||||
}
|
||||
|
||||
// PreHook logs the notification before processing and allows modification.
|
||||
func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
if lh.LogLevel >= 2 { // Info level
|
||||
// Log the notification type and content
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
seqID := int64(0)
|
||||
if slices.Contains(maintenanceNotificationTypes, notificationType) {
|
||||
// seqID is the second element in the notification array
|
||||
if len(notification) > 1 {
|
||||
if parsedSeqID, ok := notification[1].(int64); !ok {
|
||||
seqID = 0
|
||||
} else {
|
||||
seqID = parsedSeqID
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotification(connID, seqID, notificationType, notification))
|
||||
}
|
||||
return notification, true // Continue processing with unmodified notification
|
||||
}
|
||||
|
||||
// PostHook logs the result after processing.
|
||||
func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
connID := uint64(0)
|
||||
if conn, ok := notificationCtx.Conn.(*pool.Conn); ok {
|
||||
connID = conn.GetID()
|
||||
}
|
||||
if result != nil && lh.LogLevel >= 1 { // Warning level
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotificationFailed(connID, notificationType, result, notification))
|
||||
} else if lh.LogLevel >= 3 { // Debug level
|
||||
internal.Logger.Printf(ctx, logs.ProcessingNotificationSucceeded(connID, notificationType))
|
||||
}
|
||||
}
|
||||
|
||||
// NewLoggingHook creates a new logging hook with the specified log level.
|
||||
// Log levels: 0=Error, 1=Warn, 2=Info, 3=Debug
|
||||
func NewLoggingHook(logLevel int) *LoggingHook {
|
||||
return &LoggingHook{LogLevel: logLevel}
|
||||
}
|
||||
+320
@@ -0,0 +1,320 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/interfaces"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Push notification type constants for maintenance
|
||||
const (
|
||||
NotificationMoving = "MOVING"
|
||||
NotificationMigrating = "MIGRATING"
|
||||
NotificationMigrated = "MIGRATED"
|
||||
NotificationFailingOver = "FAILING_OVER"
|
||||
NotificationFailedOver = "FAILED_OVER"
|
||||
)
|
||||
|
||||
// maintenanceNotificationTypes contains all notification types that maintenance handles
|
||||
var maintenanceNotificationTypes = []string{
|
||||
NotificationMoving,
|
||||
NotificationMigrating,
|
||||
NotificationMigrated,
|
||||
NotificationFailingOver,
|
||||
NotificationFailedOver,
|
||||
}
|
||||
|
||||
// NotificationHook is called before and after notification processing
|
||||
// PreHook can modify the notification and return false to skip processing
|
||||
// PostHook is called after successful processing
|
||||
type NotificationHook interface {
|
||||
PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool)
|
||||
PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error)
|
||||
}
|
||||
|
||||
// MovingOperationKey provides a unique key for tracking MOVING operations
|
||||
// that combines sequence ID with connection identifier to handle duplicate
|
||||
// sequence IDs across multiple connections to the same node.
|
||||
type MovingOperationKey struct {
|
||||
SeqID int64 // Sequence ID from MOVING notification
|
||||
ConnID uint64 // Unique connection identifier
|
||||
}
|
||||
|
||||
// String returns a string representation of the key for debugging
|
||||
func (k MovingOperationKey) String() string {
|
||||
return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID)
|
||||
}
|
||||
|
||||
// Manager provides a simplified upgrade functionality with hooks and atomic state.
|
||||
type Manager struct {
|
||||
client interfaces.ClientInterface
|
||||
config *Config
|
||||
options interfaces.OptionsInterface
|
||||
pool pool.Pooler
|
||||
|
||||
// MOVING operation tracking - using sync.Map for better concurrent performance
|
||||
activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation
|
||||
|
||||
// Atomic state tracking - no locks needed for state queries
|
||||
activeOperationCount atomic.Int64 // Number of active operations
|
||||
closed atomic.Bool // Manager closed state
|
||||
|
||||
// Notification hooks for extensibility
|
||||
hooks []NotificationHook
|
||||
hooksMu sync.RWMutex // Protects hooks slice
|
||||
poolHooksRef *PoolHook
|
||||
}
|
||||
|
||||
// MovingOperation tracks an active MOVING operation.
|
||||
type MovingOperation struct {
|
||||
SeqID int64
|
||||
NewEndpoint string
|
||||
StartTime time.Time
|
||||
Deadline time.Time
|
||||
}
|
||||
|
||||
// NewManager creates a new simplified manager.
|
||||
func NewManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*Manager, error) {
|
||||
if client == nil {
|
||||
return nil, ErrInvalidClient
|
||||
}
|
||||
|
||||
hm := &Manager{
|
||||
client: client,
|
||||
pool: pool,
|
||||
options: client.GetOptions(),
|
||||
config: config.Clone(),
|
||||
hooks: make([]NotificationHook, 0),
|
||||
}
|
||||
|
||||
// Set up push notification handling
|
||||
if err := hm.setupPushNotifications(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hm, nil
|
||||
}
|
||||
|
||||
// GetPoolHook creates a pool hook with a custom dialer.
|
||||
func (hm *Manager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) {
|
||||
poolHook := hm.createPoolHook(baseDialer)
|
||||
hm.pool.AddPoolHook(poolHook)
|
||||
}
|
||||
|
||||
// setupPushNotifications sets up push notification handling by registering with the client's processor.
|
||||
func (hm *Manager) setupPushNotifications() error {
|
||||
processor := hm.client.GetPushProcessor()
|
||||
if processor == nil {
|
||||
return ErrInvalidClient // Client doesn't support push notifications
|
||||
}
|
||||
|
||||
// Create our notification handler
|
||||
handler := &NotificationHandler{manager: hm, operationsManager: hm}
|
||||
|
||||
// Register handlers for all upgrade notifications with the client's processor
|
||||
for _, notificationType := range maintenanceNotificationTypes {
|
||||
if err := processor.RegisterHandler(notificationType, handler, true); err != nil {
|
||||
return errors.New(logs.FailedToRegisterHandler(notificationType, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID.
|
||||
func (hm *Manager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Create MOVING operation record
|
||||
movingOp := &MovingOperation{
|
||||
SeqID: seqID,
|
||||
NewEndpoint: newEndpoint,
|
||||
StartTime: time.Now(),
|
||||
Deadline: deadline,
|
||||
}
|
||||
|
||||
// Use LoadOrStore for atomic check-and-set operation
|
||||
if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded {
|
||||
// Duplicate MOVING notification, ignore
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.DuplicateMovingOperation(connID, newEndpoint, seqID))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.TrackingMovingOperation(connID, newEndpoint, seqID))
|
||||
}
|
||||
|
||||
// Increment active operation count atomically
|
||||
hm.activeOperationCount.Add(1)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID.
|
||||
func (hm *Manager) UntrackOperationWithConnID(seqID int64, connID uint64) {
|
||||
// Create composite key
|
||||
key := MovingOperationKey{
|
||||
SeqID: seqID,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
// Remove from active operations atomically
|
||||
if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded {
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.UntrackingMovingOperation(connID, seqID))
|
||||
}
|
||||
// Decrement active operation count only if operation existed
|
||||
hm.activeOperationCount.Add(-1)
|
||||
} else {
|
||||
if internal.LogLevel.DebugOrAbove() { // Debug level
|
||||
internal.Logger.Printf(context.Background(), logs.OperationNotTracked(connID, seqID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveMovingOperations returns active operations with composite keys.
|
||||
// WARNING: This method creates a new map and copies all operations on every call.
|
||||
// Use sparingly, especially in hot paths or high-frequency logging.
|
||||
func (hm *Manager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation {
|
||||
result := make(map[MovingOperationKey]*MovingOperation)
|
||||
|
||||
// Iterate over sync.Map to build result
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
k := key.(MovingOperationKey)
|
||||
op := value.(*MovingOperation)
|
||||
|
||||
// Create a copy to avoid sharing references
|
||||
result[k] = &MovingOperation{
|
||||
SeqID: op.SeqID,
|
||||
NewEndpoint: op.NewEndpoint,
|
||||
StartTime: op.StartTime,
|
||||
Deadline: op.Deadline,
|
||||
}
|
||||
return true // Continue iteration
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// IsHandoffInProgress returns true if any handoff is in progress.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *Manager) IsHandoffInProgress() bool {
|
||||
return hm.activeOperationCount.Load() > 0
|
||||
}
|
||||
|
||||
// GetActiveOperationCount returns the number of active operations.
|
||||
// Uses atomic counter for lock-free operation.
|
||||
func (hm *Manager) GetActiveOperationCount() int64 {
|
||||
return hm.activeOperationCount.Load()
|
||||
}
|
||||
|
||||
// Close closes the manager.
|
||||
func (hm *Manager) Close() error {
|
||||
// Use atomic operation for thread-safe close check
|
||||
if !hm.closed.CompareAndSwap(false, true) {
|
||||
return nil // Already closed
|
||||
}
|
||||
|
||||
// Shutdown the pool hook if it exists
|
||||
if hm.poolHooksRef != nil {
|
||||
// Use a timeout to prevent hanging indefinitely
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := hm.poolHooksRef.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
// was not able to close pool hook, keep closed state false
|
||||
hm.closed.Store(false)
|
||||
return err
|
||||
}
|
||||
// Remove the pool hook from the pool
|
||||
if hm.pool != nil {
|
||||
hm.pool.RemovePoolHook(hm.poolHooksRef)
|
||||
}
|
||||
}
|
||||
|
||||
// Clear all active operations
|
||||
hm.activeMovingOps.Range(func(key, value interface{}) bool {
|
||||
hm.activeMovingOps.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
// Reset counter
|
||||
hm.activeOperationCount.Store(0)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetState returns current state using atomic counter for lock-free operation.
|
||||
func (hm *Manager) GetState() State {
|
||||
if hm.activeOperationCount.Load() > 0 {
|
||||
return StateMoving
|
||||
}
|
||||
return StateIdle
|
||||
}
|
||||
|
||||
// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing.
|
||||
func (hm *Manager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
currentNotification := notification
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification)
|
||||
if !shouldContinue {
|
||||
return modifiedNotification, false
|
||||
}
|
||||
currentNotification = modifiedNotification
|
||||
}
|
||||
|
||||
return currentNotification, true
|
||||
}
|
||||
|
||||
// processPostHooks calls all post-hooks with the processing result.
|
||||
func (hm *Manager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
|
||||
hm.hooksMu.RLock()
|
||||
defer hm.hooksMu.RUnlock()
|
||||
|
||||
for _, hook := range hm.hooks {
|
||||
hook.PostHook(ctx, notificationCtx, notificationType, notification, result)
|
||||
}
|
||||
}
|
||||
|
||||
// createPoolHook creates a pool hook with this manager already set.
|
||||
func (hm *Manager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook {
|
||||
if hm.poolHooksRef != nil {
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
// Get pool size from client options for better worker defaults
|
||||
poolSize := 0
|
||||
if hm.options != nil {
|
||||
poolSize = hm.options.GetPoolSize()
|
||||
}
|
||||
|
||||
hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize)
|
||||
hm.poolHooksRef.SetPool(hm.pool)
|
||||
|
||||
return hm.poolHooksRef
|
||||
}
|
||||
|
||||
func (hm *Manager) AddNotificationHook(notificationHook NotificationHook) {
|
||||
hm.hooksMu.Lock()
|
||||
defer hm.hooksMu.Unlock()
|
||||
hm.hooks = append(hm.hooks, notificationHook)
|
||||
}
|
||||
+182
@@ -0,0 +1,182 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
)
|
||||
|
||||
// OperationsManagerInterface defines the interface for completing handoff operations
|
||||
type OperationsManagerInterface interface {
|
||||
TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error
|
||||
UntrackOperationWithConnID(seqID int64, connID uint64)
|
||||
}
|
||||
|
||||
// HandoffRequest represents a request to handoff a connection to a new endpoint
|
||||
type HandoffRequest struct {
|
||||
Conn *pool.Conn
|
||||
ConnID uint64 // Unique connection identifier
|
||||
Endpoint string
|
||||
SeqID int64
|
||||
Pool pool.Pooler // Pool to remove connection from on failure
|
||||
}
|
||||
|
||||
// PoolHook implements pool.PoolHook for Redis-specific connection handling
|
||||
// with maintenance notifications support.
|
||||
type PoolHook struct {
|
||||
// Base dialer for creating connections to new endpoints during handoffs
|
||||
// args are network and address
|
||||
baseDialer func(context.Context, string, string) (net.Conn, error)
|
||||
|
||||
// Network type (e.g., "tcp", "unix")
|
||||
network string
|
||||
|
||||
// Worker manager for background handoff processing
|
||||
workerManager *handoffWorkerManager
|
||||
|
||||
// Configuration for the maintenance notifications
|
||||
config *Config
|
||||
|
||||
// Operations manager interface for operation completion tracking
|
||||
operationsManager OperationsManagerInterface
|
||||
|
||||
// Pool interface for removing connections on handoff failure
|
||||
pool pool.Pooler
|
||||
}
|
||||
|
||||
// NewPoolHook creates a new pool hook
|
||||
func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface) *PoolHook {
|
||||
return NewPoolHookWithPoolSize(baseDialer, network, config, operationsManager, 0)
|
||||
}
|
||||
|
||||
// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults
|
||||
func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, operationsManager OperationsManagerInterface, poolSize int) *PoolHook {
|
||||
// Apply defaults if config is nil or has zero values
|
||||
if config == nil {
|
||||
config = config.ApplyDefaultsWithPoolSize(poolSize)
|
||||
}
|
||||
|
||||
ph := &PoolHook{
|
||||
// baseDialer is used to create connections to new endpoints during handoffs
|
||||
baseDialer: baseDialer,
|
||||
network: network,
|
||||
config: config,
|
||||
operationsManager: operationsManager,
|
||||
}
|
||||
|
||||
// Create worker manager
|
||||
ph.workerManager = newHandoffWorkerManager(config, ph)
|
||||
|
||||
return ph
|
||||
}
|
||||
|
||||
// SetPool sets the pool interface for removing connections on handoff failure
|
||||
func (ph *PoolHook) SetPool(pooler pool.Pooler) {
|
||||
ph.pool = pooler
|
||||
}
|
||||
|
||||
// GetCurrentWorkers returns the current number of active workers (for testing)
|
||||
func (ph *PoolHook) GetCurrentWorkers() int {
|
||||
return ph.workerManager.getCurrentWorkers()
|
||||
}
|
||||
|
||||
// IsHandoffPending returns true if the given connection has a pending handoff
|
||||
func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool {
|
||||
return ph.workerManager.isHandoffPending(conn)
|
||||
}
|
||||
|
||||
// GetPendingMap returns the pending map for testing purposes
|
||||
func (ph *PoolHook) GetPendingMap() *sync.Map {
|
||||
return ph.workerManager.getPendingMap()
|
||||
}
|
||||
|
||||
// GetMaxWorkers returns the max workers for testing purposes
|
||||
func (ph *PoolHook) GetMaxWorkers() int {
|
||||
return ph.workerManager.getMaxWorkers()
|
||||
}
|
||||
|
||||
// GetHandoffQueue returns the handoff queue for testing purposes
|
||||
func (ph *PoolHook) GetHandoffQueue() chan HandoffRequest {
|
||||
return ph.workerManager.getHandoffQueue()
|
||||
}
|
||||
|
||||
// GetCircuitBreakerStats returns circuit breaker statistics for monitoring
|
||||
func (ph *PoolHook) GetCircuitBreakerStats() []CircuitBreakerStats {
|
||||
return ph.workerManager.getCircuitBreakerStats()
|
||||
}
|
||||
|
||||
// ResetCircuitBreakers resets all circuit breakers (useful for testing)
|
||||
func (ph *PoolHook) ResetCircuitBreakers() {
|
||||
ph.workerManager.resetCircuitBreakers()
|
||||
}
|
||||
|
||||
// OnGet is called when a connection is retrieved from the pool
|
||||
func (ph *PoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
|
||||
// Check if connection is marked for handoff
|
||||
// This prevents using connections that have received MOVING notifications
|
||||
if conn.ShouldHandoff() {
|
||||
return false, ErrConnectionMarkedForHandoffWithState
|
||||
}
|
||||
|
||||
// Check if connection is usable (not in UNUSABLE or CLOSED state)
|
||||
// This ensures we don't return connections that are currently being handed off or re-authenticated.
|
||||
if !conn.IsUsable() {
|
||||
return false, ErrConnectionMarkedForHandoff
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// OnPut is called when a connection is returned to the pool
|
||||
func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) {
|
||||
// first check if we should handoff for faster rejection
|
||||
if !conn.ShouldHandoff() {
|
||||
// Default behavior (no handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
// check pending handoff to not queue the same connection twice
|
||||
if ph.workerManager.isHandoffPending(conn) {
|
||||
// Default behavior (pending handoff): pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := ph.workerManager.queueHandoff(conn); err != nil {
|
||||
// Failed to queue handoff, remove the connection
|
||||
internal.Logger.Printf(ctx, logs.FailedToQueueHandoff(conn.GetID(), err))
|
||||
// Don't pool, remove connection, no error to caller
|
||||
return false, true, nil
|
||||
}
|
||||
|
||||
// Check if handoff was already processed by a worker before we can mark it as queued
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was already processed - this is normal and the connection should be pooled
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
if err := conn.MarkQueuedForHandoff(); err != nil {
|
||||
// If marking fails, check if handoff was processed in the meantime
|
||||
if !conn.ShouldHandoff() {
|
||||
// Handoff was processed - this is normal, pool the connection
|
||||
return true, false, nil
|
||||
}
|
||||
// Other error - remove the connection
|
||||
return false, true, nil
|
||||
}
|
||||
internal.Logger.Printf(ctx, logs.MarkedForHandoff(conn.GetID()))
|
||||
return true, false, nil
|
||||
}
|
||||
|
||||
func (ph *PoolHook) OnRemove(_ context.Context, _ *pool.Conn, _ error) {
|
||||
// Not used
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the processor, waiting for workers to complete
|
||||
func (ph *PoolHook) Shutdown(ctx context.Context) error {
|
||||
return ph.workerManager.shutdownWorkers(ctx)
|
||||
}
|
||||
Generated
Vendored
+282
@@ -0,0 +1,282 @@
|
||||
package maintnotifications
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9/internal"
|
||||
"github.com/redis/go-redis/v9/internal/maintnotifications/logs"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// NotificationHandler handles push notifications for the simplified manager.
|
||||
type NotificationHandler struct {
|
||||
manager *Manager
|
||||
operationsManager OperationsManagerInterface
|
||||
}
|
||||
|
||||
// HandlePushNotification processes push notifications with hook support.
|
||||
func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) == 0 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationFormat(notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
notificationType, ok := notification[0].(string)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotificationTypeFormat(notification[0]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Process pre-hooks - they can modify the notification or skip processing
|
||||
modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification)
|
||||
if !shouldContinue {
|
||||
return nil // Hooks decided to skip processing
|
||||
}
|
||||
|
||||
var err error
|
||||
switch notificationType {
|
||||
case NotificationMoving:
|
||||
err = snh.handleMoving(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrating:
|
||||
err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationMigrated:
|
||||
err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailingOver:
|
||||
err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification)
|
||||
case NotificationFailedOver:
|
||||
err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification)
|
||||
default:
|
||||
// Ignore other notification types (e.g., pub/sub messages)
|
||||
err = nil
|
||||
}
|
||||
|
||||
// Process post-hooks with the result
|
||||
snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// handleMoving processes MOVING notifications.
|
||||
// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff
|
||||
func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
if len(notification) < 3 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MOVING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
seqID, ok := notification[1].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidSeqIDInMovingNotification(notification[1]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Extract timeS
|
||||
timeS, ok := notification[2].(int64)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidTimeSInMovingNotification(notification[2]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
newEndpoint := ""
|
||||
if len(notification) > 3 {
|
||||
// Extract new endpoint
|
||||
newEndpoint, ok = notification[3].(string)
|
||||
if !ok {
|
||||
stringified := fmt.Sprintf("%v", notification[3])
|
||||
// this could be <nil> which is valid
|
||||
if notification[3] == nil || stringified == internal.RedisNull {
|
||||
newEndpoint = ""
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNewEndpointInMovingNotification(notification[3]))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the connection that received this notification
|
||||
conn := handlerCtx.Conn
|
||||
if conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MOVING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Type assert to get the underlying pool connection
|
||||
var poolConn *pool.Conn
|
||||
if pc, ok := conn.(*pool.Conn); ok {
|
||||
poolConn = pc
|
||||
} else {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MOVING", conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// If the connection is closed or not pooled, we can ignore the notification
|
||||
// this connection won't be remembered by the pool and will be garbage collected
|
||||
// Keep pubsub connections around since they are not pooled but are long-lived
|
||||
// and should be allowed to handoff (the pubsub instance will reconnect and change
|
||||
// the underlying *pool.Conn)
|
||||
if (poolConn.IsClosed() || !poolConn.IsPooled()) && !poolConn.IsPubSub() {
|
||||
return nil
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(time.Duration(timeS) * time.Second)
|
||||
// If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds
|
||||
if newEndpoint == "" || newEndpoint == internal.RedisNull {
|
||||
if internal.LogLevel.DebugOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.SchedulingHandoffToCurrentEndpoint(poolConn.GetID(), float64(timeS)/2))
|
||||
}
|
||||
// same as current endpoint
|
||||
newEndpoint = snh.manager.options.GetAddr()
|
||||
// delay the handoff for timeS/2 seconds to the same endpoint
|
||||
// do this in a goroutine to avoid blocking the notification handler
|
||||
// NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff
|
||||
// and there should be no possibility of a race condition or double handoff.
|
||||
time.AfterFunc(time.Duration(timeS/2)*time.Second, func() {
|
||||
if poolConn == nil || poolConn.IsClosed() {
|
||||
return
|
||||
}
|
||||
if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil {
|
||||
// Log error but don't fail the goroutine - use background context since original may be cancelled
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(poolConn.GetID(), err))
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
return snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline)
|
||||
}
|
||||
|
||||
func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error {
|
||||
if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil {
|
||||
internal.Logger.Printf(context.Background(), logs.FailedToMarkForHandoff(conn.GetID(), err))
|
||||
// Connection is already marked for handoff, which is acceptable
|
||||
// This can happen if multiple MOVING notifications are received for the same connection
|
||||
return nil
|
||||
}
|
||||
// Optionally track in m
|
||||
if snh.operationsManager != nil {
|
||||
connID := conn.GetID()
|
||||
// Track the operation (ignore errors since this is optional)
|
||||
_ = snh.operationsManager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID)
|
||||
} else {
|
||||
return errors.New(logs.ManagerNotInitialized())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrating processes MIGRATING notifications.
|
||||
func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATING notifications indicate that a connection is about to be migrated
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATING", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATING"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATING", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(conn.GetID(), "MIGRATING", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMigrated processes MIGRATED notifications.
|
||||
func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// MIGRATED notifications indicate that a connection migration has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("MIGRATED", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("MIGRATED"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("MIGRATED", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailingOver processes FAILING_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILING_OVER notifications indicate that a connection is about to failover
|
||||
// Apply relaxed timeouts to the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILING_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILING_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILING_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Apply relaxed timeout to this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.RelaxedTimeoutDueToNotification(connID, "FAILING_OVER", snh.manager.config.RelaxedTimeout))
|
||||
}
|
||||
conn.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFailedOver processes FAILED_OVER notifications.
|
||||
func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error {
|
||||
// FAILED_OVER notifications indicate that a connection failover has completed
|
||||
// Restore normal timeouts for the specific connection that received this notification
|
||||
if len(notification) < 2 {
|
||||
internal.Logger.Printf(ctx, logs.InvalidNotification("FAILED_OVER", notification))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
if handlerCtx.Conn == nil {
|
||||
internal.Logger.Printf(ctx, logs.NoConnectionInHandlerContext("FAILED_OVER"))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
conn, ok := handlerCtx.Conn.(*pool.Conn)
|
||||
if !ok {
|
||||
internal.Logger.Printf(ctx, logs.InvalidConnectionTypeInHandlerContext("FAILED_OVER", handlerCtx.Conn, handlerCtx))
|
||||
return ErrInvalidNotification
|
||||
}
|
||||
|
||||
// Clear relaxed timeout for this specific connection
|
||||
if internal.LogLevel.InfoOrAbove() {
|
||||
connID := conn.GetID()
|
||||
internal.Logger.Printf(ctx, logs.UnrelaxedTimeout(connID))
|
||||
}
|
||||
conn.ClearRelaxedTimeout()
|
||||
return nil
|
||||
}
|
||||
+24
@@ -0,0 +1,24 @@
|
||||
package maintnotifications
|
||||
|
||||
// State represents the current state of a maintenance operation
|
||||
type State int
|
||||
|
||||
const (
|
||||
// StateIdle indicates no upgrade is in progress
|
||||
StateIdle State = iota
|
||||
|
||||
// StateHandoff indicates a connection handoff is in progress
|
||||
StateMoving
|
||||
)
|
||||
|
||||
// String returns a string representation of the state.
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateIdle:
|
||||
return "idle"
|
||||
case StateMoving:
|
||||
return "moving"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
+149
-14
@@ -16,6 +16,9 @@ import (
|
||||
"github.com/redis/go-redis/v9/auth"
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/util"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
// Limiter is the interface of a rate limiter or a circuit breaker.
|
||||
@@ -31,7 +34,6 @@ type Limiter interface {
|
||||
|
||||
// Options keeps the settings to set up redis connection.
|
||||
type Options struct {
|
||||
|
||||
// Network type, either tcp or unix.
|
||||
//
|
||||
// default: is tcp.
|
||||
@@ -109,6 +111,16 @@ type Options struct {
|
||||
// default: 5 seconds
|
||||
DialTimeout time.Duration
|
||||
|
||||
// DialerRetries is the maximum number of retry attempts when dialing fails.
|
||||
//
|
||||
// default: 5
|
||||
DialerRetries int
|
||||
|
||||
// DialerRetryTimeout is the backoff duration between retry attempts.
|
||||
//
|
||||
// default: 100 milliseconds
|
||||
DialerRetryTimeout time.Duration
|
||||
|
||||
// ReadTimeout for socket reads. If reached, commands will fail
|
||||
// with a timeout instead of blocking. Supported values:
|
||||
//
|
||||
@@ -152,6 +164,7 @@ type Options struct {
|
||||
//
|
||||
// Note that FIFO has slightly higher overhead compared to LIFO,
|
||||
// but it helps closing idle connections faster reducing the pool size.
|
||||
// default: false
|
||||
PoolFIFO bool
|
||||
|
||||
// PoolSize is the base number of socket connections.
|
||||
@@ -162,6 +175,10 @@ type Options struct {
|
||||
// default: 10 * runtime.GOMAXPROCS(0)
|
||||
PoolSize int
|
||||
|
||||
// MaxConcurrentDials is the maximum number of concurrent connection creation goroutines.
|
||||
// If <= 0, defaults to PoolSize. If > PoolSize, it will be capped at PoolSize.
|
||||
MaxConcurrentDials int
|
||||
|
||||
// PoolTimeout is the amount of time client waits for connection if all connections
|
||||
// are busy before returning an error.
|
||||
//
|
||||
@@ -232,10 +249,24 @@ type Options struct {
|
||||
// When unstable mode is enabled, the client will use RESP3 protocol and only be able to use RawResult
|
||||
UnstableResp3 bool
|
||||
|
||||
// Push notifications are always enabled for RESP3 connections (Protocol: 3)
|
||||
// and are not available for RESP2 connections. No configuration option is needed.
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications are in "auto" mode and will be enabled if the server supports it.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *Options) init() {
|
||||
@@ -255,12 +286,23 @@ func (opt *Options) init() {
|
||||
if opt.DialTimeout == 0 {
|
||||
opt.DialTimeout = 5 * time.Second
|
||||
}
|
||||
if opt.DialerRetries == 0 {
|
||||
opt.DialerRetries = 5
|
||||
}
|
||||
if opt.DialerRetryTimeout == 0 {
|
||||
opt.DialerRetryTimeout = 100 * time.Millisecond
|
||||
}
|
||||
if opt.Dialer == nil {
|
||||
opt.Dialer = NewDialer(opt)
|
||||
}
|
||||
if opt.PoolSize == 0 {
|
||||
opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
|
||||
}
|
||||
if opt.MaxConcurrentDials <= 0 {
|
||||
opt.MaxConcurrentDials = opt.PoolSize
|
||||
} else if opt.MaxConcurrentDials > opt.PoolSize {
|
||||
opt.MaxConcurrentDials = opt.PoolSize
|
||||
}
|
||||
if opt.ReadBufferSize == 0 {
|
||||
opt.ReadBufferSize = proto.DefaultBufferSize
|
||||
}
|
||||
@@ -312,13 +354,40 @@ func (opt *Options) init() {
|
||||
case 0:
|
||||
opt.MaxRetryBackoff = 512 * time.Millisecond
|
||||
}
|
||||
|
||||
if opt.FailingTimeoutSeconds == 0 {
|
||||
opt.FailingTimeoutSeconds = 15
|
||||
}
|
||||
|
||||
opt.MaintNotificationsConfig = opt.MaintNotificationsConfig.ApplyDefaultsWithPoolConfig(opt.PoolSize, opt.MaxActiveConns)
|
||||
|
||||
// auto-detect endpoint type if not specified
|
||||
endpointType := opt.MaintNotificationsConfig.EndpointType
|
||||
if endpointType == "" || endpointType == maintnotifications.EndpointTypeAuto {
|
||||
// Auto-detect endpoint type if not specified
|
||||
endpointType = maintnotifications.DetectEndpointType(opt.Addr, opt.TLSConfig != nil)
|
||||
}
|
||||
opt.MaintNotificationsConfig.EndpointType = endpointType
|
||||
}
|
||||
|
||||
func (opt *Options) clone() *Options {
|
||||
clone := *opt
|
||||
|
||||
// Deep clone MaintNotificationsConfig to avoid sharing between clients
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
clone.MaintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &clone
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) {
|
||||
return NewDialer(opt)
|
||||
}
|
||||
|
||||
// NewDialer returns a function that will be used as the default dialer
|
||||
// when none is specified in Options.Dialer.
|
||||
func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) {
|
||||
@@ -565,6 +634,7 @@ func setupConnParams(u *url.URL, o *Options) (*Options, error) {
|
||||
o.MinIdleConns = q.int("min_idle_conns")
|
||||
o.MaxIdleConns = q.int("max_idle_conns")
|
||||
o.MaxActiveConns = q.int("max_active_conns")
|
||||
o.MaxConcurrentDials = q.int("max_concurrent_dials")
|
||||
if q.has("conn_max_idle_time") {
|
||||
o.ConnMaxIdleTime = q.duration("conn_max_idle_time")
|
||||
} else {
|
||||
@@ -604,21 +674,86 @@ func getUserPassword(u *url.URL) (string, string) {
|
||||
func newConnPool(
|
||||
opt *Options,
|
||||
dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) *pool.ConnPool {
|
||||
) (*pool.ConnPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewConnPool(&pool.Options{
|
||||
Dialer: func(ctx context.Context) (net.Conn, error) {
|
||||
return dialer(ctx, opt.Network, opt.Addr)
|
||||
},
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: opt.PoolSize,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
MinIdleConns: opt.MinIdleConns,
|
||||
MaxIdleConns: opt.MaxIdleConns,
|
||||
MaxActiveConns: opt.MaxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
})
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
MaxConcurrentDials: opt.MaxConcurrentDials,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: opt.ReadBufferSize,
|
||||
WriteBufferSize: opt.WriteBufferSize,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}), nil
|
||||
}
|
||||
|
||||
func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error),
|
||||
) (*pool.PubSubPool, error) {
|
||||
poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pool.NewPubSubPool(&pool.Options{
|
||||
PoolFIFO: opt.PoolFIFO,
|
||||
PoolSize: poolSize,
|
||||
MaxConcurrentDials: opt.MaxConcurrentDials,
|
||||
PoolTimeout: opt.PoolTimeout,
|
||||
DialTimeout: opt.DialTimeout,
|
||||
DialerRetries: opt.DialerRetries,
|
||||
DialerRetryTimeout: opt.DialerRetryTimeout,
|
||||
MinIdleConns: minIdleConns,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxActiveConns: maxActiveConns,
|
||||
ConnMaxIdleTime: opt.ConnMaxIdleTime,
|
||||
ConnMaxLifetime: opt.ConnMaxLifetime,
|
||||
ReadBufferSize: 32 * 1024,
|
||||
WriteBufferSize: 32 * 1024,
|
||||
PushNotificationsEnabled: opt.Protocol == 3,
|
||||
}, dialer), nil
|
||||
}
|
||||
|
||||
+67
-12
@@ -20,6 +20,8 @@ import (
|
||||
"github.com/redis/go-redis/v9/internal/pool"
|
||||
"github.com/redis/go-redis/v9/internal/proto"
|
||||
"github.com/redis/go-redis/v9/internal/rand"
|
||||
"github.com/redis/go-redis/v9/maintnotifications"
|
||||
"github.com/redis/go-redis/v9/push"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -38,6 +40,7 @@ type ClusterOptions struct {
|
||||
ClientName string
|
||||
|
||||
// NewClient creates a cluster node client with provided name and options.
|
||||
// If NewClient is set by the user, the user is responsible for handling maintnotifications upgrades and push notifications.
|
||||
NewClient func(opt *Options) *Client
|
||||
|
||||
// The maximum number of retries before giving up. Command is retried
|
||||
@@ -74,6 +77,10 @@ type ClusterOptions struct {
|
||||
CredentialsProviderContext func(ctx context.Context) (username string, password string, err error)
|
||||
StreamingCredentialsProvider auth.StreamingCredentialsProvider
|
||||
|
||||
// MaxRetries is the maximum number of retries before giving up.
|
||||
// For ClusterClient, retries are disabled by default (set to -1),
|
||||
// because the cluster client handles all kinds of retries internally.
|
||||
// This is intentional and differs from the standalone Options default.
|
||||
MaxRetries int
|
||||
MinRetryBackoff time.Duration
|
||||
MaxRetryBackoff time.Duration
|
||||
@@ -125,10 +132,22 @@ type ClusterOptions struct {
|
||||
// UnstableResp3 enables Unstable mode for Redis Search module with RESP3.
|
||||
UnstableResp3 bool
|
||||
|
||||
// PushNotificationProcessor is the processor for handling push notifications.
|
||||
// If nil, a default processor will be created for RESP3 connections.
|
||||
PushNotificationProcessor push.NotificationProcessor
|
||||
|
||||
// FailingTimeoutSeconds is the timeout in seconds for marking a cluster node as failing.
|
||||
// When a node is marked as failing, it will be avoided for this duration.
|
||||
// Default is 15 seconds.
|
||||
FailingTimeoutSeconds int
|
||||
|
||||
// MaintNotificationsConfig provides custom configuration for maintnotifications upgrades.
|
||||
// When MaintNotificationsConfig.Mode is not "disabled", the client will handle
|
||||
// cluster upgrade notifications gracefully and manage connection/pool state
|
||||
// transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications.
|
||||
// If nil, maintnotifications upgrades are in "auto" mode and will be enabled if the server supports it.
|
||||
// The ClusterClient does not directly work with maintnotifications, it is up to the clients in the Nodes map to work with maintnotifications.
|
||||
MaintNotificationsConfig *maintnotifications.Config
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) init() {
|
||||
@@ -319,6 +338,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er
|
||||
}
|
||||
|
||||
func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// Clone MaintNotificationsConfig to avoid sharing between cluster node clients
|
||||
var maintNotificationsConfig *maintnotifications.Config
|
||||
if opt.MaintNotificationsConfig != nil {
|
||||
configClone := *opt.MaintNotificationsConfig
|
||||
maintNotificationsConfig = &configClone
|
||||
}
|
||||
|
||||
return &Options{
|
||||
ClientName: opt.ClientName,
|
||||
Dialer: opt.Dialer,
|
||||
@@ -360,8 +386,10 @@ func (opt *ClusterOptions) clientOptions() *Options {
|
||||
// much use for ClusterSlots config). This means we cannot execute the
|
||||
// READONLY command against that node -- setting readOnly to false in such
|
||||
// situations in the options below will prevent that from happening.
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
readOnly: opt.ReadOnly && opt.ClusterSlots == nil,
|
||||
UnstableResp3: opt.UnstableResp3,
|
||||
MaintNotificationsConfig: maintNotificationsConfig,
|
||||
PushNotificationProcessor: opt.PushNotificationProcessor,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1664,7 +1692,7 @@ func (c *ClusterClient) processTxPipelineNode(
|
||||
}
|
||||
|
||||
func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
ctx context.Context, _ *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
|
||||
) error {
|
||||
if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
|
||||
return writeCmds(wr, cmds)
|
||||
@@ -1682,7 +1710,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
trimmedCmds := cmds[1 : len(cmds)-1]
|
||||
|
||||
if err := c.txPipelineReadQueued(
|
||||
ctx, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
ctx, node, cn, rd, statusCmd, trimmedCmds, failedCmds,
|
||||
); err != nil {
|
||||
setCmdsErr(cmds, err)
|
||||
|
||||
@@ -1694,23 +1722,37 @@ func (c *ClusterClient) processTxPipelineNodeConn(
|
||||
return err
|
||||
}
|
||||
|
||||
return pipelineReadCmds(rd, trimmedCmds)
|
||||
return node.Client.pipelineReadCmds(ctx, cn, rd, trimmedCmds)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClusterClient) txPipelineReadQueued(
|
||||
ctx context.Context,
|
||||
node *clusterNode,
|
||||
cn *pool.Conn,
|
||||
rd *proto.Reader,
|
||||
statusCmd *StatusCmd,
|
||||
cmds []Cmder,
|
||||
failedCmds *cmdsMap,
|
||||
) error {
|
||||
// Parse queued replies.
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
if err := statusCmd.readReply(rd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, cmd := range cmds {
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
err := statusCmd.readReply(rd)
|
||||
if err != nil {
|
||||
if c.checkMovedErr(ctx, cmd, err, failedCmds) {
|
||||
@@ -1724,6 +1766,12 @@ func (c *ClusterClient) txPipelineReadQueued(
|
||||
}
|
||||
}
|
||||
|
||||
// To be sure there are no buffered push notifications, we process them before reading the reply
|
||||
if err := node.Client.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil {
|
||||
// Log the error but don't fail the command execution
|
||||
// Push notification processing errors shouldn't break normal Redis operations
|
||||
internal.Logger.Printf(ctx, "push: error processing pending notifications before reading reply: %v", err)
|
||||
}
|
||||
// Parse number of replies.
|
||||
line, err := rd.ReadLine()
|
||||
if err != nil {
|
||||
@@ -1829,12 +1877,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s
|
||||
return err
|
||||
}
|
||||
|
||||
// maintenance notifications won't work here for now
|
||||
func (c *ClusterClient) pubSub() *PubSub {
|
||||
var node *clusterNode
|
||||
pubsub := &PubSub{
|
||||
opt: c.opt.clientOptions(),
|
||||
|
||||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
|
||||
newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) {
|
||||
if node != nil {
|
||||
panic("node != nil")
|
||||
}
|
||||
@@ -1868,18 +1916,25 @@ func (c *ClusterClient) pubSub() *PubSub {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cn, err := node.Client.newConn(context.TODO())
|
||||
cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels)
|
||||
if err != nil {
|
||||
node = nil
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// will return nil if already initialized
|
||||
err = node.Client.initConn(ctx, cn)
|
||||
if err != nil {
|
||||
_ = cn.Close()
|
||||
node = nil
|
||||
return nil, err
|
||||
}
|
||||
node.Client.pubSubPool.TrackConn(cn)
|
||||
return cn, nil
|
||||
},
|
||||
closeConn: func(cn *pool.Conn) error {
|
||||
err := node.Client.connPool.CloseConn(cn)
|
||||
// Untrack connection from PubSubPool
|
||||
node.Client.pubSubPool.UntrackConn(cn)
|
||||
err := cn.Close()
|
||||
node = nil
|
||||
return err
|
||||
},
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user