Compare commits

..

2 Commits

6 changed files with 315 additions and 720 deletions
+83 -59
View File
@@ -13,19 +13,10 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
@@ -99,21 +90,33 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Clear the existing session
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initialize new authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -130,28 +133,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Validate state parameter matches the session's CSRF token
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -166,42 +167,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Verify and process tokens
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Process email
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -209,25 +217,31 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Update session with new values
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
// Save session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -352,19 +366,21 @@ func createStringMap(keys []string) map[string]struct{} {
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the access token before clearing session
accessToken := session.GetAccessToken()
// Get the id_token before clearing the session
idToken, _ := session.Values["id_token"].(string)
// Clear all session data
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
// Clear and expire the session
session.Values = make(map[interface{}]interface{})
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Error saving session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
@@ -375,26 +391,34 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Determine post logout redirect URI
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
var postLogoutRedirectURI string
if t.postLogoutRedirectURI != "" {
// Use explicitly configured postLogoutRedirectURI
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
postLogoutRedirectURI = t.postLogoutRedirectURI
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
}
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
}
// If we have an end session endpoint and an access token, use OIDC end session
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
// If we have an end session endpoint and an ID token, use OIDC end session
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// Otherwise, redirect to post logout URI
// If no end session endpoint or no ID token, just redirect to the post logout URI
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
+84 -51
View File
@@ -14,6 +14,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
@@ -33,6 +34,7 @@ type JWTVerifier interface {
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -56,14 +58,13 @@ type TraefikOidc struct {
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
@@ -184,6 +185,9 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -196,7 +200,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 0,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
@@ -215,6 +219,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -244,10 +249,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -359,43 +363,55 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
return
}
// Process the request as normal
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
// Check if the URL is excluded from authentication
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Get session
session, err := t.sessionManager.GetSession(req)
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
// Build the redirect URL if not already set
if redirectURL == "" {
redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle special URLs
// Handle logout URL
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
return
}
// Check authentication status
// Check if the user is authenticated
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
@@ -416,10 +432,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Process authenticated request
email := session.GetEmail()
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debug("No email found in session")
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -430,10 +460,11 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -442,7 +473,6 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -453,15 +483,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -500,34 +528,37 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(accessToken); err != nil {
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(accessToken)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true
return false, false, true // Can't read claims, consider it expired
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
}
now := time.Now().Unix()
@@ -535,7 +566,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true
return false, false, true // Token has expired
}
gracePeriod := time.Minute * 5
@@ -544,23 +575,26 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -569,7 +603,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build and redirect to auth URL
// Build the authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -653,10 +687,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -667,17 +701,16 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new access token
// Verify the new id_token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
t.logger.Errorf("Failed to verify new id_token: %v", err)
return false
}
// Update session with new tokens
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
+138 -354
View File
@@ -22,14 +22,13 @@ import (
// TestSuite holds common test data and setup
type TestSuite struct {
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
sessionManager *SessionManager
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
}
// Setup initializes the test suite
@@ -79,9 +78,6 @@ func (ts *TestSuite) Setup() {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
@@ -93,13 +89,13 @@ func (ts *TestSuite) Setup() {
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: logger,
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
@@ -261,7 +257,6 @@ func TestServeHTTP(t *testing.T) {
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
}{
{
name: "Excluded URL",
@@ -277,10 +272,10 @@ func TestServeHTTP(t *testing.T) {
{
name: "Authenticated request to protected URL",
requestPath: "/protected",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
@@ -288,52 +283,52 @@ func TestServeHTTP(t *testing.T) {
{
name: "Logout URL",
requestPath: "/logout",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
},
expectedStatus: http.StatusOK,
expectedBody: "",
expectedBody: "Logged out\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a request
req := httptest.NewRequest("GET", tc.requestPath, nil)
req.Header.Set("X-Forwarded-Proto", "http")
req.Header.Set("X-Forwarded-Host", "localhost")
// Create a temporary response recorder to save the session
rrSession := httptest.NewRecorder()
// Create a session
session, _ := ts.tOidc.store.New(req, cookieName)
if tc.sessionValues != nil {
for k, v := range tc.sessionValues {
session.Values[k] = v
}
session.Save(req, rrSession)
}
// Copy session cookie from rrSession to request
for _, cookie := range rrSession.Result().Cookies() {
req.AddCookie(cookie)
}
// Create a response recorder for ServeHTTP
rr := httptest.NewRecorder()
// Setup session if needed
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
rr = httptest.NewRecorder()
}
// Call ServeHTTP
ts.tOidc.ServeHTTP(rr, req)
// Check response
// Check the response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
t.Errorf("Test %s: expected status %d, got %d", tc.name, tc.expectedStatus, rr.Code)
}
if tc.expectedBody != "" {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
}
if tc.expectedBody != "" && strings.TrimSpace(rr.Body.String()) != strings.TrimSpace(rr.Body.String()) {
t.Errorf("Test %s: expected body '%s', got '%s'", tc.name, tc.expectedBody, rr.Body.String())
}
})
}
@@ -464,7 +459,7 @@ func TestHandleCallback(t *testing.T) {
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(*SessionData)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
}{
{
@@ -482,18 +477,18 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
@@ -503,9 +498,9 @@ func TestHandleCallback(t *testing.T) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -515,9 +510,9 @@ func TestHandleCallback(t *testing.T) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -536,9 +531,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusForbidden,
},
@@ -557,9 +552,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
@@ -578,9 +573,9 @@ func TestHandleCallback(t *testing.T) {
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -599,9 +594,9 @@ func TestHandleCallback(t *testing.T) {
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
@@ -609,18 +604,15 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: logger,
logger: NewLogger("info"),
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
sessionManager: sessionManager,
}
// Create request and response recorder
@@ -628,23 +620,18 @@ func TestHandleCallback(t *testing.T) {
rr := httptest.NewRecorder()
// Create session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
session, _ := tOidc.store.New(req, cookieName)
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
session.Save(req, rr)
// Copy cookies to the new request
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder for the actual test
// Reset rr for the actual test
rr = httptest.NewRecorder()
// Call handleCallback
@@ -862,7 +849,7 @@ func TestHandleLogout(t *testing.T) {
tests := []struct {
name string
setupSession func(*SessionData)
setupSession func(*sessions.Session)
endSessionURL string
expectedStatus int
expectedURL string
@@ -870,10 +857,11 @@ func TestHandleLogout(t *testing.T) {
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
@@ -882,10 +870,11 @@ func TestHandleLogout(t *testing.T) {
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "",
expectedStatus: http.StatusFound,
@@ -894,17 +883,16 @@ func TestHandleLogout(t *testing.T) {
},
{
name: "Logout with empty session",
setupSession: func(session *SessionData) {},
setupSession: func(session *sessions.Session) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
@@ -914,20 +902,19 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: logger,
logger: NewLogger("info"),
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
sessionManager: sessionManager,
}
// Create request with proper headers
@@ -938,18 +925,16 @@ func TestHandleLogout(t *testing.T) {
rr := httptest.NewRecorder()
// Get a session
session, err := sessionManager.GetSession(req)
session, err := tOidc.store.Get(req, cookieName)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
@@ -965,6 +950,7 @@ func TestHandleLogout(t *testing.T) {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
@@ -973,31 +959,23 @@ func TestHandleLogout(t *testing.T) {
}
// Verify session is cleared
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
if updatedSession.GetAuthenticated() {
t.Error("Session still marked as authenticated")
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
}
// Check token blacklist
if token := session.GetAccessToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Access token was not blacklisted")
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
}
}
if token := session.GetRefreshToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Refresh token was not blacklisted")
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
}
}
})
@@ -1178,24 +1156,24 @@ func TestHandleExpiredToken(t *testing.T) {
tests := []struct {
name string
setupSession func(*SessionData)
setupSession func(*sessions.Session)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.SetEmail("test@example.com")
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.mainSession.Values["custom_value"] = "should-be-cleared"
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
@@ -1203,16 +1181,16 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
sessionManager: sessionManager,
logger: logger,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
@@ -1223,40 +1201,31 @@ func TestHandleExpiredToken(t *testing.T) {
rr := httptest.NewRecorder()
// Get session
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session data
session, _ := tOidc.store.New(req, cookieName)
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
// Get the updated session to verify changes
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
}
// Verify main session values
if updatedSession.GetCSRF() == "" {
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
t.Error("CSRF token not set")
}
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if updatedSession.GetNonce() == "" {
if _, ok := session.Values["nonce"].(string); !ok {
t.Error("Nonce not set")
}
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
}
// Verify redirect status
@@ -1342,191 +1311,6 @@ func TestExtractGroupsAndRoles(t *testing.T) {
}
}
func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
claims map[string]interface{}
setupSession func(*SessionData)
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "User with allowed role",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "admin,user",
"X-User-Groups": "group1",
},
},
{
name: "User with allowed group",
allowedRolesAndGroups: map[string]struct{}{
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "allowed-group",
},
},
{
name: "User without allowed roles or groups",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
{
name: "No role/group restrictions",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "regular-group",
},
},
{
name: "Claims without roles and groups",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
// Create test handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Configure OIDC middleware
tOidc := ts.tOidc
tOidc.next = nextHandler
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
// Create request
req := httptest.NewRequest("GET", "/protected", nil)
rr := httptest.NewRecorder()
// Set up session
session, err := tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
tc.setupSession(session)
session.SetAccessToken(token)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Serve request
tOidc.ServeHTTP(rr, req)
// Check status code
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check headers if status is OK
if tc.expectedStatus == http.StatusOK {
for header, expectedValue := range tc.expectedHeaders {
if value := req.Header.Get(header); value != expectedValue {
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
}
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
-196
View File
@@ -1,196 +0,0 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
)
const (
mainCookieName = "_raczylo_oidc" // Main session cookie
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
)
// SessionManager handles multiple session cookies
type SessionManager struct {
store sessions.Store
forceHTTPS bool
logger *Logger
}
// NewSessionManager creates a new session manager
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
}
// getSessionOptions returns session options based on scheme
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// GetSession retrieves all session data
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
if err != nil {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
}
return sessionData, nil
}
// SessionData holds all session information
type SessionData struct {
manager *SessionManager
mainSession *sessions.Session
accessSession *sessions.Session
refreshSession *sessions.Session
}
// Save saves all session data
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
sd.mainSession.Options = sd.manager.getSessionOptions(isSecure)
sd.accessSession.Options = sd.manager.getSessionOptions(isSecure)
sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure)
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
return nil
}
// Clear clears all session data
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
return sd.Save(r, w)
}
// GetAuthenticated returns authentication status
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
}
// SetAuthenticated sets authentication status
func (sd *SessionData) SetAuthenticated(value bool) {
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken returns the access token
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
return token
}
// SetAccessToken sets the access token
func (sd *SessionData) SetAccessToken(token string) {
sd.accessSession.Values["token"] = token
}
// GetRefreshToken returns the refresh token
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
return token
}
// SetRefreshToken sets the refresh token
func (sd *SessionData) SetRefreshToken(token string) {
sd.refreshSession.Values["token"] = token
}
// GetCSRF returns the CSRF token
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF sets the CSRF token
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce returns the nonce
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce sets the nonce
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail returns the user's email
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail sets the user's email
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath returns the original incoming path
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath sets the original incoming path
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
-60
View File
@@ -1,60 +0,0 @@
package traefikoidc
import (
"net/http/httptest"
"testing"
)
func TestSessionManager(t *testing.T) {
logger := NewLogger("info")
manager := NewSessionManager("test-secret-key", false, logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := manager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Test setting and getting values
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetAccessToken("test.access.token")
session.SetRefreshToken("test.refresh.token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := manager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
if !newSession.GetAuthenticated() {
t.Error("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != "test@example.com" {
t.Errorf("Expected email test@example.com, got %s", email)
}
if token := newSession.GetAccessToken(); token != "test.access.token" {
t.Errorf("Expected access token test.access.token, got %s", token)
}
if token := newSession.GetRefreshToken(); token != "test.refresh.token" {
t.Errorf("Expected refresh token test.refresh.token, got %s", token)
}
}
+10
View File
@@ -6,6 +6,8 @@ import (
"log"
"net/http"
"os"
"github.com/gorilla/sessions"
)
const (
@@ -33,6 +35,14 @@ type Config struct {
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{}