From 46c2f98a15066cd24078287dcf23e335f539b028 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sat, 5 Apr 2025 11:15:15 +0100 Subject: [PATCH] Optimize the code, find edge cases, polish the bugs out. --- jwt.go | 27 +++++++--- main.go | 148 ++++++++++++++++++++++++++++++++++++++++++---------- session.go | 5 ++ settings.go | 21 ++++++-- 4 files changed, 160 insertions(+), 41 deletions(-) diff --git a/jwt.go b/jwt.go index 43e81c4..febab3d 100644 --- a/jwt.go +++ b/jwt.go @@ -29,8 +29,16 @@ func cleanupReplayCache() { } } -// ClockSkewTolerance is configurable to adjust time-based validations. -var ClockSkewTolerance = 2 * time.Minute +// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'. +// Allows for more leniency with expiration checks. +var ClockSkewToleranceFuture = 2 * time.Minute + +// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'. +// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future. +var ( + ClockSkewTolerancePast = 10 * time.Second + ClockSkewTolerance = 2 * time.Minute +) // JWT represents a JSON Web Token as defined in RFC 7519. type JWT struct { @@ -202,13 +210,16 @@ func verifyTimeConstraint(unixTime float64, claimName string, future bool) error claimTime := time.Unix(int64(unixTime), 0) now := time.Now().Truncate(time.Second) - // For expiration (future=true), we add skew to now (making now later) - // For iat/nbf (future=false), we subtract skew from now (making now earlier) - skewDirection := 1 - if !future { - skewDirection = -1 + var skewedNow time.Time + if future { + // For expiration (future=true), add skew to now (making now later) + // Use the larger tolerance for future checks (exp) + skewedNow = now.Add(ClockSkewToleranceFuture) + } else { + // For iat/nbf (future=false), subtract skew from now (making now earlier) + // Use the smaller, specific tolerance for past checks (iat, nbf) + skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance } - skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance) if claimTime.Equal(now) { return nil diff --git a/main.go b/main.go index d97f366..92c9536 100644 --- a/main.go +++ b/main.go @@ -110,6 +110,7 @@ type TraefikOidc struct { postLogoutRedirectURI string sessionManager *SessionManager tokenExchanger TokenExchanger // Added field for mocking + refreshGracePeriod time.Duration // Configurable grace period for proactive refresh } // ProviderMetadata holds OIDC provider metadata @@ -356,6 +357,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups), initComplete: make(chan struct{}), logger: logger, + refreshGracePeriod: func() time.Duration { // Set refresh grace period from config or default + if config.RefreshGracePeriodSeconds > 0 { + return time.Duration(config.RefreshGracePeriodSeconds) * time.Second + } + return 60 * time.Second // Default to 60 seconds + }(), } t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger) @@ -595,9 +602,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if needsRefresh { refreshed := t.refreshToken(rw, req, session) if !refreshed { - // Original logic: Always handle failed refresh as an expired token - t.logger.Debug("Token refresh failed, handling as expired token") - t.handleExpiredToken(rw, req, session, redirectURL) + t.logger.Infof("Token refresh failed") // Changed from Warn to Infof + // Check if the client prefers JSON (likely an API call) + acceptHeader := req.Header.Get("Accept") + if strings.Contains(acceptHeader, "application/json") { + t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure") + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"}) + } else { + // Client likely a browser, initiate full re-authentication + t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)") + t.handleExpiredToken(rw, req, session, redirectURL) + } return // Stop processing } } @@ -612,7 +629,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if !t.isAllowedDomain(email) { t.logger.Infof("User with email %s is not from an allowed domain", email) - http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) + errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath) + t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } @@ -639,7 +657,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if !allowed { t.logger.Infof("User with email %s does not have any allowed roles or groups", email) - http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden) + errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath) + t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden) return } } @@ -714,8 +733,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, // Check for errors in the callback if req.URL.Query().Get("error") != "" { errorDescription := req.URL.Query().Get("error_description") - t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription) - http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest) + if errorDescription == "" { + errorDescription = req.URL.Query().Get("error") // Use error code if description is empty + } + t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription) + t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest) return } @@ -723,20 +745,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, state := req.URL.Query().Get("state") if state == "" { t.logger.Error("No state in callback") - http.Error(rw, "State parameter missing in callback", http.StatusBadRequest) + t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest) return } csrfToken := session.GetCSRF() if csrfToken == "" { t.logger.Error("CSRF token missing in session") - http.Error(rw, "CSRF token missing", http.StatusBadRequest) + t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest) return } if state != csrfToken { t.logger.Error("State parameter does not match CSRF token in session") - http.Error(rw, "Invalid state parameter", http.StatusBadRequest) + t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest) return } @@ -744,7 +766,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, code := req.URL.Query().Get("code") if code == "" { t.logger.Error("No code in callback") - http.Error(rw, "No code in callback", http.StatusBadRequest) + t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest) return } @@ -754,7 +776,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier) if err != nil { t.logger.Errorf("Failed to exchange code for token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError) return } @@ -762,14 +784,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, // Use the exported VerifyToken method now that handleCallback is in main.go if err := t.VerifyToken(tokenResponse.IDToken); err != nil { t.logger.Errorf("Failed to verify id_token: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError) return } claims, err := t.extractClaimsFunc(tokenResponse.IDToken) if err != nil { t.logger.Errorf("Failed to extract claims: %v", err) - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError) return } @@ -777,20 +799,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, nonceClaim, ok := claims["nonce"].(string) if !ok || nonceClaim == "" { t.logger.Error("Nonce claim missing in id_token") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError) return } sessionNonce := session.GetNonce() if sessionNonce == "" { t.logger.Error("Nonce not found in session") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError) return } if nonceClaim != sessionNonce { t.logger.Error("Nonce claim does not match session nonce") - http.Error(rw, "Authentication failed", http.StatusInternalServerError) + t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError) return } @@ -799,7 +821,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, email, _ := claims["email"].(string) if email == "" || !t.isAllowedDomain(email) { t.logger.Errorf("Invalid or disallowed email: %s", email) - http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) + t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden) return } @@ -905,17 +927,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo return false, false, true } - now := time.Now().Unix() expTime := int64(expClaim) // Expiration check is now handled within VerifyJWTSignatureAndClaims logic above // We only get here if the token is valid and not expired // Check if token is nearing expiration (needs refresh proactively) - // Define a grace period, e.g., 5 minutes before actual expiry - refreshGracePeriod := int64(5 * 60) - if expTime-now < refreshGracePeriod { - t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod) + // Check if token is nearing expiration using the configured grace period + if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) { + // Recalculate remaining seconds for logging clarity if needed, using the configured duration + remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds()) + t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod) return true, true, false // Needs proactive refresh } @@ -1095,18 +1117,26 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error { return nil } -// refreshToken refreshes the user's token +// refreshToken refreshes the user's token, protected by a mutex within the session. func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool { - t.logger.Debug("Refreshing token") - refreshToken := session.GetRefreshToken() + // Lock the mutex specific to this session instance before attempting refresh + session.refreshMutex.Lock() + defer session.refreshMutex.Unlock() + + t.logger.Debug("Attempting to refresh token (mutex acquired)") + refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock if refreshToken == "" { - t.logger.Debug("No refresh token found in session") + t.logger.Debug("No refresh token found in session (inside lock)") return false } newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken) if err != nil { - t.logger.Errorf("Failed to refresh token: %v", err) + // Log the error, potentially clear the invalid refresh token? + t.logger.Errorf("Failed to refresh token using refresh token: %v", err) + // Consider clearing the refresh token from the session here if the error indicates it's invalid + // session.SetRefreshToken("") // Example: Clear potentially invalid token + // session.Save(req, rw) // Need to handle potential save error return false } @@ -1216,3 +1246,65 @@ func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenRe // Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc return t.getNewTokenWithRefreshToken(refreshToken) } + +// sendErrorResponse sends an error response, adapting to the client's Accept header. +func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) { + acceptHeader := req.Header.Get("Accept") + + // Check if the client prefers JSON + if strings.Contains(acceptHeader, "application/json") { + t.logger.Debugf("Sending JSON error response (code %d): %s", code, message) + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(code) + // Use a simple error structure + json.NewEncoder(rw).Encode(map[string]interface{}{ + "error": http.StatusText(code), + "error_description": message, + "status_code": code, + }) + return + } + + // Default to HTML response for browsers + t.logger.Debugf("Sending HTML error response (code %d): %s", code, message) + + // Determine the return URL (mostly relevant for HTML) + returnURL := "/" // Default to root + session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL + if err == nil { + incomingPath := session.GetIncomingPath() + // Use incoming path if it's valid and not one of the special OIDC paths + if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath { + returnURL = incomingPath + } + } else { + t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err) + } + + // Basic HTML structure for the error page + htmlBody := fmt.Sprintf(` + + + + Authentication Error + + + +
+

