mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
546ceb949c
* fix(security): encrypt session cookies + fail closed on invalid config
Batch 1 of security audit remediation (ranks 1, 2, 6).
- session.go: derive independent HMAC + AES-256 keys via stdlib HKDF-SHA256
and build the gorilla cookie store with both, so session cookies are now
encrypted, not merely signed. The single-key store previously left OIDC
access/refresh/ID tokens recoverable from raw cookie bytes. Cookie format
changes, so existing sessions are invalidated on deploy (one-time re-login).
- main.go: call config.Validate() at construction and error out on failure,
instead of silently substituting a public hardcoded encryption key for
empty/short keys (which allowed session forgery). The yaegi analyzer
passes via .traefik.yml testData.
- settings.go: isValidSecureURL permits plaintext HTTP for loopback hosts
only (RFC 8252); remote providers must still use HTTPS.
- tests: complete configs that did not satisfy Validate(); add regression
tests in security_audit_fixes_test.go.
Configs below documented minimums (rateLimit < 10, key < 32 chars) are now
rejected at startup (fail closed).
* fix(security): validate discovered OIDC endpoints + pin introspection host
Batch 2 of security audit remediation (ranks 3, 4).
- url_helpers.go: add validateDiscoveredEndpoint, an SSRF screen for endpoints
taken from the provider discovery document (jwks_uri, token, authorization,
revocation, end_session, introspection, registration). Blocks link-local
(cloud metadata 169.254.169.254), multicast, unspecified and private
addresses (unless allowPrivateIPAddresses); blocks loopback unless the
configured providerURL is itself loopback (dev/test). Cross-domain JWKS
hosts (e.g. Google) stay allowed. Add sameHost helper.
- main.go: updateMetadataEndpoints screens every discovered endpoint and
blanks any that fail (fail closed downstream). The introspection endpoint
carries the client secret via HTTP Basic, so it is additionally pinned to
the providerURL host to stop a poisoned discovery document exfiltrating the
secret to an attacker-controlled host.
- tests: regression tests for the SSRF guard and the host pin.
* fix(security): close open redirects + anchor excluded-URL matching
Batch 3 of security audit remediation (ranks 5, 14, 15).
- auth_flow.go: run the stored incoming path through normalizeLogoutPath
before using it as the post-login redirect, so //evil.com and /\evil.com
payloads become host-relative (open-redirect, rank 5).
- url_helpers.go: excluded-URL matching is anchored at a natural boundary
(exact, sub-path "/", or file extension "."), so excluding "/public" no
longer also bypasses auth on "/publicsecret"; "/favicon" still matches
"/favicon.ico" (rank 14).
- internal/utils: X-Forwarded-Host is sanitized (first value only; reject
CRLF/whitespace/multi-value) before building redirect URLs (rank 15).
- helpers.go: the logout redirect used when there is no provider end-session
endpoint is host-relative, never an absolute URL derived from the
client-controllable request host (logout open-redirect, rank 15).
- tests: update two logout cases that asserted the old absolute redirect;
add regression tests.
* fix(security): reject unverified Azure tokens; fix transport TLS reuse
Batch 4 of security audit remediation (ranks 7, 11).
- token_validation_rs.go: an Azure nonce-bearing access token that cannot be
cryptographically verified no longer returns "authenticated" when there is
no ID token to corroborate it; it refreshes (if possible) or forces
re-authentication instead of failing open (rank 7).
- http_client_pool.go: the at-limit transport-reuse path now takes the write
lock before mutating refCount (fixes a data race) and only reuses a
transport whose TLS settings (CA pool + InsecureSkipVerify) match the
caller's, never one with a different trust store; if none matches it returns
nil so the caller falls back to a verifying default transport (rank 11).
- tests: add a transport-pool TLS-isolation regression test.
* fix(security): stop logging templated header values (token leak)
Batch 5 of security audit remediation (rank 16).
middleware.go: templated downstream headers commonly carry the access token
(e.g. "Authorization: Bearer {{.AccessToken}}"). The debug log line printed
the full header value, leaking credentials into logs. Log the header name and
byte length instead.
* fix(security): cache-key collision, cache-config divergence, fleet cleanup
Batch 6 of security audit remediation (ranks 9, 10, 12).
- token_manager.go: detectTokenType keys its cache on a SHA-256 hash of the
full token instead of the first 32 chars (which are only the base64url JWT
header). Distinct tokens sharing alg+kid no longer collide and get
mis-classified (rank 10).
- cache_manager.go: the process-global cache manager is initialized once and
shared across plugin instances; it now logs a loud warning when a later
instance requests a different explicit Redis backend that is silently
ignored, surfacing the cross-instance state-isolation hazard (rank 9).
- singleton_resources.go / main.go / utilities.go: track a process-global live
instance count; the shared singleton-token-cleanup task is stopped only when
the LAST instance shuts down, so one instance's Close() (e.g. a config reload)
no longer kills cleanup for surviving instances (rank 12).
- tests: update TestDetectTokenTypeCaching for the new key; add regression tests.
* fix(security): bound introspection cache + cookie lifetime to config
Batch 7 of security audit remediation (ranks 8, 13).
- token_introspection.go: when requireTokenIntrospection is enabled, cap the
positive introspection-result cache at 30s (instead of 5m) so a token
revoked at the provider stops passing within ~30s, matching the operator's
near-real-time revocation expectation (rank 8).
- session.go: bind the cookie store's MaxAge to the configured sessionMaxAge,
so the cookie codec's cryptographic timestamp validity is no longer fixed at
gorilla's 30-day default; a stolen cookie is valid only for the configured
session lifetime (rank 13).
- tests: add a cookie-lifetime regression test.
* fix(security): low-severity hardening (cache, DoS caps, PKCE, throttle)
Batch 8 of security audit remediation — low severity
(ranks 24, 25, 27, 29, 31, 36, 37, 41, 45, 46, 49).
- universal_cache.go: updateLocalCache updates an existing key in place instead
of orphaning its LRU element and double-counting currentSize/currentMemory
(rank 36 — the only production-reachable bug in this batch).
- jwk.go / metadata_cache.go / token_introspection.go: bound response bodies
with io.LimitReader (1 MiB) to prevent memory exhaustion from a hostile or
buggy provider (ranks 24, 25).
- jwk.go: skip JWKs not usable for signature verification (use != sig, or
key_ops without "verify") when building the key set (rank 49).
- auth_flow.go: fail closed at the callback when PKCE is enabled but the code
verifier is missing, instead of silently dropping it (rank 27).
- utilities.go / main.go: match allowedUserDomains case-insensitively (rank 31).
- bearer_auth.go: a single success no longer wipes an active per-IP penalty;
the counter resets only when no penalty is in effect (rank 29).
- main.go: handle (not discard) the NewSessionManager error (rank 37).
- error_recovery.go: take a write lock in isServiceDegraded (it deletes from a
map); compare retryable-error substrings case-insensitively (ranks 45, 46).
- singleton_resources.go: bind the generic-cache cleanup goroutine to the
resource-manager shutdown channel so it cannot outlive its owner (rank 41).
- tests: update the bearer throttle test to the corrected penalty semantics.
* fix(security): header sanitization, issuer pinning, fail-closed paths
Batch 9 of security audit remediation (ranks 18, 19, 20, 21, 22, 30, 33, 34).
- middleware.go / bearer_auth.go: sanitize claim-derived values on the cookie
auth path before injecting them into downstream headers. Drop group/role and
identifier values containing control chars, bidi-override runes, or the
, ; = delimiters (a comma would inject phantom entries into X-User-Groups);
reject control/bidi/over-length in rendered templated header output (but
permit , ; = in free-form values such as a bearer token). The bearer path
already sanitized; the cookie path did not (ranks 33, 34).
- main.go / metadata_cache.go: pin the discovered issuer to the configured
provider host (sameHost) and refuse/never-cache a mismatch, so a poisoned
discovery document cannot redefine the JWT trust anchor (ranks 21, 22).
- token_introspection.go: when a distinct API audience is configured, fail
closed on a missing or mismatched introspection audience; aud parsed as
string-or-array per RFC 7662 (rank 19).
- logout.go: front-channel logout requires a matching issuer; an empty iss is
rejected (blocks unauthenticated forced-logout via a known sid) (rank 30).
- token_validation_rs.go: an opaque access token with no ID token and no
successful introspection fails closed (re-auth) instead of authenticating
(ranks 18, 20).
- tests: realistic same-host provider mocks; regression tests for the header
sanitization distinction and the fail-closed paths.
* chore(security): remove unwired dead code with latent footguns
Batch 10 of security audit remediation — delete confirmed-dead, unwired
subsystems (ranks 26, 35, 50). None had a production caller (grep-verified);
removal eliminates the latent footguns and ~2.1k lines of dead code.
- token_validator.go (deleted): an unused *TokenValidator whose validateJWT set
Valid=true with NO signature verification — a severe footgun if ever wired
(rank 50). The wired RS-aware validators are unaffected.
- security_monitoring.go (deleted): an unused *SecurityMonitor / ExtractClientIP
that trusted spoofable X-Forwarded-For / X-Real-IP. The live bearer throttle
uses clientIPForBearer (RemoteAddr-only), unchanged (rank 35).
- dynamic_client_registration.go: removed the RFC 7592 management methods
(Update/Read/DeleteClientRegistration) that dereferenced an attacker-
influenced RegistrationClientURI with the registration token attached and no
HTTPS/SSRF gate, and had no callers. The wired RFC 7591 RegisterClient and
credential-store helpers are kept (rank 26).
- tests: removed the tests covering the deleted code.
* chore: add Makefile with yaegi load validation
No Makefile existed. The new `yaegi-validate` target interprets the plugin
under the yaegi interpreter the same way Traefik loads it, catching yaegi-only
incompatibilities (unsupported stdlib symbols, reflection edge cases) that the
native `go build` / `go test` toolchain does not. Importing the plugin forces
yaegi to interpret every file plus its vendored deps; CreateConfig + New
exercise the instantiation path.
- cmd/yaegicheck/main.go: the load driver, marked //go:build ignore so it is
excluded from `go build ./...` (avoids VCS-stamping a main binary, which
fails in git-worktree layouts) yet is run explicitly by yaegi.
- Makefile: build / fmt / vet / lint / test / vendor / yaegi-validate / check
targets; `make check` runs vet + tests + yaegi-validate.
Verified: `make yaegi-validate` passes on this branch — the HKDF cookie
encryption, net-based endpoint validation, and claim sanitizers all interpret
and instantiate cleanly under yaegi.
* ci: bump workflow Go toolchain to 1.25; pin yaegi-validate to v0.16.1
Traefik v3.7.1 (the deployed version) is built with `go 1.25.0`, so the PR and
release workflows now use Go 1.25.x to match the toolchain Traefik uses.
Important distinction: the CI Go version is the build TOOLCHAIN. The plugin's
actual interpreter-compatibility ceiling is the yaegi version Traefik bundles
(v0.16.1, which declares go 1.21 and ships a ~Go 1.22 stdlib symbol surface),
NOT the CI Go version. That ceiling is enforced by `make yaegi-validate` plus
the go.mod language directive — e.g. it is why HKDF is hand-rolled with
hmac+sha256 rather than Go 1.24's crypto/hkdf, which yaegi v0.16.1 lacks.
Also pin Makefile YAEGI_VERSION to v0.16.1 (what Traefik v3.7.1 vendors) so
yaegi-validate exercises the real deployed interpreter instead of @latest,
which could pass on a newer yaegi that supports symbols the deployed one does
not.
* docs: align README/CONFIGURATION with branch behavior changes
- excludedURLs: documented as segment/extension-boundary matching (was
"prefix-matched") — "/public" no longer also matches "/publicsecret" (rank 14).
- Front-channel logout now requires a matching `iss`; requests without one are
rejected with 400 (rank 30).
- Add an "Upgrading from an earlier release" note: session cookies are now
AES-256 encrypted with lifetime tracking sessionMaxAge (one-time re-login on
upgrade), and invalid configuration (rateLimit < 10, key < 32 bytes, missing
callbackURL, non-HTTPS remote providerURL) now fails closed at startup.
* fix: remove staticcheck-flagged unused functions; wire staticcheck into make check
CI Static Analysis (standalone staticcheck) failed with U1000 "unused":
- dynamic_client_registration.go: deleteCredentialsFromStore — its only caller
was the RFC 7592 DeleteClientRegistration removed in the dead-code batch.
- token_test.go: createTestJWTSimple — its only callers were the TokenValidator
tests removed in the same batch.
Both confirmed to have zero remaining callers and removed. build / vet /
go test ./... / staticcheck ./... all green.
The pre-commit hook runs golangci-lint, but CI runs standalone staticcheck
(which flags U1000). Add a `staticcheck` Makefile target and include it in
`make check` so this class of finding is caught locally before push.
* fix(test): stabilize flaky TestWorkerPool_TaskPanic
tasksFailed is incremented in the worker's deferred recover(), which runs after the panicking task's own defer wg.Done(). wg.Wait() could therefore return before the failure was recorded, so reading the counter immediately raced and flaked on slow CI runners. Poll until the failure lands (2s budget) instead. Verified 200x plain + 50x under -race/GOMAXPROCS=1.
4849 lines
156 KiB
Go
4849 lines
156 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"math/big"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// TestSuite holds common test data and setup
|
|
type TestSuite struct {
|
|
t *testing.T
|
|
rsaPrivateKey *rsa.PrivateKey
|
|
rsaPublicKey *rsa.PublicKey
|
|
ecPrivateKey *ecdsa.PrivateKey
|
|
tOidc *TraefikOidc
|
|
mockJWKCache *MockJWKCache
|
|
sessionManager *SessionManager
|
|
// utf *UnifiedTestFramework // Removed - consolidated test framework
|
|
token string
|
|
}
|
|
|
|
// NewTestSuite creates a new test suite with automatic cleanup
|
|
func NewTestSuite(t *testing.T) *TestSuite {
|
|
ts := &TestSuite{
|
|
t: t,
|
|
// utf: NewUnifiedTestFramework(t), // Removed
|
|
}
|
|
return ts
|
|
}
|
|
|
|
// Setup initializes the test suite
|
|
func (ts *TestSuite) Setup() {
|
|
// Initialize unified test framework if not already done
|
|
// Unified test framework removed - using direct cleanup
|
|
|
|
var err error
|
|
ts.rsaPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
ts.t.Fatalf("Failed to generate RSA key: %v", err)
|
|
}
|
|
ts.rsaPublicKey = &ts.rsaPrivateKey.PublicKey
|
|
|
|
// Generate EC key for EC key tests
|
|
ts.ecPrivateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
if err != nil {
|
|
ts.t.Fatalf("Failed to generate EC key: %v", err)
|
|
}
|
|
|
|
// Create a JWK for the RSA public key
|
|
jwk := JWK{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Alg: "RS256",
|
|
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
|
|
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
|
|
}
|
|
jwks := &JWKSet{
|
|
Keys: []JWK{jwk},
|
|
}
|
|
|
|
// Create a mock JWKCache
|
|
ts.mockJWKCache = &MockJWKCache{
|
|
JWKS: jwks,
|
|
Err: nil,
|
|
}
|
|
|
|
// Create a test JWT token signed with the RSA private key
|
|
// Create timestamps with proper clock skew
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
|
|
ts.token, err = createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
if err != nil {
|
|
ts.t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
logger := NewLogger("info")
|
|
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
|
|
// Create WaitGroup for the OIDC instance
|
|
goroutineWG := &sync.WaitGroup{}
|
|
|
|
// Initialize caches properly
|
|
tokenBlacklist := NewCache()
|
|
tokenCacheInternal := NewCache()
|
|
tokenCache := &TokenCache{}
|
|
if tokenCache.cache == nil {
|
|
// Type assert to get the underlying UniversalCache
|
|
if wrapper, ok := tokenCacheInternal.(*CacheInterfaceWrapper); ok {
|
|
tokenCache.cache = wrapper.cache
|
|
}
|
|
}
|
|
|
|
// Common TraefikOidc instance
|
|
ts.tOidc = &TraefikOidc{
|
|
issuerURL: "https://test-issuer.com",
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
roleClaimName: "roles", // Set default for backward compatibility
|
|
groupClaimName: "groups", // Set default for backward compatibility
|
|
userIdentifierClaim: "email", // Set default for backward compatibility
|
|
jwkCache: ts.mockJWKCache,
|
|
jwksURL: "https://test-jwks-url.com",
|
|
revocationURL: "https://revocation-endpoint.com",
|
|
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
|
tokenBlacklist: tokenBlacklist,
|
|
tokenCache: tokenCache,
|
|
logger: logger,
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
excludedURLs: map[string]struct{}{"/favicon": {}, "/health": {}},
|
|
httpClient: &http.Client{Timeout: 10 * time.Second},
|
|
// Explicitly set paths as New() is bypassed
|
|
redirURLPath: "/callback", // Assume default callback path for tests
|
|
logoutURLPath: "/callback/logout", // Assume default logout path for tests
|
|
tokenURL: "https://test-issuer.com/token", // Explicitly set for refresh tests
|
|
extractClaimsFunc: extractClaims,
|
|
initComplete: make(chan struct{}),
|
|
sessionManager: ts.sessionManager,
|
|
goroutineWG: goroutineWG,
|
|
ctx: context.Background(),
|
|
tokenCleanupStopChan: make(chan struct{}),
|
|
metadataRefreshStopChan: make(chan struct{}),
|
|
}
|
|
close(ts.tOidc.initComplete)
|
|
ts.tOidc.tokenVerifier = ts.tOidc
|
|
ts.tOidc.jwtVerifier = ts.tOidc
|
|
// Set default mock exchanger
|
|
ts.tOidc.tokenExchanger = &MockTokenExchanger{
|
|
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
|
// Default mock behavior for code exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token, // Use the valid token from setup
|
|
AccessToken: ts.token,
|
|
RefreshToken: "default-refresh-token",
|
|
ExpiresIn: 3600,
|
|
}, nil
|
|
},
|
|
RefreshTokenFunc: func(refreshToken string) (*TokenResponse, error) {
|
|
// Default mock behavior for refresh (can be overridden in tests)
|
|
return nil, fmt.Errorf("default mock: refresh not expected")
|
|
},
|
|
RevokeTokenFunc: func(token, tokenType string) error {
|
|
// Default mock behavior for revoke
|
|
return nil
|
|
},
|
|
}
|
|
|
|
// OIDC instance created
|
|
|
|
// Register cleanup
|
|
ts.t.Cleanup(func() {
|
|
if ts.tOidc.tokenBlacklist != nil {
|
|
ts.tOidc.tokenBlacklist.Close()
|
|
}
|
|
if ts.tOidc.tokenCache != nil && ts.tOidc.tokenCache.cache != nil {
|
|
ts.tOidc.tokenCache.cache.Close()
|
|
}
|
|
})
|
|
}
|
|
|
|
// Helper function exchangeCodeForTokenFunc removed as it's unused after refactoring to TokenExchanger interface.
|
|
|
|
// MockJWKCache implements JWKCacheInterface
|
|
type MockJWKCache struct {
|
|
Err error
|
|
JWKS *JWKSet
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// Close is a no-op for the mock.
|
|
func (m *MockJWKCache) Close() {
|
|
// No operation needed for the mock.
|
|
}
|
|
|
|
func (m *MockJWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
return m.JWKS, m.Err
|
|
}
|
|
|
|
func (m *MockJWKCache) GetPublicKey(ctx context.Context, jwksURL, kid string, httpClient *http.Client) (crypto.PublicKey, error) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
if m.Err != nil {
|
|
return nil, m.Err
|
|
}
|
|
if m.JWKS == nil {
|
|
return nil, fmt.Errorf("JWKS is nil")
|
|
}
|
|
for i := range m.JWKS.Keys {
|
|
k := &m.JWKS.Keys[i]
|
|
if k.Kid != kid {
|
|
continue
|
|
}
|
|
switch k.Kty {
|
|
case "RSA":
|
|
return k.ToRSAPublicKey()
|
|
case "EC":
|
|
return k.ToECDSAPublicKey()
|
|
default:
|
|
return nil, fmt.Errorf("unsupported key type: %s", k.Kty)
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("no matching public key found for kid: %s", kid)
|
|
}
|
|
|
|
func (m *MockJWKCache) Cleanup() {
|
|
// Mock cleanup is a no-op - we don't want to destroy the mock JWKS data
|
|
// Real cleanup is for expired entries, not resetting all data
|
|
}
|
|
|
|
// MockTokenVerifier implements TokenVerifier for testing, allowing interception of VerifyToken calls.
|
|
type MockTokenVerifier struct {
|
|
VerifyFunc func(token string) error
|
|
}
|
|
|
|
func (m *MockTokenVerifier) VerifyToken(token string) error {
|
|
if m.VerifyFunc != nil {
|
|
return m.VerifyFunc(token)
|
|
}
|
|
return fmt.Errorf("VerifyFunc not implemented in mock")
|
|
}
|
|
|
|
// MockTokenExchanger implements TokenExchanger for testing
|
|
type MockTokenExchanger struct {
|
|
ExchangeCodeFunc func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
|
|
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
|
|
RevokeTokenFunc func(token, tokenType string) error
|
|
}
|
|
|
|
func (m *MockTokenExchanger) ExchangeCodeForToken(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
|
if m.ExchangeCodeFunc != nil {
|
|
return m.ExchangeCodeFunc(ctx, grantType, codeOrToken, redirectURL, codeVerifier)
|
|
}
|
|
return nil, fmt.Errorf("ExchangeCodeFunc not implemented in mock")
|
|
}
|
|
|
|
func (m *MockTokenExchanger) GetNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
|
if m.RefreshTokenFunc != nil {
|
|
return m.RefreshTokenFunc(refreshToken)
|
|
}
|
|
return nil, fmt.Errorf("RefreshTokenFunc not implemented in mock")
|
|
}
|
|
|
|
func (m *MockTokenExchanger) RevokeTokenWithProvider(token, tokenType string) error {
|
|
if m.RevokeTokenFunc != nil {
|
|
return m.RevokeTokenFunc(token, tokenType)
|
|
}
|
|
return fmt.Errorf("RevokeTokenFunc not implemented in mock")
|
|
}
|
|
|
|
// Helper function to check if a token is a test token
|
|
func isTestToken(token string) bool {
|
|
// Parse the token without verification to check if it's a test token
|
|
claims, err := extractClaims(token)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
// Check if the issuer is our test issuer
|
|
if iss, ok := claims["iss"].(string); ok {
|
|
return iss == "https://test-issuer.com"
|
|
}
|
|
|
|
// Check if audience is our test client
|
|
if aud, ok := claims["aud"].(string); ok {
|
|
return aud == "test-client-id"
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
// Helper function to create a new valid token for refresh tests using test suite
|
|
func (ts *TestSuite) createNewValidToken() string {
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Add(-2 * time.Minute).Unix()
|
|
nbf := now.Add(-2 * time.Minute).Unix()
|
|
|
|
token, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": generateRandomString(16),
|
|
})
|
|
|
|
return token
|
|
}
|
|
|
|
// Helper function to create a JWT token
|
|
func createTestJWT(privateKey *rsa.PrivateKey, alg, kid string, claims map[string]interface{}) (string, error) {
|
|
header := map[string]interface{}{
|
|
"alg": alg,
|
|
"kid": kid,
|
|
"typ": "JWT",
|
|
}
|
|
headerJSON, err := json.Marshal(header)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
headerEncoded := base64.RawURLEncoding.EncodeToString(headerJSON)
|
|
|
|
claimsJSON, err := json.Marshal(claims)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
claimsEncoded := base64.RawURLEncoding.EncodeToString(claimsJSON)
|
|
|
|
signedContent := headerEncoded + "." + claimsEncoded
|
|
|
|
// Select the appropriate hash function based on algorithm
|
|
var hashFunc crypto.Hash
|
|
switch alg {
|
|
case "RS256", "PS256":
|
|
hashFunc = crypto.SHA256
|
|
case "RS384", "PS384":
|
|
hashFunc = crypto.SHA384
|
|
case "RS512", "PS512":
|
|
hashFunc = crypto.SHA512
|
|
default:
|
|
return "", fmt.Errorf("unsupported algorithm: %s", alg)
|
|
}
|
|
|
|
hasher := hashFunc.New()
|
|
hasher.Write([]byte(signedContent))
|
|
hashed := hasher.Sum(nil)
|
|
|
|
var signatureBytes []byte
|
|
|
|
// Use appropriate signing method based on algorithm
|
|
if strings.HasPrefix(alg, "RS") {
|
|
// PKCS1v15 signing for RS* algorithms
|
|
signatureBytes, err = rsa.SignPKCS1v15(rand.Reader, privateKey, hashFunc, hashed)
|
|
} else if strings.HasPrefix(alg, "PS") {
|
|
// PSS signing for PS* algorithms
|
|
signatureBytes, err = rsa.SignPSS(rand.Reader, privateKey, hashFunc, hashed, nil)
|
|
} else {
|
|
return "", fmt.Errorf("unsupported RSA algorithm: %s", alg)
|
|
}
|
|
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
signatureEncoded := base64.RawURLEncoding.EncodeToString(signatureBytes)
|
|
|
|
token := signedContent + "." + signatureEncoded
|
|
|
|
return token, nil
|
|
}
|
|
|
|
func bigIntToBytes(i *big.Int) []byte {
|
|
return i.Bytes()
|
|
}
|
|
|
|
// TestVerifyToken tests the VerifyToken method
|
|
func TestVerifyToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
blacklist bool
|
|
rateLimit bool
|
|
cacheToken bool
|
|
expectedError bool
|
|
}{
|
|
{
|
|
name: "Valid token",
|
|
token: ts.token,
|
|
expectedError: false,
|
|
},
|
|
{
|
|
name: "Invalid token signature",
|
|
token: ts.token + "invalid",
|
|
expectedError: true,
|
|
},
|
|
{
|
|
name: "Blacklisted token",
|
|
token: ts.token,
|
|
blacklist: true,
|
|
expectedError: true,
|
|
},
|
|
{
|
|
name: "Rate limit exceeded",
|
|
token: ts.token,
|
|
rateLimit: true,
|
|
expectedError: true,
|
|
},
|
|
{
|
|
name: "Token in cache",
|
|
token: ts.token,
|
|
cacheToken: true,
|
|
expectedError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset token blacklist and cache for each test
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
// Clear the token cache instead of creating a new one (it's a singleton)
|
|
ts.tOidc.tokenCache = NewTokenCache()
|
|
ts.tOidc.tokenCache.Clear()
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
|
|
|
// Set up the test case
|
|
if tc.blacklist {
|
|
// Use Set with a duration. Value 'true' is arbitrary.
|
|
ts.tOidc.tokenBlacklist.Set(tc.token, true, 1*time.Hour)
|
|
}
|
|
|
|
if tc.rateLimit {
|
|
// Exceed rate limit
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
|
|
}
|
|
|
|
if tc.cacheToken {
|
|
// Use more realistic claims for cached token
|
|
ts.tOidc.tokenCache.Set(tc.token, map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"sub": "test-subject",
|
|
"exp": float64(time.Now().Add(1 * time.Hour).Unix()),
|
|
"jti": generateRandomString(16), // Add a JTI claim to prevent replay detection
|
|
}, time.Minute)
|
|
|
|
// Verify the token is actually in the cache
|
|
if claims, exists := ts.tOidc.tokenCache.Get(tc.token); exists {
|
|
t.Logf("Token found in cache with claims: %v", claims)
|
|
} else {
|
|
t.Logf("Token NOT found in cache despite cacheToken=true")
|
|
}
|
|
}
|
|
|
|
err := ts.tOidc.VerifyToken(tc.token)
|
|
if tc.expectedError && err == nil {
|
|
t.Errorf("Test %s: expected error but got nil", tc.name)
|
|
}
|
|
if !tc.expectedError && err != nil {
|
|
t.Errorf("Test %s: expected no error but got %v", tc.name, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeHTTP tests the ServeHTTP method
|
|
func TestServeHTTP(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
})
|
|
ts.tOidc.next = nextHandler
|
|
ts.tOidc.name = "test"
|
|
|
|
// Helper to create an expired token
|
|
createExpiredToken := func() string {
|
|
exp := time.Now().Add(-1 * time.Hour).Unix() // Expired 1 hour ago
|
|
iat := time.Now().Add(-2 * time.Hour).Unix()
|
|
nbf := time.Now().Add(-2 * time.Hour).Unix()
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce-expired", // Different nonce for clarity
|
|
"jti": generateRandomString(16),
|
|
})
|
|
return expiredToken
|
|
}
|
|
|
|
tests := []struct {
|
|
sessionValues map[interface{}]interface{}
|
|
setupSession func(*SessionData)
|
|
mockRefreshTokenFunc func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error)
|
|
assertSessionAfterRequest func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager)
|
|
requestHeaders map[string]string
|
|
name string
|
|
requestPath string
|
|
expectedBody string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Excluded URL",
|
|
requestPath: "/favicon.ico",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Unauthenticated request (no refresh token) to protected URL",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Ensure no tokens are set
|
|
session.SetAuthenticated(false)
|
|
session.SetAccessToken("")
|
|
session.SetRefreshToken("")
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC as there's no refresh token
|
|
},
|
|
{
|
|
name: "Unauthenticated request (with refresh token) to protected URL - Expect Refresh Attempt",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(false) // Not authenticated
|
|
session.SetAccessToken("") // No access token
|
|
session.SetRefreshToken("valid-refresh-token-for-unauth-test") // BUT has refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token-for-unauth-test" {
|
|
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := ts.createNewValidToken() // Use helper from TestServeHTTP
|
|
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-unauth", ExpiresIn: 3600}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect OK after successful refresh
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Unauthenticated request (with refresh token) to protected URL - Refresh Fails",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(false) // Not authenticated
|
|
session.SetAccessToken("") // No access token
|
|
session.SetRefreshToken("invalid-refresh-token-for-unauth-test") // Invalid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid")
|
|
}
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC after failed refresh
|
|
},
|
|
{
|
|
name: "Authenticated request to protected URL (Valid Token)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
// Generate a fresh valid token for this test case to avoid replay issues
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: "OK",
|
|
},
|
|
// This test case remains valid as the logic should still attempt refresh when expired token + refresh token exist
|
|
{
|
|
name: "Authenticated request with expired token and successful refresh",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// NOTE: isUserAuthenticated now returns authenticated=false if access token is expired,
|
|
// even if session.SetAuthenticated(true) was called.
|
|
// We rely on needsRefresh=true and the presence of the refresh token to trigger the refresh attempt.
|
|
session.SetAuthenticated(true) // Set flag initially, though isUserAuthenticated will override based on token
|
|
session.SetUserIdentifier("user@example.com")
|
|
// Create an expired token for this test
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
|
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAccessToken(expiredToken) // Set expired token
|
|
session.SetRefreshToken("valid-refresh-token") // Set valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token" {
|
|
return nil, fmt.Errorf("mock error: expected 'valid-refresh-token', got '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := ts.createNewValidToken()
|
|
return &TokenResponse{
|
|
IDToken: newToken, // Return new valid token
|
|
AccessToken: newToken, // Often the same as ID token in tests
|
|
RefreshToken: "new-refresh-token",
|
|
ExpiresIn: 3600,
|
|
}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success after refresh
|
|
expectedBody: "OK",
|
|
assertSessionAfterRequest: func(t *testing.T, rr *httptest.ResponseRecorder, req *http.Request, sessionManager *SessionManager) {
|
|
// Create a new request to read the cookies set by the response recorder
|
|
reqForCookieRead := httptest.NewRequest("GET", "/protected", nil)
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
reqForCookieRead.AddCookie(cookie)
|
|
}
|
|
// Get session based on response cookies
|
|
session, err := sessionManager.GetSession(reqForCookieRead)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session after request: %v", err)
|
|
}
|
|
// Assert new tokens are in the session
|
|
if session.GetAccessToken() == "" || session.GetAccessToken() == createExpiredToken() {
|
|
t.Errorf("Expected access token to be updated in session, but it was empty or still the expired one")
|
|
}
|
|
if session.GetRefreshToken() != "new-refresh-token" {
|
|
t.Errorf("Expected refresh token to be updated to 'new-refresh-token', got '%s'", session.GetRefreshToken())
|
|
}
|
|
// Also check authenticated flag is now true
|
|
if !session.GetAuthenticated() {
|
|
t.Errorf("Expected session to be marked authenticated after successful refresh")
|
|
}
|
|
},
|
|
},
|
|
// This test case remains valid as the logic should still return 401 for API clients on refresh failure
|
|
{
|
|
name: "Logout URL",
|
|
requestPath: "/callback/logout", // Match the default logout path set in TestSuite.Setup
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@example.com",
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect after logout
|
|
expectedBody: "",
|
|
// No specific session assertion needed for logout redirect itself
|
|
},
|
|
{
|
|
name: "Authenticated request with expired token and FAILED refresh (Accept: JSON)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true) // Set flag initially
|
|
session.SetUserIdentifier("user@example.com")
|
|
// Create an expired token for this test
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
|
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAccessToken(expiredToken) // Expired access token
|
|
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
|
}
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "application/json",
|
|
},
|
|
expectedStatus: http.StatusUnauthorized, // Expect 401 for API client after failed refresh attempt
|
|
expectedBody: `{"error":"Unauthorized","error_description":"Token refresh failed","status_code":401}`,
|
|
},
|
|
// This test case remains valid as the logic should still redirect browser clients on refresh failure
|
|
{
|
|
name: "Authenticated request with expired token and FAILED refresh (Accept: HTML)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true) // Set flag initially
|
|
session.SetUserIdentifier("user@example.com")
|
|
// Create an expired token for this test
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
|
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAccessToken(expiredToken) // Expired access token
|
|
session.SetRefreshToken("valid-refresh-token") // Valid refresh token
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
// Simulate failed refresh
|
|
return nil, fmt.Errorf("mock error: refresh token invalid or provider down")
|
|
}
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "text/html", // Browser client
|
|
},
|
|
expectedStatus: http.StatusFound, // Expect redirect to OIDC for browser client after failed refresh attempt
|
|
},
|
|
// This test case remains valid as proactive refresh should still be attempted
|
|
{
|
|
name: "Authenticated request with token nearing expiry (needs refresh)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Create token expiring soon (e.g., 30s, within default 60s grace period)
|
|
exp := time.Now().Add(30 * time.Second).Unix()
|
|
iat := time.Now().Add(-1 * time.Minute).Unix()
|
|
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
|
nearExpiryToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
|
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
session.SetAccessToken(nearExpiryToken)
|
|
session.SetRefreshToken("valid-refresh-token-for-near-expiry") // Refresh token MUST exist for proactive refresh
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
if refreshToken != "valid-refresh-token-for-near-expiry" {
|
|
return nil, fmt.Errorf("mock error: unexpected refresh token '%s'", refreshToken)
|
|
}
|
|
// Simulate successful refresh
|
|
newToken := ts.createNewValidToken()
|
|
return &TokenResponse{IDToken: newToken, AccessToken: newToken, RefreshToken: "new-refresh-token-near-expiry", ExpiresIn: 3600}, nil
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success after proactive refresh
|
|
expectedBody: "OK",
|
|
},
|
|
// This test case remains valid as no refresh should be attempted
|
|
{
|
|
name: "Authenticated request with token valid (outside grace period)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
// Create token expiring later (e.g., 10 mins, outside default 60s grace period)
|
|
exp := time.Now().Add(10 * time.Minute).Unix()
|
|
iat := time.Now().Add(-1 * time.Minute).Unix()
|
|
nbf := time.Now().Add(-1 * time.Minute).Unix()
|
|
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": exp, "iat": iat, "nbf": nbf,
|
|
"sub": "test-subject", "email": "user@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
session.SetAccessToken(validToken)
|
|
session.SetIDToken(validToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("should-not-be-used-refresh-token")
|
|
},
|
|
mockRefreshTokenFunc: func(originalFunc func(refreshToken string) (*TokenResponse, error)) func(refreshToken string) (*TokenResponse, error) {
|
|
// This should NOT be called
|
|
return func(refreshToken string) (*TokenResponse, error) {
|
|
t.Errorf("Refresh token function was called unexpectedly for valid token outside grace period")
|
|
return nil, fmt.Errorf("refresh should not have been attempted")
|
|
}
|
|
},
|
|
expectedStatus: http.StatusOK, // Expect success, no refresh needed
|
|
expectedBody: "OK",
|
|
},
|
|
{
|
|
name: "Disallowed Domain (Accept: JSON)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "application/json",
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
expectedBody: `{"error":"Forbidden","error_description":"Access denied: You are not authorized to access this resource. To log out, visit: /callback/logout","status_code":403}`,
|
|
},
|
|
{
|
|
name: "Disallowed Domain (Accept: HTML)",
|
|
requestPath: "/protected",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@disallowed.com") // Use disallowed domain
|
|
// Generate a fresh valid token for this test case
|
|
freshToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
"iat": time.Now().Unix(), "nbf": time.Now().Unix(), "sub": "test-subject", "email": "user@disallowed.com", // Match email
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
session.SetAccessToken(freshToken)
|
|
session.SetIDToken(freshToken) // Ensure ID token is also set
|
|
session.SetRefreshToken("valid-refresh-token")
|
|
},
|
|
requestHeaders: map[string]string{
|
|
"Accept": "text/html",
|
|
},
|
|
expectedStatus: http.StatusForbidden, // Still Forbidden, but HTML response
|
|
expectedBody: "", // Body check is harder for HTML, focus on status and content-type
|
|
},
|
|
}
|
|
|
|
// Configure allowed domains for domain restriction tests
|
|
// This allows example.com but not disallowed.com
|
|
ts.tOidc.allowedUserDomains = map[string]struct{}{
|
|
"example.com": {},
|
|
}
|
|
|
|
// Use mock JWK cache to enable proper token verification
|
|
ts.tOidc.jwkCache = ts.mockJWKCache
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset token blacklist and cache for each test to prevent token replay detection errors
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
ts.tOidc.tokenCache = NewTokenCache()
|
|
|
|
// Reset the global replayCache to prevent "token replay detected" errors
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Store original tokenVerifier to restore later
|
|
origTokenVerifier := ts.tOidc.tokenVerifier
|
|
|
|
// Create a mock tokenVerifier that clears the replay cache before verification
|
|
// This prevents replay detection when the same token is verified multiple times within a test
|
|
mockTokenVerifier := &MockTokenVerifier{
|
|
VerifyFunc: func(token string) error {
|
|
// Clear replay cache before token verification
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// For test tokens, perform basic validation without JWKS dependency
|
|
if isTestToken(token) {
|
|
// Parse the token to check basic validity and expiration
|
|
claims, err := extractClaims(token)
|
|
if err != nil {
|
|
return fmt.Errorf("token parsing failed: %v", err)
|
|
}
|
|
|
|
// Check token expiration
|
|
if exp, ok := claims["exp"].(float64); ok {
|
|
if time.Now().Unix() > int64(exp) {
|
|
return fmt.Errorf("token has expired")
|
|
}
|
|
}
|
|
|
|
// Token is valid for test purposes - also cache the claims like the real verifier would
|
|
if ts.tOidc.tokenCache != nil {
|
|
ts.tOidc.tokenCache.Set(token, claims, time.Hour)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// For non-test tokens, call the original verifier
|
|
if origTokenVerifier != nil {
|
|
return origTokenVerifier.VerifyToken(token)
|
|
}
|
|
return fmt.Errorf("original token verifier is nil")
|
|
},
|
|
}
|
|
|
|
// Replace tokenVerifier with our mock
|
|
ts.tOidc.tokenVerifier = mockTokenVerifier
|
|
|
|
// Restore original tokenVerifier after test
|
|
defer func() {
|
|
ts.tOidc.tokenVerifier = origTokenVerifier
|
|
}()
|
|
|
|
req := httptest.NewRequest("GET", tc.requestPath, nil)
|
|
// Set common headers needed by the logic (determineScheme, determineHost)
|
|
req.Header.Set("X-Forwarded-Proto", "http") // Or https if testing that
|
|
req.Header.Set("X-Forwarded-Host", "testhost.com")
|
|
req.Host = "testhost.com" // Also set Host header
|
|
// Set request headers from test case
|
|
if tc.requestHeaders != nil {
|
|
for key, value := range tc.requestHeaders {
|
|
req.Header.Set(key, value)
|
|
}
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Setup session if needed
|
|
session, err := ts.tOidc.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Test %s: Failed to get initial session: %v", tc.name, err)
|
|
}
|
|
if tc.setupSession != nil {
|
|
tc.setupSession(session)
|
|
// Save session to recorder to get cookies
|
|
saveRecorder := httptest.NewRecorder()
|
|
if err := session.Save(req, saveRecorder); err != nil {
|
|
t.Fatalf("Test %s: Failed to save initial session: %v", tc.name, err)
|
|
}
|
|
// Copy cookies from save recorder to the actual request
|
|
for _, cookie := range saveRecorder.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
}
|
|
|
|
// Mocking setup for TokenExchanger
|
|
originalExchanger := ts.tOidc.tokenExchanger // Store original
|
|
mockExchanger, isMock := originalExchanger.(*MockTokenExchanger)
|
|
if !isMock {
|
|
// This case should ideally not happen if Setup correctly assigns the mock,
|
|
// but handle it defensively.
|
|
t.Logf("Warning: Default exchanger was not the mock. Creating a temporary mock.")
|
|
mockExchanger = &MockTokenExchanger{
|
|
ExchangeCodeFunc: originalExchanger.ExchangeCodeForToken,
|
|
RefreshTokenFunc: originalExchanger.GetNewTokenWithRefreshToken,
|
|
RevokeTokenFunc: originalExchanger.RevokeTokenWithProvider,
|
|
}
|
|
ts.tOidc.tokenExchanger = mockExchanger // Temporarily assign mock
|
|
}
|
|
|
|
// Override specific mock methods if needed for the test case
|
|
originalMockRefreshFunc := mockExchanger.RefreshTokenFunc // Store current mock func
|
|
if tc.mockRefreshTokenFunc != nil {
|
|
// Assign the test case specific mock function
|
|
mockExchanger.RefreshTokenFunc = tc.mockRefreshTokenFunc(originalExchanger.GetNewTokenWithRefreshToken)
|
|
}
|
|
|
|
// Call ServeHTTP
|
|
ts.tOidc.ServeHTTP(rr, req)
|
|
|
|
// Restore original exchanger and mock function state
|
|
ts.tOidc.tokenExchanger = originalExchanger
|
|
if tc.mockRefreshTokenFunc != nil && mockExchanger != nil {
|
|
// Restore the previous mock function if we overrode it
|
|
mockExchanger.RefreshTokenFunc = originalMockRefreshFunc
|
|
}
|
|
|
|
// Check response status
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Test %s: Expected status %d, got %d. Body: %s", tc.name, tc.expectedStatus, rr.Code, rr.Body.String())
|
|
}
|
|
|
|
// Check response body if expected
|
|
// Check response body if expected (handle JSON vs HTML)
|
|
if tc.expectedBody != "" {
|
|
// For JSON, compare directly
|
|
if strings.Contains(rr.Header().Get("Content-Type"), "application/json") {
|
|
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
|
t.Errorf("Test %s: Expected JSON body %q, got %q", tc.name, tc.expectedBody, body)
|
|
}
|
|
} else if tc.expectedBody == "OK" { // Simple check for the "OK" body from next handler
|
|
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
|
t.Errorf("Test %s: Expected body %q, got %q", tc.name, tc.expectedBody, body)
|
|
}
|
|
}
|
|
// Add more sophisticated HTML body checks if needed
|
|
}
|
|
|
|
// Perform post-request session assertions if defined
|
|
if tc.assertSessionAfterRequest != nil {
|
|
tc.assertSessionAfterRequest(t, rr, req, ts.tOidc.sessionManager)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJWKToPEM(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
jwk *JWK
|
|
name string
|
|
errorContains string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Unsupported Key Type",
|
|
jwk: &JWK{
|
|
Kty: "unsupported",
|
|
Kid: "test-key-id",
|
|
},
|
|
expectError: true,
|
|
errorContains: "unsupported key type",
|
|
},
|
|
{
|
|
name: "EC Key",
|
|
jwk: &JWK{
|
|
Kty: "EC",
|
|
Kid: "test-ec-key-id",
|
|
Crv: "P-256",
|
|
X: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.X.Bytes()),
|
|
Y: base64.RawURLEncoding.EncodeToString(ts.ecPrivateKey.PublicKey.Y.Bytes()),
|
|
},
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
pemBytes, err := jwkToPEM(tc.jwk)
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if len(pemBytes) == 0 {
|
|
t.Error("PEM bytes should not be empty")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseJWT(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
errorContains string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Invalid Format",
|
|
token: "invalid.jwt.token",
|
|
expectError: true,
|
|
errorContains: "invalid JWT format",
|
|
},
|
|
{
|
|
name: "Valid Token",
|
|
token: ts.token,
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := parseJWT(tc.token)
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error, got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJWTVerify_MissingClaims(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
jwt := &JWT{
|
|
Header: map[string]interface{}{
|
|
"alg": "RS256",
|
|
"kid": "test-key-id",
|
|
},
|
|
Claims: map[string]interface{}{
|
|
// Missing 'iss', 'aud', 'exp', 'iat', 'sub'
|
|
},
|
|
}
|
|
|
|
err := jwt.Verify("https://test-issuer.com", "test-client-id")
|
|
if err == nil {
|
|
t.Error("Expected error for missing claims, got nil")
|
|
}
|
|
}
|
|
|
|
func TestHandleCallback(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
redirectURL := "http://example.com/"
|
|
|
|
tests := []struct {
|
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
|
sessionSetupFunc func(*SessionData)
|
|
name string
|
|
queryParams string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Success",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusFound,
|
|
},
|
|
{
|
|
name: "Missing Code",
|
|
queryParams: "",
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Exchange Code Error",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return nil, fmt.Errorf("exchange code error")
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing ID Token",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Disallowed Email",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Generate a unique token for this test case to avoid replay issues
|
|
// Use claims relevant to this test (disallowed email)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
disallowedToken, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject-disallowed",
|
|
"email": "user@disallowed.com", // The disallowed email for this test
|
|
"nonce": "test-nonce", // Match the nonce set in sessionSetupFunc
|
|
"jti": generateRandomString(16), // Unique JTI
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create disallowed token for test: %w", err)
|
|
}
|
|
return &TokenResponse{
|
|
IDToken: disallowedToken,
|
|
RefreshToken: "test-refresh-token-disallowed",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "Invalid State Parameter",
|
|
queryParams: "?code=test-code&state=invalid-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Nonce Mismatch",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing Nonce in Claims",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
// Missing nonce
|
|
}, nil
|
|
},
|
|
sessionSetupFunc: func(session *SessionData) {
|
|
session.SetCSRF("test-csrf-token")
|
|
session.SetNonce("test-nonce")
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
// Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Clear the global replay cache before each test run
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Explicitly clear the shared blacklist at the start of each sub-test
|
|
// to ensure no state leaks, even though we expect the local one to be used.
|
|
// Note: This line might be redundant now that the verifier is local, but keep for safety.
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
|
|
// Create a new instance for each test to avoid state carryover
|
|
instanceExtractClaimsFunc := tc.extractClaimsFunc
|
|
if instanceExtractClaimsFunc == nil {
|
|
instanceExtractClaimsFunc = extractClaims // Default to the real function if not provided by test case
|
|
}
|
|
tOidc := &TraefikOidc{
|
|
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
|
logger: logger,
|
|
userIdentifierClaim: "email", // Required for claim extraction
|
|
// exchangeCodeForTokenFunc: tc.exchangeCodeForToken, // Removed field
|
|
extractClaimsFunc: instanceExtractClaimsFunc, // Use the potentially defaulted function
|
|
tokenVerifier: nil, // Will be set to self below
|
|
jwtVerifier: nil, // Temporarily nil, will be set below
|
|
sessionManager: sessionManager,
|
|
tokenExchanger: &MockTokenExchanger{ // Create a new mock exchanger for this specific test run
|
|
ExchangeCodeFunc: func(ctx context.Context, grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error) {
|
|
// Wrap the test case function to match the required signature
|
|
if tc.exchangeCodeForToken != nil {
|
|
// Only call if the test case provided a function
|
|
return tc.exchangeCodeForToken(codeOrToken, redirectURL, codeVerifier)
|
|
}
|
|
// Provide a default behavior or error if no mock was provided for this test case
|
|
return nil, fmt.Errorf("mock ExchangeCodeFunc not implemented for this test case")
|
|
},
|
|
// Keep other mock funcs nil or provide defaults if needed by other parts of handleCallback
|
|
},
|
|
tokenCache: NewTokenCache(), // Initialize token cache
|
|
limiter: rate.NewLimiter(rate.Inf, 0), // Initialize rate limiter
|
|
tokenBlacklist: NewCache(), // Initialize token blacklist cache
|
|
|
|
// Add potentially missing fields based on New() comparison
|
|
clientID: ts.tOidc.clientID,
|
|
audience: ts.tOidc.clientID,
|
|
issuerURL: ts.tOidc.issuerURL,
|
|
jwkCache: ts.tOidc.jwkCache, // Use the mock cache from TestSuite
|
|
httpClient: ts.tOidc.httpClient,
|
|
initComplete: make(chan struct{}), // Initialize the channel
|
|
// Setting other fields like paths, enablePKCE etc. if needed
|
|
}
|
|
tOidc.tokenVerifier = tOidc // Point tokenVerifier to the local instance NOW
|
|
tOidc.jwtVerifier = tOidc // Point jwtVerifier to the local instance NOW
|
|
close(tOidc.initComplete) // Mark this test instance as initialized
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest("GET", "/callback"+tc.queryParams, nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Create session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
if tc.sessionSetupFunc != nil {
|
|
tc.sessionSetupFunc(session)
|
|
}
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder for the actual test
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Call handleCallback
|
|
tOidc.handleCallback(rr, req, redirectURL)
|
|
|
|
// Check response
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsAllowedDomain(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
allowedDomains map[string]struct{}
|
|
allowedUsers map[string]struct{}
|
|
name string
|
|
email string
|
|
expectedLogOutput string
|
|
allowed bool
|
|
}{
|
|
{
|
|
name: "Allowed domain",
|
|
email: "user@example.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Disallowed domain",
|
|
email: "user@notallowed.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Invalid email",
|
|
email: "invalid-email",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Specific user is allowed regardless of domain",
|
|
email: "specific.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Case-insensitive email matching for specific user",
|
|
email: "Specific.User@otherdomain.com", // Mixed case
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}}, // Lowercase
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Only allowed users configured (no domains)",
|
|
email: "specific.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{}, // Empty allowed domains
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "User not in allowed list when only specific users configured",
|
|
email: "other.user@otherdomain.com",
|
|
allowedDomains: map[string]struct{}{}, // Empty allowed domains
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "No restrictions (both empty)",
|
|
email: "anyone@anydomain.com",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure TraefikOidc instance for this test case
|
|
tOidc := ts.tOidc
|
|
tOidc.allowedUserDomains = tc.allowedDomains
|
|
tOidc.allowedUsers = tc.allowedUsers
|
|
|
|
allowed := tOidc.isAllowedDomain(tc.email)
|
|
if allowed != tc.allowed {
|
|
t.Errorf("Expected allowed=%v, got %v", tc.allowed, allowed)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsAllowedUser(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
allowedDomains map[string]struct{}
|
|
allowedUsers map[string]struct{}
|
|
userIdentifierClaim string
|
|
name string
|
|
userIdentifier string
|
|
allowed bool
|
|
}{
|
|
// Email-based identification (default behavior)
|
|
{
|
|
name: "Email identifier - allowed domain",
|
|
userIdentifier: "user@example.com",
|
|
userIdentifierClaim: "email",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Email identifier - disallowed domain",
|
|
userIdentifier: "user@notallowed.com",
|
|
userIdentifierClaim: "email",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Email identifier - specific user allowed",
|
|
userIdentifier: "specific.user@otherdomain.com",
|
|
userIdentifierClaim: "email",
|
|
allowedDomains: map[string]struct{}{"example.com": {}},
|
|
allowedUsers: map[string]struct{}{"specific.user@otherdomain.com": {}},
|
|
allowed: true,
|
|
},
|
|
|
|
// Non-email identifier (sub claim - for Azure AD users without email)
|
|
{
|
|
name: "Sub identifier - allowed in allowedUsers",
|
|
userIdentifier: "abc12345-6789-0abc-def0-123456789abc",
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Sub identifier - not in allowedUsers",
|
|
userIdentifier: "xyz-not-allowed-user",
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}},
|
|
allowed: false,
|
|
},
|
|
{
|
|
name: "Sub identifier - allowedDomains ignored for non-email",
|
|
userIdentifier: "user-id-12345",
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{"example.com": {}}, // Should be ignored
|
|
allowedUsers: map[string]struct{}{"user-id-12345": {}},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Sub identifier - no restrictions allows all",
|
|
userIdentifier: "any-user-id",
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{},
|
|
allowed: true,
|
|
},
|
|
{
|
|
name: "Sub identifier - case insensitive matching",
|
|
userIdentifier: "ABC12345-6789-0ABC-DEF0-123456789ABC", // Uppercase
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{"abc12345-6789-0abc-def0-123456789abc": {}}, // Lowercase
|
|
allowed: true,
|
|
},
|
|
|
|
// OID claim (Azure AD object ID)
|
|
{
|
|
name: "OID identifier - allowed user",
|
|
userIdentifier: "oid-12345-67890",
|
|
userIdentifierClaim: "oid",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{"oid-12345-67890": {}},
|
|
allowed: true,
|
|
},
|
|
|
|
// UPN claim (Azure AD User Principal Name)
|
|
{
|
|
name: "UPN identifier - allowed user (looks like email but use sub logic)",
|
|
userIdentifier: "user@tenant.onmicrosoft.com",
|
|
userIdentifierClaim: "upn",
|
|
allowedDomains: map[string]struct{}{"example.com": {}}, // Different domain, should be ignored
|
|
allowedUsers: map[string]struct{}{"user@tenant.onmicrosoft.com": {}},
|
|
allowed: true,
|
|
},
|
|
|
|
// Edge cases
|
|
{
|
|
name: "Empty identifier - not allowed",
|
|
userIdentifier: "",
|
|
userIdentifierClaim: "sub",
|
|
allowedDomains: map[string]struct{}{},
|
|
allowedUsers: map[string]struct{}{"some-user": {}},
|
|
allowed: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure TraefikOidc instance for this test case
|
|
tOidc := ts.tOidc
|
|
tOidc.allowedUserDomains = tc.allowedDomains
|
|
tOidc.allowedUsers = tc.allowedUsers
|
|
tOidc.userIdentifierClaim = tc.userIdentifierClaim
|
|
|
|
allowed := tOidc.isAllowedUser(tc.userIdentifier)
|
|
if allowed != tc.allowed {
|
|
t.Errorf("Expected allowed=%v, got %v for userIdentifier=%q with claim=%q",
|
|
tc.allowed, allowed, tc.userIdentifier, tc.userIdentifierClaim)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUserIdentifierClaimExtraction(t *testing.T) {
|
|
// Test that the correct claim is extracted based on userIdentifierClaim config
|
|
tests := []struct {
|
|
name string
|
|
userIdentifierClaim string
|
|
claims map[string]interface{}
|
|
expectedIdentifier string
|
|
shouldFallbackToSub bool
|
|
}{
|
|
{
|
|
name: "Email claim extraction (default)",
|
|
userIdentifierClaim: "email",
|
|
claims: map[string]interface{}{
|
|
"sub": "user-sub-id",
|
|
"email": "user@example.com",
|
|
},
|
|
expectedIdentifier: "user@example.com",
|
|
shouldFallbackToSub: false,
|
|
},
|
|
{
|
|
name: "Sub claim extraction",
|
|
userIdentifierClaim: "sub",
|
|
claims: map[string]interface{}{
|
|
"sub": "user-sub-id",
|
|
"email": "user@example.com",
|
|
},
|
|
expectedIdentifier: "user-sub-id",
|
|
shouldFallbackToSub: false,
|
|
},
|
|
{
|
|
name: "OID claim extraction (Azure AD)",
|
|
userIdentifierClaim: "oid",
|
|
claims: map[string]interface{}{
|
|
"sub": "user-sub-id",
|
|
"email": "user@example.com",
|
|
"oid": "azure-object-id",
|
|
},
|
|
expectedIdentifier: "azure-object-id",
|
|
shouldFallbackToSub: false,
|
|
},
|
|
{
|
|
name: "UPN claim extraction (Azure AD)",
|
|
userIdentifierClaim: "upn",
|
|
claims: map[string]interface{}{
|
|
"sub": "user-sub-id",
|
|
"upn": "user@tenant.onmicrosoft.com",
|
|
},
|
|
expectedIdentifier: "user@tenant.onmicrosoft.com",
|
|
shouldFallbackToSub: false,
|
|
},
|
|
{
|
|
name: "Fallback to sub when configured claim is missing",
|
|
userIdentifierClaim: "email",
|
|
claims: map[string]interface{}{
|
|
"sub": "fallback-sub-id",
|
|
// email is missing
|
|
},
|
|
expectedIdentifier: "fallback-sub-id",
|
|
shouldFallbackToSub: true,
|
|
},
|
|
{
|
|
name: "preferred_username claim extraction",
|
|
userIdentifierClaim: "preferred_username",
|
|
claims: map[string]interface{}{
|
|
"sub": "user-sub-id",
|
|
"preferred_username": "jdoe",
|
|
},
|
|
expectedIdentifier: "jdoe",
|
|
shouldFallbackToSub: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Extract user identifier using the same logic as auth_flow.go
|
|
userIdentifier, _ := tc.claims[tc.userIdentifierClaim].(string)
|
|
usedFallback := false
|
|
|
|
if userIdentifier == "" && tc.userIdentifierClaim != "sub" {
|
|
userIdentifier, _ = tc.claims["sub"].(string)
|
|
usedFallback = true
|
|
}
|
|
|
|
if userIdentifier != tc.expectedIdentifier {
|
|
t.Errorf("Expected identifier %q, got %q", tc.expectedIdentifier, userIdentifier)
|
|
}
|
|
|
|
if usedFallback != tc.shouldFallbackToSub {
|
|
t.Errorf("Expected fallback=%v, got %v", tc.shouldFallbackToSub, usedFallback)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOIDCHandler(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
ts.token = "valid.jwt.token"
|
|
|
|
tests := []struct {
|
|
exchangeCodeForToken func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error)
|
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
|
sessionSetupFunc func(session *sessions.Session)
|
|
name string
|
|
queryParams string
|
|
expectedStatus int
|
|
blacklist bool
|
|
rateLimit bool
|
|
cacheToken bool
|
|
}{
|
|
{
|
|
name: "Missing Code",
|
|
queryParams: "",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims with invalid nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusInternalServerError,
|
|
},
|
|
{
|
|
name: "Missing Nonce in Claims",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims without nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Invalid State Parameter",
|
|
queryParams: "?code=test-code&state=invalid-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "Nonce Mismatch",
|
|
queryParams: "?code=test-code&state=test-csrf-token",
|
|
sessionSetupFunc: func(session *sessions.Session) {
|
|
// Set CSRF and nonce values in session
|
|
session.Values["csrf"] = "test-csrf-token"
|
|
session.Values["nonce"] = "test-nonce"
|
|
},
|
|
exchangeCodeForToken: func(code string, redirectURL string, codeVerifier string) (*TokenResponse, error) {
|
|
// Simulate token exchange
|
|
return &TokenResponse{
|
|
IDToken: ts.token,
|
|
RefreshToken: "test-refresh-token",
|
|
}, nil
|
|
},
|
|
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
|
|
// Simulate extraction of claims with mismatched nonce
|
|
return map[string]interface{}{
|
|
"email": "user@example.com",
|
|
"nonce": "invalid-nonce",
|
|
}, nil
|
|
},
|
|
expectedStatus: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
// Capture range variable
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Reset token blacklist and cache
|
|
ts.tOidc.tokenBlacklist = NewCache() // Use generic cache for blacklist
|
|
ts.tOidc.tokenCache = NewTokenCache()
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
|
|
|
|
// Set up the test case
|
|
if tc.blacklist {
|
|
// Use Set with a duration. Value 'true' is arbitrary.
|
|
ts.tOidc.tokenBlacklist.Set(ts.token, true, 1*time.Hour)
|
|
}
|
|
|
|
if tc.rateLimit {
|
|
// Exceed rate limit
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
|
|
}
|
|
|
|
if tc.cacheToken {
|
|
// Cache the token with dummy claims
|
|
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
|
|
"empty": "claim",
|
|
}, 60)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestHandleLogout tests the logout functionality
|
|
func TestHandleLogout(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create mock revocation endpoint server
|
|
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
// Verify the required parameters are present
|
|
if r.Form.Get("token") == "" {
|
|
t.Error("Missing token parameter")
|
|
}
|
|
if r.Form.Get("token_type_hint") == "" {
|
|
t.Error("Missing token_type_hint parameter")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
defer mockRevocationServer.Close()
|
|
|
|
tests := []struct {
|
|
setupSession func(*SessionData)
|
|
name string
|
|
endSessionURL string
|
|
expectedURL string
|
|
host string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "Successful logout with end session endpoint",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken(ValidAccessToken)
|
|
session.SetIDToken(ValidIDToken)
|
|
session.SetRefreshToken(ValidRefreshToken)
|
|
},
|
|
endSessionURL: "https://provider/end-session",
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "https://provider/end-session?id_token_hint=" + url.QueryEscape(ValidIDToken) + "&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Successful logout without end session endpoint",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken(ValidAccessToken)
|
|
session.SetIDToken(ValidIDToken)
|
|
session.SetRefreshToken(ValidRefreshToken)
|
|
},
|
|
endSessionURL: "",
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "/",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Logout with empty session",
|
|
setupSession: func(session *SessionData) {},
|
|
expectedStatus: http.StatusFound,
|
|
expectedURL: "/",
|
|
host: "test-host",
|
|
},
|
|
{
|
|
name: "Logout with invalid end session URL",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetAccessToken(ValidAccessToken)
|
|
session.SetIDToken(ValidIDToken)
|
|
session.SetRefreshToken(ValidRefreshToken)
|
|
},
|
|
endSessionURL: ":\\invalid-url",
|
|
expectedStatus: http.StatusInternalServerError,
|
|
host: "test-host",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
tOidc := &TraefikOidc{
|
|
revocationURL: mockRevocationServer.URL,
|
|
endSessionURL: tc.endSessionURL,
|
|
logger: logger,
|
|
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
|
httpClient: &http.Client{},
|
|
clientID: "test-client-id",
|
|
audience: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
tokenCache: NewTokenCache(),
|
|
forceHTTPS: false,
|
|
sessionManager: sessionManager,
|
|
}
|
|
|
|
// Create request with proper headers
|
|
req := httptest.NewRequest("GET", "/logout", nil)
|
|
req.Header.Set("Host", tc.host)
|
|
|
|
// Create a response recorder
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Get a session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
if tc.setupSession != nil {
|
|
tc.setupSession(session)
|
|
}
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Handle logout
|
|
tOidc.handleLogout(rr, req)
|
|
|
|
// Check response
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
|
|
if tc.expectedURL != "" {
|
|
location := rr.Header().Get("Location")
|
|
if location != tc.expectedURL {
|
|
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
|
|
}
|
|
}
|
|
|
|
// Verify session is cleared
|
|
updatedSession, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get updated session: %v", err)
|
|
}
|
|
|
|
// Verify tokens are cleared
|
|
if token := updatedSession.GetAccessToken(); token != "" {
|
|
t.Error("Access token not cleared")
|
|
}
|
|
if token := updatedSession.GetRefreshToken(); token != "" {
|
|
t.Error("Refresh token not cleared")
|
|
}
|
|
if updatedSession.GetAuthenticated() {
|
|
t.Error("Session still marked as authenticated")
|
|
}
|
|
|
|
// Check token blacklist
|
|
if token := session.GetAccessToken(); token != "" {
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Access token was not blacklisted in cache")
|
|
}
|
|
}
|
|
if token := session.GetRefreshToken(); token != "" {
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Refresh token was not blacklisted in cache")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRevokeTokenWithProvider tests the token revocation with provider
|
|
func TestRevokeTokenWithProvider(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
tokenType string
|
|
statusCode int
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Successful token revocation",
|
|
token: "valid-token",
|
|
tokenType: "refresh_token",
|
|
statusCode: http.StatusOK,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Failed token revocation",
|
|
token: "invalid-token",
|
|
tokenType: "refresh_token",
|
|
statusCode: http.StatusBadRequest,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create test server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Verify request method and content type
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
|
|
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
|
|
}
|
|
|
|
// Verify form values
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
if got := r.Form.Get("token"); got != tc.token {
|
|
t.Errorf("Expected token %s, got %s", tc.token, got)
|
|
}
|
|
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
|
|
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
|
|
}
|
|
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
|
|
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
|
|
}
|
|
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
|
|
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
|
|
}
|
|
|
|
w.WriteHeader(tc.statusCode)
|
|
}))
|
|
defer server.Close()
|
|
|
|
// Set revocation URL to test server
|
|
ts.tOidc.revocationURL = server.URL
|
|
|
|
// Test token revocation
|
|
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
|
|
if tc.expectError && err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
if !tc.expectError && err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestRevokeToken tests the token revocation functionality
|
|
func TestRevokeToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
token := "test.token.with.claims"
|
|
claims := map[string]interface{}{
|
|
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
|
}
|
|
|
|
// Test token revocation
|
|
t.Run("Token revocation", func(t *testing.T) {
|
|
// Create a new instance for this specific test
|
|
tOidc := &TraefikOidc{
|
|
tokenBlacklist: NewCache(), // Use generic cache for blacklist
|
|
tokenCache: NewTokenCache(),
|
|
logger: NewLogger("info"), // Initialize the logger
|
|
}
|
|
|
|
// Cache the token
|
|
tOidc.tokenCache.Set(token, claims, time.Hour)
|
|
|
|
// Revoke the token
|
|
tOidc.RevokeToken(token)
|
|
|
|
// Verify token was removed from cache
|
|
if _, exists := tOidc.tokenCache.Get(token); exists {
|
|
t.Error("Token was not removed from cache")
|
|
}
|
|
|
|
// Verify token was added to blacklist cache
|
|
if _, exists := tOidc.tokenBlacklist.Get(token); !exists {
|
|
t.Error("Token was not added to blacklist")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Add this new test function
|
|
func TestBuildLogoutURL(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
endSessionURL string
|
|
idToken string
|
|
postLogoutRedirect string
|
|
expectedURL string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Valid URL",
|
|
endSessionURL: "https://provider/end-session",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Invalid URL",
|
|
endSessionURL: "://invalid-url",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "URL with existing query parameters",
|
|
endSessionURL: "https://provider/end-session?existing=param",
|
|
idToken: "test.id.token",
|
|
postLogoutRedirect: "http://example.com/",
|
|
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if url != tc.expectedURL {
|
|
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Add this new test function
|
|
func TestHandleExpiredToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
setupSession func(*SessionData)
|
|
expectedPath string
|
|
}{
|
|
{
|
|
name: "Basic expired token",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
// Create an expired token for this test
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
|
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAccessToken(expiredToken)
|
|
session.SetUserIdentifier("test@example.com")
|
|
},
|
|
expectedPath: "/original/path",
|
|
},
|
|
{
|
|
name: "Session with additional values",
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
// Create an expired token for this test
|
|
expiredToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com", "aud": "test-client-id", "exp": time.Now().Add(-1 * time.Hour).Unix(),
|
|
"iat": time.Now().Add(-2 * time.Hour).Unix(), "nbf": time.Now().Add(-2 * time.Hour).Unix(),
|
|
"sub": "test-subject", "email": "test@example.com", "jti": generateRandomString(16),
|
|
})
|
|
session.SetAccessToken(expiredToken)
|
|
session.mainSession.Values["custom_value"] = "should-be-cleared"
|
|
},
|
|
expectedPath: "/another/path",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
logger := NewLogger("info")
|
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, "", "", 0, logger)
|
|
|
|
tOidc := &TraefikOidc{
|
|
sessionManager: sessionManager,
|
|
logger: logger,
|
|
tokenVerifier: ts.tOidc.tokenVerifier,
|
|
jwtVerifier: ts.tOidc.jwtVerifier,
|
|
initComplete: make(chan struct{}),
|
|
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
|
http.Redirect(rw, req, "/login", http.StatusFound)
|
|
},
|
|
}
|
|
close(tOidc.initComplete)
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", tc.expectedPath, nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Get session
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Setup session data
|
|
tc.setupSession(session)
|
|
|
|
// Handle expired token
|
|
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
|
|
|
|
// Get the updated session to verify changes
|
|
updatedSession, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get updated session: %v", err)
|
|
}
|
|
|
|
// Verify main session values
|
|
if updatedSession.GetCSRF() == "" {
|
|
t.Error("CSRF token not set")
|
|
}
|
|
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
|
|
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
|
|
}
|
|
if updatedSession.GetNonce() == "" {
|
|
t.Error("Nonce not set")
|
|
}
|
|
|
|
// Verify tokens are cleared
|
|
if token := updatedSession.GetAccessToken(); token != "" {
|
|
t.Error("Access token not cleared")
|
|
}
|
|
if token := updatedSession.GetRefreshToken(); token != "" {
|
|
t.Error("Refresh token not cleared")
|
|
}
|
|
|
|
// Verify redirect status
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Add this new test function
|
|
func TestExtractGroupsAndRoles(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claims map[string]interface{}
|
|
expectGroups []string
|
|
expectRoles []string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Valid groups and roles",
|
|
claims: map[string]interface{}{
|
|
"groups": []interface{}{"group1", "group2"},
|
|
"roles": []interface{}{"role1", "role2"},
|
|
},
|
|
expectGroups: []string{"group1", "group2"},
|
|
expectRoles: []string{"role1", "role2"},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Empty groups and roles",
|
|
claims: map[string]interface{}{
|
|
"groups": []interface{}{},
|
|
"roles": []interface{}{},
|
|
},
|
|
expectGroups: []string{},
|
|
expectRoles: []string{},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Invalid groups format",
|
|
claims: map[string]interface{}{
|
|
"groups": "not-an-array",
|
|
"roles": []interface{}{"role1"},
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create a test token with the claims
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
|
|
// Compare groups
|
|
if !stringSliceEqual(groups, tc.expectGroups) {
|
|
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
|
|
}
|
|
|
|
// Compare roles
|
|
if !stringSliceEqual(roles, tc.expectRoles) {
|
|
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
|
|
// can be created and initialized properly for different routes
|
|
func TestMultipleMiddlewareInstances(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping test in short mode")
|
|
}
|
|
|
|
// Create mock provider metadata server. Issuer + endpoints must share the
|
|
// host with ProviderURL (the httptest server), otherwise the discovery doc
|
|
// is rejected as poisoned (audit ranks 21/22). Derive them from the server.
|
|
var mockServer *httptest.Server
|
|
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
metadata := ProviderMetadata{
|
|
Issuer: mockServer.URL,
|
|
AuthURL: mockServer.URL + "/auth",
|
|
TokenURL: mockServer.URL + "/token",
|
|
JWKSURL: mockServer.URL + "/jwks",
|
|
RevokeURL: mockServer.URL + "/revoke",
|
|
EndSessionURL: mockServer.URL + "/end-session",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
// Create base config
|
|
config := &Config{
|
|
ProviderURL: mockServer.URL,
|
|
ClientID: "test-client",
|
|
ClientSecret: "test-secret",
|
|
CallbackURL: "/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
RateLimit: 100,
|
|
}
|
|
|
|
// Create multiple middleware instances
|
|
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
|
|
var middlewares []*TraefikOidc
|
|
|
|
for _, route := range routes {
|
|
config.CallbackURL = route + "/callback"
|
|
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config, "test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
|
|
}
|
|
|
|
// Type assert to access internal fields
|
|
if m, ok := middleware.(*TraefikOidc); ok {
|
|
middlewares = append(middlewares, m)
|
|
} else {
|
|
t.Fatalf("Middleware is not of type *TraefikOidc")
|
|
}
|
|
}
|
|
|
|
// Clean up all middleware instances to prevent goroutine leaks
|
|
defer func() {
|
|
for i, m := range middlewares {
|
|
if err := m.Close(); err != nil {
|
|
t.Errorf("Failed to close middleware instance %d: %v", i, err)
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Wait for all instances to initialize
|
|
for i, m := range middlewares {
|
|
select {
|
|
case <-m.initComplete:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Middleware instance %d failed to initialize", i)
|
|
}
|
|
|
|
// Verify each instance has its own unique configuration. Issuer is now
|
|
// pinned to the provider host (audit ranks 21/22), so it equals the
|
|
// mock server URL rather than a fixed literal.
|
|
if m.issuerURL != mockServer.URL {
|
|
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, mockServer.URL, m.issuerURL)
|
|
}
|
|
if m.authURL != mockServer.URL+"/auth" {
|
|
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, mockServer.URL+"/auth", m.authURL)
|
|
}
|
|
if m.tokenURL != mockServer.URL+"/token" {
|
|
t.Errorf("Instance %d: Expected token URL %s, got %s", i, mockServer.URL+"/token", m.tokenURL)
|
|
}
|
|
if m.jwksURL != mockServer.URL+"/jwks" {
|
|
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, mockServer.URL+"/jwks", m.jwksURL)
|
|
}
|
|
if m.redirURLPath != routes[i]+"/callback" {
|
|
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
|
|
}
|
|
}
|
|
|
|
// Test that each instance can handle requests independently
|
|
for i, m := range middlewares {
|
|
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rr, req)
|
|
|
|
// Should redirect (302) to the auth flow since not authenticated. The
|
|
// absolute auth URL is not asserted here: with issuer pinning (audit
|
|
// ranks 21/22) the discovery host equals the httptest server host,
|
|
// which is loopback, so buildAuthURL's SSRF guard legitimately refuses
|
|
// to emit a loopback authorization URL in this test environment. The
|
|
// per-instance auth/token/jwks/issuer URLs were already verified above;
|
|
// here we only confirm each instance independently triggers a redirect.
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestMultiRealmMetadataRefreshIsolation verifies that multiple middleware instances
|
|
// with different provider URLs (e.g., different Keycloak realms) get separate
|
|
// metadata refresh tasks. This addresses the issue reported in PR #88.
|
|
func TestMultiRealmMetadataRefreshIsolation(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping test in short mode")
|
|
}
|
|
|
|
// Create two mock provider metadata servers simulating different Keycloak realms
|
|
// Issuer + endpoints must share the host with each realm's ProviderURL
|
|
// (the httptest server), otherwise the discovery doc is rejected as
|
|
// poisoned (audit ranks 21/22). Keep the distinguishing /realms/realmN
|
|
// path so the per-realm isolation assertions below still hold, but base
|
|
// the host on the server URL — which is exactly what a same-host Keycloak
|
|
// deployment looks like.
|
|
var realm1Server *httptest.Server
|
|
realm1Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
base := realm1Server.URL + "/realms/realm1"
|
|
metadata := ProviderMetadata{
|
|
Issuer: base,
|
|
AuthURL: base + "/protocol/openid-connect/auth",
|
|
TokenURL: base + "/protocol/openid-connect/token",
|
|
JWKSURL: base + "/protocol/openid-connect/certs",
|
|
EndSessionURL: base + "/protocol/openid-connect/logout",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
}))
|
|
defer realm1Server.Close()
|
|
|
|
var realm2Server *httptest.Server
|
|
realm2Server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
base := realm2Server.URL + "/realms/realm2"
|
|
metadata := ProviderMetadata{
|
|
Issuer: base,
|
|
AuthURL: base + "/protocol/openid-connect/auth",
|
|
TokenURL: base + "/protocol/openid-connect/token",
|
|
JWKSURL: base + "/protocol/openid-connect/certs",
|
|
EndSessionURL: base + "/protocol/openid-connect/logout",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
}))
|
|
defer realm2Server.Close()
|
|
|
|
// Config for realm1
|
|
config1 := &Config{
|
|
ProviderURL: realm1Server.URL,
|
|
ClientID: "realm1-client",
|
|
ClientSecret: "realm1-secret",
|
|
CallbackURL: "/realm1/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
CookiePrefix: "_oidc_realm1_",
|
|
RateLimit: 100,
|
|
}
|
|
|
|
// Config for realm2
|
|
config2 := &Config{
|
|
ProviderURL: realm2Server.URL,
|
|
ClientID: "realm2-client",
|
|
ClientSecret: "realm2-secret",
|
|
CallbackURL: "/realm2/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
CookiePrefix: "_oidc_realm2_",
|
|
RateLimit: 100,
|
|
}
|
|
|
|
// Create middleware instances for both realms
|
|
middleware1, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config1, "realm1-middleware")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware for realm1: %v", err)
|
|
}
|
|
|
|
middleware2, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config2, "realm2-middleware")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware for realm2: %v", err)
|
|
}
|
|
|
|
m1, ok1 := middleware1.(*TraefikOidc)
|
|
m2, ok2 := middleware2.(*TraefikOidc)
|
|
if !ok1 || !ok2 {
|
|
t.Fatalf("Middleware is not of type *TraefikOidc")
|
|
}
|
|
|
|
// Clean up middleware instances
|
|
defer func() {
|
|
if err := m1.Close(); err != nil {
|
|
t.Errorf("Failed to close realm1 middleware: %v", err)
|
|
}
|
|
if err := m2.Close(); err != nil {
|
|
t.Errorf("Failed to close realm2 middleware: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Wait for both instances to initialize
|
|
select {
|
|
case <-m1.initComplete:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Realm1 middleware failed to initialize")
|
|
}
|
|
|
|
select {
|
|
case <-m2.initComplete:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Realm2 middleware failed to initialize")
|
|
}
|
|
|
|
// Verify each instance has the correct issuer URL from their respective realms
|
|
if !strings.Contains(m1.issuerURL, "realm1") {
|
|
t.Errorf("Realm1 middleware expected issuer with realm1, got %s", m1.issuerURL)
|
|
}
|
|
if !strings.Contains(m2.issuerURL, "realm2") {
|
|
t.Errorf("Realm2 middleware expected issuer with realm2, got %s", m2.issuerURL)
|
|
}
|
|
|
|
// Verify provider URLs are different
|
|
if m1.providerURL == m2.providerURL {
|
|
t.Errorf("Both middlewares should have different provider URLs, got same: %s", m1.providerURL)
|
|
}
|
|
|
|
// Test that each middleware can handle requests independently
|
|
req1 := httptest.NewRequest("GET", "/realm1/protected", nil)
|
|
rr1 := httptest.NewRecorder()
|
|
m1.ServeHTTP(rr1, req1)
|
|
|
|
req2 := httptest.NewRequest("GET", "/realm2/protected", nil)
|
|
rr2 := httptest.NewRecorder()
|
|
m2.ServeHTTP(rr2, req2)
|
|
|
|
// Both should redirect to their respective auth URLs
|
|
if rr1.Code != http.StatusFound {
|
|
t.Errorf("Realm1: Expected redirect status %d, got %d", http.StatusFound, rr1.Code)
|
|
}
|
|
if rr2.Code != http.StatusFound {
|
|
t.Errorf("Realm2: Expected redirect status %d, got %d", http.StatusFound, rr2.Code)
|
|
}
|
|
|
|
location1 := rr1.Header().Get("Location")
|
|
location2 := rr2.Header().Get("Location")
|
|
|
|
if !strings.Contains(location1, "realm1") {
|
|
t.Errorf("Realm1: Expected redirect to realm1 auth URL, got %s", location1)
|
|
}
|
|
if !strings.Contains(location2, "realm2") {
|
|
t.Errorf("Realm2: Expected redirect to realm2 auth URL, got %s", location2)
|
|
}
|
|
}
|
|
|
|
// TestMetadataRecoveryOnProviderFailure verifies that the middleware automatically
|
|
// recovers when the OIDC provider becomes available after initial failure.
|
|
func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping test in short mode")
|
|
}
|
|
|
|
// Track whether the provider is "available"
|
|
providerAvailable := false
|
|
var mu sync.Mutex
|
|
|
|
// Create mock provider that initially fails, then becomes available.
|
|
// Issuer + endpoints must share the host with ProviderURL (audit ranks
|
|
// 21/22), so derive them from the server URL.
|
|
var mockServer *httptest.Server
|
|
mockServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
mu.Lock()
|
|
available := providerAvailable
|
|
mu.Unlock()
|
|
|
|
if !available {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
if r.URL.Path == "/.well-known/openid-configuration" {
|
|
metadata := ProviderMetadata{
|
|
Issuer: mockServer.URL,
|
|
AuthURL: mockServer.URL + "/auth",
|
|
TokenURL: mockServer.URL + "/token",
|
|
JWKSURL: mockServer.URL + "/jwks",
|
|
EndSessionURL: mockServer.URL + "/logout",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusNotFound)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
config := &Config{
|
|
ProviderURL: mockServer.URL,
|
|
ClientID: "test-client",
|
|
ClientSecret: "test-secret",
|
|
CallbackURL: "/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
RateLimit: 100,
|
|
}
|
|
|
|
// Create middleware while provider is unavailable
|
|
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config, "test-recovery")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware: %v", err)
|
|
}
|
|
|
|
m, ok := middleware.(*TraefikOidc)
|
|
if !ok {
|
|
t.Fatalf("Middleware is not of type *TraefikOidc")
|
|
}
|
|
defer m.Close()
|
|
|
|
// Wait for initial initialization to complete (it should fail)
|
|
select {
|
|
case <-m.initComplete:
|
|
case <-time.After(15 * time.Second):
|
|
t.Fatal("Initialization did not complete in time")
|
|
}
|
|
|
|
// Verify initial state - should be in failed state (no issuerURL)
|
|
m.metadataMu.RLock()
|
|
initialIssuer := m.issuerURL
|
|
m.metadataMu.RUnlock()
|
|
|
|
if initialIssuer != "" {
|
|
t.Errorf("Expected empty issuerURL after failed init, got: %s", initialIssuer)
|
|
}
|
|
|
|
// First request should get 503
|
|
req1 := httptest.NewRequest("GET", "/protected", nil)
|
|
rr1 := httptest.NewRecorder()
|
|
m.ServeHTTP(rr1, req1)
|
|
|
|
if rr1.Code != http.StatusServiceUnavailable {
|
|
t.Errorf("Expected 503 when provider unavailable, got %d", rr1.Code)
|
|
}
|
|
|
|
// Now make the provider available
|
|
mu.Lock()
|
|
providerAvailable = true
|
|
mu.Unlock()
|
|
|
|
// Reset the retry timer to allow immediate retry. The field is atomic
|
|
// now, so no lock is needed.
|
|
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
|
|
|
|
// Second request should trigger recovery attempt
|
|
req2 := httptest.NewRequest("GET", "/protected", nil)
|
|
rr2 := httptest.NewRecorder()
|
|
m.ServeHTTP(rr2, req2)
|
|
|
|
// Give the async recovery a moment to complete
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// Check if recovery happened
|
|
m.metadataMu.RLock()
|
|
recoveredIssuer := m.issuerURL
|
|
m.metadataMu.RUnlock()
|
|
|
|
if recoveredIssuer == "" {
|
|
t.Error("Expected issuerURL to be recovered after provider became available")
|
|
}
|
|
|
|
// Third request should succeed (redirect to auth, not 503)
|
|
req3 := httptest.NewRequest("GET", "/protected", nil)
|
|
rr3 := httptest.NewRecorder()
|
|
m.ServeHTTP(rr3, req3)
|
|
|
|
if rr3.Code == http.StatusServiceUnavailable {
|
|
t.Errorf("Expected redirect after recovery, still got 503")
|
|
}
|
|
|
|
t.Logf("Recovery test: initial_issuer=%q, recovered_issuer=%q, final_status=%d",
|
|
initialIssuer, recoveredIssuer, rr3.Code)
|
|
}
|
|
|
|
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create consistent timestamps for all test cases
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
nbf := now.Add(-2 * time.Minute).Unix() // Account for clock skew
|
|
|
|
tests := []struct {
|
|
allowedRolesAndGroups map[string]struct{}
|
|
claims map[string]interface{}
|
|
setupSession func(*SessionData)
|
|
expectedHeaders map[string]string
|
|
name string
|
|
expectedStatus int
|
|
}{
|
|
{
|
|
name: "User with allowed role",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"admin": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"admin", "user"},
|
|
"groups": []interface{}{"group1"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "admin,user",
|
|
"X-User-Groups": "group1",
|
|
},
|
|
},
|
|
{
|
|
name: "User with allowed group",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"allowed-group": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"allowed-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "user",
|
|
"X-User-Groups": "allowed-group",
|
|
},
|
|
},
|
|
{
|
|
name: "User without allowed roles or groups",
|
|
allowedRolesAndGroups: map[string]struct{}{
|
|
"admin": {},
|
|
"allowed-group": {},
|
|
},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"regular-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusForbidden,
|
|
},
|
|
{
|
|
name: "No role/group restrictions",
|
|
allowedRolesAndGroups: map[string]struct{}{},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"roles": []interface{}{"user"},
|
|
"groups": []interface{}{"regular-group"},
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{
|
|
"X-User-Roles": "user",
|
|
"X-User-Groups": "regular-group",
|
|
},
|
|
},
|
|
{
|
|
name: "Claims without roles and groups",
|
|
allowedRolesAndGroups: map[string]struct{}{},
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"jti": generateRandomString(16),
|
|
},
|
|
setupSession: func(session *SessionData) {
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier("user@example.com")
|
|
},
|
|
expectedStatus: http.StatusOK,
|
|
expectedHeaders: map[string]string{},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create token with claims
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test token: %v", err)
|
|
}
|
|
|
|
// Create test handler
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Configure OIDC middleware
|
|
tOidc := ts.tOidc
|
|
tOidc.next = nextHandler
|
|
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", "/protected", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Set up session
|
|
session, err := tOidc.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
tc.setupSession(session)
|
|
session.SetAccessToken(token)
|
|
|
|
if err := session.Save(req, rr); err != nil {
|
|
t.Fatalf("Failed to save session: %v", err)
|
|
}
|
|
|
|
// Copy cookies to the new request
|
|
for _, cookie := range rr.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
// Reset response recorder
|
|
rr = httptest.NewRecorder()
|
|
|
|
// Serve request
|
|
tOidc.ServeHTTP(rr, req)
|
|
|
|
// Check status code
|
|
if rr.Code != tc.expectedStatus {
|
|
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
|
}
|
|
|
|
// Check headers if status is OK
|
|
if tc.expectedStatus == http.StatusOK {
|
|
for header, expectedValue := range tc.expectedHeaders {
|
|
if value := req.Header.Get(header); value != expectedValue {
|
|
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Helper function to compare string slices
|
|
func stringSliceEqual(a, b []string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i := range a {
|
|
if a[i] != b[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// TestExchangeTokensWithRedirects tests the token exchange process with redirects
|
|
func TestExchangeTokensWithRedirects(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
setupServer func() *httptest.Server
|
|
name string
|
|
errorContains string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Successful token exchange with redirects",
|
|
setupServer: func() *httptest.Server {
|
|
redirectCount := 0
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if redirectCount < 3 {
|
|
// Set a cookie before redirecting
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: fmt.Sprintf("redirect-cookie-%d", redirectCount),
|
|
Value: "test-value",
|
|
})
|
|
redirectCount++
|
|
w.Header().Set("Location", r.URL.String())
|
|
w.WriteHeader(http.StatusFound)
|
|
return
|
|
}
|
|
|
|
// Verify all cookies from previous redirects are present
|
|
cookies := r.Cookies()
|
|
if len(cookies) != 3 {
|
|
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
|
}
|
|
for i := range 3 {
|
|
found := false
|
|
expectedName := fmt.Sprintf("redirect-cookie-%d", i)
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == expectedName {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Cookie %s not found", expectedName)
|
|
}
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Too many redirects",
|
|
setupServer: func() *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Location", r.URL.String())
|
|
w.WriteHeader(http.StatusFound)
|
|
}))
|
|
},
|
|
expectError: true,
|
|
errorContains: "stopped after 50 redirects",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := tc.setupServer()
|
|
defer server.Close()
|
|
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.tokenURL = server.URL
|
|
|
|
// Test token exchange
|
|
response, err := tOidc.exchangeTokens(context.Background(), "authorization_code", "test-code", "http://callback", "test-code-verifier")
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Error("Expected error but got nil")
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing %q, got %q", tc.errorContains, err.Error())
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if response == nil {
|
|
t.Error("Expected token response but got nil")
|
|
} else if response.IDToken != "test.id.token" {
|
|
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBuildAuthURL tests the buildAuthURL function with various URL scenarios
|
|
func TestBuildAuthURL(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
authURL string
|
|
issuerURL string
|
|
redirectURL string
|
|
state string
|
|
nonce string
|
|
codeChallenge string
|
|
expectedPrefix string
|
|
enablePKCE bool
|
|
checkPKCE bool
|
|
}{
|
|
{
|
|
name: "Absolute Auth URL",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "Relative Auth URL",
|
|
authURL: "/oidc/auth",
|
|
issuerURL: "https://logto.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://logto.example.com/oidc/auth?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "Relative Auth URL with Different Issuer",
|
|
authURL: "/sign-in",
|
|
issuerURL: "https://auth.example.com:8443",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com:8443/sign-in?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "With PKCE Enabled",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: true,
|
|
codeChallenge: "test-code-challenge",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: true,
|
|
},
|
|
{
|
|
name: "With PKCE Enabled but No Challenge",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: true,
|
|
codeChallenge: "",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
{
|
|
name: "With PKCE Disabled but Challenge Provided",
|
|
authURL: "https://auth.example.com/oauth/authorize",
|
|
issuerURL: "https://auth.example.com",
|
|
redirectURL: "https://app.example.com/callback",
|
|
state: "test-state",
|
|
nonce: "test-nonce",
|
|
enablePKCE: false,
|
|
codeChallenge: "test-code-challenge",
|
|
expectedPrefix: "https://auth.example.com/oauth/authorize?",
|
|
checkPKCE: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.authURL = tc.authURL
|
|
tOidc.issuerURL = tc.issuerURL
|
|
tOidc.enablePKCE = tc.enablePKCE
|
|
|
|
// Call buildAuthURL with code challenge
|
|
result := tOidc.buildAuthURL(tc.redirectURL, tc.state, tc.nonce, tc.codeChallenge)
|
|
|
|
// Verify the URL starts with the expected prefix
|
|
if !strings.HasPrefix(result, tc.expectedPrefix) {
|
|
t.Errorf("Expected URL to start with %q, got %q", tc.expectedPrefix, result)
|
|
}
|
|
|
|
// Parse the resulting URL to verify query parameters
|
|
parsedURL, err := url.Parse(result)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse resulting URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
expectedParams := map[string]string{
|
|
"client_id": tOidc.clientID,
|
|
"response_type": "code",
|
|
"redirect_uri": tc.redirectURL,
|
|
"state": tc.state,
|
|
"nonce": tc.nonce,
|
|
}
|
|
|
|
for key, expected := range expectedParams {
|
|
if got := query.Get(key); got != expected {
|
|
t.Errorf("Expected %s=%q, got %q", key, expected, got)
|
|
}
|
|
}
|
|
|
|
// Verify PKCE parameters
|
|
if tc.checkPKCE {
|
|
if got := query.Get("code_challenge"); got != tc.codeChallenge {
|
|
t.Errorf("Expected code_challenge=%q, got %q", tc.codeChallenge, got)
|
|
}
|
|
if got := query.Get("code_challenge_method"); got != "S256" {
|
|
t.Errorf("Expected code_challenge_method=%q, got %q", "S256", got)
|
|
}
|
|
} else {
|
|
if got := query.Get("code_challenge"); got != "" {
|
|
t.Errorf("Expected no code_challenge, but got %q", got)
|
|
}
|
|
if got := query.Get("code_challenge_method"); got != "" {
|
|
t.Errorf("Expected no code_challenge_method, but got %q", got)
|
|
}
|
|
}
|
|
|
|
// Verify scopes are present and correct
|
|
if len(tOidc.scopes) > 0 {
|
|
expectedScopes := strings.Join(tOidc.scopes, " ")
|
|
if got := query.Get("scope"); got != expectedScopes {
|
|
t.Errorf("Expected scope=%q, got %q", expectedScopes, got)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestExchangeCodeForToken tests the exchangeCodeForToken function with PKCE support
|
|
func TestExchangeCodeForToken(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
setupMock func(t *testing.T) *httptest.Server
|
|
name string
|
|
codeVerifier string
|
|
enablePKCE bool
|
|
}{
|
|
{
|
|
name: "With PKCE Enabled and Code Verifier",
|
|
enablePKCE: true,
|
|
codeVerifier: "test-code-verifier",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "test-code-verifier" {
|
|
t.Errorf("Expected code_verifier=test-code-verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
{
|
|
name: "With PKCE Disabled but Code Verifier Provided",
|
|
enablePKCE: false,
|
|
codeVerifier: "test-code-verifier",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is NOT included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
{
|
|
name: "With PKCE Enabled but No Code Verifier",
|
|
enablePKCE: true,
|
|
codeVerifier: "",
|
|
setupMock: func(t *testing.T) *httptest.Server {
|
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
t.Fatalf("Failed to parse form: %v", err)
|
|
}
|
|
|
|
// Verify code_verifier is NOT included
|
|
if codeVerifier := r.Form.Get("code_verifier"); codeVerifier != "" {
|
|
t.Errorf("Expected no code_verifier, got %s", codeVerifier)
|
|
}
|
|
|
|
// Return successful token response
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(TokenResponse{
|
|
IDToken: "test.id.token",
|
|
AccessToken: "test-access-token",
|
|
TokenType: "Bearer",
|
|
ExpiresIn: 3600,
|
|
RefreshToken: "test-refresh-token",
|
|
})
|
|
}))
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
server := tc.setupMock(t)
|
|
defer server.Close()
|
|
|
|
// Configure the test instance
|
|
tOidc := ts.tOidc
|
|
tOidc.tokenURL = server.URL
|
|
tOidc.enablePKCE = tc.enablePKCE
|
|
|
|
// Test exchangeCodeForToken
|
|
response, err := tOidc.exchangeCodeForToken("test-code", "http://callback", tc.codeVerifier)
|
|
if err != nil {
|
|
t.Errorf("Unexpected error: %v", err)
|
|
}
|
|
if response == nil {
|
|
t.Error("Expected token response but got nil")
|
|
} else if response.IDToken != "test.id.token" {
|
|
t.Errorf("Expected ID token %q, got %q", "test.id.token", response.IDToken)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDefaultInitiateAuthentication_PreservesQueryParameters tests that defaultInitiateAuthentication preserves query parameters in the incoming path.
|
|
func TestDefaultInitiateAuthentication_PreservesQueryParameters(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Create a request with query parameters
|
|
req := httptest.NewRequest("GET", "/protected/resource?param1=value1¶m2=value2", nil)
|
|
responseRecorder := httptest.NewRecorder()
|
|
|
|
// Get session
|
|
session, err := ts.sessionManager.GetSession(req)
|
|
if err != nil {
|
|
t.Fatalf("Failed to get session: %v", err)
|
|
}
|
|
|
|
// Call defaultInitiateAuthentication
|
|
redirectURL := "http://example.com/callback"
|
|
ts.tOidc.defaultInitiateAuthentication(responseRecorder, req, session, redirectURL)
|
|
|
|
// Verify that the incoming path includes query parameters
|
|
incomingPath := session.GetIncomingPath()
|
|
expectedPath := "/protected/resource?param1=value1¶m2=value2"
|
|
if incomingPath != expectedPath {
|
|
t.Errorf("Expected incoming path to be '%s', got '%s'", expectedPath, incomingPath)
|
|
}
|
|
}
|
|
|
|
// TestVerifyTimeConstraint tests the time constraint verification logic with separate past/future skew tolerances.
|
|
func TestVerifyTimeConstraint(t *testing.T) {
|
|
// Define tolerances used in jwt.go (ensure they match)
|
|
toleranceFuture := 2 * time.Minute
|
|
tolerancePast := 10 * time.Second
|
|
|
|
now := time.Now()
|
|
|
|
tests := []struct {
|
|
name string
|
|
claimTime time.Time
|
|
claimName string
|
|
futureCheck bool // true for exp, false for iat/nbf
|
|
expectError bool
|
|
}{
|
|
// Expiration (future=true, tolerance=2min)
|
|
{
|
|
name: "EXP: Valid (expires in 1 min)",
|
|
claimTime: now.Add(1 * time.Minute),
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "EXP: Expired (expired 3 min ago)",
|
|
claimTime: now.Add(-3 * time.Minute), // Outside 2min tolerance
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "EXP: Valid (expired 1 min ago, within 2min tolerance)",
|
|
claimTime: now.Add(-1 * time.Minute), // Inside 2min tolerance
|
|
claimName: "exp",
|
|
futureCheck: true,
|
|
expectError: false, // Should be allowed due to future tolerance
|
|
},
|
|
|
|
// Issued At (future=false, tolerance=10s)
|
|
{
|
|
name: "IAT: Valid (issued 1 min ago)",
|
|
claimTime: now.Add(-1 * time.Minute),
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "IAT: Invalid (issued 15 sec in future)",
|
|
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: true, // "token used before issued"
|
|
},
|
|
{
|
|
name: "IAT: Valid (issued 5 sec in future, within 10s tolerance)",
|
|
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
|
claimName: "iat",
|
|
futureCheck: false,
|
|
expectError: false, // Should be allowed due to past tolerance
|
|
},
|
|
|
|
// Not Before (future=false, tolerance=10s)
|
|
{
|
|
name: "NBF: Valid (active 1 min ago)",
|
|
claimTime: now.Add(-1 * time.Minute),
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "NBF: Invalid (active in 15 sec)",
|
|
claimTime: now.Add(15 * time.Second), // Outside 10s past tolerance
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: true, // "token not yet valid"
|
|
},
|
|
{
|
|
name: "NBF: Valid (active in 5 sec, within 10s tolerance)",
|
|
claimTime: now.Add(5 * time.Second), // Inside 10s past tolerance
|
|
claimName: "nbf",
|
|
futureCheck: false,
|
|
expectError: false, // Should be allowed due to past tolerance
|
|
},
|
|
}
|
|
|
|
// Temporarily adjust global tolerances for test consistency, then restore
|
|
originalFutureTolerance := ClockSkewToleranceFuture
|
|
originalPastTolerance := ClockSkewTolerancePast
|
|
ClockSkewToleranceFuture = toleranceFuture
|
|
ClockSkewTolerancePast = tolerancePast
|
|
defer func() {
|
|
ClockSkewToleranceFuture = originalFutureTolerance
|
|
ClockSkewTolerancePast = originalPastTolerance
|
|
}()
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Convert claim time to float64 unix timestamp
|
|
unixTime := float64(tc.claimTime.Unix()) + float64(tc.claimTime.Nanosecond())/1e9
|
|
|
|
var err error
|
|
// Call the specific verification function which uses verifyTimeConstraint
|
|
if tc.claimName == "exp" {
|
|
err = verifyExpiration(unixTime)
|
|
} else if tc.claimName == "iat" {
|
|
err = verifyIssuedAt(unixTime)
|
|
} else if tc.claimName == "nbf" {
|
|
err = verifyNotBefore(unixTime)
|
|
} else {
|
|
t.Fatalf("Unknown claim name in test setup: %s", tc.claimName)
|
|
}
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error for claim %s at time %v (now=%v), but got nil", tc.claimName, tc.claimTime, now)
|
|
} else {
|
|
t.Logf("Got expected error: %v", err) // Log the error for confirmation
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Expected no error for claim %s at time %v (now=%v), but got: %v", tc.claimName, tc.claimTime, now, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
} // Add missing closing brace for TestVerifyTimeConstraint
|
|
|
|
// ===== JWT REPLAY DETECTION TESTS =====
|
|
// These tests ensure the replay detection fix works correctly and prevents regressions
|
|
|
|
// TestJWTVerifyWithSkipReplayCheck tests the new skipReplayCheck parameter functionality
|
|
func TestJWTVerifyWithSkipReplayCheck(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache before test
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT: %v", err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
errorContains string
|
|
skipReplayCheck bool
|
|
firstCall bool
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "First verification with skipReplayCheck=false should succeed",
|
|
skipReplayCheck: false,
|
|
firstCall: true,
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Second verification with skipReplayCheck=false should fail (replay detected)",
|
|
skipReplayCheck: false,
|
|
firstCall: false,
|
|
expectError: true,
|
|
errorContains: "token replay detected",
|
|
},
|
|
{
|
|
name: "Verification with skipReplayCheck=true should always succeed",
|
|
skipReplayCheck: true,
|
|
firstCall: false, // Even on subsequent calls
|
|
expectError: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
if tc.firstCall {
|
|
// Clear replay cache for first call tests
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
}
|
|
|
|
err := jwt.Verify("https://test-issuer.com", "test-client-id", tc.skipReplayCheck)
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error containing '%s', but got nil", tc.errorContains)
|
|
} else if !strings.Contains(err.Error(), tc.errorContains) {
|
|
t.Errorf("Expected error containing '%s', got '%v'", tc.errorContains, err)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Expected no error, but got: %v", err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestJWTVerifyBackwardCompatibility tests that calls without the skipReplayCheck parameter default to replay checking
|
|
func TestJWTVerifyBackwardCompatibility(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT: %v", err)
|
|
}
|
|
|
|
// First call with old signature (no skipReplayCheck parameter) should succeed
|
|
err = jwt.Verify("https://test-issuer.com", "test-client-id")
|
|
if err != nil {
|
|
t.Errorf("First verification should succeed, got: %v", err)
|
|
}
|
|
|
|
// Second call with old signature should fail due to replay detection
|
|
err = jwt.Verify("https://test-issuer.com", "test-client-id")
|
|
if err == nil {
|
|
t.Error("Second verification should fail due to replay detection")
|
|
} else if !strings.Contains(err.Error(), "token replay detected") {
|
|
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestTokenReplayDetectionFalsePositiveFix tests the specific scenario that was causing false positives
|
|
func TestTokenReplayDetectionFalsePositiveFix(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Simulate the authentication flow that was causing false positives:
|
|
// 1. Initial authentication adds JTI to cache
|
|
// 2. Subsequent request validation should not trigger false positive
|
|
|
|
// Step 1: Initial authentication (this would add JTI to cache)
|
|
jwt1, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT for initial auth: %v", err)
|
|
}
|
|
|
|
err = jwt1.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
|
|
if err != nil {
|
|
t.Fatalf("Initial authentication should succeed: %v", err)
|
|
}
|
|
|
|
// Step 2: Subsequent request validation (this should skip replay check to avoid false positive)
|
|
jwt2, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT for subsequent request: %v", err)
|
|
}
|
|
|
|
err = jwt2.Verify("https://test-issuer.com", "test-client-id", true) // Skip replay check
|
|
if err != nil {
|
|
t.Errorf("Subsequent request validation should succeed with skipReplayCheck=true: %v", err)
|
|
}
|
|
|
|
// Step 3: Verify that actual replay attacks are still detected
|
|
jwt3, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
|
|
}
|
|
|
|
err = jwt3.Verify("https://test-issuer.com", "test-client-id", false) // Normal replay check
|
|
if err == nil {
|
|
t.Error("Actual replay attack should be detected when skipReplayCheck=false")
|
|
} else if !strings.Contains(err.Error(), "token replay detected") {
|
|
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestAuthenticationFlowReplayDetection tests the complete authentication flow
|
|
func TestAuthenticationFlowReplayDetection(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Test the complete flow:
|
|
// 1. Initial authentication (should add JTI to cache)
|
|
// 2. Multiple subsequent requests (should not trigger false positives)
|
|
// 3. Actual replay attack from different source (should be detected)
|
|
|
|
// Step 1: Initial authentication
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Fatalf("Initial authentication should succeed: %v", err)
|
|
}
|
|
|
|
// Verify JTI is in cache (use shardedReplayCache which is the actual cache used)
|
|
exists := shardedReplayCache.Exists(jti)
|
|
if !exists {
|
|
t.Error("JTI should be added to replay cache during initial authentication")
|
|
}
|
|
|
|
// Step 2: Subsequent requests (simulate normal request processing)
|
|
// These should use the token cache and skip replay detection
|
|
for i := range 3 {
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Errorf("Subsequent request %d should succeed: %v", i+1, err)
|
|
}
|
|
}
|
|
|
|
// Step 3: Simulate actual replay attack by directly calling JWT.Verify with replay check
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT for replay attack test: %v", err)
|
|
}
|
|
|
|
err = jwt.Verify("https://test-issuer.com", "test-client-id", false) // Force replay check
|
|
if err == nil {
|
|
t.Error("Actual replay attack should be detected")
|
|
} else if !strings.Contains(err.Error(), "token replay detected") {
|
|
t.Errorf("Expected 'token replay detected' error, got: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestActualReplayAttackDetection ensures real replay attacks are still properly detected
|
|
func TestActualReplayAttackDetection(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT: %v", err)
|
|
}
|
|
|
|
// First verification should succeed
|
|
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
|
if err != nil {
|
|
t.Fatalf("First verification should succeed: %v", err)
|
|
}
|
|
|
|
// Simulate different types of replay attacks
|
|
replayTests := []struct {
|
|
name string
|
|
description string
|
|
}{
|
|
{
|
|
name: "Direct replay attack",
|
|
description: "Same token used again with replay checking enabled",
|
|
},
|
|
{
|
|
name: "Replay from different source",
|
|
description: "Token intercepted and replayed by attacker",
|
|
},
|
|
}
|
|
|
|
for _, rt := range replayTests {
|
|
t.Run(rt.name, func(t *testing.T) {
|
|
// Parse token again (simulating replay)
|
|
replayJWT, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT for replay test: %v", err)
|
|
}
|
|
|
|
// Attempt replay with normal replay checking
|
|
err = replayJWT.Verify("https://test-issuer.com", "test-client-id", false)
|
|
if err == nil {
|
|
t.Errorf("Replay attack should be detected for: %s", rt.description)
|
|
} else if !strings.Contains(err.Error(), "token replay detected") {
|
|
t.Errorf("Expected 'token replay detected' error for %s, got: %v", rt.description, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestConcurrentTokenValidation tests thread safety of replay detection
|
|
func TestConcurrentTokenValidation(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Configure rate limiter to allow more requests for concurrent testing
|
|
ts.tOidc.limiter = rate.NewLimiter(rate.Limit(1000), 1000) // Allow 1000 requests per second with burst of 1000
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create multiple tokens with unique JTIs
|
|
var tokens []string
|
|
var jtis []string
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
for i := range 10 {
|
|
jti := generateRandomString(16)
|
|
jtis = append(jtis, jti)
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT %d: %v", i, err)
|
|
}
|
|
tokens = append(tokens, token)
|
|
}
|
|
|
|
// Test concurrent validation
|
|
const numGoroutines = 20
|
|
const numIterations = 5
|
|
|
|
results := make(chan error, numGoroutines*numIterations)
|
|
|
|
for g := range numGoroutines {
|
|
go func(goroutineID int) {
|
|
for i := range numIterations {
|
|
tokenIndex := (goroutineID + i) % len(tokens)
|
|
token := tokens[tokenIndex]
|
|
|
|
// First validation should succeed
|
|
err := ts.tOidc.VerifyToken(token)
|
|
results <- err
|
|
|
|
// Subsequent validation with same token should also succeed (uses cache)
|
|
err = ts.tOidc.VerifyToken(token)
|
|
results <- err
|
|
}
|
|
}(g)
|
|
}
|
|
|
|
// Collect results
|
|
var errors []error
|
|
for range numGoroutines * numIterations * 2 {
|
|
if err := <-results; err != nil {
|
|
errors = append(errors, err)
|
|
}
|
|
}
|
|
|
|
// All validations should succeed (no race conditions)
|
|
if len(errors) > 0 {
|
|
t.Errorf("Expected no errors in concurrent validation, got %d errors: %v", len(errors), errors)
|
|
}
|
|
|
|
// Verify all JTIs are in cache (use shardedReplayCache which is the actual cache used)
|
|
for i, jti := range jtis {
|
|
if !shardedReplayCache.Exists(jti) {
|
|
t.Errorf("JTI %d (%s) should be in replay cache", i, jti)
|
|
}
|
|
}
|
|
}
|
|
|
|
// TestJTIBlacklistBehavior tests the JTI blacklist cache management
|
|
func TestJTIBlacklistBehavior(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Properly reinitialize the global replay cache
|
|
cleanupReplayCache() // Clean up any existing cache and reset sync.Once
|
|
initReplayCache() // Initialize new cache through proper channel
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
t.Logf("TestJTIBlacklistBehavior - JTI: %s", jti)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Test JTI blacklist behavior
|
|
tests := []struct {
|
|
action func() error
|
|
name string
|
|
description string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Initial verification adds JTI to blacklist",
|
|
action: func() error {
|
|
return ts.tOidc.VerifyToken(token)
|
|
},
|
|
expectError: false,
|
|
description: "First verification should succeed and add JTI to blacklist",
|
|
},
|
|
{
|
|
name: "JTI exists in blacklist after verification",
|
|
action: func() error {
|
|
// Use shardedReplayCache which is the actual cache used
|
|
if !shardedReplayCache.Exists(jti) {
|
|
return fmt.Errorf("JTI not found in blacklist cache")
|
|
}
|
|
return nil
|
|
},
|
|
expectError: false,
|
|
description: "JTI should be present in blacklist cache",
|
|
},
|
|
{
|
|
name: "Subsequent verification uses cache (no replay check)",
|
|
action: func() error {
|
|
return ts.tOidc.VerifyToken(token)
|
|
},
|
|
expectError: false,
|
|
description: "Subsequent verification should succeed using token cache",
|
|
},
|
|
{
|
|
name: "Direct JWT verification detects replay",
|
|
action: func() error {
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
|
},
|
|
expectError: true,
|
|
description: "Direct JWT verification should detect replay",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
err := tc.action()
|
|
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error for %s, but got nil", tc.description)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Expected no error for %s, but got: %v", tc.description, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestSessionBasedTokenRevalidation tests token revalidation in session-based scenarios
|
|
func TestSessionBasedTokenRevalidation(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping session-based token revalidation test in short mode")
|
|
}
|
|
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Clear the global replay cache
|
|
cleanupReplayCache()
|
|
initReplayCache()
|
|
|
|
// Create a test JWT with unique JTI
|
|
jti := generateRandomString(16)
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": jti,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// Simulate session-based token revalidation scenario
|
|
// This tests the specific case that was causing false positives
|
|
|
|
// Step 1: Initial authentication (callback processing)
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Fatalf("Initial authentication should succeed: %v", err)
|
|
}
|
|
|
|
// Step 2: Multiple session-based requests (normal request processing)
|
|
// These should not trigger replay detection false positives
|
|
for i := range 5 {
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Errorf("Session request %d should succeed: %v", i+1, err)
|
|
}
|
|
}
|
|
|
|
// Step 3: Verify token is in both caches appropriately
|
|
// Check token cache
|
|
if _, exists := ts.tOidc.tokenCache.Get(token); !exists {
|
|
t.Error("Token should be in token cache")
|
|
}
|
|
|
|
// Check replay cache
|
|
// Use shardedReplayCache which is the actual cache used
|
|
inReplayCache := shardedReplayCache.Exists(jti)
|
|
if !inReplayCache {
|
|
t.Error("JTI should be in replay cache")
|
|
}
|
|
|
|
// Step 4: Verify that clearing token cache still allows validation
|
|
ts.tOidc.tokenCache = NewTokenCache() // Clear token cache
|
|
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Errorf("Token validation should succeed even after cache clear: %v", err)
|
|
}
|
|
}
|
|
|
|
// TestEdgeCasesWithDifferentTokenTypes tests replay detection with different token types
|
|
func TestEdgeCasesWithDifferentTokenTypes(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
// Properly reinitialize the global replay cache
|
|
cleanupReplayCache() // Clean up any existing cache and reset sync.Once
|
|
initReplayCache() // Initialize new cache through proper channel
|
|
|
|
now := time.Now()
|
|
exp := now.Add(1 * time.Hour).Unix()
|
|
iat := now.Unix()
|
|
nbf := now.Unix()
|
|
|
|
tests := []struct {
|
|
claims map[string]interface{}
|
|
name string
|
|
tokenType string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "ID Token with JTI",
|
|
tokenType: "id_token",
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
"jti": generateRandomString(16),
|
|
"token_type": "id_token",
|
|
},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Access Token with JTI",
|
|
tokenType: "access_token",
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"scope": "openid profile email",
|
|
"jti": generateRandomString(16),
|
|
"token_type": "access_token",
|
|
},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Token without JTI",
|
|
tokenType: "no_jti",
|
|
claims: map[string]interface{}{
|
|
"iss": "https://test-issuer.com",
|
|
"aud": "test-client-id",
|
|
"exp": exp,
|
|
"iat": iat,
|
|
"nbf": nbf,
|
|
"sub": "test-subject",
|
|
"email": "user@example.com",
|
|
"nonce": "test-nonce",
|
|
// No JTI claim
|
|
},
|
|
expectError: false, // Should still work, just no replay protection
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create token with specific claims
|
|
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
|
if err != nil {
|
|
t.Fatalf("Failed to create test JWT: %v", err)
|
|
}
|
|
|
|
// First verification should succeed
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if tc.expectError {
|
|
if err == nil {
|
|
t.Errorf("Expected error for token type %s, but got nil", tc.tokenType)
|
|
}
|
|
} else {
|
|
if err != nil {
|
|
t.Errorf("Expected no error for token type %s, but got: %v", tc.tokenType, err)
|
|
}
|
|
}
|
|
|
|
// Second verification should also succeed (uses cache)
|
|
if !tc.expectError {
|
|
err = ts.tOidc.VerifyToken(token)
|
|
if err != nil {
|
|
t.Errorf("Second verification should succeed for token type %s: %v", tc.tokenType, err)
|
|
}
|
|
}
|
|
|
|
// Test direct JWT verification for replay detection
|
|
if !tc.expectError && tc.claims["jti"] != nil {
|
|
jwt, err := parseJWT(token)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse JWT: %v", err)
|
|
}
|
|
|
|
// This should detect replay for tokens with JTI
|
|
err = jwt.Verify("https://test-issuer.com", "test-client-id", false)
|
|
if err == nil {
|
|
t.Errorf("Expected replay detection for token type %s with JTI", tc.tokenType)
|
|
} else if !strings.Contains(err.Error(), "token replay detected") {
|
|
t.Errorf("Expected 'token replay detected' error for token type %s, got: %v", tc.tokenType, err)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestScopeMerging tests the scope append functionality
|
|
func TestScopeMerging(t *testing.T) {
|
|
// Helper function to compare string slices
|
|
equalSlices := func(a, b []string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i, v := range a {
|
|
if v != b[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
defaultScopes []string
|
|
userScopes []string
|
|
expectedScopes []string
|
|
}{
|
|
{
|
|
name: "Empty user scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{},
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
},
|
|
{
|
|
name: "Nil user scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: nil,
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
},
|
|
{
|
|
name: "New scopes are appended",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"custom_scope", "another_scope"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
|
|
},
|
|
{
|
|
name: "Deduplication - user scope already in defaults",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"openid", "custom_scope"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
|
|
},
|
|
{
|
|
name: "Duplicate user scopes are removed",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"custom_scope", "custom_scope", "another_scope"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
|
|
},
|
|
{
|
|
name: "Multiple overlapping scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"profile", "custom_scope", "email", "another_scope", "profile"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
|
|
},
|
|
{
|
|
name: "Only custom scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"read:users", "write:users", "admin"},
|
|
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
|
|
},
|
|
{
|
|
name: "Empty defaults",
|
|
defaultScopes: []string{},
|
|
userScopes: []string{"custom1", "custom2"},
|
|
expectedScopes: []string{"custom1", "custom2"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Test the mergeScopes function directly
|
|
result := mergeScopes(tc.defaultScopes, tc.userScopes)
|
|
if !equalSlices(result, tc.expectedScopes) {
|
|
t.Errorf("Expected %v, got %v", tc.expectedScopes, result)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestScopeMergingEdgeCases tests additional edge cases for scope deduplication
|
|
func TestScopeMergingEdgeCases(t *testing.T) {
|
|
// Helper function to compare string slices
|
|
equalSlices := func(a, b []string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i, v := range a {
|
|
if v != b[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
description string
|
|
defaultScopes []string
|
|
userScopes []string
|
|
expectedScopes []string
|
|
}{
|
|
{
|
|
name: "Case sensitivity preserved",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"OpenID", "PROFILE", "custom"},
|
|
expectedScopes: []string{"openid", "profile", "email", "OpenID", "PROFILE", "custom"},
|
|
description: "OAuth scopes are case-sensitive, so different cases should be preserved",
|
|
},
|
|
{
|
|
name: "Empty strings in user scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"", "custom", "", "another"},
|
|
expectedScopes: []string{"openid", "profile", "email", "", "custom", "another"},
|
|
description: "Empty strings should be preserved (though invalid in OAuth)",
|
|
},
|
|
{
|
|
name: "Whitespace scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{" ", "custom", " ", "another"},
|
|
expectedScopes: []string{"openid", "profile", "email", " ", "custom", " ", "another"},
|
|
description: "Whitespace-only scopes should be preserved as distinct",
|
|
},
|
|
{
|
|
name: "Large number of scopes",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: generateLargeUserScopes(),
|
|
expectedScopes: func() []string {
|
|
// Manually calculate expected result with proper deduplication
|
|
defaults := []string{"openid", "profile", "email"}
|
|
userScopes := generateLargeUserScopes()
|
|
return mergeScopes(defaults, userScopes)
|
|
}(),
|
|
description: "Performance test with larger scope lists",
|
|
},
|
|
{
|
|
name: "Complex OAuth scopes with special characters",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
|
|
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin:*", "scope/with/slashes", "scope-with-dashes"},
|
|
description: "Real-world OAuth scopes with colons, slashes, and special characters",
|
|
},
|
|
{
|
|
name: "Duplicate defaults in user scopes multiple times",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"openid", "profile", "openid", "custom", "email", "profile", "custom"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom"},
|
|
description: "Multiple duplicates of default scopes should be completely deduplicated",
|
|
},
|
|
{
|
|
name: "All user scopes are duplicates of defaults",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"email", "openid", "profile", "openid"},
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
description: "When all user scopes duplicate defaults, result should be just defaults",
|
|
},
|
|
{
|
|
name: "Single scope scenarios",
|
|
defaultScopes: []string{"openid"},
|
|
userScopes: []string{"custom"},
|
|
expectedScopes: []string{"openid", "custom"},
|
|
description: "Minimal case with single scopes",
|
|
},
|
|
{
|
|
name: "Identical scopes in same order",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"openid", "profile", "email"},
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
description: "When user scopes exactly match defaults, no duplication",
|
|
},
|
|
{
|
|
name: "Identical scopes in different order",
|
|
defaultScopes: []string{"openid", "profile", "email"},
|
|
userScopes: []string{"email", "profile", "openid"},
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
description: "Order of defaults is preserved when user scopes are reordered duplicates",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Test the mergeScopes function directly
|
|
result := mergeScopes(tc.defaultScopes, tc.userScopes)
|
|
if !equalSlices(result, tc.expectedScopes) {
|
|
t.Errorf("Expected %v, got %v\nDescription: %s", tc.expectedScopes, result, tc.description)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// generateLargeUserScopes creates a large list of user scopes for performance testing
|
|
func generateLargeUserScopes() []string {
|
|
scopes := make([]string, 100)
|
|
for i := range 100 {
|
|
scopes[i] = fmt.Sprintf("scope_%d", i)
|
|
}
|
|
// Add some duplicates to test deduplication performance
|
|
scopes = append(scopes, "scope_1", "scope_5", "scope_10", "openid") // Include a default duplicate
|
|
return scopes
|
|
}
|
|
|
|
// TestScopeMergingPerformance tests performance with large scope lists
|
|
func TestScopeMergingPerformance(t *testing.T) {
|
|
// Create large scope lists
|
|
defaultScopes := []string{"openid", "profile", "email"}
|
|
|
|
// Create 1000 user scopes with some duplicates
|
|
userScopes := make([]string, 1000)
|
|
for i := range 1000 {
|
|
if i%10 == 0 {
|
|
// Add some duplicates of defaults
|
|
userScopes[i] = defaultScopes[i%len(defaultScopes)]
|
|
} else if i%7 == 0 {
|
|
// Add some internal duplicates
|
|
userScopes[i] = fmt.Sprintf("scope_%d", i%50)
|
|
} else {
|
|
userScopes[i] = fmt.Sprintf("scope_%d", i)
|
|
}
|
|
}
|
|
|
|
// Measure performance
|
|
start := time.Now()
|
|
result := mergeScopes(defaultScopes, userScopes)
|
|
duration := time.Since(start)
|
|
|
|
// Verify result correctness
|
|
if len(result) < len(defaultScopes) {
|
|
t.Errorf("Result should contain at least the default scopes")
|
|
}
|
|
|
|
// Verify no duplicates exist
|
|
seen := make(map[string]bool)
|
|
for _, scope := range result {
|
|
if seen[scope] {
|
|
t.Errorf("Duplicate scope found in result: %s", scope)
|
|
}
|
|
seen[scope] = true
|
|
}
|
|
|
|
// Performance assertion (should be very fast)
|
|
if duration > time.Millisecond*10 {
|
|
t.Logf("Performance note: mergeScopes took %v for 1000+ scopes (still acceptable)", duration)
|
|
}
|
|
|
|
t.Logf("Performance: processed %d user scopes in %v, result has %d unique scopes",
|
|
len(userScopes), duration, len(result))
|
|
}
|
|
|
|
// TestScopeMergingMemoryEfficiency tests memory efficiency of the mergeScopes function
|
|
func TestScopeMergingMemoryEfficiency(t *testing.T) {
|
|
defaultScopes := []string{"openid", "profile", "email"}
|
|
userScopes := []string{"custom1", "custom2"}
|
|
|
|
// Test that the function doesn't modify input slices
|
|
originalDefaults := make([]string, len(defaultScopes))
|
|
copy(originalDefaults, defaultScopes)
|
|
originalUser := make([]string, len(userScopes))
|
|
copy(originalUser, userScopes)
|
|
|
|
result := mergeScopes(defaultScopes, userScopes)
|
|
|
|
// Verify input slices are unchanged
|
|
for i, scope := range defaultScopes {
|
|
if scope != originalDefaults[i] {
|
|
t.Errorf("Default scopes were modified: expected %s, got %s", originalDefaults[i], scope)
|
|
}
|
|
}
|
|
for i, scope := range userScopes {
|
|
if scope != originalUser[i] {
|
|
t.Errorf("User scopes were modified: expected %s, got %s", originalUser[i], scope)
|
|
}
|
|
}
|
|
|
|
// Verify result is independent
|
|
result[0] = "modified"
|
|
if defaultScopes[0] == "modified" {
|
|
t.Error("Modifying result affected input defaults")
|
|
}
|
|
|
|
expectedLength := len(defaultScopes) + len(userScopes)
|
|
if len(result) != expectedLength {
|
|
t.Errorf("Expected result length %d, got %d", expectedLength, len(result))
|
|
}
|
|
}
|
|
|
|
// TestNewWithScopeAppending tests that the New function properly merges scopes
|
|
func TestNewWithScopeAppending(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping test in short mode")
|
|
}
|
|
|
|
// Create mock provider metadata server
|
|
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
return
|
|
}
|
|
metadata := ProviderMetadata{
|
|
Issuer: "https://test-issuer.com",
|
|
AuthURL: "https://test-issuer.com/auth",
|
|
TokenURL: "https://test-issuer.com/token",
|
|
JWKSURL: "https://test-issuer.com/jwks",
|
|
RevokeURL: "https://test-issuer.com/revoke",
|
|
EndSessionURL: "https://test-issuer.com/end-session",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
}))
|
|
defer mockServer.Close()
|
|
|
|
tests := []struct {
|
|
name string
|
|
configScopes []string
|
|
expectedScopes []string
|
|
}{
|
|
{
|
|
name: "Default scopes only",
|
|
configScopes: []string{},
|
|
expectedScopes: []string{"openid", "profile", "email"},
|
|
},
|
|
{
|
|
name: "Custom scopes appended",
|
|
configScopes: []string{"custom_scope", "another_scope"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
|
|
},
|
|
{
|
|
name: "Overlapping scopes deduplicated",
|
|
configScopes: []string{"openid", "custom_scope"},
|
|
expectedScopes: []string{"openid", "profile", "email", "custom_scope"},
|
|
},
|
|
{
|
|
name: "OAuth scopes",
|
|
configScopes: []string{"read:users", "write:users", "admin"},
|
|
expectedScopes: []string{"openid", "profile", "email", "read:users", "write:users", "admin"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Create config with test scopes
|
|
config := &Config{
|
|
ProviderURL: mockServer.URL,
|
|
ClientID: "test-client",
|
|
ClientSecret: "test-secret",
|
|
CallbackURL: "/callback",
|
|
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
|
|
Scopes: tc.configScopes,
|
|
RateLimit: 100,
|
|
}
|
|
|
|
// Create middleware instance
|
|
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}), config, "test")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create middleware: %v", err)
|
|
}
|
|
|
|
// Wait for initialization
|
|
if m, ok := middleware.(*TraefikOidc); ok {
|
|
// Ensure middleware is properly closed to prevent goroutine leaks
|
|
defer func() {
|
|
if err := m.Close(); err != nil {
|
|
t.Errorf("Failed to close middleware: %v", err)
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-m.initComplete:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("Middleware failed to initialize")
|
|
}
|
|
|
|
// Check that scopes were properly merged
|
|
if !equalSlices(m.scopes, tc.expectedScopes) {
|
|
t.Errorf("Expected scopes %v, got %v", tc.expectedScopes, m.scopes)
|
|
}
|
|
} else {
|
|
t.Fatalf("Middleware is not of type *TraefikOidc")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBuildAuthURLWithMergedScopes tests that the auth URL includes the properly merged scopes
|
|
func TestBuildAuthURLWithMergedScopes(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup()
|
|
|
|
tests := []struct {
|
|
name string
|
|
expectedScopes string
|
|
scopes []string
|
|
}{
|
|
{
|
|
name: "Default scopes only",
|
|
scopes: []string{"openid", "profile", "email"},
|
|
expectedScopes: "openid profile email offline_access",
|
|
},
|
|
{
|
|
name: "Custom scopes appended",
|
|
scopes: []string{"openid", "profile", "email", "custom_scope", "another_scope"},
|
|
expectedScopes: "openid profile email custom_scope another_scope offline_access",
|
|
},
|
|
{
|
|
name: "OAuth scopes",
|
|
scopes: []string{"openid", "profile", "email", "read:users", "write:users"},
|
|
expectedScopes: "openid profile email read:users write:users offline_access",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
// Configure the test instance with specific scopes
|
|
tOidc := ts.tOidc
|
|
tOidc.scopes = tc.scopes // These scopes are already deduplicated by New()
|
|
tOidc.authURL = "https://auth.example.com/oauth/authorize"
|
|
tOidc.issuerURL = "https://auth.example.com"
|
|
// Reset overrideScopes for each test case, as it's part of tOidc state
|
|
// Default to false, specific tests will set it.
|
|
tOidc.overrideScopes = false
|
|
|
|
// Build auth URL
|
|
result := tOidc.buildAuthURL("https://app.example.com/callback", "test-state", "test-nonce", "")
|
|
|
|
// Parse the resulting URL to verify scopes
|
|
parsedURL, err := url.Parse(result)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse resulting URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
actualScopes := query.Get("scope")
|
|
if actualScopes != tc.expectedScopes {
|
|
t.Errorf("Expected scopes %q, got %q", tc.expectedScopes, actualScopes)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBuildAuthURL_OverrideScopes_And_OfflineAccess tests the offline_access logic in buildAuthURL
|
|
// considering the overrideScopes flag.
|
|
func TestBuildAuthURL_OverrideScopes_And_OfflineAccess(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup() // Sets up ts.tOidc
|
|
|
|
tests := []struct {
|
|
expectedParams map[string]string
|
|
name string
|
|
expectedScope string
|
|
initialScopes []string
|
|
overrideScopes bool
|
|
isGoogle bool
|
|
isAzure bool
|
|
}{
|
|
{
|
|
name: "Override false, no user scopes, non-Google/Azure",
|
|
initialScopes: []string{"openid", "profile", "email"}, // Defaults from New() when config.Scopes is empty
|
|
overrideScopes: false,
|
|
expectedScope: "openid profile email offline_access",
|
|
},
|
|
{
|
|
name: "Override false, user scopes without offline_access, non-Google/Azure",
|
|
initialScopes: []string{"openid", "profile", "email", "custom1"}, // Merged and deduplicated by New()
|
|
overrideScopes: false,
|
|
expectedScope: "openid profile email custom1 offline_access",
|
|
},
|
|
{
|
|
name: "Override false, user scopes with offline_access, non-Google/Azure",
|
|
initialScopes: []string{"openid", "profile", "email", "offline_access", "custom1"},
|
|
overrideScopes: false,
|
|
expectedScope: "openid profile email offline_access custom1", // Order might vary based on merge, but offline_access present
|
|
},
|
|
{
|
|
name: "Override true, user scopes without offline_access, non-Google/Azure",
|
|
initialScopes: []string{"custom1", "custom2"}, // Directly from config.Scopes, deduplicated
|
|
overrideScopes: true,
|
|
expectedScope: "custom1 custom2", // offline_access NOT added
|
|
},
|
|
{
|
|
name: "Override true, user scopes with offline_access, non-Google/Azure",
|
|
initialScopes: []string{"custom1", "offline_access", "custom2"},
|
|
overrideScopes: true,
|
|
expectedScope: "custom1 offline_access custom2", // User explicitly included it
|
|
},
|
|
{
|
|
name: "Override true, no user scopes (edge case), non-Google/Azure",
|
|
initialScopes: []string{}, // config.Scopes was empty
|
|
overrideScopes: true,
|
|
// In this edge case, buildAuthURL's logic `(t.overrideScopes && len(t.scopes) == 0)`
|
|
// will lead to offline_access being added, as it behaves like defaults.
|
|
expectedScope: "offline_access",
|
|
},
|
|
// Google Provider Tests (access_type=offline, prompt=consent)
|
|
{
|
|
name: "Google, Override false, no user scopes",
|
|
initialScopes: []string{"openid", "profile", "email"},
|
|
overrideScopes: false,
|
|
isGoogle: true,
|
|
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
|
expectedScope: "openid profile email", // No offline_access scope for Google
|
|
},
|
|
{
|
|
name: "Google, Override true, user scopes",
|
|
initialScopes: []string{"custom1", "custom2"},
|
|
overrideScopes: true,
|
|
isGoogle: true,
|
|
expectedParams: map[string]string{"access_type": "offline", "prompt": "consent"},
|
|
expectedScope: "custom1 custom2", // No offline_access scope for Google
|
|
},
|
|
// Azure Provider Tests (response_mode=query, offline_access scope added if not present by user)
|
|
{
|
|
name: "Azure, Override false, no user scopes",
|
|
initialScopes: []string{"openid", "profile", "email"},
|
|
overrideScopes: false,
|
|
isAzure: true,
|
|
expectedParams: map[string]string{"response_mode": "query"},
|
|
expectedScope: "openid profile email offline_access",
|
|
},
|
|
{
|
|
name: "Azure, Override true, user scopes without offline_access",
|
|
initialScopes: []string{"custom1", "custom2"},
|
|
overrideScopes: true,
|
|
isAzure: true,
|
|
expectedParams: map[string]string{"response_mode": "query"},
|
|
expectedScope: "custom1 custom2", // offline_access NOT added by default when override is true
|
|
},
|
|
{
|
|
name: "Azure, Override true, user scopes with offline_access",
|
|
initialScopes: []string{"custom1", "offline_access"},
|
|
overrideScopes: true,
|
|
isAzure: true,
|
|
expectedParams: map[string]string{"response_mode": "query"},
|
|
expectedScope: "custom1 offline_access",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
tOidc := ts.tOidc
|
|
tOidc.scopes = tc.initialScopes // Set the scopes as if they came from New()
|
|
tOidc.overrideScopes = tc.overrideScopes
|
|
|
|
// Adjust issuerURL for provider-specific tests
|
|
originalIssuerURL := tOidc.issuerURL
|
|
if tc.isGoogle {
|
|
tOidc.issuerURL = "https://accounts.google.com"
|
|
} else if tc.isAzure {
|
|
tOidc.issuerURL = "https://login.microsoftonline.com/common"
|
|
} else {
|
|
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
|
|
}
|
|
|
|
authURLString := tOidc.buildAuthURL("http://localhost/callback", "state123", "nonce123", "challenge123")
|
|
parsedAuthURL, err := url.Parse(authURLString)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
query := parsedAuthURL.Query()
|
|
|
|
actualScope := query.Get("scope")
|
|
if actualScope != tc.expectedScope {
|
|
t.Errorf("Expected scope string %q, got %q", tc.expectedScope, actualScope)
|
|
}
|
|
|
|
if tc.expectedParams != nil {
|
|
for k, v := range tc.expectedParams {
|
|
if query.Get(k) != v {
|
|
t.Errorf("Expected param %s=%s, got %s", k, v, query.Get(k))
|
|
}
|
|
}
|
|
}
|
|
|
|
// Restore original issuerURL for next test
|
|
tOidc.issuerURL = originalIssuerURL
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestBuildAuthURL_SpecificUserCase tests the buildAuthURL function with the specific user-reported scenario.
|
|
func TestBuildAuthURL_SpecificUserCase(t *testing.T) {
|
|
ts := NewTestSuite(t)
|
|
ts.Setup() // Basic setup for tOidc
|
|
|
|
// Configure the TraefikOidc instance for the specific scenario
|
|
tOidc := ts.tOidc
|
|
tOidc.scopes = []string{"email", "test3"} // This is what t.scopes should be after New()
|
|
tOidc.overrideScopes = true
|
|
tOidc.issuerURL = "https://generic-provider.com" // Non-Google/Azure
|
|
tOidc.authURL = "https://generic-provider.com/auth" // Dummy auth URL
|
|
tOidc.clientID = "test-client-id"
|
|
|
|
// Expected scope string in the URL
|
|
expectedScopeString := "email test3"
|
|
|
|
// Call buildAuthURL
|
|
authURLString := tOidc.buildAuthURL("http://localhost/callback", "test-state", "test-nonce", "")
|
|
|
|
// Parse the resulting URL
|
|
parsedAuthURL, err := url.Parse(authURLString)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse generated auth URL %q: %v", authURLString, err)
|
|
}
|
|
|
|
// Get the 'scope' query parameter
|
|
actualScopeString := parsedAuthURL.Query().Get("scope")
|
|
|
|
// Assert that the scope string is as expected
|
|
if actualScopeString != expectedScopeString {
|
|
t.Errorf("Expected scope parameter to be %q, but got %q. Full URL: %s",
|
|
expectedScopeString, actualScopeString, authURLString)
|
|
}
|
|
|
|
// Additionally, ensure 'offline_access' was not added
|
|
if strings.Contains(actualScopeString, "offline_access") {
|
|
t.Errorf("Scope parameter %q should not contain 'offline_access' when overrideScopes is true and it's not in tOidc.scopes", actualScopeString)
|
|
}
|
|
}
|