Optimize the code, find edge cases, polish the bugs out.

This commit is contained in:
2025-04-05 11:15:15 +01:00
parent 9e8634bfc0
commit 46c2f98a15
4 changed files with 160 additions and 41 deletions
+19 -8
View File
@@ -29,8 +29,16 @@ func cleanupReplayCache() {
}
}
// ClockSkewTolerance is configurable to adjust time-based validations.
var ClockSkewTolerance = 2 * time.Minute
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
// Allows for more leniency with expiration checks.
var ClockSkewToleranceFuture = 2 * time.Minute
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
var (
ClockSkewTolerancePast = 10 * time.Second
ClockSkewTolerance = 2 * time.Minute
)
// JWT represents a JSON Web Token as defined in RFC 7519.
type JWT struct {
@@ -202,13 +210,16 @@ func verifyTimeConstraint(unixTime float64, claimName string, future bool) error
claimTime := time.Unix(int64(unixTime), 0)
now := time.Now().Truncate(time.Second)
// For expiration (future=true), we add skew to now (making now later)
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
skewDirection := 1
if !future {
skewDirection = -1
var skewedNow time.Time
if future {
// For expiration (future=true), add skew to now (making now later)
// Use the larger tolerance for future checks (exp)
skewedNow = now.Add(ClockSkewToleranceFuture)
} else {
// For iat/nbf (future=false), subtract skew from now (making now earlier)
// Use the smaller, specific tolerance for past checks (iat, nbf)
skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance
}
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
if claimTime.Equal(now) {
return nil
+120 -28
View File
@@ -110,6 +110,7 @@ type TraefikOidc struct {
postLogoutRedirectURI string
sessionManager *SessionManager
tokenExchanger TokenExchanger // Added field for mocking
refreshGracePeriod time.Duration // Configurable grace period for proactive refresh
}
// ProviderMetadata holds OIDC provider metadata
@@ -356,6 +357,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
logger: logger,
refreshGracePeriod: func() time.Duration { // Set refresh grace period from config or default
if config.RefreshGracePeriodSeconds > 0 {
return time.Duration(config.RefreshGracePeriodSeconds) * time.Second
}
return 60 * time.Second // Default to 60 seconds
}(),
}
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
@@ -595,9 +602,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
// Original logic: Always handle failed refresh as an expired token
t.logger.Debug("Token refresh failed, handling as expired token")
t.handleExpiredToken(rw, req, session, redirectURL)
t.logger.Infof("Token refresh failed") // Changed from Warn to Infof
// Check if the client prefers JSON (likely an API call)
acceptHeader := req.Header.Get("Accept")
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
} else {
// Client likely a browser, initiate full re-authentication
t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)")
t.handleExpiredToken(rw, req, session, redirectURL)
}
return // Stop processing
}
}
@@ -612,7 +629,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
@@ -639,7 +657,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
return
}
}
@@ -714,8 +733,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
if errorDescription == "" {
errorDescription = req.URL.Query().Get("error") // Use error code if description is empty
}
t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
return
}
@@ -723,20 +745,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
return
}
@@ -744,7 +766,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
http.Error(rw, "No code in callback", http.StatusBadRequest)
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
return
}
@@ -754,7 +776,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
return
}
@@ -762,14 +784,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
// Use the exported VerifyToken method now that handleCallback is in main.go
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
return
}
@@ -777,20 +799,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
return
}
@@ -799,7 +821,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
return
}
@@ -905,17 +927,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return false, false, true
}
now := time.Now().Unix()
expTime := int64(expClaim)
// Expiration check is now handled within VerifyJWTSignatureAndClaims logic above
// We only get here if the token is valid and not expired
// Check if token is nearing expiration (needs refresh proactively)
// Define a grace period, e.g., 5 minutes before actual expiry
refreshGracePeriod := int64(5 * 60)
if expTime-now < refreshGracePeriod {
t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod)
// Check if token is nearing expiration using the configured grace period
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod)
return true, true, false // Needs proactive refresh
}
@@ -1095,18 +1117,26 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
return nil
}
// refreshToken refreshes the user's token
// refreshToken refreshes the user's token, protected by a mutex within the session.
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
t.logger.Debug("Refreshing token")
refreshToken := session.GetRefreshToken()
// Lock the mutex specific to this session instance before attempting refresh
session.refreshMutex.Lock()
defer session.refreshMutex.Unlock()
t.logger.Debug("Attempting to refresh token (mutex acquired)")
refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
t.logger.Debug("No refresh token found in session (inside lock)")
return false
}
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
if err != nil {
t.logger.Errorf("Failed to refresh token: %v", err)
// Log the error, potentially clear the invalid refresh token?
t.logger.Errorf("Failed to refresh token using refresh token: %v", err)
// Consider clearing the refresh token from the session here if the error indicates it's invalid
// session.SetRefreshToken("") // Example: Clear potentially invalid token
// session.Save(req, rw) // Need to handle potential save error
return false
}
@@ -1216,3 +1246,65 @@ func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenRe
// Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc
return t.getNewTokenWithRefreshToken(refreshToken)
}
// sendErrorResponse sends an error response, adapting to the client's Accept header.
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
acceptHeader := req.Header.Get("Accept")
// Check if the client prefers JSON
if strings.Contains(acceptHeader, "application/json") {
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
rw.Header().Set("Content-Type", "application/json")
rw.WriteHeader(code)
// Use a simple error structure
json.NewEncoder(rw).Encode(map[string]interface{}{
"error": http.StatusText(code),
"error_description": message,
"status_code": code,
})
return
}
// Default to HTML response for browsers
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
// Determine the return URL (mostly relevant for HTML)
returnURL := "/" // Default to root
session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL
if err == nil {
incomingPath := session.GetIncomingPath()
// Use incoming path if it's valid and not one of the special OIDC paths
if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath {
returnURL = incomingPath
}
} else {
t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err)
}
// Basic HTML structure for the error page
htmlBody := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<title>Authentication Error</title>
<style>
body { font-family: sans-serif; padding: 20px; background-color: #f8f9fa; color: #343a40; }
h1 { color: #dc3545; }
a { color: #007bff; text-decoration: none; }
a:hover { text-decoration: underline; }
.container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
</style>
</head>
<body>
<div class="container">
<h1>Authentication Error</h1>
<p>%s</p>
<p><a href="%s">Return to application</a></p>
</div>
</body>
</html>`, message, returnURL)
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
rw.WriteHeader(code)
_, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent
}
+5
View File
@@ -128,10 +128,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
// Initialize session pool.
sm.sessionPool.New = func() interface{} {
// Initialize SessionData with necessary fields and the mutex.
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
refreshMutex: sync.Mutex{}, // Initialize the mutex
}
}
@@ -251,6 +253,9 @@ type SessionData struct {
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size.
refreshTokenChunks map[int]*sessions.Session
// refreshMutex protects refresh token operations within this session instance.
refreshMutex sync.Mutex
}
// Save persists all session data to cookies in the HTTP response.
+16 -5
View File
@@ -84,6 +84,11 @@ type Config struct {
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
// RefreshGracePeriodSeconds defines how many seconds before a token expires
// the plugin should attempt to refresh it proactively (optional)
// Default: 60
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
}
const (
@@ -111,11 +116,12 @@ const (
// - EnablePKCE: false (PKCE is opt-in)
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
EnablePKCE: false, // PKCE is opt-in
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
}
return c
@@ -197,6 +203,11 @@ func (c *Config) Validate() error {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
// Validate refresh grace period
if c.RefreshGracePeriodSeconds < 0 {
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
}
return nil
}