mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Optimize the code, find edge cases, polish the bugs out.
This commit is contained in:
@@ -29,8 +29,16 @@ func cleanupReplayCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// ClockSkewTolerance is configurable to adjust time-based validations.
|
||||
var ClockSkewTolerance = 2 * time.Minute
|
||||
// ClockSkewToleranceFuture defines the tolerance for future-based claims like 'exp'.
|
||||
// Allows for more leniency with expiration checks.
|
||||
var ClockSkewToleranceFuture = 2 * time.Minute
|
||||
|
||||
// ClockSkewTolerancePast defines the tolerance for past-based claims like 'iat' and 'nbf'.
|
||||
// A smaller tolerance is typically used here to prevent accepting tokens issued too far in the future.
|
||||
var (
|
||||
ClockSkewTolerancePast = 10 * time.Second
|
||||
ClockSkewTolerance = 2 * time.Minute
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token as defined in RFC 7519.
|
||||
type JWT struct {
|
||||
@@ -202,13 +210,16 @@ func verifyTimeConstraint(unixTime float64, claimName string, future bool) error
|
||||
claimTime := time.Unix(int64(unixTime), 0)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
|
||||
// For expiration (future=true), we add skew to now (making now later)
|
||||
// For iat/nbf (future=false), we subtract skew from now (making now earlier)
|
||||
skewDirection := 1
|
||||
if !future {
|
||||
skewDirection = -1
|
||||
var skewedNow time.Time
|
||||
if future {
|
||||
// For expiration (future=true), add skew to now (making now later)
|
||||
// Use the larger tolerance for future checks (exp)
|
||||
skewedNow = now.Add(ClockSkewToleranceFuture)
|
||||
} else {
|
||||
// For iat/nbf (future=false), subtract skew from now (making now earlier)
|
||||
// Use the smaller, specific tolerance for past checks (iat, nbf)
|
||||
skewedNow = now.Add(-ClockSkewTolerancePast) // Subtract the past tolerance
|
||||
}
|
||||
skewedNow := now.Add(time.Duration(skewDirection) * ClockSkewTolerance)
|
||||
|
||||
if claimTime.Equal(now) {
|
||||
return nil
|
||||
|
||||
@@ -110,6 +110,7 @@ type TraefikOidc struct {
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
tokenExchanger TokenExchanger // Added field for mocking
|
||||
refreshGracePeriod time.Duration // Configurable grace period for proactive refresh
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
@@ -356,6 +357,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
logger: logger,
|
||||
refreshGracePeriod: func() time.Duration { // Set refresh grace period from config or default
|
||||
if config.RefreshGracePeriodSeconds > 0 {
|
||||
return time.Duration(config.RefreshGracePeriodSeconds) * time.Second
|
||||
}
|
||||
return 60 * time.Second // Default to 60 seconds
|
||||
}(),
|
||||
}
|
||||
|
||||
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
@@ -595,9 +602,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
// Original logic: Always handle failed refresh as an expired token
|
||||
t.logger.Debug("Token refresh failed, handling as expired token")
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
t.logger.Infof("Token refresh failed") // Changed from Warn to Infof
|
||||
// Check if the client prefers JSON (likely an API call)
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
t.logger.Debug("Client accepts JSON, sending 401 Unauthorized on refresh failure")
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(rw).Encode(map[string]string{"error": "unauthorized", "message": "Token refresh failed"})
|
||||
} else {
|
||||
// Client likely a browser, initiate full re-authentication
|
||||
t.logger.Debug("Client does not prefer JSON, handling refresh failure as expired token (initiating re-auth)")
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
}
|
||||
return // Stop processing
|
||||
}
|
||||
}
|
||||
@@ -612,7 +629,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
errorMsg := fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -639,7 +657,8 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
errorMsg := fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath)
|
||||
t.sendErrorResponse(rw, req, errorMsg, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -714,8 +733,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
// Check for errors in the callback
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
if errorDescription == "" {
|
||||
errorDescription = req.URL.Query().Get("error") // Use error code if description is empty
|
||||
}
|
||||
t.logger.Errorf("Authentication error from provider: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
t.sendErrorResponse(rw, req, fmt.Sprintf("Authentication error from provider: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -723,20 +745,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
t.sendErrorResponse(rw, req, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
t.sendErrorResponse(rw, req, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
t.sendErrorResponse(rw, req, "Invalid state parameter (CSRF mismatch)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -744,7 +766,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
http.Error(rw, "No code in callback", http.StatusBadRequest)
|
||||
t.sendErrorResponse(rw, req, "No authorization code received in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -754,7 +776,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
tokenResponse, err := t.tokenExchanger.ExchangeCodeForToken(req.Context(), "authorization_code", code, redirectURL, codeVerifier)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not exchange code for token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -762,14 +784,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
// Use the exported VerifyToken method now that handleCallback is in main.go
|
||||
if err := t.VerifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not verify ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Could not extract claims from token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -777,20 +799,20 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce missing in session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Nonce mismatch", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -799,7 +821,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
http.Error(rw, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
t.sendErrorResponse(rw, req, "Authentication failed: Invalid or disallowed email", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -905,17 +927,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
expTime := int64(expClaim)
|
||||
|
||||
// Expiration check is now handled within VerifyJWTSignatureAndClaims logic above
|
||||
// We only get here if the token is valid and not expired
|
||||
|
||||
// Check if token is nearing expiration (needs refresh proactively)
|
||||
// Define a grace period, e.g., 5 minutes before actual expiry
|
||||
refreshGracePeriod := int64(5 * 60)
|
||||
if expTime-now < refreshGracePeriod {
|
||||
t.logger.Debugf("Token nearing expiration (within %d seconds), scheduling refresh", refreshGracePeriod)
|
||||
// Check if token is nearing expiration using the configured grace period
|
||||
if time.Unix(expTime, 0).Before(time.Now().Add(t.refreshGracePeriod)) {
|
||||
// Recalculate remaining seconds for logging clarity if needed, using the configured duration
|
||||
remainingSeconds := int64(time.Until(time.Unix(expTime, 0)).Seconds())
|
||||
t.logger.Debugf("Token nearing expiration (expires in %d seconds, grace period %s), scheduling refresh", remainingSeconds, t.refreshGracePeriod)
|
||||
return true, true, false // Needs proactive refresh
|
||||
}
|
||||
|
||||
@@ -1095,18 +1117,26 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshToken refreshes the user's token
|
||||
// refreshToken refreshes the user's token, protected by a mutex within the session.
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
t.logger.Debug("Refreshing token")
|
||||
refreshToken := session.GetRefreshToken()
|
||||
// Lock the mutex specific to this session instance before attempting refresh
|
||||
session.refreshMutex.Lock()
|
||||
defer session.refreshMutex.Unlock()
|
||||
|
||||
t.logger.Debug("Attempting to refresh token (mutex acquired)")
|
||||
refreshToken := session.GetRefreshToken() // Get token *after* acquiring lock
|
||||
if refreshToken == "" {
|
||||
t.logger.Debug("No refresh token found in session")
|
||||
t.logger.Debug("No refresh token found in session (inside lock)")
|
||||
return false
|
||||
}
|
||||
|
||||
newToken, err := t.tokenExchanger.GetNewTokenWithRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to refresh token: %v", err)
|
||||
// Log the error, potentially clear the invalid refresh token?
|
||||
t.logger.Errorf("Failed to refresh token using refresh token: %v", err)
|
||||
// Consider clearing the refresh token from the session here if the error indicates it's invalid
|
||||
// session.SetRefreshToken("") // Example: Clear potentially invalid token
|
||||
// session.Save(req, rw) // Need to handle potential save error
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1216,3 +1246,65 @@ func (t *TraefikOidc) GetNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
// Note: The original getNewTokenWithRefreshToken helper is defined in helpers.go and is already a method on *TraefikOidc
|
||||
return t.getNewTokenWithRefreshToken(refreshToken)
|
||||
}
|
||||
|
||||
// sendErrorResponse sends an error response, adapting to the client's Accept header.
|
||||
func (t *TraefikOidc) sendErrorResponse(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
||||
acceptHeader := req.Header.Get("Accept")
|
||||
|
||||
// Check if the client prefers JSON
|
||||
if strings.Contains(acceptHeader, "application/json") {
|
||||
t.logger.Debugf("Sending JSON error response (code %d): %s", code, message)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(code)
|
||||
// Use a simple error structure
|
||||
json.NewEncoder(rw).Encode(map[string]interface{}{
|
||||
"error": http.StatusText(code),
|
||||
"error_description": message,
|
||||
"status_code": code,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Default to HTML response for browsers
|
||||
t.logger.Debugf("Sending HTML error response (code %d): %s", code, message)
|
||||
|
||||
// Determine the return URL (mostly relevant for HTML)
|
||||
returnURL := "/" // Default to root
|
||||
session, err := t.sessionManager.GetSession(req) // Attempt to get session for return URL
|
||||
if err == nil {
|
||||
incomingPath := session.GetIncomingPath()
|
||||
// Use incoming path if it's valid and not one of the special OIDC paths
|
||||
if incomingPath != "" && incomingPath != t.redirURLPath && incomingPath != t.logoutURLPath {
|
||||
returnURL = incomingPath
|
||||
}
|
||||
} else {
|
||||
t.logger.Infof("Could not get session to determine return URL in sendErrorResponse: %v", err)
|
||||
}
|
||||
|
||||
// Basic HTML structure for the error page
|
||||
htmlBody := fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Error</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; padding: 20px; background-color: #f8f9fa; color: #343a40; }
|
||||
h1 { color: #dc3545; }
|
||||
a { color: #007bff; text-decoration: none; }
|
||||
a:hover { text-decoration: underline; }
|
||||
.container { max-width: 600px; margin: auto; background: #fff; padding: 20px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>Authentication Error</h1>
|
||||
<p>%s</p>
|
||||
<p><a href="%s">Return to application</a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`, message, returnURL)
|
||||
|
||||
rw.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
rw.WriteHeader(code)
|
||||
_, _ = rw.Write([]byte(htmlBody)) // Ignore write error as header is already sent
|
||||
}
|
||||
|
||||
@@ -128,10 +128,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*
|
||||
|
||||
// Initialize session pool.
|
||||
sm.sessionPool.New = func() interface{} {
|
||||
// Initialize SessionData with necessary fields and the mutex.
|
||||
return &SessionData{
|
||||
manager: sm,
|
||||
accessTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshTokenChunks: make(map[int]*sessions.Session),
|
||||
refreshMutex: sync.Mutex{}, // Initialize the mutex
|
||||
}
|
||||
}
|
||||
|
||||
@@ -251,6 +253,9 @@ type SessionData struct {
|
||||
// refreshTokenChunks stores additional chunks of the refresh token
|
||||
// when it exceeds the maximum cookie size.
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
|
||||
// refreshMutex protects refresh token operations within this session instance.
|
||||
refreshMutex sync.Mutex
|
||||
}
|
||||
|
||||
// Save persists all session data to cookies in the HTTP response.
|
||||
|
||||
+16
-5
@@ -84,6 +84,11 @@ type Config struct {
|
||||
|
||||
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
|
||||
HTTPClient *http.Client
|
||||
|
||||
// RefreshGracePeriodSeconds defines how many seconds before a token expires
|
||||
// the plugin should attempt to refresh it proactively (optional)
|
||||
// Default: 60
|
||||
RefreshGracePeriodSeconds int `json:"refreshGracePeriodSeconds"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -111,11 +116,12 @@ const (
|
||||
// - EnablePKCE: false (PKCE is opt-in)
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
LogLevel: DefaultLogLevel,
|
||||
RateLimit: DefaultRateLimit,
|
||||
ForceHTTPS: true, // Secure by default
|
||||
EnablePKCE: false, // PKCE is opt-in
|
||||
RefreshGracePeriodSeconds: 60, // Default grace period of 60 seconds
|
||||
}
|
||||
|
||||
return c
|
||||
@@ -197,6 +203,11 @@ func (c *Config) Validate() error {
|
||||
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
|
||||
}
|
||||
|
||||
// Validate refresh grace period
|
||||
if c.RefreshGracePeriodSeconds < 0 {
|
||||
return fmt.Errorf("refreshGracePeriodSeconds cannot be negative")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user