Authentication Error

+

%s

+

Return to application

+
+ +`, message, returnURL) + + rw.Header().Set("Content-Type", "text/html; charset=utf-8") + rw.WriteHeader(code) + _, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent +} diff --git a/session.go b/session.go index 7e96428..14ff0e8 100644 --- a/session.go +++ b/session.go @@ -128,10 +128,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (* // Initialize session pool. sm.sessionPool.New = func() interface{} { + // Initialize SessionData with necessary fields and the mutex. return &SessionData{ manager: sm, accessTokenChunks: make(map[int]*sessions.Session), refreshTokenChunks: make(map[int]*sessions.Session), + refreshMutex: sync.Mutex{}, // Initialize the mutex } } @@ -251,6 +253,9 @@ type SessionData struct { // refreshTokenChunks stores additional chunks of the refresh token // when it exceeds the maximum cookie size. refreshTokenChunks map[int]*sessions.Session + + // refreshMutex protects refresh token operations within this session instance. + refreshMutex sync.Mutex } // Save persists all session data to cookies in the HTTP response. diff --git a/settings.go b/settings.go index d16c8ee..2965767 100644 --- a/settings.go +++ b/settings.go @@ -84,6 +84,11 @@ type Config struct { // HTTPClient allows customizing the HTTP client used for OIDC operations (optional) HTTPClient *http.Client + + // RefreshGracePeriodSeconds defines how many seconds before a token expires + // the plugin should attempt to refresh it proactively (optional) + // Default: 60 + RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"` } const ( @@ -111,11 +116,12 @@ const ( // - EnablePKCE: false (PKCE is opt-in) func CreateConfig() *Config { c := &Config{ - Scopes: []string{"openid", "profile", "email"}, - LogLevel: DefaultLogLevel, - RateLimit: DefaultRateLimit, - ForceHTTPS: true, // Secure by default - EnablePKCE: false, // PKCE is opt-in + Scopes: []string{"openid", "profile", "email"}, + LogLevel: DefaultLogLevel, + RateLimit: DefaultRateLimit, + ForceHTTPS: true, // Secure by default + EnablePKCE: false, // PKCE is opt-in + RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds } return c @@ -197,6 +203,11 @@ func (c *Config) Validate() error { return fmt.Errorf("rateLimit must be at least %d", MinRateLimit) } + // Validate refresh grace period + if c.RefreshGracePeriodSeconds < 0 { + return fmt.Errorf("refreshGracePeriodSeconds cannot be negative") + } + return nil }