Move session management into session manager. Split the cookies to avoid the 4k limit ( resolves issue: #15 )

This commit is contained in:
2024-12-10 10:01:06 +00:00
parent 01ee7c4dc8
commit 2b7af88ff9
5 changed files with 525 additions and 335 deletions
+53 -87
View File
@@ -13,17 +13,16 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
@@ -100,33 +99,21 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = newSessionOptions(t.determineScheme(req) == "https")
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
// Initialize new authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.store.Get(req, cookieName)
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -143,26 +130,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Validate the state parameter matches the session's CSRF token
// Validate state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -177,49 +166,42 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
// Verify and process tokens
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the nonce claim matches the one stored in session
// Verify nonce
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
// Process email
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -227,31 +209,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = newSessionOptions(t.determineScheme(req) == "https")
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
// Update session with new values
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
// Save session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
// Redirect to original path or root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -376,21 +352,19 @@ func createStringMap(keys []string) map[string]struct{} {
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the id_token before clearing the session
idToken, _ := session.Values["id_token"].(string)
// Get the access token before clearing session
accessToken := session.GetAccessToken()
// Clear and expire the session
session.Values = make(map[interface{}]interface{})
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Error saving session: %v", err)
// Clear all session data
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
@@ -401,34 +375,26 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Determine post logout redirect URI
var postLogoutRedirectURI string
if t.postLogoutRedirectURI != "" {
// Use explicitly configured postLogoutRedirectURI
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
postLogoutRedirectURI = t.postLogoutRedirectURI
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
}
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
// If we have an end session endpoint and an ID token, use OIDC end session
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
// If we have an end session endpoint and an access token, use OIDC end session
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// If no end session endpoint or no ID token, just redirect to the post logout URI
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
// Otherwise, redirect to post logout URI
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}