Compare commits

..

24 Commits

Author SHA1 Message Date
lukaszraczylo ab36f10a70 Draft. 2024-10-13 18:21:13 +01:00
lukaszraczylo 4972b21373 Improve speed of the cache module. 2024-10-13 18:00:43 +01:00
lukaszraczylo 0be45f06a5 Merge pull request #10 from lukaszraczylo/code-improvements
Code improvements
2024-10-09 09:34:45 +01:00
lukaszraczylo 2cc89f0f31 fixup! Update dependencies. 2024-10-09 09:33:20 +01:00
lukaszraczylo b9928f2f0c Update dependencies. 2024-10-09 09:32:25 +01:00
lukaszraczylo 2e1a3a9320 Add ability to verify default ECDSA keys provided by logto as well. 2024-10-09 09:30:42 +01:00
lukaszraczylo 9dabd0e5cf Revert "Update go mod dependencies."
This reverts commit dedbdf63c3.
2024-10-09 09:11:07 +01:00
lukaszraczylo dedbdf63c3 Update go mod dependencies. 2024-10-09 09:07:25 +01:00
lukaszraczylo af032c6cd3 Add simple benchmark to track the allocations and speed for future improvements. 2024-10-08 14:41:43 +01:00
lukaszraczylo 9938cff053 fixup! Cleanup and optimise the code. 2024-10-08 14:26:26 +01:00
lukaszraczylo 7a404ef76f Cleanup and optimise the code. 2024-10-08 14:14:47 +01:00
lukaszraczylo 63922f362f fixup! Add support for more algorithms. 2024-10-07 16:07:07 +01:00
lukaszraczylo 2de9297ab6 Add support for more algorithms. 2024-10-07 16:01:07 +01:00
lukaszraczylo 971c84f762 Abstract filling up maps. 2024-10-07 15:56:24 +01:00
lukaszraczylo d2a0d2167e Fix the bug with user not being redirected to originally requested URL post authentication. 2024-10-05 09:33:56 +01:00
lukaszraczylo c46d958397 Update documentation - setting secrets in kubernetes. 2024-10-04 17:15:43 +01:00
lukaszraczylo 95cf0034d6 Fix the tests hanging on the open channel. 2024-10-04 14:39:37 +01:00
lukaszraczylo 380ef96571 Improvement - startup time.
Previous implementations blocked the traefik startup until OIDC plugin was loaded.
This caused chicken-or-egg issue when called OIDC endpoint was hosted by the same traefik as well,
generating rather ridiculous situation when traefik couldn't come up because plugin tried to call the
discovery endpoint which was hosted by the same traefik.

