diff --git a/middleware.go b/middleware.go index ec37a3a..3ef3530 100644 --- a/middleware.go +++ b/middleware.go @@ -333,7 +333,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } t.logger.Debugf("Callback URL did not match (request_path=%q != configured=%q), continuing auth flow", req.URL.Path, t.redirURLPath) - authenticated, needsRefresh, expired := t.isUserAuthenticated(session) + // Token validation reads session via the captured snapshot — saves ~21 + // sd.sessionMutex.RLock acquisitions (Yaegi-dispatched, ~1-5ms each) + // across the validation path. + authenticated, needsRefresh, expired := t.isUserAuthenticatedRS(rs) if expired { t.logger.Debug("Session token is definitively expired or invalid, initiating re-auth") diff --git a/token_validation_rs.go b/token_validation_rs.go new file mode 100644 index 0000000..ce291cf --- /dev/null +++ b/token_validation_rs.go @@ -0,0 +1,279 @@ +// Package traefikoidc provides OIDC authentication middleware for Traefik. +// This file contains requestState-aware variants of the token validation +// functions. They read session field values from the captured snapshot in +// *requestState instead of calling session.GetX(), eliminating ~21 RLock +// acquisitions on sd.sessionMutex per request through the validation path +// (validateStandardTokens reads 17, validateAzureTokens reads 10, +// validateTokenExpiry reads 4 — and many are the SAME field). Under Yaegi +// each RLock costs ~1-5ms of interpreter dispatch. +// +// The non-RS variants are retained for paths that don't have a captured +// snapshot (tests that drive the validators directly, the Azure/Google path +// when reached without rs threading, etc). +package traefikoidc + +import ( + "encoding/base64" + "encoding/json" + "strings" + "time" +) + +// isUserAuthenticatedRS is the requestState-aware variant of +// isUserAuthenticated. Dispatches to the right per-provider validator based +// on the configured provider, all of which read from rs instead of session. +func (t *TraefikOidc) isUserAuthenticatedRS(rs *requestState) (bool, bool, bool) { + if t.isAzureProvider() { + return t.validateAzureTokensRS(rs) + } else if t.isGoogleProvider() { + return t.validateStandardTokensRS(rs) + } + return t.validateStandardTokensRS(rs) +} + +// validateTokenExpiryRS is the requestState-aware variant of validateTokenExpiry. +// Reads rs.refreshToken instead of session.GetRefreshToken() (4 RLocks avoided). +func (t *TraefikOidc) validateTokenExpiryRS(rs *requestState, token string) (bool, bool, bool) { + cachedClaims, found := t.tokenCache.Get(token) + if !found { + t.logger.Debug("Claims not found in cache after successful token verification") + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + + expClaim, ok := cachedClaims["exp"].(float64) + if !ok { + t.logger.Error("Failed to get expiration time ('exp' claim) from verified token") + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + + expTimeObj := time.Unix(int64(expClaim), 0) + nowObj := time.Now() + + if expTimeObj.Before(nowObj) { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + + refreshThreshold := nowObj.Add(t.refreshGracePeriod) + if expTimeObj.Before(refreshThreshold) { + if rs.refreshToken != "" { + return true, true, false + } + return true, false, false + } + + return true, false, false +} + +// validateStandardTokensRS is the requestState-aware variant of +// validateStandardTokens. Replaces all session.GetX() calls (17 of them in +// the non-RS variant, dominated by GetRefreshToken called 11 times) with +// rs field reads. Same control flow. +// +//nolint:gocognit,gocyclo // Mirrors validateStandardTokens complexity by design. +func (t *TraefikOidc) validateStandardTokensRS(rs *requestState) (bool, bool, bool) { + if !rs.authenticated { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, false + } + + if rs.accessToken == "" { + if rs.refreshToken != "" { + // ID-token grace-period check (only when accessToken is absent). + if rs.idToken != "" { + parts := strings.Split(rs.idToken, ".") + if len(parts) == 3 { + if claimsData, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil { + var claims map[string]interface{} + if err := json.Unmarshal(claimsData, &claims); err == nil { + if expClaim, ok := claims["exp"].(float64); ok { + expTime := time.Unix(int64(expClaim), 0) + if time.Now().After(expTime) { + expiredDuration := time.Since(expTime) + if expiredDuration > t.refreshGracePeriod { + return false, false, true + } + } + } + } + } + } + } + return false, true, false + } + return false, false, true + } + + dotCount := strings.Count(rs.accessToken, ".") + isOpaqueToken := dotCount != 2 + + if isOpaqueToken { + if t.allowOpaqueTokens { + if err := t.validateOpaqueToken(rs.accessToken); err != nil { + errMsg := err.Error() + isTokenInvalid := strings.Contains(errMsg, "token is not active") || + strings.Contains(errMsg, "revoked") || + strings.Contains(errMsg, "token has expired") + if isTokenInvalid { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + if t.requireTokenIntrospection { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + // Transient introspection error: fall through to ID-token validation. + } else { + // Introspection succeeded. + if rs.idToken != "" { + return t.validateTokenExpiryRS(rs, rs.idToken) + } + return true, false, false + } + } + + // Fall back to ID-token validation when opaque + no successful introspection. + if rs.idToken == "" { + if rs.refreshToken != "" { + return false, true, false + } + return true, false, false + } + if err := t.verifyToken(rs.idToken); err != nil { + if strings.Contains(err.Error(), "token has expired") { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + return t.validateTokenExpiryRS(rs, rs.idToken) + } + + // JWT access token present. + accessTokenValid := false + if err := t.verifyToken(rs.accessToken); err != nil { + errMsg := err.Error() + if strings.Contains(errMsg, "invalid audience") || strings.Contains(errMsg, "audience") { + if t.strictAudienceValidation { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + // Fall through to ID-token validation. + } + } else { + accessTokenValid = true + } + + if rs.idToken == "" { + if accessTokenValid { + return t.validateTokenExpiryRS(rs, rs.accessToken) + } + if rs.refreshToken != "" { + return true, true, false + } + return true, false, false + } + + if err := t.verifyToken(rs.idToken); err != nil { + if strings.Contains(err.Error(), "token has expired") { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + + if accessTokenValid { + return t.validateTokenExpiryRS(rs, rs.accessToken) + } + return t.validateTokenExpiryRS(rs, rs.idToken) +} + +// validateAzureTokensRS is the requestState-aware variant of validateAzureTokens. +// Eliminates 10 session.GetX() RLocks per Azure-path request. +func (t *TraefikOidc) validateAzureTokensRS(rs *requestState) (bool, bool, bool) { + if !rs.authenticated { + if rs.refreshToken != "" { + return false, true, false + } + return false, true, false + } + + if rs.accessToken != "" { + if strings.Count(rs.accessToken, ".") == 2 { + if t.isUnverifiableAzureAccessToken(rs.accessToken) { + if rs.idToken != "" { + if err := t.verifyToken(rs.idToken); err != nil { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + return t.validateTokenExpiryRS(rs, rs.idToken) + } + return true, false, false + } + if err := t.verifyToken(rs.accessToken); err != nil { + if rs.idToken != "" { + if err := t.verifyToken(rs.idToken); err != nil { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + return t.validateTokenExpiryRS(rs, rs.idToken) + } + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + return t.validateTokenExpiryRS(rs, rs.accessToken) + } + // Opaque access token. + if rs.idToken != "" { + return t.validateTokenExpiryRS(rs, rs.idToken) + } + return true, false, false + } + + if rs.idToken != "" { + if err := t.verifyToken(rs.idToken); err != nil { + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true + } + return t.validateTokenExpiryRS(rs, rs.idToken) + } + + if rs.refreshToken != "" { + return false, true, false + } + return false, false, true +}