mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
6 Commits
v0.3.7-rc2
...
v0.4.2
| Author | SHA1 | Date | |
|---|---|---|---|
| bef4212c57 | |||
| 1fee2f9e9a | |||
| 11bc6f3e31 | |||
| 2b7af88ff9 | |||
| 01ee7c4dc8 | |||
| a6fa4d8789 |
@@ -13,6 +13,7 @@ testData:
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
|
||||
@@ -38,6 +38,7 @@ spec:
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
|
||||
@@ -6,7 +6,7 @@ toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
golang.org/x/time v0.7.0
|
||||
)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
||||
+111
-137
@@ -13,10 +13,19 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce generates a random nonce
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
@@ -27,14 +36,6 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host, and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
@@ -97,104 +98,22 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleLogout handles the logout process
|
||||
func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := t.store.Get(r, cookieName)
|
||||
if err != nil {
|
||||
handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Get tokens from session
|
||||
idToken, _ := session.Values["id_token"].(string)
|
||||
refreshToken, _ := session.Values["refresh_token"].(string)
|
||||
accessToken, _ := session.Values["access_token"].(string)
|
||||
|
||||
// Revoke tokens if they exist
|
||||
if refreshToken != "" {
|
||||
t.RevokeTokenWithProvider(refreshToken, "refresh_token")
|
||||
t.RevokeToken(refreshToken)
|
||||
}
|
||||
if accessToken != "" {
|
||||
t.RevokeTokenWithProvider(accessToken, "access_token")
|
||||
t.RevokeToken(accessToken)
|
||||
}
|
||||
|
||||
// Clear session
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
if err := session.Save(r, w); err != nil {
|
||||
handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine redirect URL
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
scheme := "http"
|
||||
if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS {
|
||||
scheme = "https"
|
||||
}
|
||||
baseURL := fmt.Sprintf("%s://%s/", scheme, host)
|
||||
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL)
|
||||
if err != nil {
|
||||
handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, baseURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the logout URL with proper encoding
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid end session URL: %v", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Clear the existing session
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
|
||||
// Set new values
|
||||
session.Values["csrf"] = uuid.New().String()
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Values["nonce"], _ = generateNonce()
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Save the session before initiating authentication
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
|
||||
// Initialize new authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
@@ -211,26 +130,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the state parameter matches the session's CSRF token
|
||||
// Validate state parameter matches the session's CSRF token
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
csrfToken, ok := session.Values["csrf"].(string)
|
||||
if !ok || csrfToken == "" {
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to exchange the code for tokens
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
@@ -238,56 +159,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract id_token
|
||||
idToken := tokenResponse.IDToken
|
||||
if idToken == "" {
|
||||
t.logger.Error("No id_token in token response")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the id_token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
// Verify and process tokens
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims from id_token
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the nonce claim matches the one stored in session
|
||||
// Verify nonce
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sessionNonce, ok := session.Values["nonce"].(string)
|
||||
if !ok || sessionNonce == "" {
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the email from claims
|
||||
// Process email
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -295,31 +209,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
// Store tokens and authentication status in session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["email"] = email
|
||||
session.Values["id_token"] = idToken
|
||||
session.Values["refresh_token"] = tokenResponse.RefreshToken
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Remove CSRF and nonce from session
|
||||
delete(session.Values, "csrf")
|
||||
delete(session.Values, "nonce")
|
||||
// Update session with new values
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Authentication successful. User email: %s", email)
|
||||
|
||||
// Redirect to the original requested path or default to root
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
|
||||
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
|
||||
redirectPath = path
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
}
|
||||
|
||||
@@ -424,9 +332,9 @@ func (tc *TokenCache) Cleanup() {
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
@@ -441,3 +349,69 @@ func createStringMap(keys []string) map[string]struct{} {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout handles the logout request
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the access token before clearing session
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
// Clear all session data
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the base URL for redirects
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
// If we have an end session endpoint and an access token, use OIDC end session
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, redirect to post logout URI
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse end session URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
// Ensure postLogoutRedirectURI is properly URL encoded
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
@@ -34,7 +33,6 @@ type JWTVerifier interface {
|
||||
type TraefikOidc struct {
|
||||
next http.Handler
|
||||
name string
|
||||
store sessions.Store
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
issuerURL string
|
||||
@@ -53,18 +51,19 @@ type TraefikOidc struct {
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
excludedURLs map[string]struct{}
|
||||
allowedUserDomains map[string]struct{}
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
@@ -185,9 +184,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
|
||||
// New creates a new instance of the OIDC middleware
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
|
||||
store.Options = defaultSessionOptions
|
||||
|
||||
// Setup HTTP client
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
@@ -200,7 +196,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
@@ -219,7 +215,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
t := &TraefikOidc{
|
||||
next: next,
|
||||
name: name,
|
||||
store: store,
|
||||
redirURLPath: config.CallbackURL,
|
||||
logoutURLPath: func() string {
|
||||
if config.LogoutURL == "" {
|
||||
@@ -227,6 +222,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
return config.LogoutURL
|
||||
}(),
|
||||
postLogoutRedirectURI: func() string {
|
||||
if config.PostLogoutRedirectURI == "" {
|
||||
return "/"
|
||||
}
|
||||
return config.PostLogoutRedirectURI
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
clientID: config.ClientID,
|
||||
@@ -243,9 +244,10 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
initComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
|
||||
@@ -357,92 +359,68 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Process the request as normal
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled")
|
||||
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the URL is excluded from authentication
|
||||
// Check if URL is excluded
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the scheme (http/https) and host
|
||||
t.scheme = t.determineScheme(req)
|
||||
defaultSessionOptions.Secure = t.scheme == "https"
|
||||
host := t.determineHost(req)
|
||||
|
||||
// Build the redirect URL if not already set
|
||||
if t.redirectURL == "" {
|
||||
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
|
||||
}
|
||||
|
||||
// Get the session
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
// Get session
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
t.logger.Debugf("Session contents at start: %+v", session.Values)
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
|
||||
// Handle logout URL
|
||||
// Handle special URLs
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle callback URL
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.handleCallback(rw, req)
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the user is authenticated
|
||||
// Check authentication status
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
if expired {
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// At this point, the user is authenticated
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Errorf("No id_token found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := extractClaims(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
// Process authenticated request
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Debugf("No email found in token claims")
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
t.logger.Debug("No email found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -452,11 +430,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(idToken)
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
// Set headers for groups and roles
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
@@ -465,6 +442,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
@@ -475,13 +453,15 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
// Process the request
|
||||
t.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
@@ -520,37 +500,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
}
|
||||
|
||||
// isUserAuthenticated checks if the user is authenticated
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
|
||||
authenticated, _ := session.Values["authenticated"].(bool)
|
||||
t.logger.Debugf("Session authenticated value: %v", authenticated)
|
||||
|
||||
if !authenticated {
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Debug("No id_token found in session")
|
||||
accessToken := session.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
t.logger.Debug("No access token found in session")
|
||||
return false, false, true // Session is invalid, consider it expired
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
t.logger.Errorf("Token verification failed: %v", err)
|
||||
return false, false, true // Token is invalid, consider it expired
|
||||
}
|
||||
|
||||
claims, err := extractClaims(idToken)
|
||||
claims, err := extractClaims(accessToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
return false, false, true // Can't read claims, consider it expired
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
expClaim, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Errorf("Failed to get expiration time from claims")
|
||||
return false, false, true // No expiration, consider it expired
|
||||
t.logger.Error("Failed to get expiration time from claims")
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -558,7 +535,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
|
||||
|
||||
if now > expTime {
|
||||
t.logger.Debug("Token has expired")
|
||||
return false, false, true // Token has expired
|
||||
return false, false, true
|
||||
}
|
||||
|
||||
gracePeriod := time.Minute * 5
|
||||
@@ -567,26 +544,23 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
|
||||
return true, true, false // Token will expire soon, needs refresh
|
||||
}
|
||||
|
||||
return true, false, false // Token is valid and not expiring soon
|
||||
return true, false, false
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
// Generate CSRF token
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
csrfToken := uuid.New().String()
|
||||
session.Values["csrf"] = csrfToken
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Options = defaultSessionOptions
|
||||
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
|
||||
|
||||
// Generate nonce
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.Values["nonce"] = nonce
|
||||
t.logger.Debugf("Setting nonce: %s", nonce)
|
||||
|
||||
// Set session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath(req.URL.Path)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
@@ -595,7 +569,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
||||
return
|
||||
}
|
||||
|
||||
// Build the authentication URL
|
||||
// Build and redirect to auth URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
@@ -679,10 +653,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
}
|
||||
|
||||
// refreshToken refreshes the user's token
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
t.logger.Debug("Refreshing token")
|
||||
refreshToken, ok := session.Values["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
refreshToken := session.GetRefreshToken()
|
||||
if refreshToken == "" {
|
||||
t.logger.Debug("No refresh token found in session")
|
||||
return false
|
||||
}
|
||||
@@ -693,16 +667,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the new id_token
|
||||
// Verify the new access token
|
||||
if err := t.verifyToken(newToken.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify new id_token: %v", err)
|
||||
t.logger.Errorf("Failed to verify new access token: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update session with new tokens
|
||||
session.Values["id_token"] = newToken.IDToken
|
||||
session.Values["refresh_token"] = newToken.RefreshToken
|
||||
session.Options = defaultSessionOptions
|
||||
session.SetAccessToken(newToken.IDToken)
|
||||
session.SetRefreshToken(newToken.RefreshToken)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save refreshed session: %v", err)
|
||||
return false
|
||||
@@ -767,3 +742,18 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
|
||||
return groups, roles, nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
// If the path is already a full URL, return it as-is
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
|
||||
// Ensure the path starts with a forward slash
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
+374
-158
@@ -22,13 +22,14 @@ import (
|
||||
|
||||
// TestSuite holds common test data and setup
|
||||
type TestSuite struct {
|
||||
t *testing.T
|
||||
rsaPrivateKey *rsa.PrivateKey
|
||||
rsaPublicKey *rsa.PublicKey
|
||||
ecPrivateKey *ecdsa.PrivateKey
|
||||
tOidc *TraefikOidc
|
||||
mockJWKCache *MockJWKCache
|
||||
token string
|
||||
t *testing.T
|
||||
rsaPrivateKey *rsa.PrivateKey
|
||||
rsaPublicKey *rsa.PublicKey
|
||||
ecPrivateKey *ecdsa.PrivateKey
|
||||
tOidc *TraefikOidc
|
||||
mockJWKCache *MockJWKCache
|
||||
token string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
|
||||
// Setup initializes the test suite
|
||||
@@ -78,6 +79,9 @@ func (ts *TestSuite) Setup() {
|
||||
ts.t.Fatalf("Failed to create test JWT: %v", err)
|
||||
}
|
||||
|
||||
logger := NewLogger("info")
|
||||
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
|
||||
|
||||
// Common TraefikOidc instance
|
||||
ts.tOidc = &TraefikOidc{
|
||||
issuerURL: "https://test-issuer.com",
|
||||
@@ -89,13 +93,13 @@ func (ts *TestSuite) Setup() {
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenCache: NewTokenCache(),
|
||||
logger: NewLogger("info"),
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
logger: logger,
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
excludedURLs: map[string]struct{}{"/favicon": {}},
|
||||
httpClient: &http.Client{},
|
||||
extractClaimsFunc: extractClaims,
|
||||
initComplete: make(chan struct{}),
|
||||
sessionManager: ts.sessionManager,
|
||||
}
|
||||
close(ts.tOidc.initComplete)
|
||||
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
|
||||
@@ -104,7 +108,7 @@ func (ts *TestSuite) Setup() {
|
||||
}
|
||||
|
||||
// Helper functions used by TraefikOidc
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -257,6 +261,7 @@ func TestServeHTTP(t *testing.T) {
|
||||
sessionValues map[interface{}]interface{}
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
setupSession func(*SessionData)
|
||||
}{
|
||||
{
|
||||
name: "Excluded URL",
|
||||
@@ -272,10 +277,10 @@ func TestServeHTTP(t *testing.T) {
|
||||
{
|
||||
name: "Authenticated request to protected URL",
|
||||
requestPath: "/protected",
|
||||
sessionValues: map[interface{}]interface{}{
|
||||
"authenticated": true,
|
||||
"email": "user@example.com",
|
||||
"id_token": ts.token,
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "OK",
|
||||
@@ -283,52 +288,52 @@ func TestServeHTTP(t *testing.T) {
|
||||
{
|
||||
name: "Logout URL",
|
||||
requestPath: "/logout",
|
||||
sessionValues: map[interface{}]interface{}{
|
||||
"authenticated": true,
|
||||
"email": "user@example.com",
|
||||
"id_token": ts.token,
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
session.SetAccessToken(ts.token)
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "Logged out\n",
|
||||
expectedBody: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a request
|
||||
req := httptest.NewRequest("GET", tc.requestPath, nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
req.Header.Set("X-Forwarded-Host", "localhost")
|
||||
|
||||
// Create a temporary response recorder to save the session
|
||||
rrSession := httptest.NewRecorder()
|
||||
|
||||
// Create a session
|
||||
session, _ := ts.tOidc.store.New(req, cookieName)
|
||||
if tc.sessionValues != nil {
|
||||
for k, v := range tc.sessionValues {
|
||||
session.Values[k] = v
|
||||
}
|
||||
session.Save(req, rrSession)
|
||||
}
|
||||
|
||||
// Copy session cookie from rrSession to request
|
||||
for _, cookie := range rrSession.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Create a response recorder for ServeHTTP
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Setup session if needed
|
||||
session, err := ts.tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
if tc.setupSession != nil {
|
||||
tc.setupSession(session)
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
rr = httptest.NewRecorder()
|
||||
}
|
||||
|
||||
// Call ServeHTTP
|
||||
ts.tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check the response
|
||||
// Check response
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Test %s: expected status %d, got %d", tc.name, tc.expectedStatus, rr.Code)
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
if tc.expectedBody != "" && strings.TrimSpace(rr.Body.String()) != strings.TrimSpace(rr.Body.String()) {
|
||||
t.Errorf("Test %s: expected body '%s', got '%s'", tc.name, tc.expectedBody, rr.Body.String())
|
||||
if tc.expectedBody != "" {
|
||||
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
|
||||
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -452,18 +457,20 @@ func TestHandleCallback(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
redirectURL := "http://example.com/"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
sessionSetupFunc func(*SessionData)
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -475,49 +482,49 @@ func TestHandleCallback(t *testing.T) {
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
name: "Missing Code",
|
||||
queryParams: "",
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Exchange Code Error",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("exchange code error")
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Missing ID Token",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Disallowed Email",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -529,16 +536,16 @@ func TestHandleCallback(t *testing.T) {
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "Invalid State Parameter",
|
||||
queryParams: "?code=test-code&state=invalid-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -550,16 +557,16 @@ func TestHandleCallback(t *testing.T) {
|
||||
"nonce": "test-nonce",
|
||||
}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "Nonce Mismatch",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -571,16 +578,16 @@ func TestHandleCallback(t *testing.T) {
|
||||
"nonce": "invalid-nonce",
|
||||
}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "Missing Nonce in Claims",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -592,9 +599,9 @@ func TestHandleCallback(t *testing.T) {
|
||||
// Missing nonce
|
||||
}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
sessionSetupFunc: func(session *SessionData) {
|
||||
session.SetCSRF("test-csrf-token")
|
||||
session.SetNonce("test-nonce")
|
||||
},
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
@@ -602,15 +609,18 @@ func TestHandleCallback(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
|
||||
// Create a new instance for each test to avoid state carryover
|
||||
tOidc := &TraefikOidc{
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
allowedUserDomains: map[string]struct{}{"example.com": {}},
|
||||
logger: NewLogger("info"),
|
||||
logger: logger,
|
||||
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
|
||||
extractClaimsFunc: tc.extractClaimsFunc,
|
||||
tokenVerifier: ts.tOidc.tokenVerifier,
|
||||
jwtVerifier: ts.tOidc.jwtVerifier,
|
||||
sessionManager: sessionManager,
|
||||
}
|
||||
|
||||
// Create request and response recorder
|
||||
@@ -618,22 +628,27 @@ func TestHandleCallback(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Create session
|
||||
session, _ := tOidc.store.New(req, cookieName)
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
if tc.sessionSetupFunc != nil {
|
||||
tc.sessionSetupFunc(session)
|
||||
}
|
||||
session.Save(req, rr)
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy session cookie to request
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset rr for the actual test
|
||||
// Reset response recorder for the actual test
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Call handleCallback
|
||||
tOidc.handleCallback(rr, req)
|
||||
tOidc.handleCallback(rr, req, redirectURL)
|
||||
|
||||
// Check response
|
||||
if rr.Code != tc.expectedStatus {
|
||||
@@ -688,7 +703,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
expectedStatus int
|
||||
@@ -704,7 +719,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -728,7 +743,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -751,7 +766,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -775,7 +790,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -847,7 +862,7 @@ func TestHandleLogout(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*sessions.Session)
|
||||
setupSession func(*SessionData)
|
||||
endSessionURL string
|
||||
expectedStatus int
|
||||
expectedURL string
|
||||
@@ -855,25 +870,22 @@ func TestHandleLogout(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "Successful logout with end session endpoint",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
session.Values["refresh_token"] = "test-refresh-token"
|
||||
session.Values["access_token"] = "test-access-token"
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("test.id.token")
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
},
|
||||
endSessionURL: "https://provider/end-session",
|
||||
expectedStatus: http.StatusFound,
|
||||
// Fix: The entire URL should be URL-encoded
|
||||
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
||||
host: "test-host",
|
||||
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Successful logout without end session endpoint",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
session.Values["refresh_token"] = "test-refresh-token"
|
||||
session.Values["access_token"] = "test-access-token"
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("test.id.token")
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
},
|
||||
endSessionURL: "",
|
||||
expectedStatus: http.StatusFound,
|
||||
@@ -882,16 +894,17 @@ func TestHandleLogout(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "Logout with empty session",
|
||||
setupSession: func(session *sessions.Session) {},
|
||||
setupSession: func(session *SessionData) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with invalid end session URL",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("test.id.token")
|
||||
session.SetRefreshToken("test-refresh-token")
|
||||
},
|
||||
endSessionURL: ":\\invalid-url",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
@@ -901,19 +914,20 @@ func TestHandleLogout(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a new TraefikOidc instance for each test
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
tOidc := &TraefikOidc{
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
revocationURL: mockRevocationServer.URL,
|
||||
endSessionURL: tc.endSessionURL,
|
||||
scheme: "http",
|
||||
logger: NewLogger("info"),
|
||||
logger: logger,
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
httpClient: &http.Client{},
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
tokenCache: NewTokenCache(),
|
||||
forceHTTPS: false,
|
||||
sessionManager: sessionManager,
|
||||
}
|
||||
|
||||
// Create request with proper headers
|
||||
@@ -924,16 +938,18 @@ func TestHandleLogout(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get a session
|
||||
session, err := tOidc.store.Get(req, cookieName)
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
if tc.setupSession != nil {
|
||||
tc.setupSession(session)
|
||||
}
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Setup session
|
||||
tc.setupSession(session)
|
||||
session.Save(req, rr)
|
||||
|
||||
// Copy session cookie to request
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
@@ -949,7 +965,6 @@ func TestHandleLogout(t *testing.T) {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check redirect URL if expected
|
||||
if tc.expectedURL != "" {
|
||||
location := rr.Header().Get("Location")
|
||||
if location != tc.expectedURL {
|
||||
@@ -958,23 +973,31 @@ func TestHandleLogout(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify session is cleared
|
||||
newSession, _ := tOidc.store.Get(req, cookieName)
|
||||
if len(newSession.Values) > 0 {
|
||||
t.Error("Session was not cleared")
|
||||
updatedSession, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated session: %v", err)
|
||||
}
|
||||
if newSession.Options.MaxAge != -1 {
|
||||
t.Error("Session MaxAge was not set to -1")
|
||||
|
||||
// Verify tokens are cleared
|
||||
if token := updatedSession.GetAccessToken(); token != "" {
|
||||
t.Error("Access token not cleared")
|
||||
}
|
||||
if token := updatedSession.GetRefreshToken(); token != "" {
|
||||
t.Error("Refresh token not cleared")
|
||||
}
|
||||
if updatedSession.GetAuthenticated() {
|
||||
t.Error("Session still marked as authenticated")
|
||||
}
|
||||
|
||||
// Check token blacklist
|
||||
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
|
||||
t.Error("Refresh token was not blacklisted")
|
||||
if token := session.GetAccessToken(); token != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
t.Error("Access token was not blacklisted")
|
||||
}
|
||||
}
|
||||
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
|
||||
t.Error("Access token was not blacklisted")
|
||||
if token := session.GetRefreshToken(); token != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
t.Error("Refresh token was not blacklisted")
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -1155,24 +1178,24 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*sessions.Session)
|
||||
setupSession func(*SessionData)
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "Basic expired token",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "expired.token"
|
||||
session.Values["email"] = "test@example.com"
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("expired.token")
|
||||
session.SetEmail("test@example.com")
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
{
|
||||
name: "Session with additional values",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "expired.token"
|
||||
session.Values["custom_value"] = "should-be-cleared"
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("expired.token")
|
||||
session.mainSession.Values["custom_value"] = "should-be-cleared"
|
||||
},
|
||||
expectedPath: "/another/path",
|
||||
},
|
||||
@@ -1180,17 +1203,16 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a new TraefikOidc instance for each test
|
||||
logger := NewLogger("info")
|
||||
sessionManager := NewSessionManager("test-secret-key", false, logger)
|
||||
|
||||
tOidc := &TraefikOidc{
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
logger: NewLogger("info"),
|
||||
redirectURL: "http://example.com/callback",
|
||||
tokenVerifier: ts.tOidc.tokenVerifier,
|
||||
jwtVerifier: ts.tOidc.jwtVerifier,
|
||||
initComplete: make(chan struct{}),
|
||||
// Add this initialization of initiateAuthenticationFunc
|
||||
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
// Mock implementation for test
|
||||
sessionManager: sessionManager,
|
||||
logger: logger,
|
||||
tokenVerifier: ts.tOidc.tokenVerifier,
|
||||
jwtVerifier: ts.tOidc.jwtVerifier,
|
||||
initComplete: make(chan struct{}),
|
||||
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
http.Redirect(rw, req, "/login", http.StatusFound)
|
||||
},
|
||||
}
|
||||
@@ -1201,31 +1223,40 @@ func TestHandleExpiredToken(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, _ := tOidc.store.New(req, cookieName)
|
||||
session, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Setup session data
|
||||
tc.setupSession(session)
|
||||
|
||||
// Handle expired token
|
||||
tOidc.handleExpiredToken(rr, req, session)
|
||||
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
|
||||
|
||||
// Verify session is cleaned
|
||||
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
|
||||
t.Errorf("Expected 3 session values, got %d", len(session.Values))
|
||||
// Get the updated session to verify changes
|
||||
updatedSession, err := sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated session: %v", err)
|
||||
}
|
||||
|
||||
// Verify required values are set
|
||||
if _, ok := session.Values["csrf"].(string); !ok {
|
||||
// Verify main session values
|
||||
if updatedSession.GetCSRF() == "" {
|
||||
t.Error("CSRF token not set")
|
||||
}
|
||||
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
|
||||
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
|
||||
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
|
||||
}
|
||||
if _, ok := session.Values["nonce"].(string); !ok {
|
||||
if updatedSession.GetNonce() == "" {
|
||||
t.Error("Nonce not set")
|
||||
}
|
||||
|
||||
// Verify session options
|
||||
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
|
||||
t.Error("Session MaxAge not set correctly")
|
||||
// Verify tokens are cleared
|
||||
if token := updatedSession.GetAccessToken(); token != "" {
|
||||
t.Error("Access token not cleared")
|
||||
}
|
||||
if token := updatedSession.GetRefreshToken(); token != "" {
|
||||
t.Error("Refresh token not cleared")
|
||||
}
|
||||
|
||||
// Verify redirect status
|
||||
@@ -1311,6 +1342,191 @@ func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestServeHTTPRolesAndGroups(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
claims map[string]interface{}
|
||||
setupSession func(*SessionData)
|
||||
expectedStatus int
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
name: "User with allowed role",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"admin", "user"},
|
||||
"groups": []interface{}{"group1"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "admin,user",
|
||||
"X-User-Groups": "group1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User with allowed group",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"allowed-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "allowed-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "User without allowed roles or groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{
|
||||
"admin": {},
|
||||
"allowed-group": {},
|
||||
},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "No role/group restrictions",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
"roles": []interface{}{"user"},
|
||||
"groups": []interface{}{"regular-group"},
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-User-Roles": "user",
|
||||
"X-User-Groups": "regular-group",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Claims without roles and groups",
|
||||
allowedRolesAndGroups: map[string]struct{}{},
|
||||
claims: map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
"sub": "test-subject",
|
||||
},
|
||||
setupSession: func(session *SessionData) {
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("user@example.com")
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedHeaders: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create token with claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
// Create test handler
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Configure OIDC middleware
|
||||
tOidc := ts.tOidc
|
||||
tOidc.next = nextHandler
|
||||
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", "/protected", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Set up session
|
||||
session, err := tOidc.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
tc.setupSession(session)
|
||||
session.SetAccessToken(token)
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Copy cookies to the new request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset response recorder
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Serve request
|
||||
tOidc.ServeHTTP(rr, req)
|
||||
|
||||
// Check status code
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check headers if status is OK
|
||||
if tc.expectedStatus == http.StatusOK {
|
||||
for header, expectedValue := range tc.expectedHeaders {
|
||||
if value := req.Header.Get(header); value != expectedValue {
|
||||
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
|
||||
+355
@@ -0,0 +1,355 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
mainCookieName = "_raczylo_oidc" // Main session cookie
|
||||
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
|
||||
maxCookieSize = 2000 // Max size for each chunk to stay within 4096-byte cookie limit
|
||||
|
||||
// REASON:
|
||||
// Let x be the maximum size of the chunk (maxCookieSize).
|
||||
// Encrypted size = x + 28 bytes
|
||||
// Base64-encoded size = ((x + 28) * 4) / 3 bytes
|
||||
// ((x + 28) * 4) / 3 <= 4096
|
||||
// Multiply both sides by 3:
|
||||
// 4 * (x + 28) <= 4096 * 3
|
||||
// 4 * (x + 28) <= 12288
|
||||
// Divide both sides by 4:
|
||||
// x + 28 <= 3072
|
||||
// Subtract 28 from both sides:
|
||||
// x <= 3044
|
||||
)
|
||||
|
||||
// SessionManager handles multiple session cookies
|
||||
type SessionManager struct {
|
||||
store sessions.Store
|
||||
forceHTTPS bool
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
return &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
forceHTTPS: forceHTTPS,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// getSessionOptions returns session options based on scheme
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure || sm.forceHTTPS,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves all session data
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
mainSession, err := sm.store.Get(r, mainCookieName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||
}
|
||||
|
||||
accessSession, err := sm.store.Get(r, accessTokenCookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token session: %w", err)
|
||||
}
|
||||
|
||||
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||
}
|
||||
|
||||
sessionData := &SessionData{
|
||||
manager: sm,
|
||||
request: r,
|
||||
mainSession: mainSession,
|
||||
accessSession: accessSession,
|
||||
refreshSession: refreshSession,
|
||||
}
|
||||
|
||||
// Retrieve chunked access token sessions
|
||||
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
|
||||
// Retrieve chunked refresh token sessions
|
||||
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// getTokenChunkSessions retrieves sessions for token chunks
|
||||
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
|
||||
chunks := make(map[int]*sessions.Session)
|
||||
for i := 0; ; i++ {
|
||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||
session, err := sm.store.Get(r, sessionName)
|
||||
if err != nil || session.IsNew {
|
||||
// No more sessions
|
||||
break
|
||||
}
|
||||
chunks[i] = session
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// SessionData holds all session information
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
request *http.Request
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
accessTokenChunks map[int]*sessions.Session
|
||||
refreshTokenChunks map[int]*sessions.Session
|
||||
}
|
||||
|
||||
// Save saves all session data
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
options := sd.manager.getSessionOptions(isSecure)
|
||||
sd.mainSession.Options = options
|
||||
sd.accessSession.Options = options
|
||||
sd.refreshSession.Options = options
|
||||
|
||||
// Save main session
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token session
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
|
||||
// Save refresh token session
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
// Save access token chunks
|
||||
for _, session := range sd.accessTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Save refresh token chunks
|
||||
for _, session := range sd.refreshTokenChunks {
|
||||
session.Options = options
|
||||
if err := session.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all session data
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
}
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
}
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
// Clear chunk sessions
|
||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||
|
||||
return sd.Save(r, w)
|
||||
}
|
||||
|
||||
// clearTokenChunks clears chunked token sessions
|
||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||
for _, session := range chunks {
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAuthenticated returns authentication status
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
return auth
|
||||
}
|
||||
|
||||
// SetAuthenticated sets authentication status
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
}
|
||||
|
||||
// GetAccessToken returns the access token
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.accessTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.accessTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetAccessToken sets the access token
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
|
||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if len(token) <= maxCookieSize {
|
||||
sd.accessSession.Values["token"] = token
|
||||
} else {
|
||||
// Split token into chunks
|
||||
sd.accessSession.Values["token"] = ""
|
||||
chunks := splitIntoChunks(token, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.accessTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRefreshToken returns the refresh token
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// Reassemble token from chunks
|
||||
if len(sd.refreshTokenChunks) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var chunks []string
|
||||
for i := 0; ; i++ {
|
||||
session, ok := sd.refreshTokenChunks[i]
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
chunk, _ := session.Values["token_chunk"].(string)
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return strings.Join(chunks, "")
|
||||
}
|
||||
|
||||
// SetRefreshToken sets the refresh token
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
// Clear existing chunks
|
||||
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
|
||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||
|
||||
if len(token) <= maxCookieSize {
|
||||
sd.refreshSession.Values["token"] = token
|
||||
} else {
|
||||
// Split token into chunks
|
||||
sd.refreshSession.Values["token"] = ""
|
||||
chunks := splitIntoChunks(token, maxCookieSize)
|
||||
for i, chunk := range chunks {
|
||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||
session, _ := sd.manager.store.Get(sd.request, sessionName)
|
||||
session.Values["token_chunk"] = chunk
|
||||
sd.refreshTokenChunks[i] = session
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// splitIntoChunks splits a string into chunks of specified size
|
||||
func splitIntoChunks(s string, chunkSize int) []string {
|
||||
var chunks []string
|
||||
for len(s) > 0 {
|
||||
if len(s) > chunkSize {
|
||||
chunks = append(chunks, s[:chunkSize])
|
||||
s = s[chunkSize:]
|
||||
} else {
|
||||
chunks = append(chunks, s)
|
||||
break
|
||||
}
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
// GetCSRF returns the CSRF token
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRF token
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce returns the nonce
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce sets the nonce
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail returns the user's email
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail sets the user's email
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath returns the original incoming path
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath sets the original incoming path
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
+129
@@ -0,0 +1,129 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSessionManager tests the SessionManager functionality
|
||||
func TestSessionManager(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticated bool
|
||||
email string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
expectedCookieCount int
|
||||
}{
|
||||
{
|
||||
name: "Short tokens",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: "shortaccesstoken",
|
||||
refreshToken: "shortrefreshtoken",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
},
|
||||
{
|
||||
name: "Long tokens exceeding 4096 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 5000),
|
||||
refreshToken: strings.Repeat("y", 6000),
|
||||
// Recalculate expected cookies based on new maxCookieSize
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
|
||||
},
|
||||
{
|
||||
name: "REALLY long tokens, exceeding 25000 bytes",
|
||||
authenticated: true,
|
||||
email: "test@example.com",
|
||||
accessToken: strings.Repeat("x", 25000),
|
||||
refreshToken: strings.Repeat("y", 25000),
|
||||
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
|
||||
},
|
||||
{
|
||||
name: "Unauthenticated session",
|
||||
authenticated: false,
|
||||
email: "",
|
||||
accessToken: "",
|
||||
refreshToken: "",
|
||||
expectedCookieCount: 3, // main, access, refresh
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc // Capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := ts.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Set session values
|
||||
session.SetAuthenticated(tc.authenticated)
|
||||
session.SetEmail(tc.email)
|
||||
session.SetAccessToken(tc.accessToken)
|
||||
session.SetRefreshToken(tc.refreshToken)
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != tc.expectedCookieCount {
|
||||
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
|
||||
}
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := ts.sessionManager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
if newSession.GetAuthenticated() != tc.authenticated {
|
||||
t.Errorf("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != tc.email {
|
||||
t.Errorf("Expected email %s, got %s", tc.email, email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != tc.accessToken {
|
||||
t.Errorf("Access token not preserved")
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
|
||||
t.Errorf("Refresh token not preserved")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
|
||||
count := 3 // main, access, refresh
|
||||
|
||||
// Calculate number of chunks for access token
|
||||
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
|
||||
if accessChunks > 1 {
|
||||
count += accessChunks
|
||||
}
|
||||
|
||||
// Calculate number of chunks for refresh token
|
||||
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
|
||||
if refreshChunks > 1 {
|
||||
count += refreshChunks
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
+1
-10
@@ -6,8 +6,6 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -31,17 +29,10 @@ type Config struct {
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
var defaultSessionOptions = &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
// CreateConfig creates a new Config with default values
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
|
||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
||||
+1
-5
@@ -1,7 +1,4 @@
|
||||
# Gorilla Sessions
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
|
||||
# sessions
|
||||
|
||||

|
||||
[](https://codecov.io/github/gorilla/sessions)
|
||||
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
|
||||
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
|
||||
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
|
||||
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
|
||||
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
|
||||
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
|
||||
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
|
||||
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
|
||||
|
||||
+9
-12
@@ -1,6 +1,5 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
@@ -9,15 +8,13 @@ import "net/http"
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
Partitioned: options.Partitioned,
|
||||
SameSite: options.SameSite,
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
SameSite: options.SameSite,
|
||||
}
|
||||
|
||||
}
|
||||
+5
-10
@@ -1,11 +1,8 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
@@ -16,9 +13,7 @@ type Options struct {
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
Partitioned bool
|
||||
SameSite http.SameSite
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
}
|
||||
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
type Options struct {
|
||||
Path string
|
||||
Domain string
|
||||
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
// Defaults to http.SameSiteDefaultMode
|
||||
SameSite http.SameSite
|
||||
}
|
||||
Vendored
+2
-2
@@ -4,8 +4,8 @@ github.com/google/uuid
|
||||
# github.com/gorilla/securecookie v1.1.2
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.4.0
|
||||
## explicit; go 1.23
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# golang.org/x/time v0.7.0
|
||||
## explicit; go 1.18
|
||||
|
||||
Reference in New Issue
Block a user