mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Move session management into session manager. Split the cookies to avoid the 4k limit ( resolves issue: #15 )
This commit is contained in:
+53
-87
@@ -13,17 +13,16 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,33 +99,21 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear the existing session
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
|
||||
// Set new values
|
||||
session.Values["csrf"] = uuid.New().String()
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Values["nonce"], _ = generateNonce()
|
||||
session.Options = newSessionOptions(t.determineScheme(req) == "https")
|
||||
|
||||
// Save the session before initiating authentication
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
|
||||
// Initialize new authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
@@ -143,26 +130,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the state parameter matches the session's CSRF token
|
||||
// Validate state parameter matches the session's CSRF token
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
csrfToken, ok := session.Values["csrf"].(string)
|
||||
if !ok || csrfToken == "" {
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to exchange the code for tokens
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
@@ -177,49 +166,42 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Extract id_token
|
||||
idToken := tokenResponse.IDToken
|
||||
if idToken == "" {
|
||||
t.logger.Error("No id_token in token response")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the id_token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
// Verify and process tokens
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims from id_token
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the nonce claim matches the one stored in session
|
||||
// Verify nonce
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sessionNonce, ok := session.Values["nonce"].(string)
|
||||
if !ok || sessionNonce == "" {
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the email from claims
|
||||
// Process email
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -227,31 +209,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Store tokens and authentication status in session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["email"] = email
|
||||
session.Values["id_token"] = idToken
|
||||
session.Values["refresh_token"] = tokenResponse.RefreshToken
|
||||
session.Options = newSessionOptions(t.determineScheme(req) == "https")
|
||||
|
||||
// Remove CSRF and nonce from session
|
||||
delete(session.Values, "csrf")
|
||||
delete(session.Values, "nonce")
|
||||
// Update session with new values
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Authentication successful. User email: %s", email)
|
||||
|
||||
// Redirect to the original requested path or default to root
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
|
||||
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
|
||||
redirectPath = path
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -376,21 +352,19 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
|
||||
// handleLogout handles the logout request
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the id_token before clearing the session
|
||||
idToken, _ := session.Values["id_token"].(string)
|
||||
// Get the access token before clearing session
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
// Clear and expire the session
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
session.Options.MaxAge = -1
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Error saving session: %v", err)
|
||||
// Clear all session data
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -401,34 +375,26 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
var postLogoutRedirectURI string
|
||||
if t.postLogoutRedirectURI != "" {
|
||||
// Use explicitly configured postLogoutRedirectURI
|
||||
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
|
||||
postLogoutRedirectURI = t.postLogoutRedirectURI
|
||||
} else {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
|
||||
}
|
||||
} else {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
|
||||
|
||||
// If we have an end session endpoint and an ID token, use OIDC end session
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
// If we have an end session endpoint and an access token, use OIDC end session
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If no end session endpoint or no ID token, just redirect to the post logout URI
|
||||
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
|
||||
// Otherwise, redirect to post logout URI
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user