Compare commits

..

26 Commits

Author SHA1 Message Date
lukaszraczylo bef4212c57 Add support for the large tokens, which exceed the standard 4096 limit for cookie. 2024-12-11 12:55:16 +00:00
lukaszraczylo 1fee2f9e9a fixup! Re-introduce user roles separation with additional tests. 2024-12-11 09:11:34 +00:00
lukaszraczylo 11bc6f3e31 Re-introduce user roles separation with additional tests. 2024-12-11 09:08:50 +00:00
lukaszraczylo 2b7af88ff9 Move session management into session manager. Split the cookies to avoid the 4k limit ( resolves issue: #15 ) 2024-12-10 10:19:35 +00:00
lukaszraczylo 01ee7c4dc8 Improve cookie setting. 2024-12-10 10:19:35 +00:00
lukaszraczylo a6fa4d8789 Downgrade gorilla sessions preventing the publishing by traefik hub temporarily. 2024-12-10 10:19:34 +00:00
lukaszraczylo 8101fb2bf6 Clean up dependencies. 2024-11-06 11:51:20 +00:00
lukaszraczylo 8ca669105b Fix OIDC logout issue, improve test coverage, load provider once. 2024-11-06 11:33:29 +00:00
lukaszraczylo 555164160d Update dependencies. 2024-11-06 11:33:06 +00:00
lukaszraczylo 3fe537d38f Add ability to verify default ECDSA keys provided by logto as well. 2024-11-06 11:33:06 +00:00
lukaszraczylo 31de2c63b2 Revert "Update go mod dependencies."
This reverts commit dedbdf63c3.
2024-11-06 11:33:04 +00:00
lukaszraczylo 7dd9205277 Update go mod dependencies. 2024-11-06 11:33:04 +00:00
lukaszraczylo f3598e4ab8 Add simple benchmark to track the allocations and speed for future improvements. 2024-11-06 11:33:03 +00:00
lukaszraczylo 218165d365 Cleanup and optimise the code. 2024-11-06 11:33:03 +00:00
lukaszraczylo dc4c4824cd Add support for more algorithms. 2024-11-06 11:33:03 +00:00
lukaszraczylo 345c0c4a11 Abstract filling up maps. 2024-11-06 11:32:37 +00:00
lukaszraczylo da4f97de04 Fix the bug with user not being redirected to originally requested URL post authentication. 2024-11-06 11:32:36 +00:00
lukaszraczylo ce916f3ca3 Update documentation - setting secrets in kubernetes. 2024-11-06 11:32:36 +00:00
lukaszraczylo 6f2cf65d49 Fix the tests hanging on the open channel. 2024-11-06 11:32:36 +00:00
lukaszraczylo 78b9d611f0 Improvement - startup time.
Previous implementations blocked the traefik startup until OIDC plugin was loaded.
This caused chicken-or-egg issue when called OIDC endpoint was hosted by the same traefik as well,
generating rather ridiculous situation when traefik couldn't come up because plugin tried to call the
discovery endpoint which was hosted by the same traefik.

This version resolves the issue allowing for quickstart and lazy loading of the provider metadata.
Disadvantage is - until discovery is done, the plugin will not provide any access to the client.
2024-11-06 11:32:36 +00:00
lukaszraczylo 2bb1debeb3 First step in improvement of caching mechanism. 2024-11-06 11:32:36 +00:00
lukaszraczylo 93b49b6d17 Add support for roles and groups. 2024-11-06 11:32:35 +00:00
lukaszraczylo 7a53da6080 Update tests and additional fixups. 2024-11-06 11:32:35 +00:00
lukaszraczylo 66e08755c1 Update the tests to handle nonce 2024-11-06 11:32:35 +00:00
lukaszraczylo d6fd3467c3 Support additional verification of the token to ensure OIDC compliance 2024-11-06 11:32:35 +00:00
lukaszraczylo 6196a72a8e Update dependencies. 2024-11-06 11:32:34 +00:00
17 changed files with 1605 additions and 374 deletions
+1
View File
@@ -13,6 +13,7 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
+1
View File
@@ -38,6 +38,7 @@ spec:
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
+1 -1
View File
@@ -6,7 +6,7 @@ toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.4.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
)
-4
View File
@@ -6,9 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+111 -110
View File
@@ -13,10 +13,19 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
@@ -27,14 +36,6 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
@@ -97,77 +98,22 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
}
t.RevokeToken(refreshToken)
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
t.RevokeToken(accessToken)
}
// Remove tokens from session
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
// Initialize new authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -184,26 +130,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Validate the state parameter matches the session's CSRF token
// Validate state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -211,56 +159,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
// Verify and process tokens
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the nonce claim matches the one stored in session
// Verify nonce
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
// Process email
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -268,31 +209,25 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
// Update session with new values
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
// Save session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
// Redirect to original path or root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
@@ -397,9 +332,9 @@ func (tc *TokenCache) Cleanup() {
}
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
@@ -414,3 +349,69 @@ func createStringMap(keys []string) map[string]struct{} {
}
return result
}
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the access token before clearing session
accessToken := session.GetAccessToken()
// Clear all session data
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the base URL for redirects
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Determine post logout redirect URI
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
// If we have an end session endpoint and an access token, use OIDC end session
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
}
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// Otherwise, redirect to post logout URI
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
// Ensure postLogoutRedirectURI is properly URL encoded
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
+136 -131
View File
@@ -14,7 +14,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
@@ -34,7 +33,6 @@ type JWTVerifier interface {
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -53,26 +51,29 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
@@ -82,6 +83,14 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -175,9 +184,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -190,7 +196,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ExpectContinueTimeout: 0,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
@@ -209,7 +215,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -217,6 +222,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
}
return config.PostLogoutRedirectURI
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
@@ -233,9 +244,10 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -254,20 +266,26 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
}
close(t.initComplete)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -341,92 +359,68 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
return
}
// Process the request as normal
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
}
// Check if the URL is excluded from authentication
// Check if URL is excluded
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
// Get session
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Handle logout URL
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
t.handleCallback(rw, req, redirectURL)
return
}
// Check if the user is authenticated
// Check authentication status
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
}
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
// Process authenticated request
email := session.GetEmail()
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.logger.Debug("No email found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -436,11 +430,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -449,6 +442,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -459,13 +453,15 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -504,37 +500,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(idToken); err != nil {
if err := t.verifyToken(accessToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(idToken)
claims, err := extractClaims(accessToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true // Can't read claims, consider it expired
return false, false, true
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
}
now := time.Now().Unix()
@@ -542,7 +535,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true // Token has expired
return false, false, true
}
gracePeriod := time.Minute * 5
@@ -551,26 +544,23 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false // Token is valid and not expiring soon
return true, false, false
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -579,7 +569,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build the authentication URL
// Build and redirect to auth URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -620,14 +610,9 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -668,10 +653,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -682,16 +667,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
// Verify the new access token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
t.logger.Errorf("Failed to verify new access token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
@@ -726,29 +712,48 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
var groups []string
var roles []string
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
+807 -88
View File
File diff suppressed because it is too large Load Diff
+355
View File
@@ -0,0 +1,355 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
)
const (
mainCookieName = "_raczylo_oidc" // Main session cookie
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
maxCookieSize = 2000 // Max size for each chunk to stay within 4096-byte cookie limit
// REASON:
// Let x be the maximum size of the chunk (maxCookieSize).
// Encrypted size = x + 28 bytes
// Base64-encoded size = ((x + 28) * 4) / 3 bytes
// ((x + 28) * 4) / 3 <= 4096
// Multiply both sides by 3:
// 4 * (x + 28) <= 4096 * 3
// 4 * (x + 28) <= 12288
// Divide both sides by 4:
// x + 28 <= 3072
// Subtract 28 from both sides:
// x <= 3044
)
// SessionManager handles multiple session cookies
type SessionManager struct {
store sessions.Store
forceHTTPS bool
logger *Logger
}
// NewSessionManager creates a new session manager
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
}
// getSessionOptions returns session options based on scheme
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// GetSession retrieves all session data
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
if err != nil {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
request: r,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
}
// Retrieve chunked access token sessions
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
// Retrieve chunked refresh token sessions
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
return sessionData, nil
}
// getTokenChunkSessions retrieves sessions for token chunks
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
chunks := make(map[int]*sessions.Session)
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
return chunks
}
// SessionData holds all session information
type SessionData struct {
manager *SessionManager
request *http.Request
mainSession *sessions.Session
accessSession *sessions.Session
refreshSession *sessions.Session
accessTokenChunks map[int]*sessions.Session
refreshTokenChunks map[int]*sessions.Session
}
// Save saves all session data
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token chunk session: %w", err)
}
}
// Save refresh token chunks
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
}
}
return nil
}
// Clear clears all session data
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
return sd.Save(r, w)
}
// clearTokenChunks clears chunked token sessions
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
}
}
// GetAuthenticated returns authentication status
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
}
// SetAuthenticated sets authentication status
func (sd *SessionData) SetAuthenticated(value bool) {
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken returns the access token
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.accessTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.accessTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetAccessToken sets the access token
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
sd.accessTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.accessSession.Values["token"] = token
} else {
// Split token into chunks
sd.accessSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.accessTokenChunks[i] = session
}
}
}
// GetRefreshToken returns the refresh token
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.refreshTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.refreshTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetRefreshToken sets the refresh token
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
sd.refreshTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.refreshSession.Values["token"] = token
} else {
// Split token into chunks
sd.refreshSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.refreshTokenChunks[i] = session
}
}
}
// splitIntoChunks splits a string into chunks of specified size
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
if len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// GetCSRF returns the CSRF token
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF sets the CSRF token
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce returns the nonce
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce sets the nonce
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail returns the user's email
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail sets the user's email
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath returns the original incoming path
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath sets the original incoming path
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+129
View File
@@ -0,0 +1,129 @@
package traefikoidc
import (
"net/http/httptest"
"strings"
"testing"
)
// TestSessionManager tests the SessionManager functionality
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
// Recalculate expected cookies based on new maxCookieSize
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved")
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Calculate number of chunks for access token
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
if accessChunks > 1 {
count += accessChunks
}
// Calculate number of chunks for refresh token
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
if refreshChunks > 1 {
count += refreshChunks
}
return count
}
+2 -10
View File
@@ -6,8 +6,6 @@ import (
"log"
"net/http"
"os"
"github.com/gorilla/sessions"
)
const (
@@ -30,17 +28,11 @@ type Config struct {
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{}
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+1 -5
View File
@@ -1,7 +1,4 @@
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
# sessions
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+9 -12
View File
@@ -1,6 +1,5 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
@@ -9,15 +8,13 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
}
}
+21
View File
@@ -0,0 +1,21 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+5 -10
View File
@@ -1,11 +1,8 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -16,9 +13,7 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
MaxAge int
Secure bool
HttpOnly bool
}
+23
View File
@@ -0,0 +1,23 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -4,8 +4,8 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
## explicit; go 1.18