mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
72e2b682bb
The v1.0.14 fix replaced one contended sync.RWMutex
(RefreshCoordinator.refreshMutex) with sync.Map. Production showed the same death-spiral
signature recurring ~2 hours later — same shape, different mutex:
65 goroutines stuck on a sync.(*RWMutex).Lock at one address, pod
pinned at 1000m CPU, identical Yaegi runCfg/reflect.Value.Call stack
pattern. The mutex was RefreshCoordinator.attemptsMutex.
Generalising: under Yaegi (interpreted Go for traefik plugins), any
per-request global mutex acquisition is a latent serialization point.
reflect.Value.Call dispatch while a lock is held turns a microsecond
critical section into a multi-millisecond one, and on a GOMAXPROCS=1
pod the queue of waiting goroutines grows without bound.
This commit removes every per-request global mutex on the hot path:
1. RefreshCoordinator.attemptsMutex (sync.RWMutex)
sessionRefreshAttempts: map -> sync.Map.
refreshAttemptTracker: all fields atomic (int32, int64 UnixNano,
cooldownEndNano == 0 as the not-in-cooldown sentinel, replacing
the inCooldown bool).
isInCooldown / recordRefreshAttempt / recordRefreshSuccess /
recordRefreshFailure all become lock-free. Cooldown entry uses
CompareAndSwapInt64 so only one goroutine logs the transition.
2. RefreshCircuitBreaker.mutex (sync.RWMutex)
lastFailureTime / lastSuccessTime -> atomic.Int64 UnixNano.
state and failures already atomic.
AllowRequest / RecordSuccess / RecordFailure now pure atomic ops.
3. TraefikOidc.firstRequestMutex (sync.Mutex)
firstRequestReceived bool -> firstRequestStarted int32.
metadataRefreshStarted bool -> metadataRefreshStartedAtomic int32.
ServeHTTP bootstrap path uses CompareAndSwapInt32 — fires once,
zero steady-state cost. Previously the mutex was acquired on
every non-health request forever.
4. TraefikOidc.metadataRetryMutex (sync.Mutex)
lastMetadataRetryTime time.Time -> lastMetadataRetryNano int64.
The 30-second retry throttle is now a CAS on lastMetadataRetryNano.
cleanupStaleEntries iterates via sync.Map.Range; eviction is a
CompareAndDelete by pointer identity so a tracker freshly re-used by
a concurrent caller is not lost.
Empirical evidence (analysis of the v1.0.14 spike by three specialist agents;
profiles in /tmp/traefik-spike-1779511683/):
* mutex profile: 97% delay in sync.(*Mutex).Unlock via
HTTPHandlerSwitcher -> accesslog -> metrics -> backoff.RetryNotify
* 65 stuck goroutines at one RWMutex address (0x40022eb648),
identical Yaegi CFG pointer, all on rc.attemptsMutex via
recordRefreshAttempt + isInCooldown
* traffic driver: long-lived in-cluster Go-http-client doing
~5.4 req/s POST embeddings via OIDC cookie session → same
sessionID → contention all funnels to one tracker entry
Yaegi support for sync/atomic confirmed at
github.com/traefik/yaegi@v0.16.1/stdlib/go1_22_sync_atomic.go:
AddInt32/Int64, LoadInt32/Int64, StoreInt32/Int64,
CompareAndSwapInt32/Int64 all exposed via reflect.ValueOf. Yaegi
dispatches each call through reflect.Value.Call to the COMPILED
atomic.* function, which executes a single hardware CAS/LOCK-XADD
instruction. Each atomic op still pays Yaegi dispatch cost but
cannot block — no queueing, no death spiral.
Trade-off acknowledged: v1.0.15 issues ~6-8 atomic/sync.Map ops per
leader-path request vs the 4 mutex ops of v1.0.14. Under low
contention this is a modest CPU bump. Under high contention it's
an unbounded → bounded transformation. Net win.
All tests pass with -race; golangci-lint clean.
813 lines · 28 KiB · Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Helper builders
|
|
// =============================================================================
|
|
|
|
// makeBearerJWT assembles a compact JWT ("header.claims.signature") from the
// given header and claim maps. Each segment is base64url-encoded without
// padding.
//
// The signature segment is always the encoding of the literal "signature".
// Bearer tests never run the real cryptographic verifier. Instead, tokens are
// treated as verified by pre-seeding tokenCache, so the bearer pipeline under
// test sees them as already validated. A JSON marshalling failure aborts the
// calling test.
func makeBearerJWT(t *testing.T, header, claims map[string]interface{}) string {
	t.Helper()

	encode := base64.RawURLEncoding.EncodeToString

	headerJSON, err := json.Marshal(header)
	if err != nil {
		t.Fatalf("marshal header: %v", err)
	}

	claimsJSON, err := json.Marshal(claims)
	if err != nil {
		t.Fatalf("marshal claims: %v", err)
	}

	headerSeg := encode(headerJSON)
	claimsSeg := encode(claimsJSON)
	signatureSeg := encode([]byte("signature"))

	return fmt.Sprintf("%s.%s.%s", headerSeg, claimsSeg, signatureSeg)
}
|
|
|
|
// defaultBearerHeader returns a new JOSE header map that declares RS256 as the
// signing algorithm and "test-kid" as the key ID. It is the baseline header
// for the bearer tests.
func defaultBearerHeader() map[string]interface{} {
	header := make(map[string]interface{}, 2)
	header["alg"] = "RS256"
	header["kid"] = "test-kid"
	return header
}
|
|
|
|
// defaultBearerClaims returns a baseline access-token claim set containing:
//   - iss: https://issuer.example.com
//   - aud: https://api.example.com
//   - sub: service-account-1
//   - scope: "api:read api:write"
//   - iat: now; exp: iat + 1 hour
//
// iat and exp are float64 Unix seconds, the shape JSON decoding produces.
// Every call builds a new map, so callers may add, override or delete claims
// without cloning and without affecting other tests.
func defaultBearerClaims() map[string]interface{} {
	// Read the clock once so exp-iat is always exactly one hour. Two separate
	// time.Now() calls could straddle a second boundary and yield 3599s.
	now := time.Now()
	return map[string]interface{}{
		"iss":   "https://issuer.example.com",
		"aud":   "https://api.example.com",
		"sub":   "service-account-1",
		"scope": "api:read api:write",
		"exp":   float64(now.Add(time.Hour).Unix()),
		"iat":   float64(now.Unix()),
	}
}
|
|
|
|
// makeBearerOIDC constructs a TraefikOidc wired for bearer auth tests. The
|
|
// real verifyTokenWithOpts pipeline is short-circuited via tokenCache pre-
|
|
// seed: any token Set into t.tokenCache returns nil from VerifyToken,
|
|
// letting tests exercise the post-verify bearer logic (classifier, identifier,
|
|
// throttle, header forwarding) without standing up JWKs.
|
|
func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
|
t.Helper()
|
|
sm := createTestSessionManager(t)
|
|
oidc := &TraefikOidc{
|
|
next: next,
|
|
logger: NewLogger("error"),
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: sm,
|
|
firstRequestStarted: 1,
|
|
metadataRefreshStartedAtomic: 1,
|
|
issuerURL: "https://issuer.example.com",
|
|
audience: "https://api.example.com",
|
|
clientID: "https://api.example.com",
|
|
tokenCache: NewTokenCache(),
|
|
excludedURLs: map[string]struct{}{"/favicon.ico": {}},
|
|
allowedRolesAndGroups: map[string]struct{}{},
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 1000),
|
|
ctx: context.Background(),
|
|
enableBearerAuth: true,
|
|
stripAuthorizationHeader: true,
|
|
bearerEmitWWWAuthenticate: true,
|
|
bearerOverridesCookie: false,
|
|
bearerIdentifierClaim: "sub",
|
|
maxIdentifierLength: 256,
|
|
maxTokenAge: 24 * time.Hour,
|
|
bearerFailureThreshold: 20,
|
|
bearerFailureWindow: 60 * time.Second,
|
|
bearerFailurePenalty: 60 * time.Second,
|
|
bearerFailureTracker: newBearerFailureTracker(20, 60*time.Second, 60*time.Second),
|
|
}
|
|
oidc.extractClaimsFunc = extractClaims
|
|
close(oidc.initComplete)
|
|
return oidc
|
|
}
|
|
|
|
// seedVerified pre-populates the tokenCache so verifyTokenWithOpts short-
|
|
// circuits to nil for the given token. Mirrors the production fast-return
|
|
// path at token_manager.go for previously-verified tokens.
|
|
func seedVerified(t *testing.T, oidc *TraefikOidc, token string, claims map[string]interface{}) {
|
|
t.Helper()
|
|
if oidc.tokenCache == nil {
|
|
oidc.tokenCache = NewTokenCache()
|
|
}
|
|
oidc.tokenCache.Set(token, claims, time.Hour)
|
|
}
|
|
|
|
// =============================================================================
|
|
// Unit tests — small helpers
|
|
// =============================================================================
|
|
|
|
func TestDetectBearerToken(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
name string
|
|
header string
|
|
want string
|
|
ok bool
|
|
}{
|
|
{"missing header", "", "", false},
|
|
{"basic auth", "Basic abc", "", false},
|
|
{"bearer with token", "Bearer abc.def.ghi", "abc.def.ghi", true},
|
|
{"lowercase bearer", "bearer abc.def.ghi", "abc.def.ghi", true},
|
|
{"mixed case", "BeArEr abc.def.ghi", "abc.def.ghi", true},
|
|
{"empty token after prefix", "Bearer ", "", false},
|
|
{"bearer no space", "Bearerabc", "", false},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
if tc.header != "" {
|
|
req.Header.Set("Authorization", tc.header)
|
|
}
|
|
got, ok := detectBearerToken(req)
|
|
if ok != tc.ok || got != tc.want {
|
|
t.Fatalf("got=(%q, %v), want=(%q, %v)", got, ok, tc.want, tc.ok)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseBearerJOSEHeader(t *testing.T) {
|
|
t.Parallel()
|
|
mk := func(t *testing.T, h map[string]interface{}) string {
|
|
return makeBearerJWT(t, h, map[string]interface{}{"sub": "x"})
|
|
}
|
|
cases := []struct {
|
|
header map[string]interface{}
|
|
name string
|
|
wantErr bool
|
|
}{
|
|
{name: "valid RS256", header: map[string]interface{}{"alg": "RS256", "kid": "k1"}, wantErr: false},
|
|
{name: "valid ES512", header: map[string]interface{}{"alg": "ES512", "kid": "abc-_.="}, wantErr: false},
|
|
{name: "alg=none rejected", header: map[string]interface{}{"alg": "none", "kid": "k1"}, wantErr: true},
|
|
{name: "alg=HS256 rejected", header: map[string]interface{}{"alg": "HS256", "kid": "k1"}, wantErr: true},
|
|
{name: "missing kid", header: map[string]interface{}{"alg": "RS256"}, wantErr: true},
|
|
{name: "kid too long", header: map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}, wantErr: true},
|
|
{name: "kid bad chars", header: map[string]interface{}{"alg": "RS256", "kid": "evil/../etc/passwd"}, wantErr: true},
|
|
{name: "kid with space", header: map[string]interface{}{"alg": "RS256", "kid": "key one"}, wantErr: true},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
token := mk(t, tc.header)
|
|
err := parseBearerJOSEHeader(token)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSanitiseBearerIdentifier(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
name string
|
|
in string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{"normal sub", "service-account-1", "service-account-1", false},
|
|
{"email-like", "alice@example.com", "alice@example.com", false},
|
|
{"trim whitespace", " abc ", "abc", false},
|
|
{"empty", "", "", true},
|
|
{"only whitespace", " ", "", true},
|
|
{"control char (newline)", "alice\nbob", "", true},
|
|
{"control char (CR)", "alice\rbob", "", true},
|
|
{"control char (NUL)", "alice\x00bob", "", true},
|
|
{"bidi override", "alice\u202ebob", "", true},
|
|
{"bidi isolate", "alice\u2066bob", "", true},
|
|
{"comma delimiter", "alice,bob", "", true},
|
|
{"semicolon delimiter", "alice;bob", "", true},
|
|
{"equals delimiter", "alice=bob", "", true},
|
|
{"over length", strings.Repeat("a", 257), "", true},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got, err := sanitizeBearerIdentifier(tc.in, 256)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
|
}
|
|
if !tc.wantErr && got != tc.want {
|
|
t.Fatalf("got=%q want=%q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResolveBearerIdentifier(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
claims map[string]interface{}
|
|
name string
|
|
claim string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{name: "default sub", claims: map[string]interface{}{"sub": "abc"}, claim: "", want: "abc"},
|
|
{name: "explicit sub", claims: map[string]interface{}{"sub": "abc"}, claim: "sub", want: "abc"},
|
|
{name: "custom client_id claim", claims: map[string]interface{}{"client_id": "svc"}, claim: "client_id", want: "svc"},
|
|
{name: "missing claim", claims: map[string]interface{}{"other": "x"}, claim: "sub", wantErr: true},
|
|
{name: "non-string claim", claims: map[string]interface{}{"sub": 123}, claim: "sub", wantErr: true},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
got, err := resolveBearerIdentifier(tc.claims, tc.claim)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
|
}
|
|
if !tc.wantErr && got != tc.want {
|
|
t.Fatalf("got=%q want=%q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEnforceMultiAudienceAzp(t *testing.T) {
|
|
t.Parallel()
|
|
const cid = "https://api.example.com"
|
|
cases := []struct {
|
|
claims map[string]interface{}
|
|
name string
|
|
wantErr bool
|
|
}{
|
|
{name: "single string aud", claims: map[string]interface{}{"aud": "x"}, wantErr: false},
|
|
{name: "single element array", claims: map[string]interface{}{"aud": []interface{}{"x"}}, wantErr: false},
|
|
{name: "multi-aud with matching azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": cid}, wantErr: false},
|
|
{name: "multi-aud missing azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}}, wantErr: true},
|
|
{name: "multi-aud empty azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": ""}, wantErr: true},
|
|
{name: "multi-aud wrong azp", claims: map[string]interface{}{"aud": []interface{}{"a", "b"}, "azp": "other"}, wantErr: true},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := enforceMultiAudienceAzp(tc.claims, cid)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestEnforceIatAge(t *testing.T) {
|
|
t.Parallel()
|
|
now := time.Now()
|
|
cases := []struct {
|
|
name string
|
|
iat float64
|
|
maxAge time.Duration
|
|
wantErr bool
|
|
}{
|
|
{name: "fresh", iat: float64(now.Unix()), maxAge: time.Hour, wantErr: false},
|
|
{name: "23h59m old, max 24h", iat: float64(now.Add(-23*time.Hour - 59*time.Minute).Unix()), maxAge: 24 * time.Hour, wantErr: false},
|
|
{name: "25h old, max 24h", iat: float64(now.Add(-25 * time.Hour).Unix()), maxAge: 24 * time.Hour, wantErr: true},
|
|
{name: "1970 token", iat: float64(0), maxAge: 24 * time.Hour, wantErr: true},
|
|
{name: "maxAge disabled (0)", iat: float64(0), maxAge: 0, wantErr: false},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := enforceIatAge(map[string]interface{}{"iat": tc.iat}, tc.maxAge)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Fatalf("err=%v wantErr=%v", err, tc.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBearerFailureTracker(t *testing.T) {
|
|
t.Parallel()
|
|
tr := newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
|
|
const ip = "10.0.0.1"
|
|
// Below threshold: not blocked.
|
|
for i := 0; i < 2; i++ {
|
|
tr.recordFailure(ip)
|
|
if b, _ := tr.blocked(ip); b {
|
|
t.Fatalf("blocked too early after %d failures", i+1)
|
|
}
|
|
}
|
|
// Threshold reached: blocked.
|
|
tr.recordFailure(ip)
|
|
if b, retry := tr.blocked(ip); !b || retry <= 0 {
|
|
t.Fatalf("expected blocked with positive retry, got=%v retry=%v", b, retry)
|
|
}
|
|
// Success clears the counter.
|
|
tr.recordSuccess(ip)
|
|
if b, _ := tr.blocked(ip); b {
|
|
t.Fatalf("expected unblocked after success")
|
|
}
|
|
// Other IPs are unaffected.
|
|
if b, _ := tr.blocked("10.0.0.2"); b {
|
|
t.Fatalf("unrelated IP should not be blocked")
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Integration tests — full ServeHTTP via the bearer pipeline
|
|
// =============================================================================
|
|
|
|
func TestServeHTTP_Bearer_HappyPath(t *testing.T) {
|
|
t.Parallel()
|
|
var nextCalled atomic.Bool
|
|
var capturedHeaders http.Header
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextCalled.Store(true)
|
|
capturedHeaders = r.Header.Clone()
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if !nextCalled.Load() {
|
|
t.Fatalf("expected next handler to run; got status=%d body=%q", rw.Code, rw.Body.String())
|
|
}
|
|
if rw.Code != http.StatusOK {
|
|
t.Fatalf("status=%d, want 200", rw.Code)
|
|
}
|
|
if got := capturedHeaders.Get("X-Forwarded-User"); got != "service-account-1" {
|
|
t.Fatalf("X-Forwarded-User=%q, want service-account-1", got)
|
|
}
|
|
if got := capturedHeaders.Get("Authorization"); got != "" {
|
|
t.Fatalf("Authorization should be stripped, got=%q", got)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_StripAuthDisabled(t *testing.T) {
|
|
t.Parallel()
|
|
var capturedAuth string
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedAuth = r.Header.Get("Authorization")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.stripAuthorizationHeader = false
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if !strings.HasPrefix(capturedAuth, "Bearer ") {
|
|
t.Fatalf("expected Authorization to be forwarded, got=%q", capturedAuth)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_RejectIDToken(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for ID token rejection")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
// ID-token shape: nonce claim present and no scope. detectTokenType
|
|
// returns true.
|
|
claims := map[string]interface{}{
|
|
"iss": "https://issuer.example.com",
|
|
"aud": "https://api.example.com",
|
|
"sub": "user-1",
|
|
"nonce": "n-0S6_WzA2Mj",
|
|
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
|
"iat": float64(time.Now().Unix()),
|
|
}
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
if wa := rw.Header().Get("WWW-Authenticate"); !strings.Contains(wa, `error="invalid_token"`) {
|
|
t.Fatalf("expected WWW-Authenticate invalid_token, got=%q", wa)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_AlgNoneRejected(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for alg=none")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
header := map[string]interface{}{"alg": "none", "kid": "k1"}
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, header, claims)
|
|
// Even if we pre-seeded the cache, the early alg pin runs FIRST.
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_KidTooLongRejected(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for oversized kid")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
header := map[string]interface{}{"alg": "RS256", "kid": strings.Repeat("a", bearerKidMaxLen+1)}
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, header, claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_MultiAudRequiresAzp(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for multi-aud without azp")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
|
|
delete(claims, "azp")
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_MultiAudWithAzpAccepted(t *testing.T) {
|
|
t.Parallel()
|
|
var nextCalled atomic.Bool
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextCalled.Store(true)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
claims["aud"] = []interface{}{"https://api.example.com", "https://other.example.com"}
|
|
claims["azp"] = oidc.clientID
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusOK || !nextCalled.Load() {
|
|
t.Fatalf("expected 200 + next called; got status=%d called=%v", rw.Code, nextCalled.Load())
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_IatTooOldRejected(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for old iat")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
claims["iat"] = float64(time.Now().Add(-25 * time.Hour).Unix())
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_IdentifierWithBidiRejected(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for bidi identifier")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
claims["sub"] = "alice\u202ebob"
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_ReplayRegression(t *testing.T) {
|
|
t.Parallel()
|
|
var successCount atomic.Int32
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
successCount.Add(1)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
claims["jti"] = "regression-jti"
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
for i := 0; i < 100; i++ {
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
if rw.Code != http.StatusOK {
|
|
t.Fatalf("iteration %d: status=%d, want 200", i, rw.Code)
|
|
}
|
|
}
|
|
if successCount.Load() != 100 {
|
|
t.Fatalf("successCount=%d, want 100", successCount.Load())
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_ThrottleTrips429(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run during throttle test")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.bearerFailureTracker = newBearerFailureTracker(3, 60*time.Second, 60*time.Second)
|
|
|
|
// Send malformed bearers from the same RemoteAddr until threshold trips.
|
|
send := func() *httptest.ResponseRecorder {
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.RemoteAddr = "10.0.0.5:1234"
|
|
req.Header.Set("Authorization", "Bearer not-a-jwt")
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
return rw
|
|
}
|
|
for i := 0; i < 3; i++ {
|
|
rw := send()
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("pre-throttle iteration %d: status=%d, want 401", i, rw.Code)
|
|
}
|
|
}
|
|
// 4th request: throttled.
|
|
rw := send()
|
|
if rw.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected 429 after threshold, got %d", rw.Code)
|
|
}
|
|
if ra := rw.Header().Get("Retry-After"); ra == "" {
|
|
t.Fatalf("expected Retry-After header on 429")
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_ExcludedURLStripsAuth(t *testing.T) {
|
|
t.Parallel()
|
|
var capturedAuth string
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedAuth = r.Header.Get("Authorization")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.excludedURLs = map[string]struct{}{"/favicon.ico": {}}
|
|
|
|
req := httptest.NewRequest("GET", "/favicon.ico", nil)
|
|
req.Header.Set("Authorization", "Bearer abc.def.ghi")
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if rw.Code != http.StatusOK {
|
|
t.Fatalf("excluded path should pass; got %d", rw.Code)
|
|
}
|
|
if capturedAuth != "" {
|
|
t.Fatalf("Authorization must be stripped on excluded paths, got=%q", capturedAuth)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_RolesGate(t *testing.T) {
|
|
t.Parallel()
|
|
cases := []struct {
|
|
name string
|
|
rolesClaim []interface{}
|
|
want int
|
|
}{
|
|
{name: "matching role", rolesClaim: []interface{}{"admin"}, want: http.StatusOK},
|
|
{name: "no matching role", rolesClaim: []interface{}{"viewer"}, want: http.StatusForbidden},
|
|
}
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.allowedRolesAndGroups = map[string]struct{}{"admin": {}}
|
|
oidc.roleClaimName = "roles"
|
|
claims := defaultBearerClaims()
|
|
claims["roles"] = tc.rolesClaim
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
if rw.Code != tc.want {
|
|
t.Fatalf("status=%d, want %d", rw.Code, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_CookieWinsByDefault(t *testing.T) {
|
|
t.Parallel()
|
|
// Both cookie and bearer present: cookie path runs (which will redirect
|
|
// to /authorize since the cookie is empty/unauthenticated).
|
|
var nextCalled atomic.Bool
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextCalled.Store(true)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
prefix := oidc.sessionManager.GetCookiePrefix()
|
|
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
// Cookie path consumed the request; bearer was ignored. Since the
|
|
// cookie is empty, the cookie path will either 302 to /authorize or
|
|
// return 401 — in either case, next must NOT be called.
|
|
if nextCalled.Load() {
|
|
t.Fatalf("next must not be called when bearer is ignored due to cookie precedence")
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_BearerOverridesCookie(t *testing.T) {
|
|
t.Parallel()
|
|
var nextCalled atomic.Bool
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextCalled.Store(true)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.bearerOverridesCookie = true
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
prefix := oidc.sessionManager.GetCookiePrefix()
|
|
req.AddCookie(&http.Cookie{Name: prefix + "main", Value: "irrelevant"})
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
|
|
if !nextCalled.Load() || rw.Code != http.StatusOK {
|
|
t.Fatalf("expected bearer to win with override; status=%d called=%v", rw.Code, nextCalled.Load())
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_OversizedToken(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for oversized token")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
huge := strings.Repeat("a", AccessTokenConfig.MaxLength+1)
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+huge)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_MalformedJWT(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Fatalf("next must not run for malformed JWT")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer not.jwt") // 1 dot
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
if rw.Code != http.StatusUnauthorized {
|
|
t.Fatalf("status=%d, want 401", rw.Code)
|
|
}
|
|
}
|
|
|
|
func TestServeHTTP_Bearer_FeatureOffPassesThrough(t *testing.T) {
|
|
t.Parallel()
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Should not be reached: cookie path runs and (with no session)
|
|
// will redirect or 401. We assert no panic / next not called.
|
|
t.Fatalf("next must not run when bearer is off and no valid session exists")
|
|
})
|
|
oidc := makeBearerOIDC(t, next)
|
|
oidc.enableBearerAuth = false
|
|
claims := defaultBearerClaims()
|
|
token := makeBearerJWT(t, defaultBearerHeader(), claims)
|
|
seedVerified(t, oidc, token, claims)
|
|
req := httptest.NewRequest("GET", "/api/work", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rw := httptest.NewRecorder()
|
|
oidc.ServeHTTP(rw, req)
|
|
// Expect non-200: either 302 to /authorize or 401. The point is the
|
|
// bearer pipeline didn't run.
|
|
if rw.Code == http.StatusOK {
|
|
t.Fatalf("expected non-200 when bearer is off; got %d", rw.Code)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Startup validation tests
|
|
// =============================================================================
|
|
|
|
func TestStartupValidation_BearerRequiresAudience(t *testing.T) {
|
|
t.Parallel()
|
|
cfg := CreateConfig()
|
|
cfg.ProviderURL = "https://issuer.example.com"
|
|
cfg.ClientID = "id"
|
|
cfg.ClientSecret = "secret"
|
|
cfg.CallbackURL = "/oauth/callback"
|
|
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
cfg.EnableBearerAuth = true
|
|
cfg.Audience = ""
|
|
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
|
|
if err == nil || !strings.Contains(err.Error(), "requires Audience") {
|
|
t.Fatalf("expected audience-required error, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStartupValidation_BearerRejectsEmailIdentifier(t *testing.T) {
|
|
t.Parallel()
|
|
cfg := CreateConfig()
|
|
cfg.ProviderURL = "https://issuer.example.com"
|
|
cfg.ClientID = "id"
|
|
cfg.ClientSecret = "secret"
|
|
cfg.CallbackURL = "/oauth/callback"
|
|
cfg.SessionEncryptionKey = "0123456789abcdef0123456789abcdef0123456789abcdef"
|
|
cfg.EnableBearerAuth = true
|
|
cfg.Audience = "https://api.example.com"
|
|
cfg.BearerIdentifierClaim = "email"
|
|
_, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), cfg, "bearer-test")
|
|
if err == nil || !strings.Contains(err.Error(), "bearerIdentifierClaim=\"email\"") {
|
|
t.Fatalf("expected email-identifier rejection, got %v", err)
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// Principal invariants
|
|
// =============================================================================
|
|
|
|
func TestBuildPrincipalFromSession_NoIdentifier(t *testing.T) {
|
|
t.Parallel()
|
|
oidc := &TraefikOidc{logger: NewLogger("error")}
|
|
if p := oidc.buildPrincipalFromSession(nil); p != nil {
|
|
t.Fatalf("nil session must produce nil principal")
|
|
}
|
|
}
|