mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bc02901972 | |||
| f75b2f20e0 | |||
| cf6ed1da55 | |||
| f821b8829b | |||
| 5f9c574f95 | |||
| 7c6f09fb20 | |||
| 68e1c4319c | |||
| 17e3f8ef62 | |||
| 827926bc3a | |||
| abbfdb02a7 | |||
| 72e2b682bb | |||
| ae4ccaa89d | |||
| 984fd1c08f | |||
| 99bdd23986 |
@@ -111,6 +111,7 @@ Full reference in [docs/CONFIGURATION.md](docs/CONFIGURATION.md).
|
||||
| `logoutURL` | `callbackURL + "/logout"` | RP-initiated logout path. |
|
||||
| `postLogoutRedirectURI` | `/` | Where to send users after logout. |
|
||||
| `scopes` | appended to `openid profile email` | Extra OAuth scopes. Set `overrideScopes: true` to replace defaults. |
|
||||
| `extraAuthParams` | none | Map of extra query parameters appended to the authorization request (e.g. `screen_hint: signup`, `login_hint`, `ui_locales`, `prompt`). Plugin-managed params (`client_id`, `state`, `nonce`, `redirect_uri`, `code_challenge`, `scope`, `response_type`, …) cannot be overridden. |
|
||||
| `excludedURLs` | none | Prefix-matched paths that bypass auth. |
|
||||
| `allowedUserDomains` | none | Restrict to email domains. |
|
||||
| `allowedUsers` | none | Restrict to specific addresses (or claim values when `userIdentifierClaim != email`). |
|
||||
@@ -411,6 +412,19 @@ namespaced claims, Cognito regions, GitLab self-hosted) live in
|
||||
|
||||
Set `logLevel: debug` to surface detail.
|
||||
|
||||
## Telemetry
|
||||
|
||||
On first plugin instantiation this middleware sends a single anonymous
|
||||
adoption ping — project name, version, timestamp; no identifiers, no
|
||||
request data, no token contents. Fire-and-forget with a 2-second timeout;
|
||||
cannot block plugin load or panic.
|
||||
|
||||
Local source: [`telemetry.go`](./telemetry.go). Disclosure mirrors
|
||||
**[oss-telemetry — Disabling telemetry](https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry)**.
|
||||
|
||||
Quick opt-out: set any of `DO_NOT_TRACK=1`, `OSS_TELEMETRY_DISABLED=1`,
|
||||
or `TRAEFIKOIDC_DISABLE_TELEMETRY=1`.
|
||||
|
||||
## License
|
||||
|
||||
See [LICENSE](LICENSE).
|
||||
|
||||
+4
-2
@@ -484,7 +484,8 @@ func TestAuth0Scenario3OpaqueAccessToken(t *testing.T) {
|
||||
session.SetAccessToken(opaqueAccessToken)
|
||||
session.SetIDToken(idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
|
||||
if !authenticated || needsRefresh || expired {
|
||||
t.Errorf("Session with opaque access token and valid ID token should be authenticated. Got: auth=%v, refresh=%v, expired=%v",
|
||||
authenticated, needsRefresh, expired)
|
||||
@@ -623,7 +624,8 @@ func TestAuth0Scenario2StrictMode(t *testing.T) {
|
||||
session.SetRefreshToken("test-refresh-token") // Add refresh token so it can attempt refresh
|
||||
|
||||
// In strict mode, this should FAIL (no fallback to ID token)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := ts.tOidc.validateStandardTokensRS(rs)
|
||||
if authenticated {
|
||||
t.Errorf("Strict mode: Session with wrong access token audience should be rejected, but got authenticated=true")
|
||||
}
|
||||
|
||||
@@ -305,28 +305,6 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// isUserAuthenticated determines the authentication status and refresh requirements.
|
||||
// It delegates to provider-specific validation methods that handle different token types
|
||||
// and expiration behaviors.
|
||||
// Parameters:
|
||||
// - session: The session data containing authentication tokens.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated (bool): True if the user has valid tokens.
|
||||
// - needsRefresh (bool): True if tokens are valid but nearing expiration.
|
||||
// - expired (bool): True if the session is unauthenticated, the token is missing,
|
||||
// or the token verification failed for reasons other than nearing/actual expiration.
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if t.isAzureProvider() {
|
||||
return t.validateAzureTokens(session)
|
||||
} else if t.isGoogleProvider() {
|
||||
return t.validateGoogleTokens(session)
|
||||
}
|
||||
// Auth0 and other providers can now use standard validation
|
||||
// which handles opaque tokens generically
|
||||
return t.validateStandardTokens(session)
|
||||
}
|
||||
|
||||
// isAjaxRequest determines if this is an AJAX request that should receive 401 instead of redirect
|
||||
func (t *TraefikOidc) isAjaxRequest(req *http.Request) bool {
|
||||
xhr := req.Header.Get("X-Requested-With")
|
||||
|
||||
+8
-4
@@ -262,7 +262,8 @@ func TestAzureOIDCRegression(t *testing.T) {
|
||||
defer func() { tOidc.tokenVerifier = originalTokenVerifier }()
|
||||
|
||||
// Test that CSRF is preserved during Azure validation failures
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := tOidc.validateAzureTokensRS(rs)
|
||||
|
||||
// Should not be authenticated due to validation failure
|
||||
if authenticated {
|
||||
@@ -453,7 +454,8 @@ func TestValidateGoogleTokens(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateGoogleTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.validateGoogleTokensRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
@@ -637,7 +639,8 @@ func TestIsUserAuthenticated(t *testing.T) {
|
||||
defer func() { ts.tOidc.issuerURL = originalIssuer }()
|
||||
|
||||
session := tt.setupSession()
|
||||
auth, refresh, expired := ts.tOidc.isUserAuthenticated(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.isUserAuthenticatedRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
@@ -762,7 +765,8 @@ func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
auth, refresh, expired := ts.tOidc.validateAzureTokensRS(rs)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
|
||||
+2
-2
@@ -71,8 +71,8 @@ func makeBearerOIDC(t *testing.T, next http.Handler) *TraefikOidc {
|
||||
logger: NewLogger("error"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://issuer.example.com",
|
||||
audience: "https://api.example.com",
|
||||
clientID: "https://api.example.com",
|
||||
|
||||
@@ -5,6 +5,7 @@ go 1.24.0
|
||||
require (
|
||||
github.com/alicebob/miniredis/v2 v2.35.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3
|
||||
github.com/redis/go-redis/v9 v9.17.2
|
||||
github.com/stretchr/testify v1.10.0
|
||||
golang.org/x/time v0.14.0
|
||||
|
||||
@@ -16,6 +16,8 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3 h1:xoDtBqeZGmXj7IteiE1M5WMuzeoqag58qEleI0Cf2Ms=
|
||||
github.com/lukaszraczylo/oss-telemetry v0.2.3/go.mod h1:+Cn78qZo8rc3T9eZt0v3oICYRdd75wORtSidc8lNjDQ=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
|
||||
@@ -234,7 +234,8 @@ func TestIssue134_Followup_ValidateAzureTokensSkipsGraphAccessToken(t *testing.T
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
output := errBuf.String()
|
||||
assert.NotContains(t, output, "crypto/rsa: verification error",
|
||||
@@ -344,7 +345,8 @@ func TestIssue134_Followup_StandardAzureAccessTokenStillVerifies(t *testing.T) {
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, accessToken, idToken)
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
assert.True(t, authenticated, "standard Azure access token must verify and authenticate")
|
||||
assert.False(t, needsRefresh)
|
||||
@@ -381,7 +383,8 @@ func TestIssue134_Followup_GraphAccessTokenWithoutIDToken(t *testing.T) {
|
||||
oidc, errBuf := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, graphAccessToken, "")
|
||||
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, needsRefresh, expired := oidc.validateAzureTokensRS(rs)
|
||||
|
||||
assert.True(t, authenticated, "Graph token without ID token must remain authenticated (matches existing opaque-token semantics)")
|
||||
assert.False(t, needsRefresh)
|
||||
@@ -443,7 +446,8 @@ func TestIssue134_Followup_ConfusedDeputyAttackDoesNotBypassVerification(t *test
|
||||
oidc, _ := newAzureFollowupOIDC(t, jwks)
|
||||
session := authedSessionWithTokens(t, forgedAccessToken, forgedIDToken)
|
||||
|
||||
authenticated, _, _ := oidc.validateAzureTokens(session)
|
||||
rs := (&requestState{}).captureSession(session)
|
||||
authenticated, _, _ := oidc.validateAzureTokensRS(rs)
|
||||
assert.False(t, authenticated,
|
||||
"attacker's forged tokens must not authenticate even when the access token has a nonce header — ID token verification rejects the wrong-key signature")
|
||||
}
|
||||
|
||||
@@ -478,11 +478,10 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
|
||||
// Test 3: Rate limiting
|
||||
t.Run("RateLimiting", func(t *testing.T) {
|
||||
// Reset circuit breaker to closed state for this test
|
||||
coordinator.circuitBreaker.mutex.Lock()
|
||||
// Reset circuit breaker to closed state for this test. All fields are
|
||||
// atomic so we don't need any mutex.
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.state, 0) // closed
|
||||
atomic.StoreInt32(&coordinator.circuitBreaker.failures, 0)
|
||||
coordinator.circuitBreaker.mutex.Unlock()
|
||||
|
||||
// Temporarily increase circuit breaker threshold to not interfere
|
||||
oldMaxFailures := coordinator.circuitBreaker.config.MaxFailures
|
||||
@@ -525,9 +524,11 @@ func TestRefreshCoordinatorIntegration(t *testing.T) {
|
||||
time.Sleep(config.CleanupInterval * 3)
|
||||
|
||||
// Old sessions should be cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
count := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
count := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have fewer sessions after cleanup
|
||||
if count > 10 {
|
||||
|
||||
@@ -53,10 +53,26 @@ type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache
|
||||
// JWKCache provides thread-safe caching of JWKS using UniversalCache.
|
||||
//
|
||||
// inflightFetches deduplicates concurrent fetches for the same JWKS URL.
|
||||
// It replaces a global sync.RWMutex that was previously held for the entire
|
||||
// HTTP round-trip in GetJWKS: on a cold cache (cold pod, JWK rotation, brief
|
||||
// network blip) every concurrent request piled up on that single Lock(), and
|
||||
// under Yaegi each Lock acquisition costs 10-50ms of interpreter-dispatch
|
||||
// overhead. The singleflight pattern keeps the cold-cache cost O(1) HTTP
|
||||
// fetch regardless of how many requests are waiting.
|
||||
type JWKCache struct {
|
||||
cache *UniversalCache
|
||||
mutex sync.RWMutex
|
||||
cache *UniversalCache
|
||||
inflightFetches sync.Map // map[jwksURL string]*jwksFetch
|
||||
}
|
||||
|
||||
// jwksFetch represents an in-flight JWKS fetch. Done is closed when the fetch
|
||||
// completes; jwks and err carry the result (one of them is set, never both).
|
||||
type jwksFetch struct {
|
||||
done chan struct{}
|
||||
jwks *JWKSet
|
||||
err error
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the contract for JWK caching implementations.
|
||||
@@ -83,36 +99,58 @@ func NewJWKCache() *JWKCache {
|
||||
// request refetches from the upstream. JWK rotation is rare and a per-replica
|
||||
// HTTP fetch on cold cache is cheap, so cross-replica coherence buys nothing.
|
||||
func (c *JWKCache) GetJWKS(ctx context.Context, jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
// Check cache first
|
||||
// Fast path: cache hit.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Singleflight: dedupe concurrent fetches per URL key. The first arrival
|
||||
// performs the HTTP fetch; any later arrival for the same URL waits on
|
||||
// its done channel and shares the result. No global lock is held during
|
||||
// the fetch.
|
||||
candidate := &jwksFetch{done: make(chan struct{})}
|
||||
if existing, loaded := c.inflightFetches.LoadOrStore(jwksURL, candidate); loaded {
|
||||
f, _ := existing.(*jwksFetch)
|
||||
select {
|
||||
case <-f.done:
|
||||
return f.jwks, f.err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Double-check after acquiring lock
|
||||
// We're the leader. Make absolutely sure the result fields and the
|
||||
// in-flight map entry are cleaned up before any waiter unblocks.
|
||||
defer func() {
|
||||
c.inflightFetches.Delete(jwksURL)
|
||||
close(candidate.done)
|
||||
}()
|
||||
|
||||
// Re-check the cache in case a concurrent fetch completed between our
|
||||
// initial miss and our LoadOrStore win.
|
||||
if cachedValue, found := c.cache.GetLocal(jwksURL); found {
|
||||
if jwks, ok := cachedValue.(*JWKSet); ok {
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch from URL
|
||||
jwks, err := fetchJWKS(ctx, jwksURL, httpClient)
|
||||
if err != nil {
|
||||
candidate.err = err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(jwks.Keys) == 0 {
|
||||
return nil, fmt.Errorf("JWKS response contains no keys")
|
||||
candidate.err = fmt.Errorf("JWKS response contains no keys")
|
||||
return nil, candidate.err
|
||||
}
|
||||
|
||||
// Cache for 1 hour
|
||||
// Cache for 1 hour.
|
||||
_ = c.cache.SetLocal(jwksURL, jwks, 1*time.Hour) // Safe to ignore: cache failures are non-critical
|
||||
|
||||
candidate.jwks = jwks
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
|
||||
+4
-4
@@ -415,8 +415,8 @@ func TestMiddlewareBackchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -457,8 +457,8 @@ func TestMiddlewareFrontchannelLogoutRouting(t *testing.T) {
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://provider.example.com",
|
||||
initComplete: make(chan struct{}),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
telemetry "github.com/lukaszraczylo/oss-telemetry"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -23,6 +24,11 @@ const (
|
||||
ConstSessionTimeout = 86400
|
||||
)
|
||||
|
||||
// telemetryStartupOnce keeps the anonymous "plugin loaded" ping to one per
|
||||
// process. Traefik calls New once per route that uses the plugin; oss-telemetry
|
||||
// does not deduplicate client-side (the server does), so the gate stays here.
|
||||
var telemetryStartupOnce sync.Once
|
||||
|
||||
// isTestMode detects if the code is running in a test environment.
|
||||
func isTestMode() bool {
|
||||
if os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS") == "1" {
|
||||
@@ -89,6 +95,13 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
// - The configured TraefikOidc handler ready to process requests.
|
||||
// - An error if essential configuration is missing or invalid (e.g., short encryption key).
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
telemetryStartupOnce.Do(func() {
|
||||
// Only stamped release builds phone home; dev/local/test builds keep the
|
||||
// devPluginVersion sentinel (see version.go) and stay silent.
|
||||
if traefikoidcPluginVersion != devPluginVersion {
|
||||
telemetry.Send("traefikoidc", traefikoidcPluginVersion)
|
||||
}
|
||||
})
|
||||
return NewWithContext(ctx, config, next, name)
|
||||
}
|
||||
|
||||
@@ -201,6 +214,7 @@ func NewWithContext(ctx context.Context, config *Config, next http.Handler, name
|
||||
}(),
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
enablePKCE: config.EnablePKCE,
|
||||
extraAuthParams: config.ExtraAuthParams,
|
||||
overrideScopes: config.OverrideScopes,
|
||||
strictAudienceValidation: config.StrictAudienceValidation,
|
||||
allowOpaqueTokens: config.AllowOpaqueTokens,
|
||||
@@ -516,6 +530,19 @@ func (t *TraefikOidc) updateMetadataEndpoints(metadata *ProviderMetadata) {
|
||||
introspectionURL := t.introspectionURL
|
||||
registrationURL := t.registrationURL
|
||||
|
||||
// Publish the read-mostly URL bundle atomically. Hot-path readers Load
|
||||
// this directly instead of acquiring metadataMu.RLock per request.
|
||||
t.metadataSnapshot.Store(&MetadataSnapshot{
|
||||
IssuerURL: metadata.Issuer,
|
||||
JWKSURL: metadata.JWKSURL,
|
||||
TokenURL: metadata.TokenURL,
|
||||
AuthURL: metadata.AuthURL,
|
||||
RevocationURL: metadata.RevokeURL,
|
||||
EndSessionURL: metadata.EndSessionURL,
|
||||
IntrospectionURL: metadata.IntrospectionURL,
|
||||
RegistrationURL: metadata.RegistrationURL,
|
||||
})
|
||||
|
||||
t.metadataMu.Unlock()
|
||||
|
||||
// Log introspection endpoint availability for opaque token support
|
||||
|
||||
+14
-30
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -484,9 +485,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -508,19 +508,13 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate first request processing
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
// Simulate first request processing — single-firing via CAS.
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
// This would normally be called asynchronously
|
||||
go func() {
|
||||
oidc.initializeMetadata(server.URL)
|
||||
// initComplete is closed internally by initializeMetadata
|
||||
}()
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Wait for initialization
|
||||
@@ -556,9 +550,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
oidc := &TraefikOidc{
|
||||
providerURL: server.URL,
|
||||
firstRequestReceived: false,
|
||||
firstRequestMutex: sync.Mutex{},
|
||||
providerURL: server.URL,
|
||||
firstRequestStarted: 0,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
@@ -580,31 +573,22 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// Simulate multiple concurrent "first" requests
|
||||
// Simulate multiple concurrent "first" requests — only one CAS winner
|
||||
// fires the bootstrap path.
|
||||
const numRequests = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numRequests)
|
||||
|
||||
initStarted := 0
|
||||
var initMu sync.Mutex
|
||||
var initStarted int32
|
||||
|
||||
for i := 0; i < numRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
oidc.firstRequestMutex.Lock()
|
||||
if !oidc.firstRequestReceived {
|
||||
oidc.firstRequestReceived = true
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
|
||||
initMu.Lock()
|
||||
initStarted++
|
||||
initMu.Unlock()
|
||||
|
||||
if atomic.CompareAndSwapInt32(&oidc.firstRequestStarted, 0, 1) {
|
||||
atomic.AddInt32(&initStarted, 1)
|
||||
// Only one should actually start initialization
|
||||
oidc.initializeMetadata(server.URL)
|
||||
} else {
|
||||
oidc.firstRequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -612,8 +596,8 @@ func TestFirstRequestHandling(t *testing.T) {
|
||||
wg.Wait()
|
||||
|
||||
// Verify only one initialization was started
|
||||
if initStarted != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", initStarted)
|
||||
if atomic.LoadInt32(&initStarted) != 1 {
|
||||
t.Errorf("expected exactly 1 initialization, got %d", atomic.LoadInt32(&initStarted))
|
||||
}
|
||||
|
||||
// The metadata endpoint might be called once or not at all depending on timing
|
||||
|
||||
+28
-28
@@ -61,8 +61,8 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com", // Required for initialization check
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -92,8 +92,8 @@ func TestServeHTTP_EventStream(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -175,8 +175,8 @@ func TestServeHTTP_WebSocketUpgrade(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
}
|
||||
close(oidc.initComplete)
|
||||
@@ -272,8 +272,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close this to simulate timeout
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
@@ -307,8 +307,8 @@ func TestServeHTTP_InitializationTimeout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -337,8 +337,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -367,8 +367,8 @@ func TestServeHTTP_CallbackAndLogout(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -740,8 +740,8 @@ func TestMinimalHeaders(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: tt.minimalHeaders,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -817,8 +817,8 @@ func TestMinimalHeaders_TokenHeaderNotSet(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
minimalHeaders: true, // Enable minimal headers
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -903,8 +903,8 @@ func TestStripAuthCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: tt.stripAuthCookies,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -987,8 +987,8 @@ func TestStripAuthCookies_NoCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1034,8 +1034,8 @@ func TestStripAuthCookies_OnlyOIDCCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1085,8 +1085,8 @@ func TestStripAuthCookies_OnlyAppCookies(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
@@ -1148,8 +1148,8 @@ func TestStripAuthCookies_CustomPrefix(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sm,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
stripAuthCookies: true,
|
||||
extractClaimsFunc: func(token string) (map[string]interface{}, error) {
|
||||
|
||||
+4
-4
@@ -16,6 +16,7 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -2685,10 +2686,9 @@ func TestMetadataRecoveryOnProviderFailure(t *testing.T) {
|
||||
providerAvailable = true
|
||||
mu.Unlock()
|
||||
|
||||
// Reset the retry timer to allow immediate retry
|
||||
m.metadataRetryMutex.Lock()
|
||||
m.lastMetadataRetryTime = time.Time{} // Reset to zero time
|
||||
m.metadataRetryMutex.Unlock()
|
||||
// Reset the retry timer to allow immediate retry. The field is atomic
|
||||
// now, so no lock is needed.
|
||||
atomic.StoreInt64(&m.lastMetadataRetryNano, 0)
|
||||
|
||||
// Second request should trigger recovery attempt
|
||||
req2 := httptest.NewRequest("GET", "/protected", nil)
|
||||
|
||||
+140
-22
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/traefikoidc/internal/utils"
|
||||
@@ -145,19 +146,20 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(req.URL.Path, "/health") {
|
||||
t.firstRequestMutex.Lock()
|
||||
if !t.firstRequestReceived {
|
||||
t.firstRequestReceived = true
|
||||
// Lock-free one-shot bootstrap. The previous firstRequestMutex.Lock()
|
||||
// fired on EVERY non-health request forever (even after the boolean
|
||||
// flipped true), which under Yaegi added a per-request serialization
|
||||
// point. CAS gives single-firing semantics with zero steady-state cost.
|
||||
if atomic.CompareAndSwapInt32(&t.firstRequestStarted, 0, 1) {
|
||||
t.logger.Debug("Starting background tasks on first request")
|
||||
t.startTokenCleanup()
|
||||
|
||||
if !t.metadataRefreshStarted && t.providerURL != "" {
|
||||
t.metadataRefreshStarted = true
|
||||
if t.providerURL != "" &&
|
||||
atomic.CompareAndSwapInt32(&t.metadataRefreshStartedAtomic, 0, 1) {
|
||||
// Metadata refresh is handled by singleton resource manager
|
||||
t.startMetadataRefresh(t.providerURL)
|
||||
}
|
||||
}
|
||||
t.firstRequestMutex.Unlock()
|
||||
}
|
||||
|
||||
// Evaluate auth-bypass once, before waiting for initialization. Excluded
|
||||
@@ -207,20 +209,31 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
// Read issuerURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
issuerURL := t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
// Read issuerURL via atomic snapshot when available — replaces the
|
||||
// metadataMu.RLock that previously fired on every non-bypass request.
|
||||
// Under Yaegi each RLock acquisition costs 1-5ms of interpreter
|
||||
// dispatch; the snapshot is a single atomic.Value.Load. Falls back
|
||||
// to the legacy field+RLock for paths that haven't published a
|
||||
// snapshot yet (notably some test setups that initialize the struct
|
||||
// fields directly).
|
||||
var issuerURL string
|
||||
if snap := t.metadataSnap(); snap != nil {
|
||||
issuerURL = snap.IssuerURL
|
||||
} else {
|
||||
t.metadataMu.RLock()
|
||||
issuerURL = t.issuerURL
|
||||
t.metadataMu.RUnlock()
|
||||
}
|
||||
|
||||
if issuerURL == "" {
|
||||
// Provider metadata initialization failed - try to recover
|
||||
// Retry every 30 seconds to allow automatic recovery when provider comes back online
|
||||
t.metadataRetryMutex.Lock()
|
||||
shouldRetry := time.Since(t.lastMetadataRetryTime) >= 30*time.Second
|
||||
if shouldRetry {
|
||||
t.lastMetadataRetryTime = time.Now()
|
||||
}
|
||||
t.metadataRetryMutex.Unlock()
|
||||
// Provider metadata initialization failed - try to recover.
|
||||
// Retry every 30 seconds to allow automatic recovery. Lock-free
|
||||
// throttle via CAS on lastMetadataRetryNano: one goroutine wins
|
||||
// the window, others see shouldRetry=false.
|
||||
nowNano := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(&t.lastMetadataRetryNano)
|
||||
shouldRetry := time.Duration(nowNano-last) >= 30*time.Second &&
|
||||
atomic.CompareAndSwapInt64(&t.lastMetadataRetryNano, last, nowNano)
|
||||
|
||||
if shouldRetry && t.providerURL != "" {
|
||||
t.logger.Info("Attempting to recover OIDC provider metadata...")
|
||||
@@ -298,6 +311,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
host := utils.DetermineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Capture per-request state: one RLock on sd.sessionMutex covers all the
|
||||
// getter values the handler chain needs (instead of 5-7 separate
|
||||
// session.GetX() calls each acquiring their own RLock under Yaegi).
|
||||
// metadataSnap is also stored once so downstream handlers don't repeat
|
||||
// the atomic.Value.Load.
|
||||
rs := (&requestState{
|
||||
scheme: scheme,
|
||||
host: host,
|
||||
redirectURL: redirectURL,
|
||||
next: t.next,
|
||||
metadata: t.metadataSnap(),
|
||||
}).captureSession(session)
|
||||
|
||||
// Check if the current request is the OIDC callback
|
||||
t.logger.Debugf("Checking callback URL match: request_path=%q, configured_callback=%q", req.URL.Path, t.redirURLPath)
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
@@ -307,7 +333,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath)
|
||||
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
// Token validation reads session via the captured snapshot — saves ~21
|
||||
// sd.sessionMutex.RLock acquisitions (Yaegi-dispatched, ~1-5ms each)
|
||||
// across the validation path.
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticatedRS(rs)
|
||||
|
||||
if expired {
|
||||
t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth")
|
||||
@@ -315,7 +344,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
userIdentifier := rs.userIdentifier
|
||||
// User authorization check
|
||||
if authenticated && userIdentifier != "" {
|
||||
if !t.isAllowedUser(userIdentifier) {
|
||||
@@ -332,11 +361,11 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// methods (validateAzureTokens/validateStandardTokens) before reaching this point.
|
||||
// Redundant validation here was causing issues with Azure AD tokens that have
|
||||
// JWT format but unverifiable signatures. See issue #89.
|
||||
t.processAuthorizedRequest(rw, req, session, redirectURL)
|
||||
t.processAuthorizedRequestRS(rw, req, rs)
|
||||
return
|
||||
}
|
||||
|
||||
refreshTokenPresent := session.GetRefreshToken() != ""
|
||||
refreshTokenPresent := rs.refreshToken != ""
|
||||
|
||||
// Decide whether to answer with 401 instead of a redirect. AJAX requests
|
||||
// cannot follow a 302 into an IdP, and sub-resource loads (script/image/
|
||||
@@ -443,6 +472,95 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// - req: The HTTP request to process.
|
||||
// - session: The user's session data containing tokens and claims.
|
||||
// - redirectURL: The callback URL for re-authentication if needed.
|
||||
// processAuthorizedRequestRS is the requestState-aware variant of
|
||||
// processAuthorizedRequest. It reads SessionData fields from the captured
|
||||
// snapshot in rs instead of calling session.GetX() (each of which acquires
|
||||
// sd.sessionMutex.RLock — under Yaegi every RLock pays ~1-5ms of interpreter
|
||||
// dispatch). Only session-mutating operations (Save, ResetRedirectCount,
|
||||
// Clear, IsDirty) still go through the session pointer because those write
|
||||
// state and have no snapshot.
|
||||
func (t *TraefikOidc) processAuthorizedRequestRS(rw http.ResponseWriter, req *http.Request, rs *requestState) {
|
||||
session := rs.session
|
||||
redirectURL := rs.redirectURL
|
||||
userIdentifier := rs.userIdentifier
|
||||
if userIdentifier == "" {
|
||||
t.logger.Info("No user identifier found in session during final processing, initiating re-auth")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if session has been invalidated via backchannel or front-channel logout
|
||||
idToken := rs.idToken
|
||||
if t.enableBackchannelLogout || t.enableFrontchannelLogout {
|
||||
if idToken != "" {
|
||||
sid, sub, createdAt := t.extractSessionInfo(idToken)
|
||||
if t.isSessionInvalidated(sid, sub, createdAt) {
|
||||
t.logger.Infof("Session for user %s has been invalidated via IdP-initiated logout", userIdentifier)
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing invalidated session: %v", err)
|
||||
}
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve ID-token claims at most once per request. SessionData caches
|
||||
// the parsed claims keyed on the raw ID token.
|
||||
var (
|
||||
idClaims map[string]interface{}
|
||||
idClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc)
|
||||
}
|
||||
|
||||
var (
|
||||
groupClaims map[string]interface{}
|
||||
groupClaimsErr error
|
||||
)
|
||||
if idToken != "" {
|
||||
groupClaims, groupClaimsErr = idClaims, idClaimsErr
|
||||
} else if rs.accessToken != "" {
|
||||
groupClaims, groupClaimsErr = t.extractClaimsFunc(rs.accessToken)
|
||||
} else if len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Error("No token available but roles/groups checks are required")
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if groupClaimsErr != nil && len(t.allowedRolesAndGroups) > 0 {
|
||||
t.logger.Errorf("Failed to extract claims for roles/groups check: %v", groupClaimsErr)
|
||||
session.ResetRedirectCount()
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Persist any dirty session state BEFORE forwardAuthorized writes the
|
||||
// response.
|
||||
if session.IsDirty() {
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session after processing headers: %v", err)
|
||||
}
|
||||
} else {
|
||||
t.logger.Debug("Session not dirty, skipping save in processAuthorizedRequest")
|
||||
}
|
||||
|
||||
p := &principal{
|
||||
Source: sourceSession,
|
||||
Identifier: userIdentifier,
|
||||
AccessToken: rs.accessToken,
|
||||
IDToken: idToken,
|
||||
RefreshToken: rs.refreshToken,
|
||||
Claims: groupClaims,
|
||||
}
|
||||
|
||||
t.forwardAuthorized(rw, req, p)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
userIdentifier := session.GetUserIdentifier()
|
||||
if userIdentifier == "" {
|
||||
|
||||
@@ -13,8 +13,8 @@ func TestMiddlewareContextCancellation(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate waiting
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
}
|
||||
|
||||
// Create request with canceled context
|
||||
@@ -39,8 +39,8 @@ func TestMiddlewareSessionErrorRecovery(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -73,8 +73,8 @@ func TestMiddlewareAJAXRequestHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -102,8 +102,8 @@ func TestLogoutWorksWithoutOIDCInitialization(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}), // Never close to simulate provider unavailable
|
||||
sessionManager: createTestSessionManager(t),
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
logoutURLPath: "/logout",
|
||||
postLogoutRedirectURI: "/",
|
||||
forceHTTPS: false,
|
||||
@@ -142,8 +142,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -187,8 +187,8 @@ func TestMiddlewareDomainRestrictions(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
@@ -236,8 +236,8 @@ func TestMiddlewareOpaqueTokenHandling(t *testing.T) {
|
||||
logger: NewLogger("debug"),
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: sessionManager,
|
||||
firstRequestReceived: true,
|
||||
metadataRefreshStarted: true,
|
||||
firstRequestStarted: 1,
|
||||
metadataRefreshStartedAtomic: 1,
|
||||
issuerURL: "https://provider.example.com",
|
||||
redirURLPath: "/callback",
|
||||
logoutURLPath: "/logout",
|
||||
|
||||
+338
-198
@@ -15,18 +15,28 @@ import (
|
||||
// It implements request coalescing, rate limiting, and circuit breaking
|
||||
// specifically for token refresh operations.
|
||||
type RefreshCoordinator struct {
|
||||
inFlightRefreshes map[string]*refreshOperation
|
||||
cleanupTimers map[string]*time.Timer
|
||||
sessionRefreshAttempts map[string]*refreshAttemptTracker
|
||||
// inFlightRefreshes maps tokenHash -> *refreshOperation. sync.Map is used
|
||||
// instead of a plain map + RWMutex so concurrent refreshes do not
|
||||
// serialize on a single global lock. Under Yaegi the previous
|
||||
// refreshMutex.Lock() was held for tens of milliseconds per request due
|
||||
// to interpreter overhead on the work inside the critical section,
|
||||
// causing dozens of goroutines to stack up on it and pin one CPU core.
|
||||
inFlightRefreshes sync.Map
|
||||
// sessionRefreshAttempts maps sessionID -> *refreshAttemptTracker.
|
||||
// sync.Map + atomic tracker fields means isInCooldown/recordRefreshAttempt/
|
||||
// recordRefreshSuccess/recordRefreshFailure are lock-free. Previously
|
||||
// these used attemptsMutex sync.RWMutex; under Yaegi every Lock() acquisition
|
||||
// adds 10-50ms of dispatch overhead, and they were called twice per leader
|
||||
// request (once for recordRefreshAttempt, once for isInCooldown). That
|
||||
// serializing pattern caused the v1.0.15 death spiral after v1.0.14
|
||||
// removed the refreshMutex (same architectural shape, different mutex).
|
||||
sessionRefreshAttempts sync.Map
|
||||
circuitBreaker *RefreshCircuitBreaker
|
||||
metrics *RefreshMetrics
|
||||
logger *Logger
|
||||
stopChan chan struct{}
|
||||
config RefreshCoordinatorConfig
|
||||
wg sync.WaitGroup
|
||||
attemptsMutex sync.RWMutex
|
||||
refreshMutex sync.RWMutex
|
||||
cleanupTimerMu sync.Mutex
|
||||
}
|
||||
|
||||
// RefreshCoordinatorConfig configures the refresh coordinator behavior
|
||||
@@ -84,14 +94,46 @@ type refreshResult struct {
|
||||
fromCache bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session
|
||||
type refreshAttemptTracker struct {
|
||||
lastAttemptTime time.Time
|
||||
windowStartTime time.Time
|
||||
cooldownEndTime time.Time
|
||||
// attemptState is the immutable snapshot of a session's refresh-attempt
|
||||
// state. Lives behind refreshAttemptTracker.state (atomic.Value). Every
|
||||
// transition (record, success, failure, window-reset, cooldown-enter,
|
||||
// cooldown-exit) constructs a fresh attemptState and publishes it via
|
||||
// CompareAndSwap so the entire field set is updated together.
|
||||
//
|
||||
// Per-field atomic.Load/Store (the previous v1.0.15 design) had a benign
|
||||
// but observable hazard: the cooldown-exit reset wrote cooldownEndNano = 0
|
||||
// first, then separately stored attempts = 1 and windowStartNano = now.
|
||||
// A concurrent isInCooldown call could see cooldownEndNano = 0 (reset
|
||||
// just completed) with attempts still at MaxRefreshAttempts, triggering
|
||||
// a fresh cooldown immediately. The snapshot approach eliminates the
|
||||
// intermediate state entirely.
|
||||
type attemptState struct {
|
||||
lastAttemptNano int64 // UnixNano of last attempt
|
||||
windowStartNano int64 // UnixNano of attempt-window start
|
||||
cooldownEndNano int64 // UnixNano; 0 = not in cooldown
|
||||
attempts int32
|
||||
consecutiveFailures int32
|
||||
inCooldown bool
|
||||
}
|
||||
|
||||
// refreshAttemptTracker tracks refresh attempts for a session via a single
|
||||
// atomic.Value holding a *attemptState pointer. Readers do exactly one Load.
|
||||
// Writers do Load → construct new → CompareAndSwap (retry on conflict).
|
||||
// Under Yaegi this collapses 3-4 per-field atomic dispatches into one Load,
|
||||
// and eliminates the cross-field race in the window-reset path.
|
||||
type refreshAttemptTracker struct {
|
||||
state atomic.Value // *attemptState
|
||||
}
|
||||
|
||||
// stateOf returns the current attemptState, or a zero-value snapshot if none
|
||||
// has been published yet. The empty snapshot represents "no attempts recorded".
|
||||
func (t *refreshAttemptTracker) stateOf() *attemptState {
|
||||
if v := t.state.Load(); v != nil {
|
||||
s, _ := v.(*attemptState)
|
||||
if s != nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return &attemptState{}
|
||||
}
|
||||
|
||||
// RefreshMetrics tracks coordinator performance metrics
|
||||
@@ -106,14 +148,18 @@ type RefreshMetrics struct {
|
||||
currentInFlightRefreshes int32
|
||||
}
|
||||
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh operations
|
||||
// RefreshCircuitBreaker implements a circuit breaker specifically for refresh
|
||||
// operations. All mutable fields are atomic so AllowRequest/RecordSuccess/
|
||||
// RecordFailure run without any mutex. The previous sync.RWMutex.RLock() was
|
||||
// taken on every CoordinateRefresh — under Yaegi this added 10-50ms of
|
||||
// interpreter dispatch per call, which compounded with attemptsMutex to keep
|
||||
// the pod's single CPU core saturated.
|
||||
type RefreshCircuitBreaker struct {
|
||||
lastFailureTime time.Time
|
||||
lastSuccessTime time.Time
|
||||
lastFailureNano int64 // atomic, UnixNano of most recent failure
|
||||
lastSuccessNano int64 // atomic, UnixNano of most recent success
|
||||
config RefreshCircuitBreakerConfig
|
||||
mutex sync.RWMutex
|
||||
state int32
|
||||
failures int32
|
||||
state int32 // atomic: 0=closed, 1=open, 2=half-open
|
||||
failures int32 // atomic
|
||||
}
|
||||
|
||||
// RefreshCircuitBreakerConfig configures the refresh circuit breaker
|
||||
@@ -130,13 +176,12 @@ func NewRefreshCoordinator(config RefreshCoordinatorConfig, logger *Logger) *Ref
|
||||
}
|
||||
|
||||
rc := &RefreshCoordinator{
|
||||
inFlightRefreshes: make(map[string]*refreshOperation),
|
||||
sessionRefreshAttempts: make(map[string]*refreshAttemptTracker),
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
cleanupTimers: make(map[string]*time.Timer),
|
||||
// inFlightRefreshes and sessionRefreshAttempts are both sync.Map;
|
||||
// their zero values are ready to use.
|
||||
config: config,
|
||||
metrics: &RefreshMetrics{},
|
||||
logger: logger,
|
||||
stopChan: make(chan struct{}),
|
||||
circuitBreaker: &RefreshCircuitBreaker{
|
||||
config: RefreshCircuitBreakerConfig{
|
||||
MaxFailures: 3,
|
||||
@@ -227,13 +272,28 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
tokenHash string,
|
||||
refreshToken string,
|
||||
) (*refreshOperation, bool, error) {
|
||||
rc.refreshMutex.Lock()
|
||||
defer rc.refreshMutex.Unlock()
|
||||
// Speculatively construct the operation we WOULD register if we win the
|
||||
// race. Allocating here keeps the LoadOrStore call below atomic and
|
||||
// avoids any global lock — under Yaegi the previous map+RWMutex design
|
||||
// held the write lock long enough (tens of ms per call) that concurrent
|
||||
// refreshes on the same coordinator serialized into a queue that grew
|
||||
// without bound. See struct comment on inFlightRefreshes.
|
||||
candidate := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
|
||||
// Check for existing operation while holding the lock
|
||||
if existingOp, exists := rc.inFlightRefreshes[tokenHash]; exists {
|
||||
if existing, loaded := rc.inFlightRefreshes.LoadOrStore(tokenHash, candidate); loaded {
|
||||
existingOp, ok := existing.(*refreshOperation)
|
||||
if !ok {
|
||||
// Defensive: anything stored here is always *refreshOperation, but
|
||||
// keep the typed assert so a programming error elsewhere doesn't
|
||||
// surface as a confusing panic in an interpreter frame.
|
||||
return nil, false, fmt.Errorf("inFlightRefreshes corrupt: unexpected type %T", existing)
|
||||
}
|
||||
if existingOp.refreshToken == refreshToken {
|
||||
// Join existing operation
|
||||
atomic.AddInt32(&existingOp.waiterCount, 1)
|
||||
return existingOp, false, nil
|
||||
}
|
||||
@@ -241,41 +301,71 @@ func (rc *RefreshCoordinator) getOrCreateOperation(
|
||||
return nil, false, fmt.Errorf("refresh token mismatch")
|
||||
}
|
||||
|
||||
// No existing operation - check if we can create a new one
|
||||
// All checks happen while holding the lock to prevent races
|
||||
// We won the race and registered `candidate`. Apply gates now. If any
|
||||
// gate fails we must remove our entry from the map and signal failure
|
||||
// to any joiners that snuck in between LoadOrStore and now.
|
||||
if err := rc.applyLeaderGates(sessionID); err != nil {
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
// Check and record refresh attempt for rate limiting
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
// Reserve concurrent slot via ticket-and-return: increment optimistically,
|
||||
// decrement if we overshot the limit. The previous CAS-loop allowed a
|
||||
// transient overshoot of up to N-1 leaders when several goroutines all
|
||||
// observed `current < max` in the same scheduling slice before any one
|
||||
// of them succeeded their CAS — visible to readers as
|
||||
// currentInFlightRefreshes > MaxConcurrentRefreshes for a brief window.
|
||||
// The ticket pattern is strictly bounded: the counter momentarily reads
|
||||
// max+k for k concurrent attempts past the limit, but only the k that
|
||||
// produced max+1..max+k decrement back, and only k=1 ever observes max+1
|
||||
// as committed.
|
||||
newCount := atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
if int(newCount) > rc.config.MaxConcurrentRefreshes {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
err := fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
rc.failCandidate(tokenHash, candidate, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return candidate, true, nil
|
||||
}
|
||||
|
||||
// applyLeaderGates runs the rate-limit, cooldown, and memory-pressure checks
|
||||
// that previously ran under the global refreshMutex. Only the leader (the
|
||||
// goroutine that just registered the operation) runs them; joiners share the
|
||||
// leader's outcome via operation.done.
|
||||
func (rc *RefreshCoordinator) applyLeaderGates(sessionID string) error {
|
||||
// Cooldown check FIRST, BEFORE incrementing the attempt counter.
|
||||
// Previously this function recorded the attempt and then read the
|
||||
// cooldown state. Under burst load (many concurrent leaders with
|
||||
// different token hashes but same session) every goroutine could
|
||||
// increment past MaxRefreshAttempts before any one of them observed
|
||||
// the threshold, so the cooldown gate fired too late — the same
|
||||
// thundering-herd shape that drove v1.0.14 into the ground.
|
||||
if rc.isInCooldown(sessionID) {
|
||||
atomic.AddInt64(&rc.metrics.cooldownsTriggered, 1)
|
||||
return nil, false, fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
return fmt.Errorf("refresh attempts exceeded for session, in cooldown period")
|
||||
}
|
||||
|
||||
// Check memory pressure
|
||||
if rc.config.EnableMemoryPressureDetection && rc.isUnderMemoryPressure() {
|
||||
atomic.AddInt64(&rc.metrics.memoryPressureEvents, 1)
|
||||
return nil, false, fmt.Errorf("system under memory pressure, refresh denied")
|
||||
return fmt.Errorf("system under memory pressure, refresh denied")
|
||||
}
|
||||
// Only count attempts that actually progress past the gates.
|
||||
rc.recordRefreshAttempt(sessionID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check and reserve concurrent refresh slot atomically
|
||||
current := atomic.LoadInt32(&rc.metrics.currentInFlightRefreshes)
|
||||
if int(current) >= rc.config.MaxConcurrentRefreshes {
|
||||
return nil, false, fmt.Errorf("maximum concurrent refresh operations reached")
|
||||
}
|
||||
|
||||
// Reserve the slot - we're still holding the lock so this is safe
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, 1)
|
||||
|
||||
// Create and register new operation
|
||||
operation := &refreshOperation{
|
||||
refreshToken: refreshToken,
|
||||
done: make(chan struct{}),
|
||||
startTime: time.Now(),
|
||||
waiterCount: 1,
|
||||
}
|
||||
rc.inFlightRefreshes[tokenHash] = operation
|
||||
|
||||
return operation, true, nil
|
||||
// failCandidate removes the leader's just-registered operation from the
|
||||
// in-flight map and signals the error to any joiners by recording the result
|
||||
// and closing the done channel. This keeps the (nil, false, err) return path
|
||||
// equivalent to the pre-sync.Map version: callers see the error directly,
|
||||
// joiners see it via operation.done.
|
||||
func (rc *RefreshCoordinator) failCandidate(tokenHash string, op *refreshOperation, err error) {
|
||||
rc.inFlightRefreshes.Delete(tokenHash)
|
||||
op.mutex.Lock()
|
||||
op.result = &refreshResult{err: err}
|
||||
op.mutex.Unlock()
|
||||
close(op.done)
|
||||
}
|
||||
|
||||
// executeRefreshAsync performs the actual refresh operation asynchronously
|
||||
@@ -338,130 +428,196 @@ func (rc *RefreshCoordinator) executeRefreshAsync(
|
||||
}
|
||||
}
|
||||
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning a goroutine
|
||||
// This prevents goroutine explosion under high load (500+ req/sec)
|
||||
// scheduleDelayedCleanup schedules a cleanup using a timer instead of spawning
|
||||
// a goroutine — time.AfterFunc uses the runtime's timer heap and never spawns
|
||||
// a per-timer goroutine until the callback actually fires.
|
||||
//
|
||||
// The previous implementation tracked every pending timer in a map guarded by
|
||||
// cleanupTimerMu so a duplicate scheduling could cancel the prior timer. That
|
||||
// "shouldn't happen" path was the only consumer of the map, but the mutex
|
||||
// fired on every successful refresh completion — yet another per-request
|
||||
// Yaegi-dispatched lock acquisition. performCleanup is already idempotent
|
||||
// (LoadAndDelete on the sync.Map), so a duplicate scheduling at worst fires
|
||||
// performCleanup twice; the second call is a no-op. Dropping the map removes
|
||||
// the whole class of contention on this code path.
|
||||
func (rc *RefreshCoordinator) scheduleDelayedCleanup(tokenHash string) {
|
||||
delay := rc.config.DeduplicationCleanupDelay
|
||||
if delay <= 0 {
|
||||
// Immediate cleanup
|
||||
rc.performCleanup(tokenHash)
|
||||
return
|
||||
}
|
||||
|
||||
// Use time.AfterFunc which is more efficient than spawning a goroutine with Sleep
|
||||
// time.AfterFunc uses the runtime's timer heap which is much more efficient
|
||||
rc.cleanupTimerMu.Lock()
|
||||
// Cancel any existing timer for this hash (shouldn't happen, but just in case)
|
||||
if existingTimer, exists := rc.cleanupTimers[tokenHash]; exists {
|
||||
existingTimer.Stop()
|
||||
}
|
||||
rc.cleanupTimers[tokenHash] = time.AfterFunc(delay, func() {
|
||||
rc.performCleanup(tokenHash)
|
||||
// Remove timer from map
|
||||
rc.cleanupTimerMu.Lock()
|
||||
delete(rc.cleanupTimers, tokenHash)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
})
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
time.AfterFunc(delay, func() { rc.performCleanup(tokenHash) })
|
||||
}
|
||||
|
||||
// performCleanup removes the operation from the in-flight map.
|
||||
// Idempotent: only decrements the in-flight counter if an entry was actually
|
||||
// removed. This guards against any future path accidentally calling cleanup
|
||||
// twice for the same tokenHash (which would corrupt the refresh budget).
|
||||
// removed. LoadAndDelete is atomic so any concurrent failCandidate or repeat
|
||||
// cleanup call will see exactly one removal — the budget cannot be corrupted
|
||||
// by double-decrement.
|
||||
func (rc *RefreshCoordinator) performCleanup(tokenHash string) {
|
||||
rc.refreshMutex.Lock()
|
||||
_, existed := rc.inFlightRefreshes[tokenHash]
|
||||
if existed {
|
||||
delete(rc.inFlightRefreshes, tokenHash)
|
||||
}
|
||||
rc.refreshMutex.Unlock()
|
||||
if existed {
|
||||
if _, existed := rc.inFlightRefreshes.LoadAndDelete(tokenHash); existed {
|
||||
atomic.AddInt32(&rc.metrics.currentInFlightRefreshes, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown after recording an attempt
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
// getOrCreateTracker fetches the tracker for sessionID or atomically creates a
|
||||
// fresh one. The sync.Map.LoadOrStore semantics make this lock-free even under
|
||||
// concurrent first-touch races: at most one tracker per sessionID survives.
|
||||
//
|
||||
// trackerFromMapValue centralizes the type assertion so the lint-mandated
|
||||
// two-value form lives in one place; the stored type is always
|
||||
// *refreshAttemptTracker by construction.
|
||||
func trackerFromMapValue(v interface{}) *refreshAttemptTracker {
|
||||
t, _ := v.(*refreshAttemptTracker)
|
||||
return t
|
||||
}
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
func (rc *RefreshCoordinator) getOrCreateTracker(sessionID string) *refreshAttemptTracker {
|
||||
if v, ok := rc.sessionRefreshAttempts.Load(sessionID); ok {
|
||||
return trackerFromMapValue(v)
|
||||
}
|
||||
fresh := &refreshAttemptTracker{}
|
||||
fresh.state.Store(&attemptState{windowStartNano: time.Now().UnixNano()})
|
||||
actual, _ := rc.sessionRefreshAttempts.LoadOrStore(sessionID, fresh)
|
||||
return trackerFromMapValue(actual)
|
||||
}
|
||||
|
||||
// mutateState performs a CompareAndSwap loop that applies mutate to the
|
||||
// current snapshot. mutate must be PURE: it receives an immutable view of
|
||||
// the current state and returns a fresh *attemptState. If mutate returns nil
|
||||
// the update is skipped (used by isInCooldown for "no change needed" paths).
|
||||
//
|
||||
// Retries on CAS conflict are bounded by the number of concurrent writers —
|
||||
// in practice 1-3. Under Yaegi each retry pays the dispatch cost of one Load
|
||||
// + one CompareAndSwap; still cheaper than the previous per-field atomic
|
||||
// sequence and immune to the cross-field race the v1.0.15 design had.
|
||||
func (t *refreshAttemptTracker) mutateState(mutate func(cur *attemptState) *attemptState) *attemptState {
|
||||
for {
|
||||
cur := t.stateOf()
|
||||
next := mutate(cur)
|
||||
if next == nil {
|
||||
return cur
|
||||
}
|
||||
if t.state.CompareAndSwap(cur, next) {
|
||||
return next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isInCooldown checks if a session is in cooldown. Snapshot-based: every
|
||||
// transition publishes a fresh *attemptState atomically so readers never see
|
||||
// a partially-updated state. The previous per-field atomic design had a
|
||||
// benign race in the cooldown-exit path (cooldownEndNano reset before
|
||||
// attempts reset) that could double-trigger cooldown.
|
||||
func (rc *RefreshCoordinator) isInCooldown(sessionID string) bool {
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return false // No tracker means first attempt, not in cooldown
|
||||
}
|
||||
|
||||
tracker := trackerFromMapValue(v)
|
||||
now := time.Now()
|
||||
nowNano := now.UnixNano()
|
||||
maxAttempts := rc.config.MaxRefreshAttempts
|
||||
window := rc.config.RefreshAttemptWindow
|
||||
cooldownPeriod := rc.config.RefreshCooldownPeriod
|
||||
|
||||
// Check if already in cooldown
|
||||
if tracker.inCooldown {
|
||||
if now.After(tracker.cooldownEndTime) {
|
||||
// Cooldown expired, reset tracker
|
||||
tracker.inCooldown = false
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.consecutiveFailures = 0
|
||||
tracker.windowStartTime = now
|
||||
return false
|
||||
cur := tracker.stateOf()
|
||||
|
||||
// Already in cooldown?
|
||||
if cur.cooldownEndNano != 0 {
|
||||
if nowNano <= cur.cooldownEndNano {
|
||||
return true // still in cooldown
|
||||
}
|
||||
return true // Still in cooldown
|
||||
}
|
||||
|
||||
// Check if window expired
|
||||
if now.Sub(tracker.windowStartTime) > rc.config.RefreshAttemptWindow {
|
||||
// Reset window
|
||||
tracker.attempts = 1 // Already recorded one attempt
|
||||
tracker.windowStartTime = now
|
||||
// Cooldown expired: atomically publish a fresh state with the window
|
||||
// restarted from one attempt. Whichever goroutine wins the CAS sets
|
||||
// the new snapshot; losers see it via the next stateOf load.
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano == 0 || nowNano <= s.cooldownEndNano {
|
||||
return nil // someone else already reset, or back in cooldown
|
||||
}
|
||||
return &attemptState{
|
||||
windowStartNano: nowNano,
|
||||
attempts: 1,
|
||||
}
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if just exceeded attempt limit
|
||||
if int(tracker.attempts) >= rc.config.MaxRefreshAttempts {
|
||||
// Enter cooldown now
|
||||
tracker.inCooldown = true
|
||||
tracker.cooldownEndTime = now.Add(rc.config.RefreshCooldownPeriod)
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, tracker.attempts)
|
||||
// Window expired?
|
||||
if time.Duration(nowNano-cur.windowStartNano) > window {
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if time.Duration(nowNano-s.windowStartNano) <= window {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.windowStartNano = nowNano
|
||||
next.attempts = 1
|
||||
return &next
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
// Just exceeded attempt limit?
|
||||
if int(cur.attempts) >= maxAttempts {
|
||||
end := now.Add(cooldownPeriod).UnixNano()
|
||||
published := tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
if s.cooldownEndNano != 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.cooldownEndNano = end
|
||||
return &next
|
||||
})
|
||||
if published.cooldownEndNano == end {
|
||||
rc.logger.Infof("Session %s entering refresh cooldown after %d attempts",
|
||||
sessionID, published.attempts)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting
|
||||
// recordRefreshAttempt records a refresh attempt for rate limiting. Lock-free
|
||||
// snapshot mutation; attempts and lastAttemptNano are advanced atomically.
|
||||
func (rc *RefreshCoordinator) recordRefreshAttempt(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
tracker, exists := rc.sessionRefreshAttempts[sessionID]
|
||||
if !exists {
|
||||
tracker = &refreshAttemptTracker{
|
||||
windowStartTime: time.Now(),
|
||||
}
|
||||
rc.sessionRefreshAttempts[sessionID] = tracker
|
||||
}
|
||||
|
||||
atomic.AddInt32(&tracker.attempts, 1)
|
||||
tracker.lastAttemptTime = time.Now()
|
||||
tracker := rc.getOrCreateTracker(sessionID)
|
||||
nowNano := time.Now().UnixNano()
|
||||
tracker.mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.attempts++
|
||||
next.lastAttemptNano = nowNano
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshSuccess records a successful refresh
|
||||
// recordRefreshSuccess records a successful refresh: zero consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshSuccess(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
tracker.consecutiveFailures = 0
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
if s.consecutiveFailures == 0 {
|
||||
return nil
|
||||
}
|
||||
next := *s
|
||||
next.consecutiveFailures = 0
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// recordRefreshFailure records a failed refresh
|
||||
// recordRefreshFailure records a failed refresh: increments consecutiveFailures.
|
||||
func (rc *RefreshCoordinator) recordRefreshFailure(sessionID string) {
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
if tracker, exists := rc.sessionRefreshAttempts[sessionID]; exists {
|
||||
atomic.AddInt32(&tracker.consecutiveFailures, 1)
|
||||
v, ok := rc.sessionRefreshAttempts.Load(sessionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
trackerFromMapValue(v).mutateState(func(s *attemptState) *attemptState {
|
||||
next := *s
|
||||
next.consecutiveFailures++
|
||||
return &next
|
||||
})
|
||||
}
|
||||
|
||||
// hashRefreshToken creates a hash of the refresh token for deduplication
|
||||
@@ -512,20 +668,22 @@ func (rc *RefreshCoordinator) cleanupRoutine() {
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupStaleEntries removes outdated tracking entries
|
||||
// cleanupStaleEntries removes outdated tracking entries. Lock-free iteration
|
||||
// via sync.Map.Range; safe to race with concurrent reads/writes.
|
||||
func (rc *RefreshCoordinator) cleanupStaleEntries() {
|
||||
now := time.Now()
|
||||
|
||||
rc.attemptsMutex.Lock()
|
||||
defer rc.attemptsMutex.Unlock()
|
||||
|
||||
// Clean up old session trackers
|
||||
for sessionID, tracker := range rc.sessionRefreshAttempts {
|
||||
// Remove trackers that haven't been used recently
|
||||
if now.Sub(tracker.lastAttemptTime) > 2*rc.config.RefreshAttemptWindow {
|
||||
delete(rc.sessionRefreshAttempts, sessionID)
|
||||
cutoff := time.Now().Add(-2 * rc.config.RefreshAttemptWindow).UnixNano()
|
||||
rc.sessionRefreshAttempts.Range(func(key, value interface{}) bool {
|
||||
tracker := trackerFromMapValue(value)
|
||||
if tracker == nil {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if tracker.stateOf().lastAttemptNano < cutoff {
|
||||
// Compare-and-delete to avoid evicting a tracker that was just
|
||||
// re-used by a concurrent caller. We compare by pointer identity.
|
||||
rc.sessionRefreshAttempts.CompareAndDelete(key, value)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// GetMetrics returns current coordinator metrics
|
||||
@@ -543,78 +701,60 @@ func (rc *RefreshCoordinator) GetMetrics() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown gracefully shuts down the coordinator
|
||||
// Shutdown gracefully shuts down the coordinator. Pending delayed-cleanup
|
||||
// timers are NOT canceled explicitly: time.AfterFunc callbacks are tiny
|
||||
// (one map LoadAndDelete) and harmless after Shutdown — sync.Map operations
|
||||
// remain safe on an unused coordinator until GC.
|
||||
func (rc *RefreshCoordinator) Shutdown() {
|
||||
close(rc.stopChan)
|
||||
|
||||
// Cancel all pending cleanup timers
|
||||
rc.cleanupTimerMu.Lock()
|
||||
for _, timer := range rc.cleanupTimers {
|
||||
timer.Stop()
|
||||
}
|
||||
rc.cleanupTimers = make(map[string]*time.Timer)
|
||||
rc.cleanupTimerMu.Unlock()
|
||||
|
||||
rc.wg.Wait()
|
||||
}
|
||||
|
||||
// AllowRequest checks if the circuit breaker allows a request
|
||||
// AllowRequest reports whether the circuit breaker allows a request. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) AllowRequest() bool {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
switch state {
|
||||
case 0: // Closed
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0: // closed
|
||||
return true
|
||||
case 1: // Open
|
||||
if time.Since(cb.lastFailureTime) > cb.config.OpenDuration {
|
||||
// Try to transition to half-open
|
||||
case 1: // open
|
||||
lastFail := atomic.LoadInt64(&cb.lastFailureNano)
|
||||
if time.Duration(time.Now().UnixNano()-lastFail) > cb.config.OpenDuration {
|
||||
// Transition to half-open; first CAS winner gets the probe.
|
||||
if atomic.CompareAndSwapInt32(&cb.state, 1, 2) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
case 2: // Half-open
|
||||
case 2: // half-open
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSuccess records a successful operation
|
||||
// RecordSuccess records a successful operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordSuccess() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
if state == 2 { // Half-open
|
||||
// Close the circuit
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 2: // half-open -> close
|
||||
atomic.StoreInt32(&cb.state, 0)
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
} else if state == 0 { // Closed
|
||||
// Reset failure count on success
|
||||
case 0: // closed
|
||||
atomic.StoreInt32(&cb.failures, 0)
|
||||
}
|
||||
cb.lastSuccessTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastSuccessNano, time.Now().UnixNano())
|
||||
}
|
||||
|
||||
// RecordFailure records a failed operation
|
||||
// RecordFailure records a failed operation. Lock-free.
|
||||
func (cb *RefreshCircuitBreaker) RecordFailure() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
failures := atomic.AddInt32(&cb.failures, 1)
|
||||
cb.lastFailureTime = time.Now()
|
||||
atomic.StoreInt64(&cb.lastFailureNano, time.Now().UnixNano())
|
||||
|
||||
state := atomic.LoadInt32(&cb.state)
|
||||
|
||||
if state == 0 && int(failures) >= cb.config.MaxFailures {
|
||||
// Open the circuit
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
} else if state == 2 {
|
||||
// Half-open failed, return to open
|
||||
switch atomic.LoadInt32(&cb.state) {
|
||||
case 0:
|
||||
if int(failures) >= cb.config.MaxFailures {
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
case 2:
|
||||
// Half-open probe failed -> back to open.
|
||||
atomic.StoreInt32(&cb.state, 1)
|
||||
}
|
||||
}
|
||||
|
||||
+28
-34
@@ -165,9 +165,14 @@ func TestRefreshRateLimiting(t *testing.T) {
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Verify that cooldown was triggered after max attempts
|
||||
// With the new logic, the Nth attempt triggers cooldown, so we get N-1 successful attempts
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts - 1
|
||||
// Verify that cooldown was triggered after max attempts.
|
||||
// With applyLeaderGates checking cooldown BEFORE recording the attempt
|
||||
// (the v1.0.16 reorder fixing the thundering-herd off-by-one), N attempts
|
||||
// run to completion and the (N+1)th is denied. Previously the Nth was
|
||||
// denied as it tried to record, which under burst load let multiple
|
||||
// concurrent leaders increment past the limit before any one of them
|
||||
// observed the gate.
|
||||
expectedSuccessfulAttempts := config.MaxRefreshAttempts
|
||||
if attempts != expectedSuccessfulAttempts {
|
||||
t.Errorf("Expected %d successful attempts before cooldown, got %d", expectedSuccessfulAttempts, attempts)
|
||||
}
|
||||
@@ -365,10 +370,12 @@ func TestMemoryLeakPrevention(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Verify cleanup is working
|
||||
coordinator.attemptsMutex.RLock()
|
||||
sessionCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
// Verify cleanup is working. sync.Map has no Len(); count via Range.
|
||||
sessionCount := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
sessionCount++
|
||||
return true
|
||||
})
|
||||
|
||||
// Should have cleaned up old sessions (only recent ones remain)
|
||||
if sessionCount > numWorkers*2 {
|
||||
@@ -650,24 +657,23 @@ func TestCleanupRoutine(t *testing.T) {
|
||||
coordinator.recordRefreshAttempt(fmt.Sprintf("session_%d", i))
|
||||
}
|
||||
|
||||
// Verify sessions exist
|
||||
coordinator.attemptsMutex.RLock()
|
||||
initialCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
countSessions := func() int {
|
||||
n := 0
|
||||
coordinator.sessionRefreshAttempts.Range(func(_, _ interface{}) bool {
|
||||
n++
|
||||
return true
|
||||
})
|
||||
return n
|
||||
}
|
||||
|
||||
if initialCount != 5 {
|
||||
if initialCount := countSessions(); initialCount != 5 {
|
||||
t.Errorf("Expected 5 sessions, got %d", initialCount)
|
||||
}
|
||||
|
||||
// Wait for cleanup to run (2x window + cleanup interval)
|
||||
time.Sleep(2*config.RefreshAttemptWindow + 2*config.CleanupInterval)
|
||||
|
||||
// Verify sessions were cleaned up
|
||||
coordinator.attemptsMutex.RLock()
|
||||
finalCount := len(coordinator.sessionRefreshAttempts)
|
||||
coordinator.attemptsMutex.RUnlock()
|
||||
|
||||
if finalCount != 0 {
|
||||
if finalCount := countSessions(); finalCount != 0 {
|
||||
t.Errorf("Expected 0 sessions after cleanup, got %d", finalCount)
|
||||
}
|
||||
}
|
||||
@@ -720,11 +726,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
currentGoroutines := runtime.NumGoroutine()
|
||||
t.Logf("Goroutines after %d refresh operations: %d", numRefreshes, currentGoroutines)
|
||||
|
||||
// Check timer count
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
timerCount := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
t.Logf("Active cleanup timers: %d", timerCount)
|
||||
// (Coordinator no longer tracks pending timers; time.AfterFunc closures
|
||||
// fire performCleanup directly. This test now only checks the goroutine
|
||||
// budget, which was always the real invariant.)
|
||||
|
||||
// With timer-based cleanup, goroutine increase should be minimal
|
||||
// Timers don't create goroutines - they use the runtime timer heap
|
||||
@@ -740,19 +744,9 @@ func TestNoGoroutineExplosionWithTimers(t *testing.T) {
|
||||
initialGoroutines, currentGoroutines, goroutineIncrease)
|
||||
}
|
||||
|
||||
// Wait for timers to fire and cleanup
|
||||
// Wait for timers to fire and cleanup.
|
||||
time.Sleep(config.DeduplicationCleanupDelay + 50*time.Millisecond)
|
||||
|
||||
// Verify timers were cleaned up
|
||||
coordinator.cleanupTimerMu.Lock()
|
||||
remainingTimers := len(coordinator.cleanupTimers)
|
||||
coordinator.cleanupTimerMu.Unlock()
|
||||
|
||||
// Most timers should have fired and been removed
|
||||
if remainingTimers > 10 {
|
||||
t.Errorf("Too many cleanup timers remaining: %d", remainingTimers)
|
||||
}
|
||||
|
||||
// Verify goroutines returned to near initial
|
||||
runtime.GC()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// requestState bundles read-mostly fields for a single ServeHTTP call.
|
||||
package traefikoidc
|
||||
|
||||
import "net/http"
|
||||
|
||||
// requestState is a per-request context object allocated at the top of
|
||||
// ServeHTTP and threaded through to downstream handlers. It caches values
|
||||
// that would otherwise require a Yaegi-dispatched lock acquisition each time
|
||||
// they're read:
|
||||
//
|
||||
// - The metadata snapshot (atomic.Value.Load once, not per-handler).
|
||||
// - SessionData getter results (one RLock on sd.sessionMutex covers all
|
||||
// fields, instead of 5-7 separate RLock/RUnlock pairs scattered through
|
||||
// the handler chain).
|
||||
//
|
||||
// The struct is alloc'd at request entry, populated under at most one RLock
|
||||
// of sd.sessionMutex, and discarded at request exit. It is NOT shared across
|
||||
// requests and never written from another goroutine, so no synchronization
|
||||
// on its fields is required.
|
||||
//
|
||||
// Cross-request global caches (tokenCache, JWKCache, sessionEntries,
|
||||
// sessionInvalidationCache) remain — they're orthogonal. requestState's job
|
||||
// is to eliminate redundant per-handler reads of values that don't change
|
||||
// within a single request.
|
||||
type requestState struct {
|
||||
// Globals snapshotted once.
|
||||
metadata *MetadataSnapshot
|
||||
|
||||
// SessionData fields snapshotted under one RLock. The pointer to the
|
||||
// SessionData is retained so handlers that genuinely need to mutate
|
||||
// (Save, Clear, etc.) still have access.
|
||||
session *SessionData
|
||||
|
||||
authenticated bool
|
||||
accessToken string
|
||||
idToken string
|
||||
refreshToken string
|
||||
userIdentifier string
|
||||
createdAtUnixSec int64
|
||||
|
||||
// Output: scheme/host/redirect path determined at top of ServeHTTP.
|
||||
scheme string
|
||||
host string
|
||||
redirectURL string
|
||||
|
||||
// Carry the next handler so forwardAuthorized doesn't need to close over t.
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// captureSession populates requestState's SessionData-derived fields under a
|
||||
// single RLock of sd.sessionMutex. Returns the populated rs for chaining.
|
||||
//
|
||||
// Replaces a sequence of SessionData.GetX() calls each of which acquires
|
||||
// sd.sessionMutex.RLock(). Under Yaegi each RLock costs ~1-5ms of
|
||||
// interpreter dispatch; batching saves the rest.
|
||||
func (rs *requestState) captureSession(sd *SessionData) *requestState {
|
||||
if sd == nil {
|
||||
return rs
|
||||
}
|
||||
rs.session = sd
|
||||
sd.sessionMutex.RLock()
|
||||
rs.authenticated = sd.getAuthenticatedUnsafe()
|
||||
rs.accessToken = sd.getAccessTokenUnsafe()
|
||||
rs.idToken = sd.getIDTokenUnsafe()
|
||||
rs.refreshToken = sd.getRefreshTokenUnsafe()
|
||||
rs.userIdentifier = sd.getUserIdentifierUnsafe()
|
||||
rs.createdAtUnixSec = sd.getCreatedAtUnsafe()
|
||||
sd.sessionMutex.RUnlock()
|
||||
return rs
|
||||
}
|
||||
@@ -54,6 +54,7 @@ type Config struct {
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedUsers []string `json:"allowedUsers"`
|
||||
Headers []TemplatedHeader `json:"headers"`
|
||||
ExtraAuthParams map[string]string `json:"extraAuthParams,omitempty"`
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
// MaxRefreshTokenAgeSeconds is a heuristic upper bound on the lifetime of
|
||||
// a stored refresh token. Once the token has been in the session longer
|
||||
|
||||
@@ -5,8 +5,6 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -860,437 +858,6 @@ func (t *TraefikOidc) isAzureProvider() bool {
|
||||
strings.Contains(issuerURL, "login.windows.net")
|
||||
}
|
||||
|
||||
// validateAzureTokens validates tokens with Azure AD-specific logic.
|
||||
// Azure tokens may be opaque access tokens that cannot be verified as JWTs,
|
||||
// so this method handles both JWT and opaque token scenarios.
|
||||
// Parameters:
|
||||
// - session: The session data containing tokens to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated: Whether the user has valid authentication.
|
||||
// - needsRefresh: Whether tokens need to be refreshed.
|
||||
// - expired: Whether tokens have expired and cannot be refreshed.
|
||||
//
|
||||
//nolint:gocognit // Azure-specific validation requires multiple token type checks
|
||||
func (t *TraefikOidc) validateAzureTokens(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("Azure user is not authenticated according to session flag")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Azure session not authenticated, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, true, false
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
idToken := session.GetIDToken()
|
||||
|
||||
if accessToken != "" {
|
||||
if strings.Count(accessToken, ".") == 2 {
|
||||
// Microsoft documents that client apps cannot validate access
|
||||
// tokens issued for Microsoft-owned APIs (Graph, Azure Mgmt) due
|
||||
// to their proprietary signing format (nonce in JWT header is
|
||||
// the marker — signed bytes hash the nonce, wire bytes ship the
|
||||
// raw value, so rsa verification always fails). Treat such
|
||||
// tokens as opaque, matching Microsoft's guidance and avoiding
|
||||
// per-request signature-error log spam (issue #134 followup).
|
||||
//
|
||||
// https://learn.microsoft.com/en-us/entra/identity-platform/access-tokens
|
||||
// "you can't validate tokens for Microsoft Graph according to
|
||||
// these rules due to their proprietary format"
|
||||
if t.isUnverifiableAzureAccessToken(accessToken) {
|
||||
t.logger.Debug("Azure access token is Microsoft-proprietary (Graph/Mgmt) — treating as opaque per Microsoft guidance")
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Debugf("Azure: ID token validation failed while access token was opaque: %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Debugf("Azure: Both access and ID token validation failed: %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
}
|
||||
t.logger.Debug("Azure access token appears opaque, treating as valid")
|
||||
if idToken != "" {
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
if idToken != "" {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// validateGoogleTokens handles Google-specific token validation logic.
|
||||
// Currently delegates to standard token validation but provides a hook
|
||||
// for Google-specific validation requirements in the future.
|
||||
// Parameters:
|
||||
// - session: The session data containing tokens to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated: Whether the user has valid authentication.
|
||||
// - needsRefresh: Whether tokens need to be refreshed.
|
||||
// - expired: Whether tokens have expired and cannot be refreshed.
|
||||
func (t *TraefikOidc) validateGoogleTokens(session *SessionData) (bool, bool, bool) {
|
||||
return t.validateStandardTokens(session)
|
||||
}
|
||||
|
||||
// validateStandardTokens handles standard OIDC token validation logic.
|
||||
// This is the default validation method for generic OIDC providers.
|
||||
// It verifies ID tokens and handles access tokens appropriately.
|
||||
// Parameters:
|
||||
// - session: The session data containing tokens to validate.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated: Whether the user has valid authentication.
|
||||
// - needsRefresh: Whether tokens need to be refreshed.
|
||||
// - expired: Whether tokens have expired and cannot be refreshed.
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Complex validation logic handles multiple token scenarios and edge cases
|
||||
func (t *TraefikOidc) validateStandardTokens(session *SessionData) (bool, bool, bool) {
|
||||
authenticated := session.GetAuthenticated()
|
||||
// Removed debug output
|
||||
if !authenticated {
|
||||
t.logger.Debug("User is not authenticated according to session flag")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Session not authenticated, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
// Removed debug output
|
||||
if accessToken == "" {
|
||||
t.logger.Debug("Authenticated flag set, but no access token found in session")
|
||||
if session.GetRefreshToken() != "" {
|
||||
// Check if we have an ID token to determine if we're beyond grace period
|
||||
// When access token is missing, check ID token expiry to determine if refresh is viable
|
||||
idToken := session.GetIDToken()
|
||||
t.logger.Debugf("Checking ID token for grace period: ID token present: %v", idToken != "")
|
||||
if idToken != "" {
|
||||
// Try to parse the ID token to check its expiry
|
||||
parts := strings.Split(idToken, ".")
|
||||
if len(parts) == 3 {
|
||||
// Decode the claims part
|
||||
claimsData, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err == nil {
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(claimsData, &claims); err == nil {
|
||||
if expClaim, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(expClaim), 0)
|
||||
if time.Now().After(expTime) {
|
||||
expiredDuration := time.Since(expTime)
|
||||
if expiredDuration > t.refreshGracePeriod {
|
||||
t.logger.Debugf("ID token expired beyond grace period (%v > %v), must re-authenticate",
|
||||
expiredDuration, t.refreshGracePeriod)
|
||||
return false, false, true // expired, cannot refresh
|
||||
}
|
||||
t.logger.Debugf("ID token expired %v ago, within grace period %v, allowing refresh",
|
||||
expiredDuration, t.refreshGracePeriod)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
t.logger.Debug("Access token missing, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// Check if access token is opaque (doesn't have JWT structure)
|
||||
dotCount := strings.Count(accessToken, ".")
|
||||
isOpaqueToken := dotCount != 2
|
||||
|
||||
// For opaque access tokens, use introspection if available (RFC 7662 - Option C: Scenario 3)
|
||||
if isOpaqueToken {
|
||||
t.logger.Debugf("Access token appears to be opaque (dots: %d)", dotCount)
|
||||
|
||||
// Try introspection first if opaque tokens are allowed
|
||||
if t.allowOpaqueTokens {
|
||||
if err := t.validateOpaqueToken(accessToken); err != nil {
|
||||
errMsg := err.Error()
|
||||
t.logger.Infof("⚠️ Opaque access token validation via introspection failed: %v", err)
|
||||
|
||||
// Check if the token was explicitly marked as inactive/revoked/expired by the provider
|
||||
// In these cases, we should NOT fall back to ID token - the provider has explicitly
|
||||
// told us this token is no longer valid. We must refresh or re-authenticate.
|
||||
isTokenInvalid := strings.Contains(errMsg, "token is not active") ||
|
||||
strings.Contains(errMsg, "revoked") ||
|
||||
strings.Contains(errMsg, "token has expired")
|
||||
|
||||
if isTokenInvalid {
|
||||
t.logger.Infof("⚠️ Token explicitly marked as invalid by provider, cannot fall back to ID token")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Refresh token available, attempting refresh")
|
||||
return false, true, false
|
||||
}
|
||||
t.logger.Debug("No refresh token available, must re-authenticate")
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// If introspection required, reject the session
|
||||
if t.requireTokenIntrospection {
|
||||
t.logger.Errorf("❌ SECURITY: Opaque token rejected (introspection required but failed)")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// Only fall back to ID token validation for transient errors (network issues, etc.)
|
||||
// where the introspection endpoint couldn't be reached
|
||||
t.logger.Infof("⚠️ Falling back to ID token validation for opaque access token (transient error)")
|
||||
} else {
|
||||
// Introspection successful
|
||||
t.logger.Debugf("✓ Opaque access token validated via introspection")
|
||||
// Still need to check ID token for session expiry
|
||||
idToken := session.GetIDToken()
|
||||
if idToken != "" {
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
} else {
|
||||
// Opaque tokens not allowed - log warning and reject or fall back
|
||||
t.logger.Infof("⚠️ Opaque access token detected but allowOpaqueTokens=false")
|
||||
}
|
||||
|
||||
// Fall back to ID token validation
|
||||
idToken := session.GetIDToken()
|
||||
if idToken == "" {
|
||||
t.logger.Debug("Opaque access token present but no ID token found")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("ID token missing but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
// Accept session with opaque access token even without ID token
|
||||
// The OAuth provider validated it when issued
|
||||
t.logger.Debug("Accepting session with opaque access token")
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// Validate ID token if present
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
t.logger.Debugf("ID token expired with opaque access token, needs refresh")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
t.logger.Errorf("ID token verification failed with opaque access token: %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// Use ID token for expiry validation
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
|
||||
// JWT access token present - validate it explicitly to detect Scenario 2
|
||||
// (Option C: Scenario 2 detection and strict mode)
|
||||
accessTokenValid := false
|
||||
accessTokenError := ""
|
||||
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
// Access token validation failed
|
||||
accessTokenError = err.Error()
|
||||
|
||||
// Check if it's an audience validation failure (Scenario 2)
|
||||
if strings.Contains(accessTokenError, "invalid audience") || strings.Contains(accessTokenError, "audience") {
|
||||
// SCENARIO 2 DETECTED: Access token has wrong audience
|
||||
t.logger.Infof("⚠️ SCENARIO 2 DETECTED: Access token validation failed due to audience mismatch: %v", err)
|
||||
|
||||
if t.strictAudienceValidation {
|
||||
// Strict mode: Reject the session (don't fall back to ID token)
|
||||
t.logger.Errorf("❌ SECURITY: Session rejected due to access token audience mismatch (strictAudienceValidation=true)")
|
||||
t.logger.Errorf("❌ This prevents potential cross-API token confusion attacks (Auth0 Scenario 2)")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false // try refresh
|
||||
}
|
||||
return false, false, true // must re-authenticate
|
||||
}
|
||||
// Backward compatibility mode: Log loud warning but allow fallback to ID token
|
||||
t.logger.Infof("⚠️⚠️⚠️ SECURITY WARNING: Falling back to ID token validation despite access token audience mismatch!")
|
||||
t.logger.Infof("⚠️ This could allow tokens intended for different APIs to grant access")
|
||||
t.logger.Infof("⚠️ Set strictAudienceValidation=true to enforce proper audience validation")
|
||||
t.logger.Infof("⚠️ See: https://github.com/lukaszraczylo/traefikoidc/issues/74")
|
||||
} else if !strings.Contains(accessTokenError, "token has expired") {
|
||||
// Other validation errors (not expiration, not audience)
|
||||
t.logger.Debugf("Access token validation failed (non-expiration, non-audience): %v", err)
|
||||
}
|
||||
} else {
|
||||
// Access token is valid
|
||||
accessTokenValid = true
|
||||
}
|
||||
|
||||
idToken := session.GetIDToken()
|
||||
if idToken == "" {
|
||||
if accessTokenValid {
|
||||
// Access token is valid, no ID token needed
|
||||
t.logger.Debug("Access token valid, no ID token present")
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
}
|
||||
|
||||
t.logger.Debug("Authenticated flag set with access token, but no ID token found in session")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("ID token missing but refresh token exists. Signaling conditional refresh to obtain ID token.")
|
||||
return true, true, false
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// Validate ID token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
t.logger.Debugf("ID token signature/claims valid but token expired, needs refresh")
|
||||
if session.GetRefreshToken() != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
t.logger.Errorf("ID token verification failed (non-expiration): %v", err)
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("ID token verification failed, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
// If access token was valid, use it for expiry; otherwise use ID token
|
||||
if accessTokenValid {
|
||||
return t.validateTokenExpiry(session, accessToken)
|
||||
}
|
||||
|
||||
return t.validateTokenExpiry(session, idToken)
|
||||
}
|
||||
|
||||
// validateTokenExpiry checks if a token is nearing expiration and needs refresh.
|
||||
// It uses the configured grace period to determine when proactive refresh should occur.
|
||||
// Parameters:
|
||||
// - session: The session data for refresh token availability.
|
||||
// - token: The token to check expiry for.
|
||||
//
|
||||
// Returns:
|
||||
// - authenticated: Whether the token is currently valid.
|
||||
// - needsRefresh: Whether the token is nearing expiration and should be refreshed.
|
||||
// - expired: Whether the token is invalid or verification failed.
|
||||
func (t *TraefikOidc) validateTokenExpiry(session *SessionData, token string) (bool, bool, bool) {
|
||||
cachedClaims, found := t.tokenCache.Get(token)
|
||||
if !found {
|
||||
t.logger.Debug("Claims not found in cache after successful token verification")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Claims missing post-verification, attempting refresh to recover.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
expClaim, ok := cachedClaims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debug("Token missing 'exp' claim, but refresh token exists. Signaling need for refresh.")
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
expTime := int64(expClaim)
|
||||
expTimeObj := time.Unix(expTime, 0)
|
||||
nowObj := time.Now()
|
||||
|
||||
// Check if token has already expired
|
||||
if expTimeObj.Before(nowObj) {
|
||||
// Token has expired
|
||||
expiredDuration := nowObj.Sub(expTimeObj)
|
||||
|
||||
t.logger.Debugf("Token expired %v ago, grace period is %v",
|
||||
expiredDuration, t.refreshGracePeriod)
|
||||
|
||||
// If we have a refresh token, always attempt to use it regardless of grace period
|
||||
// The refresh token has its own expiry and the provider will reject it if invalid
|
||||
if session.GetRefreshToken() != "" {
|
||||
t.logger.Debugf("Token expired, attempting refresh with available refresh token")
|
||||
return false, true, false // needs refresh
|
||||
}
|
||||
|
||||
// No refresh token available - must re-authenticate
|
||||
t.logger.Debugf("Token expired and no refresh token available, must re-authenticate")
|
||||
return false, false, true // expired, cannot refresh
|
||||
}
|
||||
|
||||
// Token not yet expired - check if nearing expiration
|
||||
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
|
||||
|
||||
t.logger.Debugf("Token expires at %v, now is %v, refresh threshold is %v",
|
||||
expTimeObj.Format(time.RFC3339),
|
||||
nowObj.Format(time.RFC3339),
|
||||
refreshThreshold.Format(time.RFC3339))
|
||||
|
||||
if expTimeObj.Before(refreshThreshold) {
|
||||
remainingSeconds := int64(time.Until(expTimeObj).Seconds())
|
||||
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling proactive refresh",
|
||||
remainingSeconds, t.refreshGracePeriod)
|
||||
|
||||
if session.GetRefreshToken() != "" {
|
||||
return true, true, false
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token nearing expiration but no refresh token available, cannot proactively refresh.")
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token is valid and not nearing expiration (expires in %d seconds, outside %s grace period)",
|
||||
int64(time.Until(expTimeObj).Seconds()), t.refreshGracePeriod)
|
||||
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// startTokenCleanup starts background cleanup goroutines for cache maintenance.
|
||||
// It runs periodic cleanup of token cache, JWK cache, and session chunks.
|
||||
|
||||
@@ -0,0 +1,286 @@
|
||||
// Package traefikoidc provides OIDC authentication middleware for Traefik.
|
||||
// This file contains requestState-aware variants of the token validation
|
||||
// functions. They read session field values from the captured snapshot in
|
||||
// *requestState instead of calling session.GetX(), eliminating ~21 RLock
|
||||
// acquisitions on sd.sessionMutex per request through the validation path
|
||||
// (validateStandardTokens reads 17, validateAzureTokens reads 10,
|
||||
// validateTokenExpiry reads 4 — and many are the SAME field). Under Yaegi
|
||||
// each RLock costs ~1-5ms of interpreter dispatch.
|
||||
//
|
||||
// The non-RS variants are retained for paths that don't have a captured
|
||||
// snapshot (tests that drive the validators directly, the Azure/Google path
|
||||
// when reached without rs threading, etc).
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// isUserAuthenticatedRS is the requestState-aware variant of
|
||||
// isUserAuthenticated. Dispatches to the right per-provider validator based
|
||||
// on the configured provider, all of which read from rs instead of session.
|
||||
func (t *TraefikOidc) isUserAuthenticatedRS(rs *requestState) (bool, bool, bool) {
|
||||
if t.isAzureProvider() {
|
||||
return t.validateAzureTokensRS(rs)
|
||||
} else if t.isGoogleProvider() {
|
||||
return t.validateGoogleTokensRS(rs)
|
||||
}
|
||||
return t.validateStandardTokensRS(rs)
|
||||
}
|
||||
|
||||
// validateGoogleTokensRS handles Google-specific token validation. Currently
|
||||
// delegates to standard token validation; retained as a hook for any future
|
||||
// Google-specific behavior (matches the v1.0.20 layout of the non-RS variant).
|
||||
func (t *TraefikOidc) validateGoogleTokensRS(rs *requestState) (bool, bool, bool) {
|
||||
return t.validateStandardTokensRS(rs)
|
||||
}
|
||||
|
||||
// validateTokenExpiryRS is the requestState-aware variant of validateTokenExpiry.
|
||||
// Reads rs.refreshToken instead of session.GetRefreshToken() (4 RLocks avoided).
|
||||
func (t *TraefikOidc) validateTokenExpiryRS(rs *requestState, token string) (bool, bool, bool) {
|
||||
cachedClaims, found := t.tokenCache.Get(token)
|
||||
if !found {
|
||||
t.logger.Debug("Claims not found in cache after successful token verification")
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
expClaim, ok := cachedClaims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Error("Failed to get expiration time ('exp' claim) from verified token")
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
expTimeObj := time.Unix(int64(expClaim), 0)
|
||||
nowObj := time.Now()
|
||||
|
||||
if expTimeObj.Before(nowObj) {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
refreshThreshold := nowObj.Add(t.refreshGracePeriod)
|
||||
if expTimeObj.Before(refreshThreshold) {
|
||||
if rs.refreshToken != "" {
|
||||
return true, true, false
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// validateStandardTokensRS is the requestState-aware variant of
|
||||
// validateStandardTokens. Replaces all session.GetX() calls (17 of them in
|
||||
// the non-RS variant, dominated by GetRefreshToken called 11 times) with
|
||||
// rs field reads. Same control flow.
|
||||
//
|
||||
//nolint:gocognit,gocyclo // Mirrors validateStandardTokens complexity by design.
|
||||
func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bool) {
|
||||
if !rs.authenticated {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
if rs.accessToken == "" {
|
||||
if rs.refreshToken != "" {
|
||||
// ID-token grace-period check (only when accessToken is absent).
|
||||
if rs.idToken != "" {
|
||||
parts := strings.Split(rs.idToken, ".")
|
||||
if len(parts) == 3 {
|
||||
if claimsData, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(claimsData, &claims); err == nil {
|
||||
if expClaim, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(expClaim), 0)
|
||||
if time.Now().After(expTime) {
|
||||
expiredDuration := time.Since(expTime)
|
||||
if expiredDuration > t.refreshGracePeriod {
|
||||
return false, false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
dotCount := strings.Count(rs.accessToken, ".")
|
||||
isOpaqueToken := dotCount != 2
|
||||
|
||||
if isOpaqueToken {
|
||||
if t.allowOpaqueTokens {
|
||||
if err := t.validateOpaqueToken(rs.accessToken); err != nil {
|
||||
errMsg := err.Error()
|
||||
isTokenInvalid := strings.Contains(errMsg, "token is not active") ||
|
||||
strings.Contains(errMsg, "revoked") ||
|
||||
strings.Contains(errMsg, "token has expired")
|
||||
if isTokenInvalid {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
if t.requireTokenIntrospection {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
// Transient introspection error: fall through to ID-token validation.
|
||||
} else {
|
||||
// Introspection succeeded.
|
||||
if rs.idToken != "" {
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to ID-token validation when opaque + no successful introspection.
|
||||
if rs.idToken == "" {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
|
||||
// JWT access token present.
|
||||
accessTokenValid := false
|
||||
if err := t.verifyToken(rs.accessToken); err != nil {
|
||||
errMsg := err.Error()
|
||||
if strings.Contains(errMsg, "invalid audience") || strings.Contains(errMsg, "audience") {
|
||||
if t.strictAudienceValidation {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
// Fall through to ID-token validation.
|
||||
}
|
||||
} else {
|
||||
accessTokenValid = true
|
||||
}
|
||||
|
||||
if rs.idToken == "" {
|
||||
if accessTokenValid {
|
||||
return t.validateTokenExpiryRS(rs, rs.accessToken)
|
||||
}
|
||||
if rs.refreshToken != "" {
|
||||
return true, true, false
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if strings.Contains(err.Error(), "token has expired") {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
if accessTokenValid {
|
||||
return t.validateTokenExpiryRS(rs, rs.accessToken)
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
|
||||
// validateAzureTokensRS is the requestState-aware variant of validateAzureTokens.
|
||||
// Eliminates 10 session.GetX() RLocks per Azure-path request.
|
||||
func (t *TraefikOidc) validateAzureTokensRS(rs *requestState) (bool, bool, bool) {
|
||||
if !rs.authenticated {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, true, false
|
||||
}
|
||||
|
||||
if rs.accessToken != "" {
|
||||
if strings.Count(rs.accessToken, ".") == 2 {
|
||||
if t.isUnverifiableAzureAccessToken(rs.accessToken) {
|
||||
if rs.idToken != "" {
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
if err := t.verifyToken(rs.accessToken); err != nil {
|
||||
if rs.idToken != "" {
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.accessToken)
|
||||
}
|
||||
// Opaque access token.
|
||||
if rs.idToken != "" {
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
if rs.idToken != "" {
|
||||
if err := t.verifyToken(rs.idToken); err != nil {
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
return t.validateTokenExpiryRS(rs, rs.idToken)
|
||||
}
|
||||
|
||||
if rs.refreshToken != "" {
|
||||
return false, true, false
|
||||
}
|
||||
return false, false, true
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
@@ -64,8 +65,46 @@ type ProviderMetadata struct {
|
||||
// It integrates with various OIDC providers, manages sessions, caches tokens, and handles
|
||||
// the complete authentication flow. It's designed to work seamlessly with Traefik's
|
||||
// plugin system and provides flexible configuration options.
|
||||
// MetadataSnapshot is an immutable bundle of provider-metadata URLs that the
|
||||
// plugin needs on the hot request path. Published atomically via
|
||||
// TraefikOidc.metadataSnapshot; readers do exactly one atomic.Value.Load to
|
||||
// access all fields. Replaces 3 per-request metadataMu.RLock acquisitions
|
||||
// in middleware.ServeHTTP + token_manager paths, each of which paid
|
||||
// 1-5ms of Yaegi-dispatch overhead.
|
||||
//
|
||||
// The fields are a strict subset of the metadataMu-guarded TraefikOidc
|
||||
// fields; the legacy fields are still written under metadataMu for
|
||||
// less-frequent code paths that have not been migrated.
|
||||
type MetadataSnapshot struct {
|
||||
IssuerURL string
|
||||
JWKSURL string
|
||||
TokenURL string
|
||||
AuthURL string
|
||||
RevocationURL string
|
||||
EndSessionURL string
|
||||
IntrospectionURL string
|
||||
RegistrationURL string
|
||||
}
|
||||
|
||||
type TraefikOidc struct {
|
||||
lastMetadataRetryTime time.Time
|
||||
// metadataSnapshot atomically publishes the read-mostly URL bundle.
|
||||
// Hot-path readers (middleware.ServeHTTP, token verification) load it
|
||||
// directly; less-frequent paths still acquire metadataMu.RLock and
|
||||
// read the individual fields below.
|
||||
metadataSnapshot atomic.Value
|
||||
// lastMetadataRetryNano is the UnixNano timestamp of the last metadata
|
||||
// recovery attempt. Stored atomically so the hot ServeHTTP path can
|
||||
// throttle retries without acquiring metadataRetryMutex on every request.
|
||||
lastMetadataRetryNano int64
|
||||
// firstRequestStarted is 0 until the very first non-health request fires
|
||||
// the background-task bootstrap; then it flips to 1 via CAS. Replaces the
|
||||
// firstRequestMutex + firstRequestReceived combo which previously took
|
||||
// a write lock on every non-health request forever.
|
||||
firstRequestStarted int32
|
||||
// metadataRefreshStartedAtomic is the CAS-only variant of the old
|
||||
// metadataRefreshStarted bool. Both flags live under the same atomic so
|
||||
// concurrent first-request goroutines race exactly once.
|
||||
metadataRefreshStartedAtomic int32
|
||||
jwkCache JWKCacheInterface
|
||||
jwtVerifier JWTVerifier
|
||||
ctx context.Context
|
||||
@@ -126,21 +165,18 @@ type TraefikOidc struct {
|
||||
frontchannelLogoutPath string
|
||||
scopesSupported []string
|
||||
scopes []string
|
||||
extraAuthParams map[string]string
|
||||
refreshGracePeriod time.Duration
|
||||
maxRefreshTokenAge time.Duration
|
||||
metadataMu sync.RWMutex
|
||||
shutdownOnce sync.Once
|
||||
metadataRetryMutex sync.Mutex
|
||||
firstRequestMutex sync.Mutex
|
||||
sessionInvalidationCache CacheInterface
|
||||
refreshResultCache CacheInterface
|
||||
minimalHeaders bool
|
||||
stripAuthCookies bool
|
||||
enableBackchannelLogout bool
|
||||
enableFrontchannelLogout bool
|
||||
firstRequestReceived bool
|
||||
requireTokenIntrospection bool
|
||||
metadataRefreshStarted bool
|
||||
allowPrivateIPAddresses bool
|
||||
disableReplayDetection bool
|
||||
allowOpaqueTokens bool
|
||||
|
||||
+44
-10
@@ -396,8 +396,16 @@ func (c *UniversalCache) getLocal(key string) (interface{}, bool) {
|
||||
return value, true
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
// Expired — fall through to the write-locked slow path below to
|
||||
// remove the entry under exclusive access.
|
||||
// Expired — return miss immediately. The periodic cleanup goroutine
|
||||
// will evict the stale entry. NEVER fall through to the write-locked
|
||||
// slow path for Token/JWK/Session caches: under Yaegi the write Lock
|
||||
// at line 403 costs 10-100ms per acquisition, and Go's RWMutex
|
||||
// writer-priority semantics block ALL new RLock callers while a Lock
|
||||
// is pending. A single expired-token event turns every concurrent
|
||||
// request from read-parallel into write-serialized — the exact
|
||||
// convoy that produced the 737-goroutine pileup at 0x400275a608.
|
||||
atomic.AddInt64(&c.misses, 1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
@@ -595,15 +603,28 @@ func (c *UniversalCache) removeItem(key string, item *CacheItem) {
|
||||
|
||||
// evictOldest evicts the oldest item from the cache (must be called with lock held)
|
||||
func (c *UniversalCache) evictOldest() {
|
||||
if elem := c.lruList.Back(); elem != nil {
|
||||
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
|
||||
if item, exists := c.items[key]; exists {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
}
|
||||
elem := c.lruList.Back()
|
||||
if elem == nil {
|
||||
return
|
||||
}
|
||||
key, _ := elem.Value.(string) // Safe to ignore: cache internal type assertion
|
||||
if item, exists := c.items[key]; exists && item.element == elem {
|
||||
c.removeItem(key, item)
|
||||
atomic.AddInt64(&c.evictions, 1)
|
||||
if c.logger.IsDebug() {
|
||||
c.logger.Debugf("UniversalCache[%s]: Evicted key=%s", c.config.Type, key)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Defensive forward-progress guard: the back node is dangling — its key is
|
||||
// absent from c.items, or c.items[key] points at a newer node (a stale
|
||||
// duplicate). Drop the node directly so an eviction loop
|
||||
// (`for ... && c.lruList.Len() > 0`) is guaranteed to terminate and can
|
||||
// never spin holding c.mu.Lock(). With the updateLocalCache replace-in-place
|
||||
// fix this branch should be unreachable, but it makes the spin impossible.
|
||||
c.lruList.Remove(elem)
|
||||
if c.currentSize > 0 {
|
||||
c.currentSize--
|
||||
}
|
||||
}
|
||||
|
||||
@@ -936,6 +957,19 @@ func (c *UniversalCache) updateLocalCache(key string, value interface{}, ttl tim
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// Replace any existing entry in place. Without this, a repeat populate of
|
||||
// the same key (the per-request Get->backend-hit path at line ~359)
|
||||
// PushFronts a second list node and overwrites c.items[key], orphaning the
|
||||
// previous node. Orphans inflate currentMemory/currentSize and — once the
|
||||
// eviction loop deletes the key — leave Back() nodes whose key is absent
|
||||
// from c.items, so evictOldest() no-ops while lruList.Len()>0 stays true:
|
||||
// an infinite loop while holding c.mu.Lock(), i.e. the 100%-CPU holder and
|
||||
// write-lock convoy. setLocal already dedups on this path; this mirrors it.
|
||||
if existing, ok := c.items[key]; ok {
|
||||
c.removeItem(key, existing)
|
||||
}
|
||||
|
||||
item := &CacheItem{
|
||||
Key: key,
|
||||
Value: value,
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// newOrphanTestCache builds a Token-type cache with background cleanup disabled
|
||||
// so the test fully controls lruList/items state.
|
||||
func newOrphanTestCache(maxMem int64) *UniversalCache {
|
||||
return NewUniversalCache(UniversalCacheConfig{
|
||||
Type: CacheTypeToken,
|
||||
DefaultTTL: time.Hour,
|
||||
MaxSize: 1_000_000, // large: keep the size-branch out of the way
|
||||
MaxMemoryBytes: maxMem,
|
||||
EnableMemoryLimit: maxMem > 0,
|
||||
SkipAutoCleanup: true,
|
||||
EnableAutoCleanup: false,
|
||||
})
|
||||
}
|
||||
|
||||
// TestUpdateLocalCache_NoOrphanElements is the direct red test: repeatedly
|
||||
// populating the SAME key via updateLocalCache (the per-request Get->backend-hit
|
||||
// path) must NOT leave dangling lruList elements. Today updateLocalCache blindly
|
||||
// PushFronts + overwrites c.items[key] without removing the prior element, so the
|
||||
// list grows one orphan per call while items stays at 1 entry.
|
||||
func TestUpdateLocalCache_NoOrphanElements(t *testing.T) {
|
||||
c := newOrphanTestCache(0) // memory limit off: isolate the orphan, no eviction
|
||||
const key = "same-key"
|
||||
|
||||
for range 5 {
|
||||
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
|
||||
t.Fatalf("updateLocalCache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
c.mu.RLock()
|
||||
listLen := c.lruList.Len()
|
||||
itemCount := len(c.items)
|
||||
c.mu.RUnlock()
|
||||
|
||||
if itemCount != 1 {
|
||||
t.Fatalf("items: got %d want 1", itemCount)
|
||||
}
|
||||
if listLen != 1 {
|
||||
t.Fatalf("ORPHAN BUG: lruList.Len()=%d but items=%d (one list element per key expected)", listLen, itemCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateLocalCache_EvictionTerminates is the convoy reproducer: once orphans
|
||||
// for a key exist and the memory-eviction loop runs, evictOldest() deletes the
|
||||
// key from items on the first eviction, after which every remaining orphan at
|
||||
// Back() has a key absent from items -> evictOldest() no-ops while lruList.Len()>0
|
||||
// stays true -> infinite loop while holding c.mu.Lock(). That is the 100%-CPU
|
||||
// holder + write-lock convoy observed in pprof.
|
||||
func TestUpdateLocalCache_EvictionTerminates(t *testing.T) {
|
||||
c := newOrphanTestCache(0) // start with memory limit OFF to accumulate orphans
|
||||
const key = "same-key"
|
||||
|
||||
// Build 3 same-key list elements (3 orphans, items={key}).
|
||||
for range 3 {
|
||||
if err := c.updateLocalCache(key, "v", time.Hour); err != nil {
|
||||
t.Fatalf("seed updateLocalCache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Arm the trap: tiny memory limit so the next call enters the eviction loop.
|
||||
c.mu.Lock()
|
||||
c.config.MaxMemoryBytes = 1
|
||||
c.mu.Unlock()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
_ = c.updateLocalCache(key, "v", time.Hour) // triggers the eviction loop
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// fix present: loop made forward progress and returned
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("INFINITE LOOP: eviction loop did not terminate within 2s (orphan whose key was deleted is never removed from lruList)")
|
||||
}
|
||||
}
|
||||
@@ -146,6 +146,21 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce, codeChallenge stri
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Final scope string being sent to OIDC provider: %s", finalScopeString)
|
||||
}
|
||||
|
||||
// Apply operator-configured extra authorization parameters (e.g.
|
||||
// screen_hint, login_hint, ui_locales, prompt). These are added last but
|
||||
// can never override parameters the plugin itself manages (client_id,
|
||||
// state, nonce, redirect_uri, code_challenge, scope, response_type, ...):
|
||||
// a key already present in params is left untouched, so this cannot
|
||||
// weaken security-critical parameters.
|
||||
for key, value := range t.extraAuthParams {
|
||||
if params.Get(key) == "" {
|
||||
params.Set(key, value)
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Added extra auth param %s", key)
|
||||
} else {
|
||||
t.logger.Debugf("TraefikOidc.buildAuthURL: Skipped extra auth param %s (already set by plugin)", key)
|
||||
}
|
||||
}
|
||||
|
||||
// Read authURL with RLock
|
||||
t.metadataMu.RLock()
|
||||
authURL := t.authURL
|
||||
|
||||
@@ -554,3 +554,54 @@ func TestForceHTTPSIntegration(t *testing.T) {
|
||||
"should use https from X-Forwarded-Proto when forceHTTPS is false")
|
||||
})
|
||||
}
|
||||
|
||||
// TestBuildAuthURLExtraAuthParams verifies operator-configured extra
|
||||
// authorization parameters are appended to the authorization URL, and that
|
||||
// they can never override parameters the plugin itself manages.
|
||||
func TestBuildAuthURLExtraAuthParams(t *testing.T) {
|
||||
t.Run("extra params are added (e.g. screen_hint=signup)", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.extraAuthParams = map[string]string{
|
||||
"screen_hint": "signup",
|
||||
"ui_locales": "en",
|
||||
}
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "screen_hint=signup")
|
||||
assert.Contains(t, authURL, "ui_locales=en")
|
||||
})
|
||||
|
||||
t.Run("nil/empty extraAuthParams is a no-op", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
// extraAuthParams left nil
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
assert.Contains(t, authURL, "client_id=test-client")
|
||||
assert.NotContains(t, authURL, "screen_hint")
|
||||
})
|
||||
|
||||
t.Run("extra params CANNOT override plugin-managed params", func(t *testing.T) {
|
||||
middleware := createMinimalMiddleware()
|
||||
middleware.extraAuthParams = map[string]string{
|
||||
"client_id": "ATTACKER",
|
||||
"state": "ATTACKER",
|
||||
"redirect_uri": "https://evil.example.com",
|
||||
"response_type": "token",
|
||||
}
|
||||
|
||||
authURL := middleware.buildAuthURL(
|
||||
"https://app.com/callback", "state123", "nonce456", "",
|
||||
)
|
||||
|
||||
// Plugin-managed values must win; injected values must be absent.
|
||||
assert.Contains(t, authURL, "client_id=test-client")
|
||||
assert.NotContains(t, authURL, "ATTACKER")
|
||||
assert.NotContains(t, authURL, "evil.example.com")
|
||||
assert.Contains(t, authURL, "response_type=code")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -14,6 +14,19 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// metadataSnap returns the most recently published *MetadataSnapshot, or nil
|
||||
// if metadata has not yet been resolved. Single atomic.Value.Load — the hot
|
||||
// ServeHTTP path uses this instead of acquiring metadataMu.RLock, which under
|
||||
// Yaegi pays 1-5ms of interpreter-dispatch overhead per acquisition.
|
||||
func (t *TraefikOidc) metadataSnap() *MetadataSnapshot {
|
||||
v := t.metadataSnapshot.Load()
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
s, _ := v.(*MetadataSnapshot)
|
||||
return s
|
||||
}
|
||||
|
||||
// safeLogDebug provides nil-safe logging for debug messages
|
||||
func (t *TraefikOidc) safeLogDebug(msg string) {
|
||||
if t.logger != nil {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
.docs
|
||||
+36
@@ -0,0 +1,36 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 2m
|
||||
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- bodyclose
|
||||
- errcheck
|
||||
- errorlint
|
||||
- gocritic
|
||||
- gocyclo
|
||||
- govet
|
||||
- ineffassign
|
||||
- misspell
|
||||
- prealloc
|
||||
- revive
|
||||
- staticcheck
|
||||
- unconvert
|
||||
- unused
|
||||
settings:
|
||||
gocyclo:
|
||||
min-complexity: 12
|
||||
revive:
|
||||
rules:
|
||||
- name: var-naming
|
||||
- name: indent-error-flow
|
||||
- name: superfluous-else
|
||||
- name: unused-parameter
|
||||
- name: redefines-builtin-id
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
- goimports
|
||||
+42
@@ -0,0 +1,42 @@
|
||||
# Configuration for lukaszraczylo/semver-generator.
|
||||
# Reference: https://github.com/lukaszraczylo/semver-generator
|
||||
#
|
||||
# Word matching is fuzzy + case-insensitive. The keywords below mirror the
|
||||
# Conventional Commits prefixes used in this repo's git history. Same pattern
|
||||
# as github.com/lukaszraczylo/go-telegram/.semver.yaml.
|
||||
|
||||
version: 1
|
||||
|
||||
# Respect existing v* tags as the version baseline. semver-generator finds
|
||||
# the highest existing tag and bumps from there. With no tags yet, the first
|
||||
# release computes from the empty base.
|
||||
force:
|
||||
existing: true
|
||||
|
||||
# Skip merge commits and machine-generated traffic that would otherwise
|
||||
# spuriously bump the version.
|
||||
blacklist:
|
||||
- "Merge branch"
|
||||
- "Merge pull request"
|
||||
- "Merge remote-tracking branch"
|
||||
- "go mod tidy"
|
||||
|
||||
wording:
|
||||
patch:
|
||||
- "fix"
|
||||
- "chore"
|
||||
- "docs"
|
||||
- "test"
|
||||
- "style"
|
||||
- "refactor"
|
||||
- "build"
|
||||
- "ci"
|
||||
- "perf"
|
||||
minor:
|
||||
- "feat"
|
||||
major:
|
||||
# Match only the canonical Conventional Commits trailer. The bare word
|
||||
# "breaking" is too greedy under semver-generator's fuzzy match — it
|
||||
# triggers on substrings inside a commit body and wrongly produces a
|
||||
# major bump.
|
||||
- "BREAKING CHANGE"
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
# oss-telemetry
|
||||
|
||||
A tiny Go client that fires one anonymous "this binary started" ping at a
|
||||
central ingest endpoint. Designed to be embedded in your own open-source
|
||||
projects so you can see approximate adoption and version spread without
|
||||
collecting anything that could identify a user.
|
||||
|
||||
This is the **client library only**. The ingest endpoint, server-side
|
||||
deduplication, rate limiting, and metrics are out of scope here.
|
||||
|
||||
## What it sends
|
||||
|
||||
A single HTTP `POST` per call to `Send`:
|
||||
|
||||
```json
|
||||
{
|
||||
"project": "my-tool",
|
||||
"version": "1.2.3",
|
||||
"ts": 1747782200
|
||||
}
|
||||
```
|
||||
|
||||
No identifiers, no IP, no machine info, no user data. The server dedupes
|
||||
incoming requests; the client just fires and forgets.
|
||||
|
||||
## Failproof by design
|
||||
|
||||
- Never blocks the caller — work runs in a goroutine.
|
||||
- Never panics — the goroutine recovers internally.
|
||||
- Never returns errors — bad input and network failures are silently dropped.
|
||||
- Never retries, never persists state, never reads back.
|
||||
- 2-second hard timeout on every request.
|
||||
- Zero third-party dependencies (Go stdlib only).
|
||||
|
||||
The endpoint is hardcoded and not overridable from consuming code, by design.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
go get github.com/lukaszraczylo/oss-telemetry
|
||||
```
|
||||
|
||||
Requires Go 1.22+.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
telemetry "github.com/lukaszraczylo/oss-telemetry"
|
||||
)
|
||||
|
||||
const version = "1.2.3"
|
||||
|
||||
func main() {
|
||||
telemetry.Send("my-tool", version)
|
||||
|
||||
// ... your program runs ...
|
||||
|
||||
// Only needed for short-lived CLIs that may exit before the goroutine
|
||||
// finishes its POST. Long-running services do not need this.
|
||||
telemetry.Wait(2 * time.Second)
|
||||
}
|
||||
```
|
||||
|
||||
Call `Send` once at boot. Calling it more often just sends more pings; the
|
||||
server deduplicates.
|
||||
|
||||
## Disabling telemetry
|
||||
|
||||
If you ship a binary that imports this library, link your users to this
|
||||
section (`https://github.com/lukaszraczylo/oss-telemetry#disabling-telemetry`)
|
||||
so they can find the opt-out paths.
|
||||
|
||||
Any one of these turns it off:
|
||||
|
||||
| Mechanism | How |
|
||||
| ---------------------------------------- | ---------------------------------------------------------------- |
|
||||
| Universal opt-out | `DO_NOT_TRACK=1` |
|
||||
| Library-wide opt-out | `OSS_TELEMETRY_DISABLED=1` |
|
||||
| Project-specific opt-out | `<UPPER_PROJECT>_DISABLE_TELEMETRY=1` |
|
||||
| Programmatic (e.g. behind a `--no-telemetry` flag) | `telemetry.Disable()` before the first `Send` |
|
||||
|
||||
Project-specific env var derivation: uppercase the project name and replace
|
||||
`-` with `_`. For `my-tool` the variable is `MY_TOOL_DISABLE_TELEMETRY`.
|
||||
|
||||
Truthy values: `1`, `true`, `yes`, `on` (case-insensitive). Anything else is
|
||||
treated as "not set".
|
||||
|
||||
## Validation rules (silently dropped if violated)
|
||||
|
||||
- `project`: matches `^[a-z0-9-]+$`, length 1–64.
|
||||
- `version`: matches `^[A-Za-z0-9.+_-]+$`, length 1–32.
|
||||
|
||||
Bad input is a no-op — the library never logs, never errors, never crashes.
|
||||
|
||||
## API
|
||||
|
||||
```go
|
||||
// Fire a single ping in the background. Returns immediately.
|
||||
func Send(project, version string)
|
||||
|
||||
// Suppress all subsequent Send calls in this process. Idempotent.
|
||||
func Disable()
|
||||
|
||||
// Block until in-flight pings complete or timeout elapses, whichever first.
|
||||
// Useful for short-lived CLI processes.
|
||||
func Wait(timeout time.Duration)
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
go test -race ./...
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Pick one before publishing. None bundled.
|
||||
+367
@@ -0,0 +1,367 @@
|
||||
// Package telemetry sends anonymous usage pings for open-source Go projects.
|
||||
//
|
||||
// Wire format (POST application/json):
|
||||
//
|
||||
// {"project":"<name>","version":"<ver>","ts":<unix-seconds>}
|
||||
//
|
||||
// Design contract (failproof):
|
||||
// - never blocks the caller (work happens in a goroutine)
|
||||
// - never panics (background goroutine recovers internally)
|
||||
// - never returns errors (silently no-ops on bad input or network failure)
|
||||
// - never retries, never deduplicates, never persists state — the client
|
||||
// fires a single ping and forgets; the server is responsible for
|
||||
// deduplication, abuse protection, and aggregation
|
||||
//
|
||||
// Typical usage at program startup:
|
||||
//
|
||||
// telemetry.Send("my-tool", "1.2.3")
|
||||
//
|
||||
// For short-lived CLI processes that may exit before the goroutine finishes:
|
||||
//
|
||||
// telemetry.Send("my-tool", "1.2.3")
|
||||
// defer telemetry.Wait(2 * time.Second)
|
||||
//
|
||||
// Disablement (any one of these suppresses pings):
|
||||
// - environment variable DO_NOT_TRACK=1
|
||||
// - environment variable OSS_TELEMETRY_DISABLED=1
|
||||
// - environment variable <UPPER_PROJECT>_DISABLE_TELEMETRY=1
|
||||
// (project name uppercased, dashes replaced with underscores)
|
||||
// - calling telemetry.Disable() at runtime
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEndpoint = "https://oss.raczylo.com/v1/ping"
|
||||
httpTimeout = 2 * time.Second
|
||||
maxProjectLen = 64
|
||||
maxVersionLen = 32
|
||||
)
|
||||
|
||||
// Yaegi note: this package is consumed by the traefikoidc Traefik plugin, which
|
||||
// Traefik interprets with Yaegi (it vendors and interprets dependency source).
|
||||
// It therefore avoids generic stdlib types (atomic.Pointer[T], atomic.Bool) and
|
||||
// range-over-int (Go 1.22), which some Traefik/Yaegi runtimes cannot interpret.
|
||||
// Endpoint mutation uses a mutex-guarded string; the disabled flag uses the
|
||||
// function-based sync/atomic int32 API (atomic.LoadInt32/StoreInt32).
|
||||
var (
|
||||
// endpointURL holds the ingest URL. Production code never mutates it; the
|
||||
// setter exists only so the test suite can retarget it at httptest servers
|
||||
// while goroutines started by Send are still in flight.
|
||||
endpointMu sync.RWMutex
|
||||
endpointURL = defaultEndpoint
|
||||
|
||||
disabled int32 // 0 = enabled, 1 = disabled; accessed via sync/atomic only
|
||||
inflight sync.WaitGroup
|
||||
|
||||
client = &http.Client{Timeout: httpTimeout}
|
||||
)
|
||||
|
||||
func currentEndpoint() string {
|
||||
endpointMu.RLock()
|
||||
defer endpointMu.RUnlock()
|
||||
return endpointURL
|
||||
}
|
||||
|
||||
func setEndpointURL(u string) {
|
||||
endpointMu.Lock()
|
||||
endpointURL = u
|
||||
endpointMu.Unlock()
|
||||
}
|
||||
|
||||
// Send fires a single anonymous telemetry ping in the background and returns
|
||||
// immediately. It never blocks, never panics, and never reports errors.
|
||||
// Invalid inputs, disabled state, and network failures are silently dropped.
|
||||
//
|
||||
// Version strings are validated against a SemVer-ish shape that mirrors the
|
||||
// receiver. An optional leading "v" or "V" is accepted and stripped before
|
||||
// transmission so that callers can pass either "v1.2.3" or "1.2.3"; the
|
||||
// wire form is always the unprefixed canonical version.
|
||||
//
|
||||
// Call once at program startup. Calling repeatedly will send repeated pings;
|
||||
// the server is responsible for deduplication.
|
||||
func Send(project, version string) {
|
||||
if atomic.LoadInt32(&disabled) != 0 {
|
||||
return
|
||||
}
|
||||
if isDisabledByEnv(project) {
|
||||
return
|
||||
}
|
||||
if !validProject(project) || !validVersion(version) {
|
||||
return
|
||||
}
|
||||
canonical := normalizeVersion(version)
|
||||
|
||||
inflight.Add(1)
|
||||
go func() {
|
||||
defer inflight.Done()
|
||||
defer func() { _ = recover() }()
|
||||
dispatch(project, canonical)
|
||||
}()
|
||||
}
|
||||
|
||||
// SendForModule is the recommended call form for Go libraries: it resolves
|
||||
// the version automatically from Go's build info for the given module path
|
||||
// so consumers do not need to maintain a hand-bumped version constant in
|
||||
// source. Behaviour and contract are otherwise identical to [Send].
|
||||
//
|
||||
// Resolution order:
|
||||
//
|
||||
// 1. debug.ReadBuildInfo Deps entry for modulePath (authoritative when the
|
||||
// library is consumed via go.mod);
|
||||
// 2. debug.ReadBuildInfo Main when the library is itself the main module
|
||||
// (e.g. running its own tests or examples);
|
||||
// 3. fallback parameter, used only when build info is unavailable or
|
||||
// unhelpful (replace directives, detached `go run`, ldflag override).
|
||||
//
|
||||
// Any leading "v" reported by build info is stripped to match the canonical
|
||||
// wire form. Empty / "(devel)" build versions are skipped in favour of the
|
||||
// next resolution source. Typical usage:
|
||||
//
|
||||
// telemetry.SendForModule("my-tool", "github.com/me/my-tool", "0.0.0-dev")
|
||||
func SendForModule(project, modulePath, fallback string) {
|
||||
Send(project, ResolveModuleVersion(modulePath, fallback))
|
||||
}
|
||||
|
||||
// ResolveModuleVersion implements the version resolution used by
|
||||
// SendForModule. Exposed for callers that need to format the resolved
|
||||
// version (e.g. logging) without firing a ping.
|
||||
func ResolveModuleVersion(modulePath, fallback string) string {
|
||||
if info, ok := debug.ReadBuildInfo(); ok {
|
||||
for _, d := range info.Deps {
|
||||
if d != nil && d.Path == modulePath && isUsableBuildVersion(d.Version) {
|
||||
return strings.TrimPrefix(d.Version, "v")
|
||||
}
|
||||
}
|
||||
if info.Main.Path == modulePath && isUsableBuildVersion(info.Main.Version) {
|
||||
return strings.TrimPrefix(info.Main.Version, "v")
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func isUsableBuildVersion(v string) bool {
|
||||
return v != "" && v != "(devel)"
|
||||
}
|
||||
|
||||
// Disable suppresses all subsequent Send calls in this process.
|
||||
// Idempotent and safe to call from any goroutine.
|
||||
func Disable() {
|
||||
atomic.StoreInt32(&disabled, 1)
|
||||
}
|
||||
|
||||
// Wait blocks until all in-flight pings have completed, or until timeout
|
||||
// elapses — whichever comes first. Useful for short-lived CLI processes
|
||||
// that may otherwise exit before the background goroutine finishes its POST.
|
||||
//
|
||||
// A non-positive timeout returns immediately.
|
||||
func Wait(timeout time.Duration) {
|
||||
if timeout <= 0 {
|
||||
return
|
||||
}
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
inflight.Wait()
|
||||
close(done)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(timeout):
|
||||
}
|
||||
}
|
||||
|
||||
func dispatch(project, version string) {
|
||||
body := buildPayload(project, version, time.Now().Unix())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), httpTimeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, currentEndpoint(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
// buildPayload writes the JSON body without encoding/json. The validators
|
||||
// restrict project and version to characters that never require JSON
|
||||
// escaping, so direct concatenation is safe.
|
||||
func buildPayload(project, version string, ts int64) []byte {
|
||||
// Wrapper text plus 20 chars for a signed int64.
|
||||
const overhead = len(`{"project":"","version":"","ts":}`) + 20
|
||||
buf := make([]byte, 0, len(project)+len(version)+overhead)
|
||||
buf = append(buf, `{"project":"`...)
|
||||
buf = append(buf, project...)
|
||||
buf = append(buf, `","version":"`...)
|
||||
buf = append(buf, version...)
|
||||
buf = append(buf, `","ts":`...)
|
||||
buf = strconv.AppendInt(buf, ts, 10)
|
||||
buf = append(buf, '}')
|
||||
return buf
|
||||
}
|
||||
|
||||
func validProject(p string) bool {
|
||||
n := len(p)
|
||||
if n == 0 || n > maxProjectLen {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
c := p[i]
|
||||
switch {
|
||||
case c >= 'a' && c <= 'z',
|
||||
c >= '0' && c <= '9',
|
||||
c == '-':
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// validVersion accepts SemVer-ish version strings with an optional leading
|
||||
// "v"/"V" prefix. Acceptable shape (after stripping the leading v):
|
||||
//
|
||||
// MAJOR[.MINOR[.PATCH]] ("-"prerelease)? ("+"build)?
|
||||
//
|
||||
// where MAJOR/MINOR/PATCH are ASCII digit sequences and the prerelease/build
|
||||
// payloads are non-empty runs of [0-9A-Za-z.-]. This intentionally mirrors
|
||||
// the receiver's version regex so junk like "dev" or "git-2026-05-22" never
|
||||
// leaves the client (where it would only be rejected with HTTP 400 anyway).
|
||||
func validVersion(v string) bool {
|
||||
n := len(v)
|
||||
if n == 0 || n > maxVersionLen {
|
||||
return false
|
||||
}
|
||||
if v[0] == 'v' || v[0] == 'V' {
|
||||
v = v[1:]
|
||||
}
|
||||
if len(v) == 0 {
|
||||
return false
|
||||
}
|
||||
return checkSemverShape(v)
|
||||
}
|
||||
|
||||
// normalizeVersion strips an optional leading "v"/"V" so the on-the-wire
|
||||
// version matches the form stored server-side by the version refresher cron
|
||||
// (which also strips the leading v from release tags). Callers may pass
|
||||
// either "v1.2.3" or "1.2.3" — only the unprefixed form is transmitted.
|
||||
func normalizeVersion(v string) string {
|
||||
if len(v) > 0 && (v[0] == 'v' || v[0] == 'V') {
|
||||
return v[1:]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func checkSemverShape(s string) bool {
|
||||
i := 0
|
||||
if !readDigitRun(s, &i) {
|
||||
return false
|
||||
}
|
||||
for groups := 0; groups < 2 && i < len(s) && s[i] == '.'; groups++ {
|
||||
i++
|
||||
if !readDigitRun(s, &i) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if i < len(s) && s[i] == '-' {
|
||||
i++
|
||||
if !readIdentRun(s, &i, '+') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if i < len(s) && s[i] == '+' {
|
||||
i++
|
||||
if !readIdentRun(s, &i, 0) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return i == len(s)
|
||||
}
|
||||
|
||||
func readDigitRun(s string, i *int) bool {
|
||||
start := *i
|
||||
for *i < len(s) && s[*i] >= '0' && s[*i] <= '9' {
|
||||
*i++
|
||||
}
|
||||
return *i > start
|
||||
}
|
||||
|
||||
// readIdentRun consumes [0-9A-Za-z.-] until end-of-string or until `stop`
|
||||
// is hit (stop=0 disables the early-stop check). Returns false if no
|
||||
// characters were consumed (i.e. empty payload).
|
||||
func readIdentRun(s string, i *int, stop byte) bool {
|
||||
start := *i
|
||||
for *i < len(s) {
|
||||
c := s[*i]
|
||||
if stop != 0 && c == stop {
|
||||
break
|
||||
}
|
||||
valid := (c >= '0' && c <= '9') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= 'a' && c <= 'z') ||
|
||||
c == '.' || c == '-'
|
||||
if !valid {
|
||||
return false
|
||||
}
|
||||
*i++
|
||||
}
|
||||
return *i > start
|
||||
}
|
||||
|
||||
func isDisabledByEnv(project string) bool {
|
||||
if truthy(os.Getenv("DO_NOT_TRACK")) {
|
||||
return true
|
||||
}
|
||||
if truthy(os.Getenv("OSS_TELEMETRY_DISABLED")) {
|
||||
return true
|
||||
}
|
||||
if project == "" {
|
||||
return false
|
||||
}
|
||||
key := projectEnvKey(project)
|
||||
return truthy(os.Getenv(key))
|
||||
}
|
||||
|
||||
// projectEnvKey returns "<UPPER_PROJECT>_DISABLE_TELEMETRY" using a single
|
||||
// allocation rather than chained strings.ToUpper(strings.ReplaceAll(...)).
|
||||
func projectEnvKey(project string) string {
|
||||
const suffix = "_DISABLE_TELEMETRY"
|
||||
buf := make([]byte, 0, len(project)+len(suffix))
|
||||
for i := 0; i < len(project); i++ {
|
||||
c := project[i]
|
||||
switch {
|
||||
case c == '-':
|
||||
c = '_'
|
||||
case c >= 'a' && c <= 'z':
|
||||
c -= 'a' - 'A'
|
||||
}
|
||||
buf = append(buf, c)
|
||||
}
|
||||
buf = append(buf, suffix...)
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
func truthy(s string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||
case "1", "true", "yes", "on":
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
Vendored
+3
@@ -24,6 +24,9 @@ github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# github.com/lukaszraczylo/oss-telemetry v0.2.3
|
||||
## explicit; go 1.22
|
||||
github.com/lukaszraczylo/oss-telemetry
|
||||
# github.com/pmezard/go-difflib v1.0.0
|
||||
## explicit
|
||||
github.com/pmezard/go-difflib/difflib
|
||||
|
||||
+17
@@ -0,0 +1,17 @@
|
||||
package traefikoidc
|
||||
|
||||
// devPluginVersion is the placeholder carried by source-tree / local / test
|
||||
// builds. Telemetry is suppressed while the plugin still reports this sentinel,
|
||||
// so only stamped release builds emit a "plugin loaded" ping.
|
||||
const devPluginVersion = "0.0.0-dev"
|
||||
|
||||
// traefikoidcPluginVersion is the released version of this plugin. It is stamped
|
||||
// at release time by ./workflow-prepare.sh (invoked by the shared go-release
|
||||
// workflow before GoReleaser builds and tags), which rewrites the string below
|
||||
// to the computed semver.
|
||||
//
|
||||
// Traefik runs this plugin under Yaegi, where the version cannot be resolved
|
||||
// from build info at runtime (debug.ReadBuildInfo sees Traefik's build graph,
|
||||
// not the interpreted plugin). This build-stamped constant is therefore the
|
||||
// single source of truth for the version reported by anonymous usage telemetry.
|
||||
const traefikoidcPluginVersion = "1.0.25"
|
||||
Executable
+67
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# workflow-prepare.sh — stamp the release version into version.go at build time.
|
||||
#
|
||||
# The shared go-release workflow (lukaszraczylo/shared-actions go-release.yaml)
|
||||
# runs this script, if present, from the repository root BEFORE GoReleaser
|
||||
# builds and tags. Traefik runs this plugin under Yaegi, where the version
|
||||
# cannot be resolved from build info at runtime, so the released semver must be
|
||||
# baked into source here.
|
||||
#
|
||||
# Version source — first non-empty wins:
|
||||
# $VERSION $VERSION_TAG $SEMVER $NEW_VERSION $RELEASE_VERSION
|
||||
# A leading "v"/"V" is stripped.
|
||||
#
|
||||
# NOTE: go-release.yaml @main does not yet pass the computed version into this
|
||||
# step's environment. Add it to the "Run workflow prepare script" step, e.g.:
|
||||
# env:
|
||||
# VERSION: ${{ needs.version.outputs.version }} # bare, no leading v
|
||||
#
|
||||
# The shared workflow runs this script in its test, version AND release jobs,
|
||||
# but only the release job has a computed version. So a missing version is a
|
||||
# no-op (leave the dev sentinel) — NOT a hard failure, otherwise the test/version
|
||||
# jobs would break. A malformed version that IS provided is a hard error. Wire
|
||||
# the env only on the release job's prepare step (see header note above).
|
||||
set -euo pipefail
|
||||
|
||||
FILE="version.go"
|
||||
CONST="traefikoidcPluginVersion"
|
||||
|
||||
VER="${VERSION:-${VERSION_TAG:-${SEMVER:-${NEW_VERSION:-${RELEASE_VERSION:-}}}}}"
|
||||
VER="${VER#v}"
|
||||
VER="${VER#V}"
|
||||
|
||||
if [ -z "$VER" ]; then
|
||||
if [ "${GITHUB_ACTIONS:-}" = "true" ]; then
|
||||
echo "workflow-prepare: WARNING no version provided; leaving ${FILE} at the dev placeholder. If this is the release build, set 'env: VERSION: \${{ needs.version.outputs.version }}' on the release job's prepare step — otherwise the release ships 0.0.0-dev and emits no telemetry." >&2
|
||||
else
|
||||
echo "workflow-prepare: no version provided; leaving dev placeholder in ${FILE} (local build)"
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Accept MAJOR[.MINOR[.PATCH]] with optional -prerelease / +build (semver-ish,
|
||||
# matching the oss-telemetry receiver's validator).
|
||||
if ! printf '%s' "$VER" | grep -Eq '^[0-9]+(\.[0-9]+){0,2}(-[0-9A-Za-z.-]+)?(\+[0-9A-Za-z.-]+)?$'; then
|
||||
echo "workflow-prepare: ERROR version '${VER}' is not semver-shaped" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$FILE" ]; then
|
||||
echo "workflow-prepare: ERROR ${FILE} not found (run from repository root)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Rewrite only the value of ${CONST}, anchored on the constant name so the
|
||||
# sibling devPluginVersion sentinel is left untouched.
|
||||
tmp="$(mktemp)"
|
||||
sed -E "s/(${CONST}[[:space:]]*=[[:space:]]*\")[^\"]*(\")/\1${VER}\2/" "$FILE" > "$tmp"
|
||||
mv "$tmp" "$FILE"
|
||||
|
||||
if ! grep -Eq "${CONST}[[:space:]]*=[[:space:]]*\"${VER}\"" "$FILE"; then
|
||||
echo "workflow-prepare: ERROR failed to stamp version into ${FILE}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
command -v gofmt >/dev/null 2>&1 && gofmt -w "$FILE"
|
||||
echo "workflow-prepare: stamped ${CONST} = \"${VER}\" in ${FILE}"
|
||||
Reference in New Issue
Block a user