This version resolves the issue allowing for quickstart and lazy loading of the provider metadata.
Disadvantage is - until discovery is done, the plugin will not provide any access to the client.
2024-10-04 14:22:16 +01:00
lukaszraczylo 1886396dc1 First step in improvement of caching mechanism. 2024-10-04 14:05:12 +01:00
lukaszraczylo 24ecf00053 Add support for roles and groups. 2024-10-03 19:48:43 +01:00
lukaszraczylo e338992f84 Update tests and additional fixups. 2024-10-03 14:00:47 +01:00
lukaszraczylo a9a596031b Update the tests to handle nonce 2024-10-03 14:00:44 +01:00
lukaszraczylo 23afcad2ba Support additional verification of the token to ensure OIDC compliance 2024-10-03 14:00:44 +01:00
lukaszraczylo d06f9fcf90 Update dependencies. 2024-10-03 14:00:44 +01:00
18 changed files with 387 additions and 1046 deletions
-1
View File
@@ -13,7 +13,6 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
-1
View File
@@ -38,7 +38,6 @@ spec:
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
+14 -10
View File
@@ -8,7 +8,7 @@ import (
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
ExpiresAt int64 // Changed to int64 for faster comparisons
}
// Cache is a simple in-memory cache
@@ -27,43 +27,47 @@ func NewCache() *Cache {
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Removed defer for slightly better performance
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: time.Now().Add(expiration).UnixNano(), // Store as UnixNano for faster comparisons
}
c.mutex.Unlock()
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
if time.Now().UnixNano() > item.ExpiresAt {
c.mutex.RUnlock()
// Use a separate goroutine to delete expired items to avoid blocking
go c.Delete(key)
return nil, false
}
c.mutex.RUnlock()
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.mutex.Unlock()
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
now := time.Now().UnixNano()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
if now > item.ExpiresAt {
delete(c.items, key)
}
}
c.mutex.Unlock()
}
+1 -1
View File
@@ -6,7 +6,7 @@ toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
github.com/gorilla/sessions v1.4.0
golang.org/x/time v0.7.0
)
+4
View File
@@ -6,5 +6,9 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+89 -163
View File
@@ -20,13 +20,20 @@ import (
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
if err != nil {
if _, err := rand.Read(nonceBytes); err != nil {
return "", fmt.Errorf("could not generate nonce: %w", err)
}
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
@@ -35,14 +42,15 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
"client_secret": {t.clientSecret},
}
if grantType == "authorization_code" {
switch grantType {
case "authorization_code":
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
} else if grantType == "refresh_token" {
case "refresh_token":
data.Set("refresh_token", codeOrToken)
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -89,10 +97,50 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
// Revoke tokens if available
for _, tokenType := range []string{"refresh_token", "access_token"} {
if token, ok := session.Values[tokenType].(string); ok && token != "" {
if err := t.RevokeTokenWithProvider(token, tokenType); err != nil {
t.logger.Errorf("Failed to revoke %s: %v", tokenType, err)
}
t.RevokeToken(token)
}
delete(session.Values, tokenType)
}
// Remove other session values
delete(session.Values, "id_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = &sessions.Options{MaxAge: -1, Path: "/", HttpOnly: true, Secure: true}
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
if session == nil {
t.logger.Error("Session is nil in handleExpiredToken")
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
@@ -101,21 +149,19 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, redirectURL)
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Session error: %v", err)
@@ -125,34 +171,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
if errParam := req.URL.Query().Get("error"); errParam != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.logger.Errorf("Authentication error: %s - %s", errParam, errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
if !ok || state == "" || csrfToken == "" || state != csrfToken {
t.logger.Error("Invalid state parameter or CSRF token")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -160,14 +193,13 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
@@ -175,14 +207,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
@@ -190,26 +220,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
sessionNonce, ok2 := session.Values["nonce"].(string)
if !ok || !ok2 || nonceClaim == "" || sessionNonce == "" || nonceClaim != sessionNonce {
t.logger.Error("Invalid nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -217,14 +235,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
@@ -236,7 +252,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
@@ -267,42 +282,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
blacklist sync.Map
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
return &TokenBlacklist{}
}
func (tb *TokenBlacklist) Add(token string, expiration time.Time) {
tb.blacklist.Store(token, expiration)
}
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
if exp, ok := tb.blacklist.Load(token); ok {
return time.Now().Before(exp.(time.Time))
}
return false
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
tb.blacklist.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
tb.blacklist.Delete(key)
}
}
return true
})
}
// TokenCache caches tokens
@@ -319,14 +324,12 @@ func NewTokenCache() *TokenCache {
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
tc.cache.Set("t-"+token, claims, expiration)
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
value, found := tc.cache.Get("t-" + token)
if !found {
return nil, false
}
@@ -336,8 +339,7 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
tc.cache.Delete("t-" + token)
}
// Cleanup cleans up expired tokens from the cache
@@ -346,9 +348,9 @@ func (tc *TokenCache) Cleanup() {
}
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
@@ -357,85 +359,9 @@ func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*To
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
result := make(map[string]struct{}, len(keys))
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the id_token before clearing the session
idToken, _ := session.Values["id_token"].(string)
// Clear and expire the session
session.Values = make(map[interface{}]interface{})
session.Options.MaxAge = -1
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Error saving session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the base URL for redirects
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Determine post logout redirect URI
var postLogoutRedirectURI string
if t.postLogoutRedirectURI != "" {
// Use explicitly configured postLogoutRedirectURI
if strings.HasPrefix(t.postLogoutRedirectURI, "http://") || strings.HasPrefix(t.postLogoutRedirectURI, "https://") {
postLogoutRedirectURI = t.postLogoutRedirectURI
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, t.postLogoutRedirectURI)
}
} else {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, "/")
}
t.logger.Debugf("Using post logout redirect URI: %s", postLogoutRedirectURI)
// If we have an end session endpoint and an ID token, use OIDC end session
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, postLogoutRedirectURI)
if err != nil {
handleError(rw, fmt.Sprintf("Failed to build logout URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
t.logger.Debugf("Redirecting to end session URL: %s", logoutURL)
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// If no end session endpoint or no ID token, just redirect to the post logout URI
t.logger.Debugf("Redirecting to post logout URI: %s", postLogoutRedirectURI)
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
// Ensure postLogoutRedirectURI is properly URL encoded
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
+30 -33
View File
@@ -4,13 +4,12 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
@@ -58,6 +57,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check locking pattern
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
@@ -120,25 +120,12 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
N: new(big.Int).SetBytes(nBytes),
E: int(new(big.Int).SetBytes(eBytes).Int64()),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return pubKeyPEM, nil
return marshalPublicKey(pubKey)
}
// ecJWKToPEM converts an EC JWK to PEM
@@ -152,16 +139,9 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
curve, err := getCurve(jwk.Crv)
if err != nil {
return nil, err
}
pubKey := &ecdsa.PublicKey{
@@ -170,15 +150,32 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
Y: new(big.Int).SetBytes(yBytes),
}
return marshalPublicKey(pubKey)
}
// getCurve returns the elliptic curve based on the JWK curve parameter
func getCurve(crv string) (elliptic.Curve, error) {
switch crv {
case "P-256":
return elliptic.P256(), nil
case "P-384":
return elliptic.P384(), nil
case "P-521":
return elliptic.P521(), nil
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", crv)
}
}
// marshalPublicKey marshals a public key to PEM format
func marshalPublicKey(pubKey interface{}) ([]byte, error) {
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
return pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return pubKeyPEM, nil
}), nil
}
+119 -133
View File
@@ -4,18 +4,29 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"time"
)
var (
ErrInvalidJWTFormat = errors.New("invalid JWT format")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrTokenExpired = errors.New("token has expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrMissingClaim = errors.New("missing claim")
ErrInvalidClaimType = errors.New("invalid claim type")
ErrUnsupportedAlgorithm = errors.New("unsupported algorithm")
ErrInvalidSignature = errors.New("invalid signature")
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
@@ -28,212 +39,187 @@ type JWT struct {
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
return nil, fmt.Errorf("%w: expected 3 parts, got %d", ErrInvalidJWTFormat, len(parts))
}
jwt := &JWT{
Token: tokenString,
jwt := &JWT{Token: tokenString}
if err := decodeJSONPart(parts[0], &jwt.Header); err != nil {
return nil, fmt.Errorf("failed to decode header: %w", err)
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err := decodeJSONPart(parts[1], &jwt.Claims); err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
var err error
jwt.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
return nil, fmt.Errorf("failed to decode signature: %w", err)
}
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
}
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
return jwt, nil
}
func decodeJSONPart(part string, target interface{}) error {
bytes, err := base64.RawURLEncoding.DecodeString(part)
if err != nil {
return err
}
return json.Unmarshal(bytes, target)
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
}
if err := verifyIssuer(iss, issuerURL); err != nil {
if err := verifyIssuer(j.Claims["iss"], issuerURL); err != nil {
return err
}
aud, ok := claims["aud"]
if !ok {
return fmt.Errorf("missing 'aud' claim")
}
if err := verifyAudience(aud, clientID); err != nil {
if err := verifyAudience(j.Claims["aud"], clientID); err != nil {
return err
}
exp, ok := claims["exp"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'exp' claim")
}
if err := verifyExpiration(exp); err != nil {
if err := verifyExpiration(j.Claims["exp"]); err != nil {
return err
}
iat, ok := claims["iat"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'iat' claim")
}
if err := verifyIssuedAt(iat); err != nil {
if err := verifyIssuedAt(j.Claims["iat"]); err != nil {
return err
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
if sub, ok := j.Claims["sub"].(string); !ok || sub == "" {
return fmt.Errorf("%w: sub", ErrMissingClaim)
}
return nil
}
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
return ErrInvalidAudience
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
return nil
}
}
if !found {
return fmt.Errorf("invalid audience")
}
return ErrInvalidAudience
default:
return fmt.Errorf("invalid 'aud' claim type")
return fmt.Errorf("%w: aud", ErrInvalidClaimType)
}
return nil
}
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
func verifyIssuer(tokenIssuer interface{}, expectedIssuer string) error {
iss, ok := tokenIssuer.(string)
if !ok {
return fmt.Errorf("%w: iss", ErrMissingClaim)
}
if iss != expectedIssuer {
return ErrInvalidIssuer
}
return nil
}
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
func verifyExpiration(expiration interface{}) error {
exp, ok := expiration.(float64)
if !ok {
return fmt.Errorf("%w: exp", ErrInvalidClaimType)
}
if time.Now().After(time.Unix(int64(exp), 0)) {
return ErrTokenExpired
}
return nil
}
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
func verifyIssuedAt(issuedAt interface{}) error {
iat, ok := issuedAt.(float64)
if !ok {
return fmt.Errorf("%w: iat", ErrInvalidClaimType)
}
if time.Now().Before(time.Unix(int64(iat), 0)) {
return ErrTokenUsedBeforeIssued
}
return nil
}
// verifySignature verifies the token signature using the provided public key and algorithm
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
return ErrInvalidJWTFormat
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
pubKey, err := parsePublicKey(publicKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
return err
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
hashFunc, err := getHashFunc(alg)
if err != nil {
return err
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
hashed := hashFunc.New().Sum([]byte(signedContent))
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
return verifyRSASignature(pubKey, hashFunc, hashed, signature, alg)
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
return verifyECDSASignature(pubKey, hashed, signature)
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
}
func parsePublicKey(publicKeyPEM []byte) (interface{}, error) {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse PEM block containing the public key")
}
return x509.ParsePKIXPublicKey(block.Bytes)
}
func getHashFunc(alg string) (crypto.Hash, error) {
switch alg {
case "RS256", "PS256", "ES256":
return crypto.SHA256, nil
case "RS384", "PS384", "ES384":
return crypto.SHA384, nil
case "RS512", "PS512", "ES512":
return crypto.SHA512, nil
default:
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
}
func verifyRSASignature(pubKey *rsa.PublicKey, hashFunc crypto.Hash, hashed, signature []byte, alg string) error {
if strings.HasPrefix(alg, "RS") {
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
}
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
func verifyECDSASignature(pubKey *ecdsa.PublicKey, hashed, signature []byte) error {
sigLen := len(signature)
if sigLen%2 != 0 {
return errors.New("invalid ECDSA signature length")
}
r, s := new(big.Int), new(big.Int)
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, r, s) {
return nil
}
return ErrInvalidSignature
}
+65 -122
View File
@@ -53,28 +53,26 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
@@ -84,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -227,12 +217,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
}
return config.PostLogoutRedirectURI
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
@@ -270,26 +254,20 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
t.initOnce.Do(func() {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
}
close(t.initComplete)
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -381,12 +359,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
// Build the redirect URL if not already set
if redirectURL == "" {
redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", redirectURL)
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
@@ -407,7 +383,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
t.handleCallback(rw, req)
return
}
@@ -415,19 +391,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
}
@@ -436,21 +412,21 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
@@ -460,7 +436,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
groups, roles := t.extractGroupsAndRoles(claims)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
@@ -507,16 +483,16 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
switch {
case t.forceHTTPS:
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
case req.Header.Get(headerXForwardedProto) != "":
return req.Header.Get(headerXForwardedProto)
case req.TLS != nil:
return "https"
default:
return "http"
}
return "http"
}
// determineHost determines the host of the request
@@ -644,9 +620,14 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -722,71 +703,33 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
return true
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false // Invalid email format
atIndex := strings.LastIndex(email, "@")
if atIndex == -1 {
return false
}
domain := parts[1]
domain := email[atIndex+1:]
_, ok := t.allowedUserDomains[domain]
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
var groups []string
var roles []string
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
return groups, roles, nil
func (t *TraefikOidc) extractGroupsAndRoles(claims map[string]interface{}) ([]string, []string) {
groups := extractStringSlice(claims, "groups")
roles := extractStringSlice(claims, "roles")
return groups, roles
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
func extractStringSlice(claims map[string]interface{}, key string) []string {
if slice, ok := claims[key].([]interface{}); ok {
result := make([]string, 0, len(slice))
for _, item := range slice {
if str, ok := item.(string); ok {
result = append(result, str)
}
}
return result
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
return nil
}
+27 -518
View File
@@ -104,7 +104,7 @@ func (ts *TestSuite) Setup() {
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -167,6 +167,18 @@ func TestVerifyToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.mockJWKCache.JWKS = &JWKSet{
Keys: []JWK{
{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
},
},
}
tests := []struct {
name string
token string
@@ -452,12 +464,10 @@ func TestHandleCallback(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
redirectURL := "http://example.com/"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
@@ -465,7 +475,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Success",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -495,7 +505,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *sessions.Session) {
@@ -507,7 +517,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
@@ -519,7 +529,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -540,7 +550,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -561,7 +571,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -582,7 +592,7 @@ func TestHandleCallback(t *testing.T) {
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -635,7 +645,7 @@ func TestHandleCallback(t *testing.T) {
rr = httptest.NewRecorder()
// Call handleCallback
tOidc.handleCallback(rr, req, redirectURL)
tOidc.handleCallback(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
@@ -690,7 +700,7 @@ func TestOIDCHandler(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
@@ -706,7 +716,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -730,7 +740,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -753,7 +763,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -777,7 +787,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -822,504 +832,3 @@ func TestOIDCHandler(t *testing.T) {
})
}
}
// TestHandleLogout tests the logout functionality
func TestHandleLogout(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
// Create mock revocation endpoint server
mockRevocationServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
// Verify the required parameters are present
if r.Form.Get("token") == "" {
t.Error("Missing token parameter")
}
if r.Form.Get("token_type_hint") == "" {
t.Error("Missing token_type_hint parameter")
}
w.WriteHeader(http.StatusOK)
}))
defer mockRevocationServer.Close()
tests := []struct {
name string
setupSession func(*sessions.Session)
endSessionURL string
expectedStatus int
expectedURL string
host string
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
},
endSessionURL: "",
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with empty session",
setupSession: func(session *sessions.Session) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
host: "test-host",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: NewLogger("info"),
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
}
// Create request with proper headers
req := httptest.NewRequest("GET", "/logout", nil)
req.Header.Set("Host", tc.host)
// Create a response recorder
rr := httptest.NewRecorder()
// Get a session
session, err := tOidc.store.Get(req, cookieName)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Handle logout
tOidc.handleLogout(rr, req)
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
t.Errorf("Expected redirect to %q, got %q", tc.expectedURL, location)
}
}
// Verify session is cleared
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
}
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
}
// Check token blacklist
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
}
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
}
}
})
}
}
// TestRevokeTokenWithProvider tests the token revocation with provider
func TestRevokeTokenWithProvider(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
token string
tokenType string
statusCode int
expectError bool
}{
{
name: "Successful token revocation",
token: "valid-token",
tokenType: "refresh_token",
statusCode: http.StatusOK,
expectError: false,
},
{
name: "Failed token revocation",
token: "invalid-token",
tokenType: "refresh_token",
statusCode: http.StatusBadRequest,
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create test server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request method and content type
if r.Method != "POST" {
t.Errorf("Expected POST request, got %s", r.Method)
}
if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
t.Errorf("Expected Content-Type application/x-www-form-urlencoded, got %s", ct)
}
// Verify form values
if err := r.ParseForm(); err != nil {
t.Fatalf("Failed to parse form: %v", err)
}
if got := r.Form.Get("token"); got != tc.token {
t.Errorf("Expected token %s, got %s", tc.token, got)
}
if got := r.Form.Get("token_type_hint"); got != tc.tokenType {
t.Errorf("Expected token_type_hint %s, got %s", tc.tokenType, got)
}
if got := r.Form.Get("client_id"); got != ts.tOidc.clientID {
t.Errorf("Expected client_id %s, got %s", ts.tOidc.clientID, got)
}
if got := r.Form.Get("client_secret"); got != ts.tOidc.clientSecret {
t.Errorf("Expected client_secret %s, got %s", ts.tOidc.clientSecret, got)
}
w.WriteHeader(tc.statusCode)
}))
defer server.Close()
// Set revocation URL to test server
ts.tOidc.revocationURL = server.URL
// Test token revocation
err := ts.tOidc.RevokeTokenWithProvider(tc.token, tc.tokenType)
if tc.expectError && err == nil {
t.Error("Expected error but got nil")
}
if !tc.expectError && err != nil {
t.Errorf("Unexpected error: %v", err)
}
})
}
}
// TestRevokeToken tests the token revocation functionality
func TestRevokeToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
token := "test.token.with.claims"
claims := map[string]interface{}{
"exp": float64(time.Now().Add(time.Hour).Unix()),
}
// Test token revocation
t.Run("Token revocation", func(t *testing.T) {
// Create a new instance for this specific test
tOidc := &TraefikOidc{
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
}
// Cache the token
tOidc.tokenCache.Set(token, claims, time.Hour)
// Revoke the token
tOidc.RevokeToken(token)
// Verify token was removed from cache
if _, exists := tOidc.tokenCache.Get(token); exists {
t.Error("Token was not removed from cache")
}
// Verify token was added to blacklist
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Token was not added to blacklist")
}
})
}
// Add this new test function
func TestBuildLogoutURL(t *testing.T) {
tests := []struct {
name string
endSessionURL string
idToken string
postLogoutRedirect string
expectedURL string
expectError bool
}{
{
name: "Valid URL",
endSessionURL: "https://provider/end-session",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
{
name: "Invalid URL",
endSessionURL: "://invalid-url",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectError: true,
},
{
name: "URL with existing query parameters",
endSessionURL: "https://provider/end-session?existing=param",
idToken: "test.id.token",
postLogoutRedirect: "http://example.com/",
expectedURL: "https://provider/end-session?existing=param&id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
expectError: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
url, err := BuildLogoutURL(tc.endSessionURL, tc.idToken, tc.postLogoutRedirect)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if url != tc.expectedURL {
t.Errorf("Expected URL %q, got %q", tc.expectedURL, url)
}
}
})
}
}
// Add this new test function
func TestHandleExpiredToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
setupSession func(*sessions.Session)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
close(tOidc.initComplete)
// Create request
req := httptest.NewRequest("GET", tc.expectedPath, nil)
rr := httptest.NewRecorder()
// Get session
session, _ := tOidc.store.New(req, cookieName)
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
}
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
t.Error("CSRF token not set")
}
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if _, ok := session.Values["nonce"].(string); !ok {
t.Error("Nonce not set")
}
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
}
// Verify redirect status
if rr.Code != http.StatusFound {
t.Errorf("Expected status %d, got %d", http.StatusFound, rr.Code)
}
})
}
}
// Add this new test function
func TestExtractGroupsAndRoles(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
claims map[string]interface{}
expectGroups []string
expectRoles []string
expectError bool
}{
{
name: "Valid groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{"group1", "group2"},
"roles": []interface{}{"role1", "role2"},
},
expectGroups: []string{"group1", "group2"},
expectRoles: []string{"role1", "role2"},
expectError: false,
},
{
name: "Empty groups and roles",
claims: map[string]interface{}{
"groups": []interface{}{},
"roles": []interface{}{},
},
expectGroups: []string{},
expectRoles: []string{},
expectError: false,
},
{
name: "Invalid groups format",
claims: map[string]interface{}{
"groups": "not-an-array",
"roles": []interface{}{"role1"},
},
expectError: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a test token with the claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
groups, roles, err := ts.tOidc.extractGroupsAndRoles(token)
if tc.expectError {
if err == nil {
t.Error("Expected error but got nil")
}
} else {
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
// Compare groups
if !stringSliceEqual(groups, tc.expectGroups) {
t.Errorf("Expected groups %v, got %v", tc.expectGroups, groups)
}
// Compare roles
if !stringSliceEqual(roles, tc.expectRoles) {
t.Errorf("Expected roles %v, got %v", tc.expectRoles, roles)
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}
+8 -2
View File
@@ -14,6 +14,14 @@ const (
cookieName = "_raczylo_oidc"
)
const (
headerXForwardedProto = "X-Forwarded-Proto"
headerXForwardedHost = "X-Forwarded-Host"
headerXForwardedUser = "X-Forwarded-User"
headerXUserGroups = "X-User-Groups"
headerXUserRoles = "X-User-Roles"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
@@ -30,8 +38,6 @@ type Config struct {
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
HTTPClient *http.Client
}
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+5 -1
View File
@@ -1,4 +1,7 @@
# sessions
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -74,6 +77,7 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+12 -9
View File
@@ -1,5 +1,6 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
@@ -8,13 +9,15 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
}
}
-21
View File
@@ -1,21 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+10 -5
View File
@@ -1,8 +1,11 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -13,7 +16,9 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
}
-23
View File
@@ -1,23 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -4,8 +4,8 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
## explicit; go 1.18