mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-07 22:53:58 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ab36f10a70 | |||
| 4972b21373 | |||
| 0be45f06a5 | |||
| 2cc89f0f31 | |||
| b9928f2f0c | |||
| 2e1a3a9320 | |||
| 9dabd0e5cf | |||
| dedbdf63c3 | |||
| af032c6cd3 | |||
| 9938cff053 | |||
| 7a404ef76f | |||
| 63922f362f | |||
| 2de9297ab6 | |||
| 971c84f762 | |||
| d2a0d2167e | |||
| c46d958397 | |||
| 95cf0034d6 | |||
| 380ef96571 | |||
| 1886396dc1 | |||
| 24ecf00053 | |||
| e338992f84 | |||
| a9a596031b | |||
| 23afcad2ba | |||
| d06f9fcf90 |
@@ -13,7 +13,6 @@ testData:
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
|
||||
@@ -38,7 +38,6 @@ spec:
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
// CacheItem represents an item in the cache
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
ExpiresAt int64 // Changed to int64 for faster comparisons
|
||||
}
|
||||
|
||||
// Cache is a simple in-memory cache
|
||||
@@ -27,43 +27,47 @@ func NewCache() *Cache {
|
||||
// Set adds an item to the cache
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
// Removed defer for slightly better performance
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
ExpiresAt: time.Now().Add(expiration).UnixNano(), // Store as UnixNano for faster comparisons
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
item, found := c.items[key]
|
||||
if !found {
|
||||
c.mutex.RUnlock()
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
if time.Now().UnixNano() > item.ExpiresAt {
|
||||
c.mutex.RUnlock()
|
||||
// Use a separate goroutine to delete expired items to avoid blocking
|
||||
go c.Delete(key)
|
||||
return nil, false
|
||||
}
|
||||
c.mutex.RUnlock()
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
// Cleanup removes expired items from the cache
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
now := time.Now()
|
||||
now := time.Now().UnixNano()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
if now > item.ExpiresAt {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ toolchain go1.23.1
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
github.com/gorilla/sessions v1.4.0
|
||||
golang.org/x/time v0.7.0
|
||||
)
|
||||
|
||||
|
||||
@@ -6,5 +6,9 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
|
||||
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
||||
+89
-163
@@ -20,13 +20,20 @@ import (
|
||||
// generateNonce generates a random nonce
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
if err != nil {
|
||||
if _, err := rand.Read(nonceBytes); err != nil {
|
||||
return "", fmt.Errorf("could not generate nonce: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host, and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
data := url.Values{
|
||||
@@ -35,14 +42,15 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
|
||||
if grantType == "authorization_code" {
|
||||
switch grantType {
|
||||
case "authorization_code":
|
||||
data.Set("code", codeOrToken)
|
||||
data.Set("redirect_uri", redirectURL)
|
||||
} else if grantType == "refresh_token" {
|
||||
case "refresh_token":
|
||||
data.Set("refresh_token", codeOrToken)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.tokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
@@ -89,10 +97,50 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// handleLogout handles the user logout
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
t.logger.Debugf("Logging out user")
|
||||
if err != nil {
|
||||
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Revoke tokens if available
|
||||
for _, tokenType := range []string{"refresh_token", "access_token"} {
|
||||
if token, ok := session.Values[tokenType].(string); ok && token != "" {
|
||||
if err := t.RevokeTokenWithProvider(token, tokenType); err != nil {
|
||||
t.logger.Errorf("Failed to revoke %s: %v", tokenType, err)
|
||||
}
|
||||
t.RevokeToken(token)
|
||||
}
|
||||
delete(session.Values, tokenType)
|
||||
}
|
||||
|
||||
// Remove other session values
|
||||
delete(session.Values, "id_token")
|
||||
delete(session.Values, "authenticated")
|
||||
|
||||
// Set session options to delete the session
|
||||
session.Options = &sessions.Options{MaxAge: -1, Path: "/", HttpOnly: true, Secure: true}
|
||||
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte("Logged out successfully"))
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
|
||||
if session == nil {
|
||||
t.logger.Error("Session is nil in handleExpiredToken")
|
||||
http.Error(rw, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Clear the existing session
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
@@ -101,21 +149,19 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
|
||||
session.Values["csrf"] = uuid.New().String()
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Values["nonce"], _ = generateNonce()
|
||||
session.Options = defaultSessionOptions
|
||||
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
|
||||
|
||||
// Save the session before initiating authentication
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
|
||||
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
@@ -125,34 +171,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the query parameters
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
if errParam := req.URL.Query().Get("error"); errParam != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
t.logger.Errorf("Authentication error: %s - %s", errParam, errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the state parameter matches the session's CSRF token
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
csrfToken, ok := session.Values["csrf"].(string)
|
||||
if !ok || csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
if !ok || state == "" || csrfToken == "" || state != csrfToken {
|
||||
t.logger.Error("Invalid state parameter or CSRF token")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Proceed to exchange the code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
@@ -160,14 +193,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract id_token
|
||||
idToken := tokenResponse.IDToken
|
||||
if idToken == "" {
|
||||
t.logger.Error("No id_token in token response")
|
||||
@@ -175,14 +207,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the id_token
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract claims from id_token
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
@@ -190,26 +220,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Verify the nonce claim matches the one stored in session
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
sessionNonce, ok := session.Values["nonce"].(string)
|
||||
if !ok || sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
sessionNonce, ok2 := session.Values["nonce"].(string)
|
||||
if !ok || !ok2 || nonceClaim == "" || sessionNonce == "" || nonceClaim != sessionNonce {
|
||||
t.logger.Error("Invalid nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the email from claims
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -217,14 +235,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Store tokens and authentication status in session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["email"] = email
|
||||
session.Values["id_token"] = idToken
|
||||
session.Values["refresh_token"] = tokenResponse.RefreshToken
|
||||
session.Options = defaultSessionOptions
|
||||
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
|
||||
|
||||
// Remove CSRF and nonce from session
|
||||
delete(session.Values, "csrf")
|
||||
delete(session.Values, "nonce")
|
||||
|
||||
@@ -236,7 +252,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
t.logger.Debugf("Authentication successful. User email: %s", email)
|
||||
|
||||
// Redirect to the original requested path or default to root
|
||||
redirectPath := "/"
|
||||
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
|
||||
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
|
||||
@@ -267,42 +282,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
|
||||
// TokenBlacklist maintains a blacklist of tokens
|
||||
type TokenBlacklist struct {
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
blacklist sync.Map
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
return &TokenBlacklist{}
|
||||
}
|
||||
func (tb *TokenBlacklist) Add(token string, expiration time.Time) {
|
||||
tb.blacklist.Store(token, expiration)
|
||||
}
|
||||
|
||||
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
|
||||
if exp, ok := tb.blacklist.Load(token); ok {
|
||||
return time.Now().Before(exp.(time.Time))
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is blacklisted
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
expiration, exists := tb.blacklist[tokenID]
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for tokenID, expiration := range tb.blacklist {
|
||||
if now.After(expiration) {
|
||||
delete(tb.blacklist, tokenID)
|
||||
tb.blacklist.Range(func(key, value interface{}) bool {
|
||||
if now.After(value.(time.Time)) {
|
||||
tb.blacklist.Delete(key)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// TokenCache caches tokens
|
||||
@@ -319,14 +324,12 @@ func NewTokenCache() *TokenCache {
|
||||
|
||||
// Set sets a token in the cache
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
tc.cache.Set("t-"+token, claims, expiration)
|
||||
}
|
||||
|
||||
// Get retrieves a token from the cache
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
value, found := tc.cache.Get("t-" + token)
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
@@ -336,8 +339,7 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
|
||||
// Delete removes a token from the cache
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
tc.cache.Delete("t-" + token)
|
||||
}
|
||||
|
||||
// Cleanup cleans up expired tokens from the cache
|
||||
@@ -346,9 +348,9 @@ func (tc *TokenCache) Cleanup() {
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
@@ -357,85 +359,9 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To
|
||||
|
||||
// createStringMap creates a map from a slice of strings
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
result := make(map[string]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
result[key] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout handles the logout request
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the id_token before clearing the session
|
||||
idToken, _ := session.Values["id_token"].(string)
|
||||
|
||||
// Clear and expire the session
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
session.Options.MaxAge = -1
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Error saving session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the base URL for redirects
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
var postLogoutRedirectURI string
|
||||
if t.postLogoutRedirectURI != "" {
|
||||
// Use explicitly configured postLogoutRedirectURI
|
||||
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
|
||||
postLogoutRedirectURI = t.postLogoutRedirectURI
|
||||
} else {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
|
||||
}
|
||||
} else {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
|
||||
}
|
||||
|
||||
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
|
||||
|
||||
// If we have an end session endpoint and an ID token, use OIDC end session
|
||||
if t.endSessionURL != "" && idToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// If no end session endpoint or no ID token, just redirect to the post logout URI
|
||||
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse end session URL: %w", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
// Ensure postLogoutRedirectURI is properly URL encoded
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
}
|
||||
|
||||
@@ -4,13 +4,12 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -58,6 +57,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
// Double-check locking pattern
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
return c.jwks, nil
|
||||
}
|
||||
@@ -120,25 +120,12 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
|
||||
}
|
||||
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := new(big.Int).SetBytes(eBytes)
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: n,
|
||||
E: int(e.Int64()),
|
||||
N: new(big.Int).SetBytes(nBytes),
|
||||
E: int(new(big.Int).SetBytes(eBytes).Int64()),
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
})
|
||||
|
||||
return pubKeyPEM, nil
|
||||
return marshalPublicKey(pubKey)
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC JWK to PEM
|
||||
@@ -152,16 +139,9 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
|
||||
}
|
||||
|
||||
var curve elliptic.Curve
|
||||
switch jwk.Crv {
|
||||
case "P-256":
|
||||
curve = elliptic.P256()
|
||||
case "P-384":
|
||||
curve = elliptic.P384()
|
||||
case "P-521":
|
||||
curve = elliptic.P521()
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
|
||||
curve, err := getCurve(jwk.Crv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKey := &ecdsa.PublicKey{
|
||||
@@ -170,15 +150,32 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
Y: new(big.Int).SetBytes(yBytes),
|
||||
}
|
||||
|
||||
return marshalPublicKey(pubKey)
|
||||
}
|
||||
|
||||
// getCurve returns the elliptic curve based on the JWK curve parameter
|
||||
func getCurve(crv string) (elliptic.Curve, error) {
|
||||
switch crv {
|
||||
case "P-256":
|
||||
return elliptic.P256(), nil
|
||||
case "P-384":
|
||||
return elliptic.P384(), nil
|
||||
case "P-521":
|
||||
return elliptic.P521(), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported elliptic curve: %s", crv)
|
||||
}
|
||||
}
|
||||
|
||||
// marshalPublicKey marshals a public key to PEM format
|
||||
func marshalPublicKey(pubKey interface{}) ([]byte, error) {
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
return pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
})
|
||||
|
||||
return pubKeyPEM, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
@@ -4,18 +4,29 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidJWTFormat = errors.New("invalid JWT format")
|
||||
ErrInvalidAudience = errors.New("invalid audience")
|
||||
ErrInvalidIssuer = errors.New("invalid issuer")
|
||||
ErrTokenExpired = errors.New("token has expired")
|
||||
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
|
||||
ErrMissingClaim = errors.New("missing claim")
|
||||
ErrInvalidClaimType = errors.New("invalid claim type")
|
||||
ErrUnsupportedAlgorithm = errors.New("unsupported algorithm")
|
||||
ErrInvalidSignature = errors.New("invalid signature")
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
@@ -28,212 +39,187 @@ type JWT struct {
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
return nil, fmt.Errorf("%w: expected 3 parts, got %d", ErrInvalidJWTFormat, len(parts))
|
||||
}
|
||||
|
||||
jwt := &JWT{
|
||||
Token: tokenString,
|
||||
jwt := &JWT{Token: tokenString}
|
||||
|
||||
if err := decodeJSONPart(parts[0], &jwt.Header); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode header: %w", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err := decodeJSONPart(parts[1], &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode claims: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
jwt.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
|
||||
return nil, fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
|
||||
}
|
||||
|
||||
// Decode and unmarshal the claims
|
||||
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
|
||||
}
|
||||
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
}
|
||||
jwt.Signature = signatureBytes
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
func decodeJSONPart(part string, target interface{}) error {
|
||||
bytes, err := base64.RawURLEncoding.DecodeString(part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(bytes, target)
|
||||
}
|
||||
|
||||
// Verify verifies the standard claims in the JWT
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
claims := j.Claims
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'iss' claim")
|
||||
}
|
||||
if err := verifyIssuer(iss, issuerURL); err != nil {
|
||||
if err := verifyIssuer(j.Claims["iss"], issuerURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
aud, ok := claims["aud"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing 'aud' claim")
|
||||
}
|
||||
if err := verifyAudience(aud, clientID); err != nil {
|
||||
if err := verifyAudience(j.Claims["aud"], clientID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
exp, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing or invalid 'exp' claim")
|
||||
}
|
||||
if err := verifyExpiration(exp); err != nil {
|
||||
if err := verifyExpiration(j.Claims["exp"]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing or invalid 'iat' claim")
|
||||
}
|
||||
if err := verifyIssuedAt(iat); err != nil {
|
||||
if err := verifyIssuedAt(j.Claims["iat"]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sub, ok := claims["sub"].(string)
|
||||
if !ok || sub == "" {
|
||||
return fmt.Errorf("missing or empty 'sub' claim")
|
||||
if sub, ok := j.Claims["sub"].(string); !ok || sub == "" {
|
||||
return fmt.Errorf("%w: sub", ErrMissingClaim)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience verifies the audience claim
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
return ErrInvalidAudience
|
||||
}
|
||||
case []interface{}:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
found = true
|
||||
break
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
return ErrInvalidAudience
|
||||
default:
|
||||
return fmt.Errorf("invalid 'aud' claim type")
|
||||
return fmt.Errorf("%w: aud", ErrInvalidClaimType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer verifies the issuer claim
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
func verifyIssuer(tokenIssuer interface{}, expectedIssuer string) error {
|
||||
iss, ok := tokenIssuer.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: iss", ErrMissingClaim)
|
||||
}
|
||||
if iss != expectedIssuer {
|
||||
return ErrInvalidIssuer
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token has expired
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
return fmt.Errorf("token has expired")
|
||||
func verifyExpiration(expiration interface{}) error {
|
||||
exp, ok := expiration.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: exp", ErrInvalidClaimType)
|
||||
}
|
||||
if time.Now().After(time.Unix(int64(exp), 0)) {
|
||||
return ErrTokenExpired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks if the token was issued in the future
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
func verifyIssuedAt(issuedAt interface{}) error {
|
||||
iat, ok := issuedAt.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("%w: iat", ErrInvalidClaimType)
|
||||
}
|
||||
if time.Now().Before(time.Unix(int64(iat), 0)) {
|
||||
return ErrTokenUsedBeforeIssued
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifySignature verifies the token signature using the provided public key and algorithm
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
return ErrInvalidJWTFormat
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
pubKey, err := parsePublicKey(publicKeyPEM)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
hashFunc = crypto.SHA256
|
||||
case "RS384", "PS384", "ES384":
|
||||
hashFunc = crypto.SHA384
|
||||
case "RS512", "PS512", "ES512":
|
||||
hashFunc = crypto.SHA512
|
||||
default:
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
hashFunc, err := getHashFunc(alg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
hashed := hashFunc.New().Sum([]byte(signedContent))
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
return verifyRSASignature(pubKey, hashFunc, hashed, signature, alg)
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
return fmt.Errorf("invalid ECDSA signature length")
|
||||
}
|
||||
r.SetBytes(signature[:sigLen/2])
|
||||
s.SetBytes(signature[sigLen/2:])
|
||||
if ecdsa.Verify(pubKey, hashed, &r, &s) {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
return verifyECDSASignature(pubKey, hashed, signature)
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %T", pubKey)
|
||||
}
|
||||
}
|
||||
|
||||
func parsePublicKey(publicKeyPEM []byte) (interface{}, error) {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return nil, errors.New("failed to parse PEM block containing the public key")
|
||||
}
|
||||
return x509.ParsePKIXPublicKey(block.Bytes)
|
||||
}
|
||||
|
||||
func getHashFunc(alg string) (crypto.Hash, error) {
|
||||
switch alg {
|
||||
case "RS256", "PS256", "ES256":
|
||||
return crypto.SHA256, nil
|
||||
case "RS384", "PS384", "ES384":
|
||||
return crypto.SHA384, nil
|
||||
case "RS512", "PS512", "ES512":
|
||||
return crypto.SHA512, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyRSASignature(pubKey *rsa.PublicKey, hashFunc crypto.Hash, hashed, signature []byte, alg string) error {
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
}
|
||||
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
|
||||
}
|
||||
|
||||
func verifyECDSASignature(pubKey *ecdsa.PublicKey, hashed, signature []byte) error {
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
return errors.New("invalid ECDSA signature length")
|
||||
}
|
||||
r, s := new(big.Int), new(big.Int)
|
||||
r.SetBytes(signature[:sigLen/2])
|
||||
s.SetBytes(signature[sigLen/2:])
|
||||
if ecdsa.Verify(pubKey, hashed, r, s) {
|
||||
return nil
|
||||
}
|
||||
return ErrInvalidSignature
|
||||
}
|
||||
|
||||
@@ -53,28 +53,26 @@ type TraefikOidc struct {
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
excludedURLs map[string]struct{}
|
||||
allowedUserDomains map[string]struct{}
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initOnce sync.Once
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
type ProviderMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
// defaultExcludedURLs are the paths that are excluded from authentication
|
||||
@@ -84,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{
|
||||
|
||||
var newTicker = time.NewTicker
|
||||
|
||||
var (
|
||||
globalMetadataCache struct {
|
||||
sync.Once
|
||||
metadata *ProviderMetadata
|
||||
err error
|
||||
}
|
||||
)
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
@@ -227,12 +217,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
return config.LogoutURL
|
||||
}(),
|
||||
postLogoutRedirectURI: func() string {
|
||||
if config.PostLogoutRedirectURI == "" {
|
||||
return "/"
|
||||
}
|
||||
return config.PostLogoutRedirectURI
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
clientID: config.ClientID,
|
||||
@@ -270,26 +254,20 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
globalMetadataCache.Once.Do(func() {
|
||||
t.logger.Debug("Starting global provider metadata discovery")
|
||||
t.initOnce.Do(func() {
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
globalMetadataCache.metadata = metadata
|
||||
globalMetadataCache.err = err
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v", err)
|
||||
} else {
|
||||
t.logger.Debug("Provider metadata discovered successfully")
|
||||
t.jwksURL = metadata.JWKSURL
|
||||
t.authURL = metadata.AuthURL
|
||||
t.tokenURL = metadata.TokenURL
|
||||
t.issuerURL = metadata.Issuer
|
||||
t.revocationURL = metadata.RevokeURL
|
||||
}
|
||||
close(t.initComplete)
|
||||
})
|
||||
|
||||
if globalMetadataCache.err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
|
||||
} else if globalMetadataCache.metadata != nil {
|
||||
t.logger.Debug("Using cached provider metadata")
|
||||
t.jwksURL = globalMetadataCache.metadata.JWKSURL
|
||||
t.authURL = globalMetadataCache.metadata.AuthURL
|
||||
t.tokenURL = globalMetadataCache.metadata.TokenURL
|
||||
t.issuerURL = globalMetadataCache.metadata.Issuer
|
||||
t.revocationURL = globalMetadataCache.metadata.RevokeURL
|
||||
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
|
||||
}
|
||||
|
||||
close(t.initComplete)
|
||||
}
|
||||
|
||||
// discoverProviderMetadata fetches the OIDC provider metadata
|
||||
@@ -381,12 +359,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
defaultSessionOptions.Secure = t.scheme == "https"
|
||||
host := t.determineHost(req)
|
||||
|
||||
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
|
||||
// Build the redirect URL if not already set
|
||||
if redirectURL == "" {
|
||||
redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Redirect URL updated to: %s", redirectURL)
|
||||
if t.redirectURL == "" {
|
||||
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
|
||||
}
|
||||
|
||||
// Get the session
|
||||
@@ -407,7 +383,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
// Handle callback URL
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
t.handleCallback(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -415,19 +391,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
if expired {
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
return
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -436,21 +412,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Errorf("No id_token found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := extractClaims(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Debugf("No email found in token claims")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -460,7 +436,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(idToken)
|
||||
groups, roles := t.extractGroupsAndRoles(claims)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
@@ -507,16 +483,16 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if t.forceHTTPS {
|
||||
switch {
|
||||
case t.forceHTTPS:
|
||||
return "https"
|
||||
}
|
||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||
return scheme
|
||||
}
|
||||
if req.TLS != nil {
|
||||
case req.Header.Get(headerXForwardedProto) != "":
|
||||
return req.Header.Get(headerXForwardedProto)
|
||||
case req.TLS != nil:
|
||||
return "https"
|
||||
default:
|
||||
return "http"
|
||||
}
|
||||
return "http"
|
||||
}
|
||||
|
||||
// determineHost determines the host of the request
|
||||
@@ -644,9 +620,14 @@ func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// Remove from cache
|
||||
t.tokenCache.Delete(token)
|
||||
|
||||
// Add to blacklist with default expiration
|
||||
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
|
||||
t.tokenBlacklist.Add(token, expiry)
|
||||
// Add to blacklist
|
||||
claims, err := extractClaims(token)
|
||||
if err == nil {
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
t.tokenBlacklist.Add(token, expTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeTokenWithProvider revokes the token with the provider
|
||||
@@ -722,71 +703,33 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
// isAllowedDomain checks if the user's email domain is allowed
|
||||
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
if len(t.allowedUserDomains) == 0 {
|
||||
return true // If no domains are specified, all are allowed
|
||||
return true
|
||||
}
|
||||
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 {
|
||||
return false // Invalid email format
|
||||
atIndex := strings.LastIndex(email, "@")
|
||||
if atIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
domain := parts[1]
|
||||
domain := email[atIndex+1:]
|
||||
_, ok := t.allowedUserDomains[domain]
|
||||
return ok
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts groups and roles from the id_token
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
|
||||
}
|
||||
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups with type checking
|
||||
if groupsClaim, exists := claims["groups"]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
}
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles with type checking
|
||||
if rolesClaim, exists := claims["roles"]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
}
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return groups, roles, nil
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(claims map[string]interface{}) ([]string, []string) {
|
||||
groups := extractStringSlice(claims, "groups")
|
||||
roles := extractStringSlice(claims, "roles")
|
||||
return groups, roles
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
// If the path is already a full URL, return it as-is
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
func extractStringSlice(claims map[string]interface{}, key string) []string {
|
||||
if slice, ok := claims[key].([]interface{}); ok {
|
||||
result := make([]string, 0, len(slice))
|
||||
for _, item := range slice {
|
||||
if str, ok := item.(string); ok {
|
||||
result = append(result, str)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Ensure the path starts with a forward slash
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
return nil
|
||||
}
|
||||
|
||||
+27
-518
@@ -104,7 +104,7 @@ func (ts *TestSuite) Setup() {
|
||||
}
|
||||
|
||||
// Helper functions used by TraefikOidc
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
|
||||
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -167,6 +167,18 @@ func TestVerifyToken(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
ts.mockJWKCache.JWKS = &JWKSet{
|
||||
Keys: []JWK{
|
||||
{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Alg: "RS256",
|
||||
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
|
||||
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
@@ -452,12 +464,10 @@ func TestHandleCallback(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
redirectURL := "http://example.com/"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
expectedStatus int
|
||||
@@ -465,7 +475,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Success",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -495,7 +505,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Exchange Code Error",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return nil, fmt.Errorf("exchange code error")
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
@@ -507,7 +517,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Missing ID Token",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{}, nil
|
||||
},
|
||||
sessionSetupFunc: func(session *sessions.Session) {
|
||||
@@ -519,7 +529,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Disallowed Email",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -540,7 +550,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Invalid State Parameter",
|
||||
queryParams: "?code=test-code&state=invalid-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -561,7 +571,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Nonce Mismatch",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -582,7 +592,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
{
|
||||
name: "Missing Nonce in Claims",
|
||||
queryParams: "?code=test-code&state=test-csrf-token",
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
RefreshToken: "test-refresh-token",
|
||||
@@ -635,7 +645,7 @@ func TestHandleCallback(t *testing.T) {
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Call handleCallback
|
||||
tOidc.handleCallback(rr, req, redirectURL)
|
||||
tOidc.handleCallback(rr, req)
|
||||
|
||||
// Check response
|
||||
if rr.Code != tc.expectedStatus {
|
||||
@@ -690,7 +700,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams string
|
||||
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
|
||||
exchangeCodeForToken func(code string) (*TokenResponse, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
sessionSetupFunc func(session *sessions.Session)
|
||||
expectedStatus int
|
||||
@@ -706,7 +716,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -730,7 +740,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -753,7 +763,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -777,7 +787,7 @@ func TestOIDCHandler(t *testing.T) {
|
||||
session.Values["csrf"] = "test-csrf-token"
|
||||
session.Values["nonce"] = "test-nonce"
|
||||
},
|
||||
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
|
||||
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
|
||||
// Simulate token exchange
|
||||
return &TokenResponse{
|
||||
IDToken: ts.token,
|
||||
@@ -822,504 +832,3 @@ func TestOIDCHandler(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleLogout tests the logout functionality
|
||||
func TestHandleLogout(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
// Create mock revocation endpoint server
|
||||
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form: %v", err)
|
||||
}
|
||||
// Verify the required parameters are present
|
||||
if r.Form.Get("token") == "" {
|
||||
t.Error("Missing token parameter")
|
||||
}
|
||||
if r.Form.Get("token_type_hint") == "" {
|
||||
t.Error("Missing token_type_hint parameter")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer mockRevocationServer.Close()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*sessions.Session)
|
||||
endSessionURL string
|
||||
expectedStatus int
|
||||
expectedURL string
|
||||
host string
|
||||
}{
|
||||
{
|
||||
name: "Successful logout with end session endpoint",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
session.Values["refresh_token"] = "test-refresh-token"
|
||||
session.Values["access_token"] = "test-access-token"
|
||||
},
|
||||
endSessionURL: "https://provider/end-session",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Successful logout without end session endpoint",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
session.Values["refresh_token"] = "test-refresh-token"
|
||||
session.Values["access_token"] = "test-access-token"
|
||||
},
|
||||
endSessionURL: "",
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with empty session",
|
||||
setupSession: func(session *sessions.Session) {},
|
||||
expectedStatus: http.StatusFound,
|
||||
expectedURL: "http://example.com/",
|
||||
host: "test-host",
|
||||
},
|
||||
{
|
||||
name: "Logout with invalid end session URL",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "test.id.token"
|
||||
},
|
||||
endSessionURL: ":\\invalid-url",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
host: "test-host",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a new TraefikOidc instance for each test
|
||||
tOidc := &TraefikOidc{
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
revocationURL: mockRevocationServer.URL,
|
||||
endSessionURL: tc.endSessionURL,
|
||||
scheme: "http",
|
||||
logger: NewLogger("info"),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
httpClient: &http.Client{},
|
||||
clientID: "test-client-id",
|
||||
clientSecret: "test-client-secret",
|
||||
tokenCache: NewTokenCache(),
|
||||
forceHTTPS: false,
|
||||
}
|
||||
|
||||
// Create request with proper headers
|
||||
req := httptest.NewRequest("GET", "/logout", nil)
|
||||
req.Header.Set("Host", tc.host)
|
||||
|
||||
// Create a response recorder
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get a session
|
||||
session, err := tOidc.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Setup session
|
||||
tc.setupSession(session)
|
||||
session.Save(req, rr)
|
||||
|
||||
// Copy session cookie to request
|
||||
for _, cookie := range rr.Result().Cookies() {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Reset response recorder
|
||||
rr = httptest.NewRecorder()
|
||||
|
||||
// Handle logout
|
||||
tOidc.handleLogout(rr, req)
|
||||
|
||||
// Check response
|
||||
if rr.Code != tc.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
|
||||
}
|
||||
|
||||
// Check redirect URL if expected
|
||||
if tc.expectedURL != "" {
|
||||
location := rr.Header().Get("Location")
|
||||
if location != tc.expectedURL {
|
||||
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify session is cleared
|
||||
newSession, _ := tOidc.store.Get(req, cookieName)
|
||||
if len(newSession.Values) > 0 {
|
||||
t.Error("Session was not cleared")
|
||||
}
|
||||
if newSession.Options.MaxAge != -1 {
|
||||
t.Error("Session MaxAge was not set to -1")
|
||||
}
|
||||
|
||||
// Check token blacklist
|
||||
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
|
||||
t.Error("Refresh token was not blacklisted")
|
||||
}
|
||||
}
|
||||
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
|
||||
t.Error("Access token was not blacklisted")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRevokeTokenWithProvider tests the token revocation with provider
|
||||
func TestRevokeTokenWithProvider(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
tokenType string
|
||||
statusCode int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Successful token revocation",
|
||||
token: "valid-token",
|
||||
tokenType: "refresh_token",
|
||||
statusCode: http.StatusOK,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Failed token revocation",
|
||||
token: "invalid-token",
|
||||
tokenType: "refresh_token",
|
||||
statusCode: http.StatusBadRequest,
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create test server
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request method and content type
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("Expected POST request, got %s", r.Method)
|
||||
}
|
||||
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
|
||||
}
|
||||
|
||||
// Verify form values
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatalf("Failed to parse form: %v", err)
|
||||
}
|
||||
if got := r.Form.Get("token"); got != tc.token {
|
||||
t.Errorf("Expected token %s, got %s", tc.token, got)
|
||||
}
|
||||
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
|
||||
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
|
||||
}
|
||||
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
|
||||
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
|
||||
}
|
||||
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
|
||||
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
|
||||
}
|
||||
|
||||
w.WriteHeader(tc.statusCode)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Set revocation URL to test server
|
||||
ts.tOidc.revocationURL = server.URL
|
||||
|
||||
// Test token revocation
|
||||
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
|
||||
if tc.expectError && err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
if !tc.expectError && err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRevokeToken tests the token revocation functionality
|
||||
func TestRevokeToken(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
token := "test.token.with.claims"
|
||||
claims := map[string]interface{}{
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
|
||||
// Test token revocation
|
||||
t.Run("Token revocation", func(t *testing.T) {
|
||||
// Create a new instance for this specific test
|
||||
tOidc := &TraefikOidc{
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
tokenCache: NewTokenCache(),
|
||||
}
|
||||
|
||||
// Cache the token
|
||||
tOidc.tokenCache.Set(token, claims, time.Hour)
|
||||
|
||||
// Revoke the token
|
||||
tOidc.RevokeToken(token)
|
||||
|
||||
// Verify token was removed from cache
|
||||
if _, exists := tOidc.tokenCache.Get(token); exists {
|
||||
t.Error("Token was not removed from cache")
|
||||
}
|
||||
|
||||
// Verify token was added to blacklist
|
||||
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
|
||||
t.Error("Token was not added to blacklist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Add this new test function
|
||||
func TestBuildLogoutURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
endSessionURL string
|
||||
idToken string
|
||||
postLogoutRedirect string
|
||||
expectedURL string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid URL",
|
||||
endSessionURL: "https://provider/end-session",
|
||||
idToken: "test.id.token",
|
||||
postLogoutRedirect: "http://example.com/",
|
||||
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL",
|
||||
endSessionURL: "://invalid-url",
|
||||
idToken: "test.id.token",
|
||||
postLogoutRedirect: "http://example.com/",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "URL with existing query parameters",
|
||||
endSessionURL: "https://provider/end-session?existing=param",
|
||||
idToken: "test.id.token",
|
||||
postLogoutRedirect: "http://example.com/",
|
||||
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
if url != tc.expectedURL {
|
||||
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add this new test function
|
||||
func TestHandleExpiredToken(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func(*sessions.Session)
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "Basic expired token",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "expired.token"
|
||||
session.Values["email"] = "test@example.com"
|
||||
},
|
||||
expectedPath: "/original/path",
|
||||
},
|
||||
{
|
||||
name: "Session with additional values",
|
||||
setupSession: func(session *sessions.Session) {
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = "expired.token"
|
||||
session.Values["custom_value"] = "should-be-cleared"
|
||||
},
|
||||
expectedPath: "/another/path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a new TraefikOidc instance for each test
|
||||
tOidc := &TraefikOidc{
|
||||
store: sessions.NewCookieStore([]byte("test-secret-key")),
|
||||
logger: NewLogger("info"),
|
||||
tokenVerifier: ts.tOidc.tokenVerifier,
|
||||
jwtVerifier: ts.tOidc.jwtVerifier,
|
||||
initComplete: make(chan struct{}),
|
||||
// Add this initialization of initiateAuthenticationFunc
|
||||
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
// Mock implementation for test
|
||||
http.Redirect(rw, req, "/login", http.StatusFound)
|
||||
},
|
||||
}
|
||||
close(tOidc.initComplete)
|
||||
|
||||
// Create request
|
||||
req := httptest.NewRequest("GET", tc.expectedPath, nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Get session
|
||||
session, _ := tOidc.store.New(req, cookieName)
|
||||
tc.setupSession(session)
|
||||
|
||||
// Handle expired token
|
||||
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
|
||||
|
||||
// Verify session is cleaned
|
||||
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
|
||||
t.Errorf("Expected 3 session values, got %d", len(session.Values))
|
||||
}
|
||||
|
||||
// Verify required values are set
|
||||
if _, ok := session.Values["csrf"].(string); !ok {
|
||||
t.Error("CSRF token not set")
|
||||
}
|
||||
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
|
||||
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
|
||||
}
|
||||
if _, ok := session.Values["nonce"].(string); !ok {
|
||||
t.Error("Nonce not set")
|
||||
}
|
||||
|
||||
// Verify session options
|
||||
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
|
||||
t.Error("Session MaxAge not set correctly")
|
||||
}
|
||||
|
||||
// Verify redirect status
|
||||
if rr.Code != http.StatusFound {
|
||||
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add this new test function
|
||||
func TestExtractGroupsAndRoles(t *testing.T) {
|
||||
ts := &TestSuite{t: t}
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims map[string]interface{}
|
||||
expectGroups []string
|
||||
expectRoles []string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid groups and roles",
|
||||
claims: map[string]interface{}{
|
||||
"groups": []interface{}{"group1", "group2"},
|
||||
"roles": []interface{}{"role1", "role2"},
|
||||
},
|
||||
expectGroups: []string{"group1", "group2"},
|
||||
expectRoles: []string{"role1", "role2"},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty groups and roles",
|
||||
claims: map[string]interface{}{
|
||||
"groups": []interface{}{},
|
||||
"roles": []interface{}{},
|
||||
},
|
||||
expectGroups: []string{},
|
||||
expectRoles: []string{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid groups format",
|
||||
claims: map[string]interface{}{
|
||||
"groups": "not-an-array",
|
||||
"roles": []interface{}{"role1"},
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Create a test token with the claims
|
||||
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test token: %v", err)
|
||||
}
|
||||
|
||||
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
|
||||
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got nil")
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Compare groups
|
||||
if !stringSliceEqual(groups, tc.expectGroups) {
|
||||
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
|
||||
}
|
||||
|
||||
// Compare roles
|
||||
if !stringSliceEqual(roles, tc.expectRoles) {
|
||||
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare string slices
|
||||
func stringSliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
+8
-2
@@ -14,6 +14,14 @@ const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
)
|
||||
|
||||
const (
|
||||
headerXForwardedProto = "X-Forwarded-Proto"
|
||||
headerXForwardedHost = "X-Forwarded-Host"
|
||||
headerXForwardedUser = "X-Forwarded-User"
|
||||
headerXUserGroups = "X-User-Groups"
|
||||
headerXUserRoles = "X-User-Roles"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
@@ -30,8 +38,6 @@ type Config struct {
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
|
||||
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
|
||||
+5
-1
@@ -1,4 +1,7 @@
|
||||
# sessions
|
||||
# Gorilla Sessions
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
|
||||
|
||||

|
||||
[](https://codecov.io/github/gorilla/sessions)
|
||||
@@ -74,6 +77,7 @@ Other implementations of the `sessions.Store` interface:
|
||||
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
|
||||
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
|
||||
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
|
||||
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
|
||||
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
|
||||
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
|
||||
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
|
||||
|
||||
+12
-9
@@ -1,5 +1,6 @@
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sessions
|
||||
|
||||
@@ -8,13 +9,15 @@ import "net/http"
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
Partitioned: options.Partitioned,
|
||||
SameSite: options.SameSite,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
-21
@@ -1,21 +0,0 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// newCookieFromOptions returns an http.Cookie with the options set.
|
||||
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: options.Path,
|
||||
Domain: options.Domain,
|
||||
MaxAge: options.MaxAge,
|
||||
Secure: options.Secure,
|
||||
HttpOnly: options.HttpOnly,
|
||||
SameSite: options.SameSite,
|
||||
}
|
||||
|
||||
}
|
||||
+10
-5
@@ -1,8 +1,11 @@
|
||||
//go:build !go1.11
|
||||
// +build !go1.11
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
@@ -13,7 +16,9 @@ type Options struct {
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
Partitioned bool
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
-23
@@ -1,23 +0,0 @@
|
||||
//go:build go1.11
|
||||
// +build go1.11
|
||||
|
||||
package sessions
|
||||
|
||||
import "net/http"
|
||||
|
||||
// Options stores configuration for a session or session store.
|
||||
//
|
||||
// Fields are a subset of http.Cookie fields.
|
||||
type Options struct {
|
||||
Path string
|
||||
Domain string
|
||||
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
|
||||
// deleted after the browser session ends.
|
||||
// MaxAge<0 means delete cookie immediately.
|
||||
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
// Defaults to http.SameSiteDefaultMode
|
||||
SameSite http.SameSite
|
||||
}
|
||||
Vendored
+2
-2
@@ -4,8 +4,8 @@ github.com/google/uuid
|
||||
# github.com/gorilla/securecookie v1.1.2
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
# github.com/gorilla/sessions v1.4.0
|
||||
## explicit; go 1.23
|
||||
github.com/gorilla/sessions
|
||||
# golang.org/x/time v0.7.0
|
||||
## explicit; go 1.18
|
||||
|
||||
Reference in New Issue
Block a user