Compare commits

...

10 Commits

Author SHA1 Message Date
lukaszraczylo 7e3dc46b6e Improve initial fetch of the provider metadata until successful. 2025-01-06 12:19:11 +00:00
lukaszraczylo 147aa0b169 Fix the issue #16
Removed global metadata cache and sync.Once
Each middleware instance now handles its own metadata initialization
Added tests to verify multiple instances work correctly
The changes ensure that:

Each route gets its own properly initialized middleware instance
Metadata is fetched and set correctly for each instance
No shared state between instances that could cause conflicts
Each instance can handle requests independently
The added test verifies this by creating multiple middleware instances with different routes and confirming they all initialize and function correctly. The test specifically checks that:

Each instance initializes successfully
Each instance gets its own metadata configuration
Each instance can handle requests independently
Callback URLs are correctly set per route
2025-01-06 11:23:12 +00:00
lukaszraczylo eecb7dfc92 Improve test coverage 2025-01-06 11:01:20 +00:00
lukaszraczylo a8d65688c4 Improve documentation. 2025-01-06 10:44:49 +00:00
lukaszraczylo bef4212c57 Add support for the large tokens, which exceed the standard 4096 limit for cookie. 2024-12-11 12:55:16 +00:00
lukaszraczylo 1fee2f9e9a fixup! Re-introduce user roles separation with additional tests. 2024-12-11 09:11:34 +00:00
lukaszraczylo 11bc6f3e31 Re-introduce user roles separation with additional tests. 2024-12-11 09:08:50 +00:00
lukaszraczylo 2b7af88ff9 Move session management into session manager. Split the cookies to avoid the 4k limit ( resolves issue: #15 ) 2024-12-10 10:19:35 +00:00
lukaszraczylo 01ee7c4dc8 Improve cookie setting. 2024-12-10 10:19:35 +00:00
lukaszraczylo a6fa4d8789 Downgrade gorilla sessions preventing the publishing by traefik hub temporarily. 2024-12-10 10:19:34 +00:00
22 changed files with 2393 additions and 551 deletions
+1
View File
@@ -13,6 +13,7 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
+1
View File
@@ -38,6 +38,7 @@ spec:
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
+34 -8
View File
@@ -5,26 +5,41 @@ import (
"time"
)
// CacheItem represents an item in the cache
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
Value interface{}
// Value is the cached data of any type
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
ExpiresAt time.Time
}
// Cache is a simple in-memory cache
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
}
// NewCache creates a new Cache
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -34,7 +49,13 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
}
}
// Get retrieves an item from the cache
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
@@ -49,14 +70,19 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return item.Value, true
}
// Delete removes an item from the cache
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes expired items from the cache
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
+306
View File
@@ -0,0 +1,306 @@
package traefikoidc
import (
"reflect"
"testing"
"time"
)
func TestCache(t *testing.T) {
t.Run("Basic Set and Get", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Test Set
cache.Set(key, value, expiration)
// Test Get
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value {
t.Errorf("Expected value %v, got %v", value, got)
}
})
t.Run("Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 10 * time.Millisecond
// Set with short expiration
cache.Set(key, value, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be expired")
}
})
t.Run("Delete", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Set and then delete
cache.Set(key, value, expiration)
cache.Delete(key)
// Should not find deleted key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be deleted")
}
})
t.Run("Cleanup", func(t *testing.T) {
cache := NewCache()
// Add multiple items with different expirations
cache.Set("expired1", "value1", 10*time.Millisecond)
cache.Set("expired2", "value2", 10*time.Millisecond)
cache.Set("valid", "value3", 1*time.Second)
// Wait for some items to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Check expired items are removed
_, found1 := cache.Get("expired1")
_, found2 := cache.Get("expired2")
_, found3 := cache.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid item to remain in cache")
}
})
t.Run("Concurrent Access", func(t *testing.T) {
cache := NewCache()
done := make(chan bool)
// Start multiple goroutines to access cache concurrently
for i := 0; i < 10; i++ {
go func(id int) {
key := "key"
value := "value"
expiration := 1 * time.Second
// Perform multiple operations
cache.Set(key, value, expiration)
cache.Get(key)
cache.Delete(key)
cache.Cleanup()
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
})
t.Run("Zero Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with zero expiration
cache.Set(key, value, 0)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with zero expiration to be immediately expired")
}
})
t.Run("Negative Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with negative expiration
cache.Set(key, value, -1*time.Second)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with negative expiration to be immediately expired")
}
})
t.Run("Update Existing Key", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value1 := "value1"
value2 := "value2"
expiration := 1 * time.Second
// Set initial value
cache.Set(key, value1, expiration)
// Update value
cache.Set(key, value2, expiration)
// Check updated value
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value2 {
t.Errorf("Expected updated value %v, got %v", value2, got)
}
})
t.Run("Different Value Types", func(t *testing.T) {
cache := NewCache()
expiration := 1 * time.Second
// Test with different value types
testCases := []struct {
key string
value interface{}
}{
{"string", "test"},
{"int", 42},
{"float", 3.14},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"a": 1, "b": 2}},
{"struct", struct{ Name string }{"test"}},
}
for _, tc := range testCases {
t.Run(tc.key, func(t *testing.T) {
cache.Set(tc.key, tc.value, expiration)
got, found := cache.Get(tc.key)
if !found {
t.Error("Expected to find key in cache")
}
// Use reflect.DeepEqual for comparing complex types like slices and maps
if !reflect.DeepEqual(got, tc.value) {
t.Errorf("Expected value %v, got %v", tc.value, got)
}
})
}
})
}
func TestTokenCache(t *testing.T) {
t.Run("Basic Operations", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}
expiration := 1 * time.Second
// Test Set and Get
tc.Set(token, claims, expiration)
gotClaims, found := tc.Get(token)
if !found {
t.Error("Expected to find token in cache")
}
if len(gotClaims) != len(claims) {
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
}
for k, v := range claims {
if gotClaims[k] != v {
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
}
}
// Test Delete
tc.Delete(token)
_, found = tc.Get(token)
if found {
t.Error("Expected token to be deleted")
}
})
t.Run("Expiration", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 10 * time.Millisecond
// Set with short expiration
tc.Set(token, claims, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired token
_, found := tc.Get(token)
if found {
t.Error("Expected token to be expired")
}
})
t.Run("Cleanup", func(t *testing.T) {
tc := NewTokenCache()
// Add multiple tokens with different expirations
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
// Wait for some tokens to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
tc.Cleanup()
// Check expired tokens are removed
_, found1 := tc.Get("expired1")
_, found2 := tc.Get("expired2")
_, found3 := tc.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid token to remain in cache")
}
})
t.Run("Token Prefix", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 1 * time.Second
// Set token
tc.Set(token, claims, expiration)
// Verify internal storage uses prefix
_, found := tc.cache.Get("t-" + token)
if !found {
t.Error("Expected to find prefixed token in underlying cache")
}
})
}
+1 -1
View File
@@ -6,7 +6,7 @@ toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.4.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
)
+2 -2
View File
@@ -4,7 +4,7 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+178 -168
View File
@@ -13,11 +13,30 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
// generateNonce generates a random nonce
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
// the authentication request.
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -27,15 +46,33 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
type TokenResponse struct {
// IDToken is the OIDC ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth 2.0 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds of the access token
ExpiresIn int `json:"expires_in"`
// TokenType is the type of token, typically "Bearer"
TokenType string `json:"token_type"`
}
// exchangeTokens exchanges a code or refresh token for tokens
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
// It supports both authorization code and refresh token grant types.
// Parameters:
// - ctx: Context for the HTTP request
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
@@ -75,16 +112,8 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
return &tokenResponse, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
// This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
@@ -93,108 +122,25 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
}
t.logger.Debugf("Token response: %+v", tokenResponse)
return tokenResponse, nil
}
// handleLogout handles the logout process
func (t *TraefikOidc) handleLogout(w http.ResponseWriter, r *http.Request) {
session, err := t.store.Get(r, cookieName)
if err != nil {
handleError(w, fmt.Sprintf("Error getting session: %v", err), http.StatusInternalServerError, t.logger)
return
}
// Get tokens from session
idToken, _ := session.Values["id_token"].(string)
refreshToken, _ := session.Values["refresh_token"].(string)
accessToken, _ := session.Values["access_token"].(string)
// Revoke tokens if they exist
if refreshToken != "" {
t.RevokeTokenWithProvider(refreshToken, "refresh_token")
t.RevokeToken(refreshToken)
}
if accessToken != "" {
t.RevokeTokenWithProvider(accessToken, "access_token")
t.RevokeToken(accessToken)
}
// Clear session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
if err := session.Save(r, w); err != nil {
handleError(w, fmt.Sprintf("Error saving session: %v", err), http.StatusInternalServerError, t.logger)
return
}
// Determine redirect URL
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
scheme := "http"
if r.Header.Get("X-Forwarded-Proto") == "https" || t.forceHTTPS {
scheme = "https"
}
baseURL := fmt.Sprintf("%s://%s/", scheme, host)
if t.endSessionURL != "" && idToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, idToken, baseURL)
if err != nil {
handleError(w, fmt.Sprintf("Invalid end session URL: %v", err), http.StatusInternalServerError, t.logger)
return
}
http.Redirect(w, r, logoutURL, http.StatusFound)
return
}
http.Redirect(w, r, baseURL, http.StatusFound)
}
// BuildLogoutURL constructs the logout URL with proper encoding
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("invalid end session URL: %v", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
u.RawQuery = q.Encode()
return u.String(), nil
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -203,7 +149,7 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
@@ -211,26 +157,28 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Validate the state parameter matches the session's CSRF token
// Validate CSRF state
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -238,56 +186,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
// Verify tokens and claims
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the nonce claim matches the one stored in session
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
// Validate user's email domain
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -295,16 +236,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
@@ -312,18 +248,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
// Redirect to original path or root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims extracts claims from a JWT token
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -343,27 +278,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
mutex sync.RWMutex
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -371,7 +311,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -383,25 +323,29 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache caches tokens
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
type TokenCache struct {
// cache is the underlying cache implementation
cache *Cache
}
// NewTokenCache creates a new TokenCache
// NewTokenCache creates a new TokenCache instance.
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set sets a token in the cache
// Set stores a token's claims in the cache with an expiration time.
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token from the cache
// Get retrieves a token's claims from the cache.
// Returns the claims and a boolean indicating if the token was found.
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
@@ -412,28 +356,29 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache
// Delete removes a token from the cache.
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup cleans up expired tokens from the cache
// Cleanup removes expired tokens from the cache.
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
// exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
// createStringMap creates a map from a slice of strings.
// Used for efficient lookups in allowed domains and roles.
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
@@ -441,3 +386,68 @@ func createStringMap(keys []string) map[string]struct{} {
}
return result
}
// handleLogout manages the OIDC logout process.
// It clears the session and redirects either to the OIDC provider's
// end session endpoint (if available) or to the configured post-logout URL.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
accessToken := session.GetAccessToken()
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
}
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
// Parameters:
// - endSessionURL: The OIDC provider's end session endpoint
// - idToken: The ID token to be invalidated
// - postLogoutRedirectURI: Where to redirect after logout completes
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
+66 -15
View File
@@ -16,37 +16,74 @@ import (
"time"
)
// JWK represents a JSON Web Key
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
type JWKCache struct {
jwks *JWKSet
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
expiresAt time.Time
mutex sync.RWMutex
// mutex protects concurrent access to the cache
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -73,7 +110,14 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -93,7 +137,9 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JWK to PEM format
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -109,7 +155,9 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -141,7 +189,10 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC JWK to PEM
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
+63 -11
View File
@@ -16,15 +16,31 @@ import (
"time"
)
// JWT represents a JSON Web Token
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Signature []byte
Token string
// Token is the original JWT string
Token string
}
// parseJWT parses a JWT token string into a JWT struct
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -63,7 +79,14 @@ func parseJWT(tokenString string) (*JWT, error) {
return jwt, nil
}
// Verify verifies the standard claims in the JWT
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future
// - issued at time (iat) is in the past
// - subject (sub) is present and not empty
// Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
@@ -107,7 +130,13 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience verifies the audience claim
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
@@ -131,7 +160,12 @@ func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
return nil
}
// verifyIssuer verifies the issuer claim
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
@@ -139,7 +173,11 @@ func verifyIssuer(tokenIssuer, expectedIssuer string) error {
return nil
}
// verifyExpiration checks if the token has expired
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time.
// Parameters:
// - expiration: The expiration timestamp from the token
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
@@ -148,7 +186,12 @@ func verifyExpiration(expiration float64) error {
return nil
}
// verifyIssuedAt checks if the token was issued in the future
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, which could
// indicate clock skew or a malicious token.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
@@ -157,7 +200,16 @@ func verifyIssuedAt(issuedAt float64) error {
return nil
}
// verifySignature verifies the token signature using the provided public key and algorithm
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
+120 -118
View File
@@ -10,11 +10,9 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
@@ -34,7 +32,6 @@ type JWTVerifier interface {
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -53,18 +50,19 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
@@ -84,14 +82,6 @@ var defaultExcludedURLs = map[string]struct{}{
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
@@ -185,9 +175,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -200,7 +187,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ExpectContinueTimeout: 0,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
@@ -219,7 +206,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -227,6 +213,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
}
return config.PostLogoutRedirectURI
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
@@ -243,9 +235,10 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -264,26 +257,43 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -353,96 +363,76 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Debug("OIDC middleware not yet initialized")
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
return
}
// Process the request as normal
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
case <-time.After(30 * time.Second):
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
return
}
// Check if the URL is excluded from authentication
// Check if URL is excluded
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
// Get session
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Handle logout URL
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
t.handleCallback(rw, req, redirectURL)
return
}
// Check if the user is authenticated
// Check authentication status
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
}
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
// Process authenticated request
email := session.GetEmail()
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.logger.Debug("No email found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -452,11 +442,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -465,6 +454,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -475,13 +465,15 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -520,37 +512,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(idToken); err != nil {
if err := t.verifyToken(accessToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(idToken)
claims, err := extractClaims(accessToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true // Can't read claims, consider it expired
return false, false, true
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
}
now := time.Now().Unix()
@@ -558,7 +547,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true // Token has expired
return false, false, true
}
gracePeriod := time.Minute * 5
@@ -567,26 +556,23 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false // Token is valid and not expiring soon
return true, false, false
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -595,7 +581,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build the authentication URL
// Build and redirect to auth URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -679,10 +665,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -693,16 +679,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
// Verify the new access token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
t.logger.Errorf("Failed to verify new access token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
@@ -767,3 +754,18 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
+468 -158
View File
@@ -1,6 +1,7 @@
package traefikoidc
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
@@ -22,13 +23,14 @@ import (
// TestSuite holds common test data and setup
type TestSuite struct {
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
t *testing.T
rsaPrivateKey *rsa.PrivateKey
rsaPublicKey *rsa.PublicKey
ecPrivateKey *ecdsa.PrivateKey
tOidc *TraefikOidc
mockJWKCache *MockJWKCache
token string
sessionManager *SessionManager
}
// Setup initializes the test suite
@@ -78,6 +80,9 @@ func (ts *TestSuite) Setup() {
ts.t.Fatalf("Failed to create test JWT: %v", err)
}
logger := NewLogger("info")
ts.sessionManager = NewSessionManager("test-secret-key", false, logger)
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
@@ -89,13 +94,13 @@ func (ts *TestSuite) Setup() {
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: logger,
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
sessionManager: ts.sessionManager,
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
@@ -104,7 +109,7 @@ func (ts *TestSuite) Setup() {
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
func (ts *TestSuite) exchangeCodeForTokenFunc(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -257,6 +262,7 @@ func TestServeHTTP(t *testing.T) {
sessionValues map[interface{}]interface{}
expectedStatus int
expectedBody string
setupSession func(*SessionData)
}{
{
name: "Excluded URL",
@@ -272,10 +278,10 @@ func TestServeHTTP(t *testing.T) {
{
name: "Authenticated request to protected URL",
requestPath: "/protected",
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
},
expectedStatus: http.StatusOK,
expectedBody: "OK",
@@ -283,52 +289,52 @@ func TestServeHTTP(t *testing.T) {
{
name: "Logout URL",
requestPath: "/logout",
sessionValues: map[interface{}]interface{}{
"authenticated": true,
"email": "user@example.com",
"id_token": ts.token,
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
session.SetAccessToken(ts.token)
},
expectedStatus: http.StatusOK,
expectedBody: "Logged out\n",
expectedBody: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a request
req := httptest.NewRequest("GET", tc.requestPath, nil)
req.Header.Set("X-Forwarded-Proto", "http")
req.Header.Set("X-Forwarded-Host", "localhost")
// Create a temporary response recorder to save the session
rrSession := httptest.NewRecorder()
// Create a session
session, _ := ts.tOidc.store.New(req, cookieName)
if tc.sessionValues != nil {
for k, v := range tc.sessionValues {
session.Values[k] = v
}
session.Save(req, rrSession)
}
// Copy session cookie from rrSession to request
for _, cookie := range rrSession.Result().Cookies() {
req.AddCookie(cookie)
}
// Create a response recorder for ServeHTTP
rr := httptest.NewRecorder()
// Setup session if needed
session, err := ts.tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
rr = httptest.NewRecorder()
}
// Call ServeHTTP
ts.tOidc.ServeHTTP(rr, req)
// Check the response
// Check response
if rr.Code != tc.expectedStatus {
t.Errorf("Test %s: expected status %d, got %d", tc.name, tc.expectedStatus, rr.Code)
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
if tc.expectedBody != "" && strings.TrimSpace(rr.Body.String()) != strings.TrimSpace(rr.Body.String()) {
t.Errorf("Test %s: expected body '%s', got '%s'", tc.name, tc.expectedBody, rr.Body.String())
if tc.expectedBody != "" {
if body := strings.TrimSpace(rr.Body.String()); body != tc.expectedBody {
t.Errorf("Expected body %q, got %q", tc.expectedBody, body)
}
}
})
}
@@ -452,18 +458,20 @@ func TestHandleCallback(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
redirectURL := "http://example.com/"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
sessionSetupFunc func(*SessionData)
expectedStatus int
}{
{
name: "Success",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -475,49 +483,49 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Exchange Code Error",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing ID Token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Disallowed Email",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -529,16 +537,16 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusForbidden,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -550,16 +558,16 @@ func TestHandleCallback(t *testing.T) {
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -571,16 +579,16 @@ func TestHandleCallback(t *testing.T) {
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
@@ -592,9 +600,9 @@ func TestHandleCallback(t *testing.T) {
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
sessionSetupFunc: func(session *SessionData) {
session.SetCSRF("test-csrf-token")
session.SetNonce("test-nonce")
},
expectedStatus: http.StatusInternalServerError,
},
@@ -602,15 +610,18 @@ func TestHandleCallback(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
// Create a new instance for each test to avoid state carryover
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
logger: NewLogger("info"),
logger: logger,
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
sessionManager: sessionManager,
}
// Create request and response recorder
@@ -618,22 +629,27 @@ func TestHandleCallback(t *testing.T) {
rr := httptest.NewRecorder()
// Create session
session, _ := tOidc.store.New(req, cookieName)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
session.Save(req, rr)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy session cookie to request
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset rr for the actual test
// Reset response recorder for the actual test
rr = httptest.NewRecorder()
// Call handleCallback
tOidc.handleCallback(rr, req)
tOidc.handleCallback(rr, req, redirectURL)
// Check response
if rr.Code != tc.expectedStatus {
@@ -688,7 +704,7 @@ func TestOIDCHandler(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (*TokenResponse, error)
exchangeCodeForToken func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
@@ -704,7 +720,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -728,7 +744,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -751,7 +767,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -775,7 +791,7 @@ func TestOIDCHandler(t *testing.T) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
exchangeCodeForToken: func(code string, redirectURL string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
@@ -847,7 +863,7 @@ func TestHandleLogout(t *testing.T) {
tests := []struct {
name string
setupSession func(*sessions.Session)
setupSession func(*SessionData)
endSessionURL string
expectedStatus int
expectedURL string
@@ -855,25 +871,22 @@ func TestHandleLogout(t *testing.T) {
}{
{
name: "Successful logout with end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
},
endSessionURL: "https://provider/end-session",
expectedStatus: http.StatusFound,
// Fix: The entire URL should be URL-encoded
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
expectedURL: "https://provider/end-session?id_token_hint=test.id.token&post_logout_redirect_uri=http%3A%2F%2Fexample.com%2F",
host: "test-host",
},
{
name: "Successful logout without end session endpoint",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
session.Values["refresh_token"] = "test-refresh-token"
session.Values["access_token"] = "test-access-token"
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
},
endSessionURL: "",
expectedStatus: http.StatusFound,
@@ -882,16 +895,17 @@ func TestHandleLogout(t *testing.T) {
},
{
name: "Logout with empty session",
setupSession: func(session *sessions.Session) {},
setupSession: func(session *SessionData) {},
expectedStatus: http.StatusFound,
expectedURL: "http://example.com/",
host: "test-host",
},
{
name: "Logout with invalid end session URL",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "test.id.token"
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("test.id.token")
session.SetRefreshToken("test-refresh-token")
},
endSessionURL: ":\\invalid-url",
expectedStatus: http.StatusInternalServerError,
@@ -901,19 +915,20 @@ func TestHandleLogout(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
revocationURL: mockRevocationServer.URL,
endSessionURL: tc.endSessionURL,
scheme: "http",
logger: NewLogger("info"),
logger: logger,
tokenBlacklist: NewTokenBlacklist(),
httpClient: &http.Client{},
clientID: "test-client-id",
clientSecret: "test-client-secret",
tokenCache: NewTokenCache(),
forceHTTPS: false,
sessionManager: sessionManager,
}
// Create request with proper headers
@@ -924,16 +939,18 @@ func TestHandleLogout(t *testing.T) {
rr := httptest.NewRecorder()
// Get a session
session, err := tOidc.store.Get(req, cookieName)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
if tc.setupSession != nil {
tc.setupSession(session)
}
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Setup session
tc.setupSession(session)
session.Save(req, rr)
// Copy session cookie to request
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
@@ -949,7 +966,6 @@ func TestHandleLogout(t *testing.T) {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check redirect URL if expected
if tc.expectedURL != "" {
location := rr.Header().Get("Location")
if location != tc.expectedURL {
@@ -958,23 +974,31 @@ func TestHandleLogout(t *testing.T) {
}
// Verify session is cleared
newSession, _ := tOidc.store.Get(req, cookieName)
if len(newSession.Values) > 0 {
t.Error("Session was not cleared")
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
}
if newSession.Options.MaxAge != -1 {
t.Error("Session MaxAge was not set to -1")
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
if updatedSession.GetAuthenticated() {
t.Error("Session still marked as authenticated")
}
// Check token blacklist
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(refreshToken) {
t.Error("Refresh token was not blacklisted")
if token := session.GetAccessToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Access token was not blacklisted")
}
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(accessToken) {
t.Error("Access token was not blacklisted")
if token := session.GetRefreshToken(); token != "" {
if !tOidc.tokenBlacklist.IsBlacklisted(token) {
t.Error("Refresh token was not blacklisted")
}
}
})
@@ -1155,24 +1179,24 @@ func TestHandleExpiredToken(t *testing.T) {
tests := []struct {
name string
setupSession func(*sessions.Session)
setupSession func(*SessionData)
expectedPath string
}{
{
name: "Basic expired token",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["email"] = "test@example.com"
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.SetEmail("test@example.com")
},
expectedPath: "/original/path",
},
{
name: "Session with additional values",
setupSession: func(session *sessions.Session) {
session.Values["authenticated"] = true
session.Values["id_token"] = "expired.token"
session.Values["custom_value"] = "should-be-cleared"
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetAccessToken("expired.token")
session.mainSession.Values["custom_value"] = "should-be-cleared"
},
expectedPath: "/another/path",
},
@@ -1180,17 +1204,16 @@ func TestHandleExpiredToken(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create a new TraefikOidc instance for each test
logger := NewLogger("info")
sessionManager := NewSessionManager("test-secret-key", false, logger)
tOidc := &TraefikOidc{
store: sessions.NewCookieStore([]byte("test-secret-key")),
logger: NewLogger("info"),
redirectURL: "http://example.com/callback",
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
// Add this initialization of initiateAuthenticationFunc
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Mock implementation for test
sessionManager: sessionManager,
logger: logger,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
initComplete: make(chan struct{}),
initiateAuthenticationFunc: func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
http.Redirect(rw, req, "/login", http.StatusFound)
},
}
@@ -1201,31 +1224,40 @@ func TestHandleExpiredToken(t *testing.T) {
rr := httptest.NewRecorder()
// Get session
session, _ := tOidc.store.New(req, cookieName)
session, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Setup session data
tc.setupSession(session)
// Handle expired token
tOidc.handleExpiredToken(rr, req, session)
tOidc.handleExpiredToken(rr, req, session, tc.expectedPath)
// Verify session is cleaned
if len(session.Values) != 3 { // Should only have csrf, incoming_path, and nonce
t.Errorf("Expected 3 session values, got %d", len(session.Values))
// Get the updated session to verify changes
updatedSession, err := sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get updated session: %v", err)
}
// Verify required values are set
if _, ok := session.Values["csrf"].(string); !ok {
// Verify main session values
if updatedSession.GetCSRF() == "" {
t.Error("CSRF token not set")
}
if path, ok := session.Values["incoming_path"].(string); !ok || path != tc.expectedPath {
if path := updatedSession.GetIncomingPath(); path != tc.expectedPath {
t.Errorf("Expected path %s, got %s", tc.expectedPath, path)
}
if _, ok := session.Values["nonce"].(string); !ok {
if updatedSession.GetNonce() == "" {
t.Error("Nonce not set")
}
// Verify session options
if session.Options.MaxAge != defaultSessionOptions.MaxAge {
t.Error("Session MaxAge not set correctly")
// Verify tokens are cleared
if token := updatedSession.GetAccessToken(); token != "" {
t.Error("Access token not cleared")
}
if token := updatedSession.GetRefreshToken(); token != "" {
t.Error("Refresh token not cleared")
}
// Verify redirect status
@@ -1311,6 +1343,284 @@ func TestExtractGroupsAndRoles(t *testing.T) {
}
}
// TestMultipleMiddlewareInstances verifies that multiple middleware instances
// can be created and initialized properly for different routes
func TestMultipleMiddlewareInstances(t *testing.T) {
// Create mock provider metadata server
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
metadata := ProviderMetadata{
Issuer: "https://test-issuer.com",
AuthURL: "https://test-issuer.com/auth",
TokenURL: "https://test-issuer.com/token",
JWKSURL: "https://test-issuer.com/jwks",
RevokeURL: "https://test-issuer.com/revoke",
EndSessionURL: "https://test-issuer.com/end-session",
}
json.NewEncoder(w).Encode(metadata)
}))
defer mockServer.Close()
// Create base config
config := &Config{
ProviderURL: mockServer.URL,
ClientID: "test-client",
ClientSecret: "test-secret",
CallbackURL: "/callback",
SessionEncryptionKey: "test-encryption-key-thats-long-enough",
}
// Create multiple middleware instances
routes := []string{"/api/v1", "/api/v2", "/api/v3"}
var middlewares []*TraefikOidc
for _, route := range routes {
config.CallbackURL = route + "/callback"
middleware, err := New(context.Background(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}), config, "test")
if err != nil {
t.Fatalf("Failed to create middleware for route %s: %v", route, err)
}
// Type assert to access internal fields
if m, ok := middleware.(*TraefikOidc); ok {
middlewares = append(middlewares, m)
} else {
t.Fatalf("Middleware is not of type *TraefikOidc")
}
}
// Wait for all instances to initialize
for i, m := range middlewares {
select {
case <-m.initComplete:
case <-time.After(5 * time.Second):
t.Fatalf("Middleware instance %d failed to initialize", i)
}
// Verify each instance has its own unique configuration
if m.issuerURL != "https://test-issuer.com" {
t.Errorf("Instance %d: Expected issuer URL %s, got %s", i, "https://test-issuer.com", m.issuerURL)
}
if m.authURL != "https://test-issuer.com/auth" {
t.Errorf("Instance %d: Expected auth URL %s, got %s", i, "https://test-issuer.com/auth", m.authURL)
}
if m.tokenURL != "https://test-issuer.com/token" {
t.Errorf("Instance %d: Expected token URL %s, got %s", i, "https://test-issuer.com/token", m.tokenURL)
}
if m.jwksURL != "https://test-issuer.com/jwks" {
t.Errorf("Instance %d: Expected JWKS URL %s, got %s", i, "https://test-issuer.com/jwks", m.jwksURL)
}
if m.redirURLPath != routes[i]+"/callback" {
t.Errorf("Instance %d: Expected callback URL %s, got %s", i, routes[i]+"/callback", m.redirURLPath)
}
}
// Test that each instance can handle requests independently
for i, m := range middlewares {
req := httptest.NewRequest("GET", routes[i]+"/protected", nil)
rr := httptest.NewRecorder()
m.ServeHTTP(rr, req)
// Should redirect to auth URL since not authenticated
if rr.Code != http.StatusFound {
t.Errorf("Instance %d: Expected redirect status %d, got %d", i, http.StatusFound, rr.Code)
}
location := rr.Header().Get("Location")
if !strings.Contains(location, "https://test-issuer.com/auth") {
t.Errorf("Instance %d: Expected redirect to auth URL, got %s", i, location)
}
}
}
func TestServeHTTPRolesAndGroups(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
allowedRolesAndGroups map[string]struct{}
claims map[string]interface{}
setupSession func(*SessionData)
expectedStatus int
expectedHeaders map[string]string
}{
{
name: "User with allowed role",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"admin", "user"},
"groups": []interface{}{"group1"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "admin,user",
"X-User-Groups": "group1",
},
},
{
name: "User with allowed group",
allowedRolesAndGroups: map[string]struct{}{
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"allowed-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "allowed-group",
},
},
{
name: "User without allowed roles or groups",
allowedRolesAndGroups: map[string]struct{}{
"admin": {},
"allowed-group": {},
},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusForbidden,
},
{
name: "No role/group restrictions",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
"roles": []interface{}{"user"},
"groups": []interface{}{"regular-group"},
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{
"X-User-Roles": "user",
"X-User-Groups": "regular-group",
},
},
{
name: "Claims without roles and groups",
allowedRolesAndGroups: map[string]struct{}{},
claims: map[string]interface{}{
"iss": "https://test-issuer.com",
"aud": "test-client-id",
"exp": time.Now().Add(1 * time.Hour).Unix(),
"iat": time.Now().Unix(),
"sub": "test-subject",
},
setupSession: func(session *SessionData) {
session.SetAuthenticated(true)
session.SetEmail("user@example.com")
},
expectedStatus: http.StatusOK,
expectedHeaders: map[string]string{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Create token with claims
token, err := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", tc.claims)
if err != nil {
t.Fatalf("Failed to create test token: %v", err)
}
// Create test handler
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Configure OIDC middleware
tOidc := ts.tOidc
tOidc.next = nextHandler
tOidc.allowedRolesAndGroups = tc.allowedRolesAndGroups
// Create request
req := httptest.NewRequest("GET", "/protected", nil)
rr := httptest.NewRecorder()
// Set up session
session, err := tOidc.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
tc.setupSession(session)
session.SetAccessToken(token)
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Copy cookies to the new request
for _, cookie := range rr.Result().Cookies() {
req.AddCookie(cookie)
}
// Reset response recorder
rr = httptest.NewRecorder()
// Serve request
tOidc.ServeHTTP(rr, req)
// Check status code
if rr.Code != tc.expectedStatus {
t.Errorf("Expected status %d, got %d", tc.expectedStatus, rr.Code)
}
// Check headers if status is OK
if tc.expectedStatus == http.StatusOK {
for header, expectedValue := range tc.expectedHeaders {
if value := req.Header.Get(header); value != expectedValue {
t.Errorf("Expected header %s to be %s, got %s", header, expectedValue, value)
}
}
}
})
}
}
// Helper function to compare string slices
func stringSliceEqual(a, b []string) bool {
if len(a) != len(b) {
+456
View File
@@ -0,0 +1,456 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
)
// Cookie names and configuration constants used for session management
const (
// mainCookieName is the name of the main session cookie that stores authentication state
// and basic user information like email and CSRF tokens
mainCookieName = "_raczylo_oidc"
// accessTokenCookie is the name of the cookie that stores the OIDC access token
// This may be split into multiple cookies if the token is large
accessTokenCookie = "_raczylo_oidc_access"
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
// This may be split into multiple cookies if the token is large
refreshTokenCookie = "_raczylo_oidc_refresh"
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
// 3. Calculation:
// - Let x be the chunk size
// - After encryption: x + 28 bytes
// - After base64: ((x + 28) * 4/3) bytes
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
)
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
type SessionManager struct {
// store is the underlying session store for cookie management
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
forceHTTPS bool
// logger provides structured logging capabilities
logger *Logger
}
// NewSessionManager creates a new session manager with the specified configuration.
// Parameters:
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
// - Using SameSite=Lax for CSRF protection
// - Set with appropriate timeout and path settings
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// GetSession retrieves all session data for the current request.
// It loads the main session and token sessions, including any chunked token data,
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
if err != nil {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
request: r,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
}
// Retrieve chunked access token sessions
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
// Retrieve chunked refresh token sessions
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
return sessionData, nil
}
// getTokenChunkSessions retrieves all session chunks for a given token type.
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// Returns a map of chunk index to session, used for handling large tokens
// that exceed single cookie size limits.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
chunks := make(map[int]*sessions.Session)
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
return chunks
}
// SessionData holds all session information for an authenticated user.
// It manages multiple session cookies to handle the main session state
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
manager *SessionManager
// request is the current HTTP request associated with this session
request *http.Request
// mainSession stores authentication state and basic user info
mainSession *sessions.Session
// accessSession stores the primary access token cookie
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
refreshTokenChunks map[int]*sessions.Session
}
// Save persists all session data to cookies in the HTTP response.
// It saves the main session, token sessions, and any token chunks,
// applying appropriate security options to each cookie. All cookies
// are saved with consistent security settings based on the request scheme.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token chunk session: %w", err)
}
}
// Save refresh token chunks
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
}
}
return nil
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
return sd.Save(r, w)
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
}
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication,
// false otherwise or if the authentication status cannot be determined.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
}
// SetAuthenticated updates the session's authentication status.
// This should be called after successful OIDC authentication or during logout.
func (sd *SessionData) SetAuthenticated(value bool) {
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.accessTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.accessTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
sd.accessTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.accessSession.Values["token"] = token
} else {
// Split token into chunks
sd.accessSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.accessTokenChunks[i] = session
}
}
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.refreshTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.refreshTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
sd.refreshTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.refreshSession.Values["token"] = token
} else {
// Split token into chunks
sd.refreshSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.refreshTokenChunks[i] = session
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
if len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+129
View File
@@ -0,0 +1,129 @@
package traefikoidc
import (
"net/http/httptest"
"strings"
"testing"
)
// TestSessionManager tests the SessionManager functionality
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
// Recalculate expected cookies based on new maxCookieSize
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved")
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Calculate number of chunks for access token
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
if accessChunks > 1 {
count += accessChunks
}
// Calculate number of chunks for refresh token
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
if refreshChunks > 1 {
count += refreshChunks
}
return count
}
+144 -40
View File
@@ -5,44 +5,93 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"github.com/gorilla/sessions"
"strings"
)
const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHTTPS"`
RateLimit int `json:"rateLimit"`
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
// ProviderURL is the base URL of the OIDC provider (required)
// Example: https://accounts.google.com
ProviderURL string `json:"providerURL"`
// RevocationURL is the endpoint for revoking tokens (optional)
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
// LogoutURL is the path for handling logout requests (optional)
// If not provided, it will be set to CallbackURL + "/logout"
LogoutURL string `json:"logoutURL"`
// ClientID is the OAuth 2.0 client identifier (required)
ClientID string `json:"clientID"`
// ClientSecret is the OAuth 2.0 client secret (required)
ClientSecret string `json:"clientSecret"`
// Scopes defines the OAuth 2.0 scopes to request (optional)
// Defaults to ["openid", "profile", "email"] if not provided
Scopes []string `json:"scopes"`
// LogLevel sets the logging verbosity (optional)
// Valid values: "debug", "info", "error"
// Default: "info"
LogLevel string `json:"logLevel"`
// SessionEncryptionKey is used to encrypt session data (required)
// Must be a secure random string
SessionEncryptionKey string `json:"sessionEncryptionKey"`
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
// Default: false
ForceHTTPS bool `json:"forceHTTPS"`
// RateLimit sets the maximum number of requests per second (optional)
// Default: 100
RateLimit int `json:"rateLimit"`
// ExcludedURLs lists paths that bypass authentication (optional)
// Example: ["/health", "/metrics"]
ExcludedURLs []string `json:"excludedURLs"`
// AllowedUserDomains restricts access to specific email domains (optional)
// Example: ["company.com", "subsidiary.com"]
AllowedUserDomains []string `json:"allowedUserDomains"`
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
// Example: ["admin", "developer"]
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
HTTPClient *http.Client
// OIDCEndSessionURL is the provider's end session endpoint (optional)
// If not provided, it will be discovered from provider metadata
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
// Default: "/"
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
// CreateConfig creates a new Config with default values
// CreateConfig creates a new Config with sensible default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
func CreateConfig() *Config {
c := &Config{}
@@ -65,14 +114,22 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
// Validate performs validation checks on the Config.
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid URL")
}
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
@@ -82,25 +139,58 @@ func (c *Config) Validate() error {
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < 32 {
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
}
if c.RateLimit < 0 {
return fmt.Errorf("rateLimit must be non-negative")
}
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
return nil
}
// Logger is a simple logger with different levels
// isValidURL checks if the provided string is a valid URL
func isValidURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
func isValidLogLevel(level string) bool {
return level == "debug" || level == "info" || level == "error"
}
// Logger provides structured logging capabilities with different severity levels.
// It supports error, info, and debug levels with appropriate output streams
// and formatting for each level.
type Logger struct {
// logError handles error-level messages, writing to stderr
logError *log.Logger
logInfo *log.Logger
// logInfo handles informational messages, writing to stdout
logInfo *log.Logger
// logDebug handles debug-level messages, writing to stdout when debug is enabled
logDebug *log.Logger
}
// NewLogger creates a new Logger
// NewLogger creates a new Logger with the specified log level.
// The log level determines which messages are output:
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
logInfo.SetOutput(os.Stdout)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
@@ -112,37 +202,51 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
// Info logs an informational message.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
// Debug logs a debug message.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
// Error logs an error message.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
// Infof logs an informational message using Printf formatting.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
// Debugf logs a debug message using Printf formatting.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
// Errorf logs an error message using Printf formatting.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
// handleError writes an error message to both the HTTP response and the error log.
// It ensures consistent error handling across the middleware by logging the error
// and sending an appropriate HTTP response to the client.
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
+362
View File
@@ -0,0 +1,362 @@
package traefikoidc
import (
"bytes"
"log"
"net/http"
"testing"
)
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
// Check default scopes
expectedScopes := []string{"openid", "profile", "email"}
if len(config.Scopes) != len(expectedScopes) {
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
}
for i, scope := range expectedScopes {
if config.Scopes[i] != scope {
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
}
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
})
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
config *Config
expectedError string
}{
{
name: "Empty Config",
config: &Config{},
expectedError: "providerURL is required",
},
{
name: "Missing CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
},
expectedError: "callbackURL is required",
},
{
name: "Missing ClientID",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
},
expectedError: "clientID is required",
},
{
name: "Missing ClientSecret",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
},
expectedError: "clientSecret is required",
},
{
name: "Missing SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
},
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
},
{
name: "Invalid CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "callback", // Missing leading slash
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "callbackURL must start with /",
},
{
name: "Short SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "short",
},
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
},
expectedError: "rateLimit must be non-negative",
},
{
name: "Invalid LogLevel",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "invalid",
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Valid Config",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
},
expectedError: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
}
}
})
}
}
func TestLogger(t *testing.T) {
// Capture log output
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
name: "Debug Level",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut == "" {
t.Error("Expected debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Info Level",
logLevel: "info",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Error Level",
logLevel: "error",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut != "" {
t.Error("Did not expect info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Printf Methods",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debugf("debug %s", "formatted")
l.Infof("info %s", "formatted")
l.Errorf("error %s", "formatted")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
t.Error("Expected formatted debug message")
}
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
t.Error("Expected formatted info message")
}
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
t.Error("Expected formatted error message")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset buffers
debugBuf.Reset()
infoBuf.Reset()
errorBuf.Reset()
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
if tc.logLevel == "debug" {
logger.logDebug.SetOutput(&debugBuf)
}
// Run test
tc.testFunc(logger)
// Check results
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
})
}
}
func TestHandleError(t *testing.T) {
// Create a test logger with captured output
var errorBuf bytes.Buffer
logger := &Logger{
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
}
logger.logError.SetOutput(&errorBuf)
// Create a test response recorder
rr := &testResponseRecorder{
headers: make(map[string][]string),
}
// Test error handling
message := "test error message"
code := 400
handleError(rr, message, code, logger)
// Check response code
if rr.statusCode != code {
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
}
// Check response body
expectedBody := message + "\n"
if rr.body != expectedBody {
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
}
// Check error was logged
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
t.Error("Error message was not logged")
}
}
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
}
func (r *testResponseRecorder) Header() http.Header {
return r.headers
}
func (r *testResponseRecorder) Write(b []byte) (int, error) {
r.body = string(b)
return len(b), nil
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+1 -5
View File
@@ -1,7 +1,4 @@
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
# sessions
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+9 -12
View File
@@ -1,6 +1,5 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
@@ -9,15 +8,13 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
}
}
+21
View File
@@ -0,0 +1,21 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+5 -10
View File
@@ -1,11 +1,8 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -16,9 +13,7 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
MaxAge int
Secure bool
HttpOnly bool
}
+23
View File
@@ -0,0 +1,23 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -4,8 +4,8 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
## explicit; go 1.18