mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
Multiple improvements.
* Add todo list. * fixup! Add todo list. * fixup! fixup! Add todo list. * fixup! fixup! fixup! Add todo list. * Improve the session handling and cache. * Fix an issue where expired session can cause infinite redirect loop * fixup! Fix an issue where expired session can cause infinite redirect loop * Add semver setup for automatic releases. * fixup! Add semver setup for automatic releases. * fixup! fixup! Add semver setup for automatic releases. * fixup! fixup! fixup! Add semver setup for automatic releases.
This commit is contained in:
@@ -0,0 +1,4 @@
|
|||||||
|
## TODO / wishlist
|
||||||
|
|
||||||
|
- [x] Improve caching mechanism
|
||||||
|
- [x] Add automatic release and semver generation
|
||||||
@@ -1,172 +1,153 @@
|
|||||||
package traefikoidc
|
package traefikoidc
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CacheItem represents an item stored in the cache with its associated metadata.
|
// CacheItem represents an item stored in the cache with its associated metadata.
|
||||||
type CacheItem struct {
|
type CacheItem struct {
|
||||||
// Value is the cached data of any type
|
// Value is the cached data of any type.
|
||||||
Value interface{}
|
Value interface{}
|
||||||
|
|
||||||
// ExpiresAt is the timestamp when this item should be considered expired
|
// ExpiresAt is the timestamp when this item should be considered expired.
|
||||||
// and removed from the cache during cleanup operations
|
|
||||||
ExpiresAt time.Time
|
ExpiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
// lruEntry represents an entry in the LRU list.
|
||||||
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
|
type lruEntry struct {
|
||||||
type Cache struct {
|
key string
|
||||||
// items stores the cached data with string keys
|
|
||||||
items map[string]CacheItem
|
|
||||||
|
|
||||||
// mutex protects concurrent access to the items map
|
|
||||||
// Use RLock/RUnlock for reads and Lock/Unlock for writes
|
|
||||||
mutex sync.RWMutex
|
|
||||||
|
|
||||||
// maxSize is the maximum number of items allowed in the cache
|
|
||||||
maxSize int
|
|
||||||
|
|
||||||
// accessList maintains the order of item access for eviction
|
|
||||||
accessList []string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultMaxSize is the default maximum number of items in the cache
|
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
|
||||||
|
// It implements an LRU (Least Recently Used) eviction policy using a doubly-linked list for efficiency.
|
||||||
|
type Cache struct {
|
||||||
|
// items stores the cached data with string keys.
|
||||||
|
items map[string]CacheItem
|
||||||
|
|
||||||
|
// order maintains the usage order; most recently used items are at the back.
|
||||||
|
order *list.List
|
||||||
|
|
||||||
|
// elems maps keys to their corresponding list elements for O(1) access.
|
||||||
|
elems map[string]*list.Element
|
||||||
|
|
||||||
|
// mutex protects concurrent access to the cache.
|
||||||
|
mutex sync.RWMutex
|
||||||
|
|
||||||
|
// maxSize is the maximum number of items allowed in the cache.
|
||||||
|
maxSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultMaxSize is the default maximum number of items in the cache.
|
||||||
const DefaultMaxSize = 1000
|
const DefaultMaxSize = 1000
|
||||||
|
|
||||||
// NewCache creates a new empty cache instance.
|
// NewCache creates a new empty cache instance that is ready for use.
|
||||||
// The cache is immediately ready for use and is thread-safe.
|
|
||||||
func NewCache() *Cache {
|
func NewCache() *Cache {
|
||||||
return &Cache{
|
return &Cache{
|
||||||
items: make(map[string]CacheItem),
|
items: make(map[string]CacheItem, DefaultMaxSize),
|
||||||
maxSize: DefaultMaxSize,
|
order: list.New(),
|
||||||
accessList: make([]string, 0, DefaultMaxSize),
|
elems: make(map[string]*list.Element, DefaultMaxSize),
|
||||||
|
maxSize: DefaultMaxSize,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set adds or updates an item in the cache with the specified expiration duration.
|
// Set adds or updates an item in the cache with the specified expiration duration.
|
||||||
// Parameters:
|
// It moves the item to the most recently used position.
|
||||||
// - key: Unique identifier for the cached item
|
|
||||||
// - value: The data to cache (can be of any type)
|
|
||||||
// - expiration: How long the item should remain in the cache
|
|
||||||
//
|
|
||||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
|
||||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
// If key exists, update it
|
now := time.Now()
|
||||||
|
expTime := now.Add(expiration)
|
||||||
|
|
||||||
|
// Update existing item.
|
||||||
if _, exists := c.items[key]; exists {
|
if _, exists := c.items[key]; exists {
|
||||||
c.items[key] = CacheItem{
|
c.items[key] = CacheItem{
|
||||||
Value: value,
|
Value: value,
|
||||||
ExpiresAt: time.Now().Add(expiration),
|
ExpiresAt: expTime,
|
||||||
|
}
|
||||||
|
if elem, ok := c.elems[key]; ok {
|
||||||
|
c.order.MoveToBack(elem)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If cache is full, remove oldest item
|
// Evict oldest item if cache is full.
|
||||||
if len(c.items) >= c.maxSize {
|
if len(c.items) >= c.maxSize {
|
||||||
c.evictOldest()
|
c.evictOldest()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new item
|
// Add new item.
|
||||||
c.items[key] = CacheItem{
|
c.items[key] = CacheItem{
|
||||||
Value: value,
|
Value: value,
|
||||||
ExpiresAt: time.Now().Add(expiration),
|
ExpiresAt: expTime,
|
||||||
}
|
}
|
||||||
c.accessList = append(c.accessList, key)
|
elem := c.order.PushBack(lruEntry{key: key})
|
||||||
|
c.elems[key] = elem
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves an item from the cache if it exists and hasn't expired.
|
// Get retrieves an item from the cache if it exists and hasn't expired.
|
||||||
// Parameters:
|
// Moving the accessed item to the most recently used position.
|
||||||
// - key: The identifier of the item to retrieve
|
|
||||||
//
|
|
||||||
// Returns:
|
|
||||||
// - value: The cached data (nil if not found or expired)
|
|
||||||
// - found: true if the item was found and is valid, false otherwise
|
|
||||||
//
|
|
||||||
// Thread-safe: Uses read locking to ensure safe concurrent access.
|
|
||||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||||
c.mutex.RLock()
|
|
||||||
item, found := c.items[key]
|
|
||||||
c.mutex.RUnlock()
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if time.Now().After(item.ExpiresAt) {
|
|
||||||
c.mutex.Lock()
|
|
||||||
c.removeItem(key)
|
|
||||||
c.mutex.Unlock()
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update access order
|
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
c.updateAccessOrder(key)
|
defer c.mutex.Unlock()
|
||||||
c.mutex.Unlock()
|
|
||||||
|
item, exists := c.items[key]
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for expiration.
|
||||||
|
if time.Now().After(item.ExpiresAt) {
|
||||||
|
c.removeItem(key)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move item to the back (most recently used).
|
||||||
|
if elem, ok := c.elems[key]; ok {
|
||||||
|
c.order.MoveToBack(elem)
|
||||||
|
}
|
||||||
|
|
||||||
return item.Value, true
|
return item.Value, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes an item from the cache if it exists.
|
// Delete removes an item from the cache.
|
||||||
// If the item doesn't exist, this operation is a no-op.
|
|
||||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
|
||||||
func (c *Cache) Delete(key string) {
|
func (c *Cache) Delete(key string) {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
delete(c.items, key)
|
|
||||||
|
c.removeItem(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup removes all expired items from the cache.
|
// Cleanup removes all expired items from the cache. This should be called periodically
|
||||||
// This should be called periodically to prevent memory leaks from
|
// to prevent memory bloat from expired entries.
|
||||||
// expired items that haven't been accessed (and thus not removed during Get operations).
|
|
||||||
// Thread-safe: Uses write locking to ensure safe concurrent access.
|
|
||||||
func (c *Cache) Cleanup() {
|
func (c *Cache) Cleanup() {
|
||||||
c.mutex.Lock()
|
c.mutex.Lock()
|
||||||
defer c.mutex.Unlock()
|
defer c.mutex.Unlock()
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var newAccessList []string
|
for key, item := range c.items {
|
||||||
|
if now.After(item.ExpiresAt) {
|
||||||
for _, key := range c.accessList {
|
c.removeItem(key)
|
||||||
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
|
|
||||||
newAccessList = append(newAccessList, key)
|
|
||||||
} else {
|
|
||||||
delete(c.items, key)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.accessList = newAccessList
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// evictOldest removes the least recently used item from the cache
|
// evictOldest removes the least recently used item from the cache.
|
||||||
func (c *Cache) evictOldest() {
|
func (c *Cache) evictOldest() {
|
||||||
if len(c.accessList) > 0 {
|
elem := c.order.Front()
|
||||||
oldest := c.accessList[0]
|
if elem != nil {
|
||||||
c.removeItem(oldest)
|
entry := elem.Value.(lruEntry)
|
||||||
|
c.removeItem(entry.key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeItem removes an item from both the cache and access list
|
// removeItem removes an item from both the cache and the LRU tracking structures.
|
||||||
func (c *Cache) removeItem(key string) {
|
func (c *Cache) removeItem(key string) {
|
||||||
delete(c.items, key)
|
delete(c.items, key)
|
||||||
for i, k := range c.accessList {
|
if elem, ok := c.elems[key]; ok {
|
||||||
if k == key {
|
c.order.Remove(elem)
|
||||||
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
|
delete(c.elems, key)
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateAccessOrder moves the accessed key to the end of the access list
|
|
||||||
func (c *Cache) updateAccessOrder(key string) {
|
|
||||||
for i, k := range c.accessList {
|
|
||||||
if k == key {
|
|
||||||
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
-20
@@ -12,28 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/sessions"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// newSessionOptions creates secure session cookie options.
|
|
||||||
// Parameters:
|
|
||||||
// - isSecure: Whether to set the Secure flag on cookies
|
|
||||||
//
|
|
||||||
// Returns session options configured for security with:
|
|
||||||
// - HttpOnly flag to prevent JavaScript access
|
|
||||||
// - SameSite=Lax for CSRF protection
|
|
||||||
// - Appropriate timeout and path settings
|
|
||||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
|
||||||
return &sessions.Options{
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: isSecure,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
MaxAge: ConstSessionTimeout,
|
|
||||||
Path: "/",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateNonce creates a cryptographically secure random nonce
|
// generateNonce creates a cryptographically secure random nonce
|
||||||
// for use in the OIDC authentication flow. The nonce is used to
|
// for use in the OIDC authentication flow. The nonce is used to
|
||||||
// prevent replay attacks by ensuring the token received matches
|
// prevent replay attacks by ensuring the token received matches
|
||||||
|
|||||||
@@ -62,7 +62,6 @@ type TraefikOidc struct {
|
|||||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||||
initComplete chan struct{}
|
initComplete chan struct{}
|
||||||
endSessionURL string
|
endSessionURL string
|
||||||
baseURL string
|
|
||||||
postLogoutRedirectURI string
|
postLogoutRedirectURI string
|
||||||
sessionManager *SessionManager
|
sessionManager *SessionManager
|
||||||
}
|
}
|
||||||
@@ -82,8 +81,6 @@ var defaultExcludedURLs = map[string]struct{}{
|
|||||||
"/favicon": {},
|
"/favicon": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
var newTicker = time.NewTicker
|
|
||||||
|
|
||||||
// VerifyToken verifies the provided JWT token
|
// VerifyToken verifies the provided JWT token
|
||||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||||
t.logger.Debugf("Verifying token")
|
t.logger.Debugf("Verifying token")
|
||||||
@@ -264,7 +261,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
|||||||
// Assign the initialized logger
|
// Assign the initialized logger
|
||||||
t.logger = logger
|
t.logger = logger
|
||||||
|
|
||||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
t.sessionManager, _ = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||||
t.extractClaimsFunc = extractClaims
|
t.extractClaimsFunc = extractClaims
|
||||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||||
@@ -531,9 +528,6 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
|||||||
|
|
||||||
// determineScheme determines the scheme (http or https) of the request
|
// determineScheme determines the scheme (http or https) of the request
|
||||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||||
if t.forceHTTPS {
|
|
||||||
return "https"
|
|
||||||
}
|
|
||||||
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
|
||||||
return scheme
|
return scheme
|
||||||
}
|
}
|
||||||
@@ -602,14 +596,17 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
|||||||
// defaultInitiateAuthentication initiates the authentication process
|
// defaultInitiateAuthentication initiates the authentication process
|
||||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||||
// Generate CSRF token and nonce
|
// Generate CSRF token and nonce
|
||||||
csrfToken := uuid.New().String()
|
csrfToken := uuid.NewString()
|
||||||
nonce, err := generateNonce()
|
nonce, err := generateNonce()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set session values
|
// Clear any existing session data to avoid stale state causing redirect loops
|
||||||
|
session.Clear(req, rw)
|
||||||
|
|
||||||
|
// Set new session values
|
||||||
session.SetCSRF(csrfToken)
|
session.SetCSRF(csrfToken)
|
||||||
session.SetNonce(nonce)
|
session.SetNonce(nonce)
|
||||||
session.SetIncomingPath(req.URL.RequestURI())
|
session.SetIncomingPath(req.URL.RequestURI())
|
||||||
@@ -621,7 +618,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build and redirect to auth URL
|
// Build and redirect to authentication URL
|
||||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||||
}
|
}
|
||||||
@@ -647,7 +644,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
|||||||
|
|
||||||
// startTokenCleanup starts the token cleanup goroutine
|
// startTokenCleanup starts the token cleanup goroutine
|
||||||
func (t *TraefikOidc) startTokenCleanup() {
|
func (t *TraefikOidc) startTokenCleanup() {
|
||||||
ticker := newTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
go func() {
|
go func() {
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
t.logger.Debug("Cleaning up token cache")
|
t.logger.Debug("Cleaning up token cache")
|
||||||
|
|||||||
+4
-4
@@ -89,7 +89,7 @@ func (ts *TestSuite) Setup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger := NewLogger("info")
|
logger := NewLogger("info")
|
||||||
ts.sessionManager = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
ts.sessionManager, _ = NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||||
|
|
||||||
// Common TraefikOidc instance
|
// Common TraefikOidc instance
|
||||||
ts.tOidc = &TraefikOidc{
|
ts.tOidc = &TraefikOidc{
|
||||||
@@ -619,7 +619,7 @@ func TestHandleCallback(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
logger := NewLogger("info")
|
logger := NewLogger("info")
|
||||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||||
|
|
||||||
// Create a new instance for each test to avoid state carryover
|
// Create a new instance for each test to avoid state carryover
|
||||||
tOidc := &TraefikOidc{
|
tOidc := &TraefikOidc{
|
||||||
@@ -924,7 +924,7 @@ func TestHandleLogout(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
logger := NewLogger("info")
|
logger := NewLogger("info")
|
||||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||||
tOidc := &TraefikOidc{
|
tOidc := &TraefikOidc{
|
||||||
revocationURL: mockRevocationServer.URL,
|
revocationURL: mockRevocationServer.URL,
|
||||||
endSessionURL: tc.endSessionURL,
|
endSessionURL: tc.endSessionURL,
|
||||||
@@ -1213,7 +1213,7 @@ func TestHandleExpiredToken(t *testing.T) {
|
|||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
logger := NewLogger("info")
|
logger := NewLogger("info")
|
||||||
sessionManager := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
sessionManager, _ := NewSessionManager("test-secret-key-that-is-at-least-32-bytes", false, logger)
|
||||||
|
|
||||||
tOidc := &TraefikOidc{
|
tOidc := &TraefikOidc{
|
||||||
sessionManager: sessionManager,
|
sessionManager: sessionManager,
|
||||||
|
|||||||
+10
@@ -0,0 +1,10 @@
|
|||||||
|
version: 1
|
||||||
|
force:
|
||||||
|
existing: true
|
||||||
|
wording:
|
||||||
|
patch:
|
||||||
|
- patch-release
|
||||||
|
minor:
|
||||||
|
- minor-release
|
||||||
|
major:
|
||||||
|
- breaking
|
||||||
+99
-142
@@ -16,13 +16,14 @@ import (
|
|||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// generateSecureRandomString creates a cryptographically secure random string of specified length
|
// generateSecureRandomString creates a cryptographically secure random string of specified length.
|
||||||
func generateSecureRandomString(length int) string {
|
// It returns the generated string or an error if random generation fails.
|
||||||
|
func generateSecureRandomString(length int) (string, error) {
|
||||||
bytes := make([]byte, length)
|
bytes := make([]byte, length)
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
panic("failed to generate random string")
|
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
}
|
}
|
||||||
return hex.EncodeToString(bytes)
|
return hex.EncodeToString(bytes), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cookie names and configuration constants used for session management
|
// Cookie names and configuration constants used for session management
|
||||||
@@ -55,7 +56,7 @@ const (
|
|||||||
minEncryptionKeyLength = 32
|
minEncryptionKeyLength = 32
|
||||||
)
|
)
|
||||||
|
|
||||||
// compressToken compresses a token using gzip and base64 encodes it
|
// compressToken compresses a token using gzip and base64 encodes it.
|
||||||
func compressToken(token string) string {
|
func compressToken(token string) string {
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
gz := gzip.NewWriter(&b)
|
gz := gzip.NewWriter(&b)
|
||||||
@@ -68,7 +69,7 @@ func compressToken(token string) string {
|
|||||||
return base64.StdEncoding.EncodeToString(b.Bytes())
|
return base64.StdEncoding.EncodeToString(b.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
// decompressToken decompresses a base64 encoded gzipped token
|
// decompressToken decompresses a base64 encoded gzipped token.
|
||||||
func decompressToken(compressed string) string {
|
func decompressToken(compressed string) string {
|
||||||
data, err := base64.StdEncoding.DecodeString(compressed)
|
data, err := base64.StdEncoding.DecodeString(compressed)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -91,18 +92,18 @@ func decompressToken(compressed string) string {
|
|||||||
|
|
||||||
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
// SessionManager handles the management of multiple session cookies for OIDC authentication.
|
||||||
// It provides functionality for storing and retrieving authentication state, tokens,
|
// It provides functionality for storing and retrieving authentication state, tokens,
|
||||||
// and other session-related data across multiple cookies to handle large tokens.
|
// and other session-related data across multiple cookies.
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
// store is the underlying session store for cookie management
|
// store is the underlying session store for cookie management.
|
||||||
store sessions.Store
|
store sessions.Store
|
||||||
|
|
||||||
// forceHTTPS enforces secure cookie attributes regardless of request scheme
|
// forceHTTPS enforces secure cookie attributes regardless of request scheme.
|
||||||
forceHTTPS bool
|
forceHTTPS bool
|
||||||
|
|
||||||
// logger provides structured logging capabilities
|
// logger provides structured logging capabilities.
|
||||||
logger *Logger
|
logger *Logger
|
||||||
|
|
||||||
// sessionPool is a sync.Pool for reusing SessionData objects
|
// sessionPool is a sync.Pool for reusing SessionData objects.
|
||||||
sessionPool sync.Pool
|
sessionPool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,11 +113,11 @@ type SessionManager struct {
|
|||||||
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
|
||||||
// - logger: Logger instance for recording session-related events
|
// - logger: Logger instance for recording session-related events
|
||||||
//
|
//
|
||||||
// The manager handles session creation, storage, and cookie security settings.
|
// Returns an error if the encryption key does not meet minimum length requirements.
|
||||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) (*SessionManager, error) {
|
||||||
// Validate encryption key length
|
// Validate encryption key length.
|
||||||
if len(encryptionKey) < minEncryptionKeyLength {
|
if len(encryptionKey) < minEncryptionKeyLength {
|
||||||
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
|
return nil, fmt.Errorf("encryption key must be at least %d bytes long", minEncryptionKeyLength)
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := &SessionManager{
|
sm := &SessionManager{
|
||||||
@@ -125,7 +126,7 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize session pool
|
// Initialize session pool.
|
||||||
sm.sessionPool.New = func() interface{} {
|
sm.sessionPool.New = func() interface{} {
|
||||||
return &SessionData{
|
return &SessionData{
|
||||||
manager: sm,
|
manager: sm,
|
||||||
@@ -134,12 +135,12 @@ func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *S
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return sm
|
return sm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSessionOptions returns secure session options configured for the current request.
|
// getSessionOptions returns secure session options configured for the current request.
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - isSecure: Whether the current request is using HTTPS
|
// - isSecure: Whether the current request is using HTTPS.
|
||||||
//
|
//
|
||||||
// The options ensure cookies are:
|
// The options ensure cookies are:
|
||||||
// - HTTP-only (not accessible via JavaScript)
|
// - HTTP-only (not accessible via JavaScript)
|
||||||
@@ -161,7 +162,7 @@ func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
|||||||
// and combines them into a single SessionData structure for easy access.
|
// and combines them into a single SessionData structure for easy access.
|
||||||
// Returns an error if any session component cannot be loaded.
|
// Returns an error if any session component cannot be loaded.
|
||||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||||
// Get session from pool
|
// Get session from pool.
|
||||||
sessionData := sm.sessionPool.Get().(*SessionData)
|
sessionData := sm.sessionPool.Get().(*SessionData)
|
||||||
sessionData.request = r
|
sessionData.request = r
|
||||||
|
|
||||||
@@ -172,11 +173,10 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
|||||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for absolute session timeout
|
// Check for absolute session timeout.
|
||||||
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
|
||||||
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
|
||||||
// Session has expired
|
sessionData.Clear(r, nil)
|
||||||
sm.sessionPool.Put(sessionData)
|
|
||||||
return nil, fmt.Errorf("session expired")
|
return nil, fmt.Errorf("session expired")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -193,7 +193,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
|||||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear and reuse chunk maps
|
// Clear and reuse chunk maps.
|
||||||
for k := range sessionData.accessTokenChunks {
|
for k := range sessionData.accessTokenChunks {
|
||||||
delete(sessionData.accessTokenChunks, k)
|
delete(sessionData.accessTokenChunks, k)
|
||||||
}
|
}
|
||||||
@@ -201,7 +201,7 @@ func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
|||||||
delete(sessionData.refreshTokenChunks, k)
|
delete(sessionData.refreshTokenChunks, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve chunked token sessions
|
// Retrieve chunked token sessions.
|
||||||
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
|
||||||
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
|
||||||
|
|
||||||
@@ -218,7 +218,6 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
|||||||
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
sessionName := fmt.Sprintf("%s_%d", baseName, i)
|
||||||
session, err := sm.store.Get(r, sessionName)
|
session, err := sm.store.Get(r, sessionName)
|
||||||
if err != nil || session.IsNew {
|
if err != nil || session.IsNew {
|
||||||
// No more sessions
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
chunks[i] = session
|
chunks[i] = session
|
||||||
@@ -230,27 +229,27 @@ func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string
|
|||||||
// and potentially large access and refresh tokens that may need to be
|
// and potentially large access and refresh tokens that may need to be
|
||||||
// split across multiple cookies due to browser size limitations.
|
// split across multiple cookies due to browser size limitations.
|
||||||
type SessionData struct {
|
type SessionData struct {
|
||||||
// manager is the SessionManager that created this SessionData
|
// manager is the SessionManager that created this SessionData.
|
||||||
manager *SessionManager
|
manager *SessionManager
|
||||||
|
|
||||||
// request is the current HTTP request associated with this session
|
// request is the current HTTP request associated with this session.
|
||||||
request *http.Request
|
request *http.Request
|
||||||
|
|
||||||
// mainSession stores authentication state and basic user info
|
// mainSession stores authentication state and basic user info.
|
||||||
mainSession *sessions.Session
|
mainSession *sessions.Session
|
||||||
|
|
||||||
// accessSession stores the primary access token cookie
|
// accessSession stores the primary access token cookie.
|
||||||
accessSession *sessions.Session
|
accessSession *sessions.Session
|
||||||
|
|
||||||
// refreshSession stores the primary refresh token cookie
|
// refreshSession stores the primary refresh token cookie.
|
||||||
refreshSession *sessions.Session
|
refreshSession *sessions.Session
|
||||||
|
|
||||||
// accessTokenChunks stores additional chunks of the access token
|
// accessTokenChunks stores additional chunks of the access token
|
||||||
// when it exceeds the maximum cookie size
|
// when it exceeds the maximum cookie size.
|
||||||
accessTokenChunks map[int]*sessions.Session
|
accessTokenChunks map[int]*sessions.Session
|
||||||
|
|
||||||
// refreshTokenChunks stores additional chunks of the refresh token
|
// refreshTokenChunks stores additional chunks of the refresh token
|
||||||
// when it exceeds the maximum cookie size
|
// when it exceeds the maximum cookie size.
|
||||||
refreshTokenChunks map[int]*sessions.Session
|
refreshTokenChunks map[int]*sessions.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,28 +260,28 @@ type SessionData struct {
|
|||||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||||
|
|
||||||
// Set options for all sessions
|
// Set options for all sessions.
|
||||||
options := sd.manager.getSessionOptions(isSecure)
|
options := sd.manager.getSessionOptions(isSecure)
|
||||||
sd.mainSession.Options = options
|
sd.mainSession.Options = options
|
||||||
sd.accessSession.Options = options
|
sd.accessSession.Options = options
|
||||||
sd.refreshSession.Options = options
|
sd.refreshSession.Options = options
|
||||||
|
|
||||||
// Save main session
|
// Save main session.
|
||||||
if err := sd.mainSession.Save(r, w); err != nil {
|
if err := sd.mainSession.Save(r, w); err != nil {
|
||||||
return fmt.Errorf("failed to save main session: %w", err)
|
return fmt.Errorf("failed to save main session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save access token session
|
// Save access token session.
|
||||||
if err := sd.accessSession.Save(r, w); err != nil {
|
if err := sd.accessSession.Save(r, w); err != nil {
|
||||||
return fmt.Errorf("failed to save access token session: %w", err)
|
return fmt.Errorf("failed to save access token session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save refresh token session
|
// Save refresh token session.
|
||||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save access token chunks
|
// Save access token chunks.
|
||||||
for _, session := range sd.accessTokenChunks {
|
for _, session := range sd.accessTokenChunks {
|
||||||
session.Options = options
|
session.Options = options
|
||||||
if err := session.Save(r, w); err != nil {
|
if err := session.Save(r, w); err != nil {
|
||||||
@@ -290,7 +289,7 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save refresh token chunks
|
// Save refresh token chunks.
|
||||||
for _, session := range sd.refreshTokenChunks {
|
for _, session := range sd.refreshTokenChunks {
|
||||||
session.Options = options
|
session.Options = options
|
||||||
if err := session.Save(r, w); err != nil {
|
if err := session.Save(r, w); err != nil {
|
||||||
@@ -302,10 +301,8 @@ func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clear removes all session data by expiring all cookies and clearing their values.
|
// Clear removes all session data by expiring all cookies and clearing their values.
|
||||||
// This is typically used during logout to ensure all session data is properly cleaned up.
|
|
||||||
// It handles both main session data and any token chunks that may exist.
|
|
||||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||||
// Clear and expire all sessions
|
// Clear and expire all sessions.
|
||||||
sd.mainSession.Options.MaxAge = -1
|
sd.mainSession.Options.MaxAge = -1
|
||||||
sd.accessSession.Options.MaxAge = -1
|
sd.accessSession.Options.MaxAge = -1
|
||||||
sd.refreshSession.Options.MaxAge = -1
|
sd.refreshSession.Options.MaxAge = -1
|
||||||
@@ -320,7 +317,7 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
|||||||
delete(sd.refreshSession.Values, k)
|
delete(sd.refreshSession.Values, k)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear chunk sessions
|
// Clear chunk sessions.
|
||||||
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
sd.clearTokenChunks(r, sd.accessTokenChunks)
|
||||||
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
sd.clearTokenChunks(r, sd.refreshTokenChunks)
|
||||||
|
|
||||||
@@ -329,15 +326,13 @@ func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
|||||||
err = sd.Save(r, w)
|
err = sd.Save(r, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return session to pool
|
// Return session to pool.
|
||||||
sd.manager.sessionPool.Put(sd)
|
sd.manager.sessionPool.Put(sd)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// clearTokenChunks removes all session chunks for a given token type.
|
// clearTokenChunks removes all session chunks for a given token type.
|
||||||
// It expires the cookies and removes all stored values to ensure
|
|
||||||
// no token data remains after logout or token invalidation.
|
|
||||||
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
|
||||||
for _, session := range chunks {
|
for _, session := range chunks {
|
||||||
session.Options.MaxAge = -1
|
session.Options.MaxAge = -1
|
||||||
@@ -348,15 +343,13 @@ func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*session
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthenticated returns whether the current session is authenticated.
|
// GetAuthenticated returns whether the current session is authenticated.
|
||||||
// Returns true if the user has successfully completed OIDC authentication
|
|
||||||
// and the session hasn't expired, false otherwise.
|
|
||||||
func (sd *SessionData) GetAuthenticated() bool {
|
func (sd *SessionData) GetAuthenticated() bool {
|
||||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||||
if !auth {
|
if !auth {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check session expiration
|
// Check session expiration.
|
||||||
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
@@ -365,21 +358,21 @@ func (sd *SessionData) GetAuthenticated() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
// SetAuthenticated updates the session's authentication status and rotates session ID.
|
||||||
// This should be called after successful OIDC authentication or during logout.
|
// Returns an error if generating a new session ID fails.
|
||||||
// Session ID rotation helps prevent session fixation attacks.
|
func (sd *SessionData) SetAuthenticated(value bool) error {
|
||||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
|
||||||
if value {
|
if value {
|
||||||
// Generate new session ID and set creation time
|
id, err := generateSecureRandomString(32)
|
||||||
sd.mainSession.ID = generateSecureRandomString(32)
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate secure session id: %w", err)
|
||||||
|
}
|
||||||
|
sd.mainSession.ID = id
|
||||||
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
sd.mainSession.Values["created_at"] = time.Now().Unix()
|
||||||
}
|
}
|
||||||
sd.mainSession.Values["authenticated"] = value
|
sd.mainSession.Values["authenticated"] = value
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken retrieves the complete access token from the session.
|
// GetAccessToken retrieves the complete access token from the session.
|
||||||
// If the token was split into chunks due to size limitations, it will
|
|
||||||
// automatically reassemble the complete token from all chunks.
|
|
||||||
// Returns an empty string if no token is found.
|
|
||||||
func (sd *SessionData) GetAccessToken() string {
|
func (sd *SessionData) GetAccessToken() string {
|
||||||
token, _ := sd.accessSession.Values["token"].(string)
|
token, _ := sd.accessSession.Values["token"].(string)
|
||||||
if token != "" {
|
if token != "" {
|
||||||
@@ -390,7 +383,7 @@ func (sd *SessionData) GetAccessToken() string {
|
|||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reassemble token from chunks
|
// Reassemble token from chunks.
|
||||||
if len(sd.accessTokenChunks) == 0 {
|
if len(sd.accessTokenChunks) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -414,45 +407,23 @@ func (sd *SessionData) GetAccessToken() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetAccessToken stores the access token in the session.
|
// SetAccessToken stores the access token in the session.
|
||||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
|
||||||
// multiple cookie chunks to handle large tokens while staying within
|
|
||||||
// browser cookie size limits. Any existing token or chunks are cleared
|
|
||||||
// before setting the new token.
|
|
||||||
// expireAccessTokenChunks expires any existing access token chunk cookies
|
|
||||||
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
|
||||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
||||||
if err != nil || session.IsNew {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Expire the cookie
|
|
||||||
session.Options.MaxAge = -1
|
|
||||||
session.Values = make(map[interface{}]interface{})
|
|
||||||
// Save expired cookie
|
|
||||||
if err := session.Save(sd.request, w); err != nil {
|
|
||||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sd *SessionData) SetAccessToken(token string) {
|
func (sd *SessionData) SetAccessToken(token string) {
|
||||||
// Expire any existing chunk cookies first
|
// Expire any existing chunk cookies first.
|
||||||
if sd.request != nil {
|
if sd.request != nil {
|
||||||
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called
|
sd.expireAccessTokenChunks(nil) // Will be saved when Save() is called.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear and prepare chunks map for new token
|
// Clear and prepare chunks map for new token.
|
||||||
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
sd.accessTokenChunks = make(map[int]*sessions.Session)
|
||||||
|
|
||||||
// Compress token
|
// Compress token.
|
||||||
compressed := compressToken(token)
|
compressed := compressToken(token)
|
||||||
|
|
||||||
if len(compressed) <= maxCookieSize {
|
if len(compressed) <= maxCookieSize {
|
||||||
sd.accessSession.Values["token"] = compressed
|
sd.accessSession.Values["token"] = compressed
|
||||||
sd.accessSession.Values["compressed"] = true
|
sd.accessSession.Values["compressed"] = true
|
||||||
} else {
|
} else {
|
||||||
// Split compressed token into chunks
|
// Split compressed token into chunks.
|
||||||
sd.accessSession.Values["token"] = ""
|
sd.accessSession.Values["token"] = ""
|
||||||
sd.accessSession.Values["compressed"] = true
|
sd.accessSession.Values["compressed"] = true
|
||||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||||
@@ -466,9 +437,6 @@ func (sd *SessionData) SetAccessToken(token string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRefreshToken retrieves the complete refresh token from the session.
|
// GetRefreshToken retrieves the complete refresh token from the session.
|
||||||
// If the token was split into chunks due to size limitations, it will
|
|
||||||
// automatically reassemble the complete token from all chunks.
|
|
||||||
// Returns an empty string if no token is found.
|
|
||||||
func (sd *SessionData) GetRefreshToken() string {
|
func (sd *SessionData) GetRefreshToken() string {
|
||||||
token, _ := sd.refreshSession.Values["token"].(string)
|
token, _ := sd.refreshSession.Values["token"].(string)
|
||||||
if token != "" {
|
if token != "" {
|
||||||
@@ -479,7 +447,7 @@ func (sd *SessionData) GetRefreshToken() string {
|
|||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reassemble token from chunks
|
// Reassemble token from chunks.
|
||||||
if len(sd.refreshTokenChunks) == 0 {
|
if len(sd.refreshTokenChunks) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -503,45 +471,23 @@ func (sd *SessionData) GetRefreshToken() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetRefreshToken stores the refresh token in the session.
|
// SetRefreshToken stores the refresh token in the session.
|
||||||
// If the token exceeds maxCookieSize, it is automatically compressed and split into
|
|
||||||
// multiple cookie chunks to handle large tokens while staying within
|
|
||||||
// browser cookie size limits. Any existing token or chunks are cleared
|
|
||||||
// before setting the new token.
|
|
||||||
// expireRefreshTokenChunks expires any existing refresh token chunk cookies
|
|
||||||
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
|
||||||
session, err := sd.manager.store.Get(sd.request, sessionName)
|
|
||||||
if err != nil || session.IsNew {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
// Expire the cookie
|
|
||||||
session.Options.MaxAge = -1
|
|
||||||
session.Values = make(map[interface{}]interface{})
|
|
||||||
// Save expired cookie
|
|
||||||
if err := session.Save(sd.request, w); err != nil {
|
|
||||||
sd.manager.logger.Errorf("Failed to save expired cookie: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (sd *SessionData) SetRefreshToken(token string) {
|
func (sd *SessionData) SetRefreshToken(token string) {
|
||||||
// Expire any existing chunk cookies first
|
// Expire any existing chunk cookies first.
|
||||||
if sd.request != nil {
|
if sd.request != nil {
|
||||||
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called
|
sd.expireRefreshTokenChunks(nil) // Will be saved when Save() is called.
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear and prepare chunks map for new token
|
// Clear and prepare chunks map for new token.
|
||||||
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
sd.refreshTokenChunks = make(map[int]*sessions.Session)
|
||||||
|
|
||||||
// Compress token
|
// Compress token.
|
||||||
compressed := compressToken(token)
|
compressed := compressToken(token)
|
||||||
|
|
||||||
if len(compressed) <= maxCookieSize {
|
if len(compressed) <= maxCookieSize {
|
||||||
sd.refreshSession.Values["token"] = compressed
|
sd.refreshSession.Values["token"] = compressed
|
||||||
sd.refreshSession.Values["compressed"] = true
|
sd.refreshSession.Values["compressed"] = true
|
||||||
} else {
|
} else {
|
||||||
// Split compressed token into chunks
|
// Split compressed token into chunks.
|
||||||
sd.refreshSession.Values["token"] = ""
|
sd.refreshSession.Values["token"] = ""
|
||||||
sd.refreshSession.Values["compressed"] = true
|
sd.refreshSession.Values["compressed"] = true
|
||||||
chunks := splitIntoChunks(compressed, maxCookieSize)
|
chunks := splitIntoChunks(compressed, maxCookieSize)
|
||||||
@@ -554,13 +500,43 @@ func (sd *SessionData) SetRefreshToken(token string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// expireAccessTokenChunks expires any existing access token chunk cookies.
|
||||||
|
func (sd *SessionData) expireAccessTokenChunks(w http.ResponseWriter) {
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
|
||||||
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||||
|
if err != nil || session.IsNew {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
session.Values = make(map[interface{}]interface{})
|
||||||
|
if w != nil {
|
||||||
|
if err := session.Save(sd.request, w); err != nil {
|
||||||
|
sd.manager.logger.Errorf("failed to save expired access token cookie: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expireRefreshTokenChunks expires any existing refresh token chunk cookies.
|
||||||
|
func (sd *SessionData) expireRefreshTokenChunks(w http.ResponseWriter) {
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
|
||||||
|
session, err := sd.manager.store.Get(sd.request, sessionName)
|
||||||
|
if err != nil || session.IsNew {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
session.Options.MaxAge = -1
|
||||||
|
session.Values = make(map[interface{}]interface{})
|
||||||
|
if w != nil {
|
||||||
|
if err := session.Save(sd.request, w); err != nil {
|
||||||
|
sd.manager.logger.Errorf("failed to save expired refresh token cookie: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// splitIntoChunks splits a string into chunks of specified size.
|
// splitIntoChunks splits a string into chunks of specified size.
|
||||||
// This is used internally to handle large tokens that exceed cookie size limits.
|
|
||||||
// Parameters:
|
|
||||||
// - s: The string to split
|
|
||||||
// - chunkSize: Maximum size of each chunk
|
|
||||||
//
|
|
||||||
// Returns an array of string chunks, each no larger than chunkSize.
|
|
||||||
func splitIntoChunks(s string, chunkSize int) []string {
|
func splitIntoChunks(s string, chunkSize int) []string {
|
||||||
var chunks []string
|
var chunks []string
|
||||||
for len(s) > 0 {
|
for len(s) > 0 {
|
||||||
@@ -576,64 +552,45 @@ func splitIntoChunks(s string, chunkSize int) []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetCSRF retrieves the CSRF token from the session.
|
// GetCSRF retrieves the CSRF token from the session.
|
||||||
// This token is used to prevent cross-site request forgery attacks
|
|
||||||
// by ensuring requests originate from the authenticated user.
|
|
||||||
// Returns an empty string if no CSRF token is found.
|
|
||||||
func (sd *SessionData) GetCSRF() string {
|
func (sd *SessionData) GetCSRF() string {
|
||||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||||
return csrf
|
return csrf
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCSRF stores a new CSRF token in the session.
|
// SetCSRF stores a new CSRF token in the session.
|
||||||
// This should be called when initiating authentication to generate
|
|
||||||
// a new token for the authentication flow.
|
|
||||||
func (sd *SessionData) SetCSRF(token string) {
|
func (sd *SessionData) SetCSRF(token string) {
|
||||||
sd.mainSession.Values["csrf"] = token
|
sd.mainSession.Values["csrf"] = token
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNonce retrieves the nonce value from the session.
|
// GetNonce retrieves the nonce value from the session.
|
||||||
// The nonce is used to prevent replay attacks in the OIDC flow
|
|
||||||
// by ensuring the token received matches the authentication request.
|
|
||||||
// Returns an empty string if no nonce is found.
|
|
||||||
func (sd *SessionData) GetNonce() string {
|
func (sd *SessionData) GetNonce() string {
|
||||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||||
return nonce
|
return nonce
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNonce stores a new nonce value in the session.
|
// SetNonce stores a new nonce value in the session.
|
||||||
// This should be called when initiating authentication to generate
|
|
||||||
// a new nonce for the OIDC authentication flow.
|
|
||||||
func (sd *SessionData) SetNonce(nonce string) {
|
func (sd *SessionData) SetNonce(nonce string) {
|
||||||
sd.mainSession.Values["nonce"] = nonce
|
sd.mainSession.Values["nonce"] = nonce
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEmail retrieves the authenticated user's email address from the session.
|
// GetEmail retrieves the authenticated user's email address from the session.
|
||||||
// The email is typically extracted from the OIDC ID token claims.
|
|
||||||
// Returns an empty string if no email is found.
|
|
||||||
func (sd *SessionData) GetEmail() string {
|
func (sd *SessionData) GetEmail() string {
|
||||||
email, _ := sd.mainSession.Values["email"].(string)
|
email, _ := sd.mainSession.Values["email"].(string)
|
||||||
return email
|
return email
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEmail stores the user's email address in the session.
|
// SetEmail stores the user's email address in the session.
|
||||||
// This should be called after successful authentication when
|
|
||||||
// processing the OIDC ID token claims.
|
|
||||||
func (sd *SessionData) SetEmail(email string) {
|
func (sd *SessionData) SetEmail(email string) {
|
||||||
sd.mainSession.Values["email"] = email
|
sd.mainSession.Values["email"] = email
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetIncomingPath retrieves the original request path that triggered
|
// GetIncomingPath retrieves the original request path that triggered the authentication flow.
|
||||||
// the authentication flow. This is used to redirect the user back
|
|
||||||
// to their intended destination after successful authentication.
|
|
||||||
// Returns an empty string if no path was stored.
|
|
||||||
func (sd *SessionData) GetIncomingPath() string {
|
func (sd *SessionData) GetIncomingPath() string {
|
||||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetIncomingPath stores the original request path that triggered
|
// SetIncomingPath stores the original request path that triggered the authentication flow.
|
||||||
// the authentication flow. This should be called before redirecting
|
|
||||||
// to the OIDC provider to remember where to send the user afterward.
|
|
||||||
func (sd *SessionData) SetIncomingPath(path string) {
|
func (sd *SessionData) SetIncomingPath(path string) {
|
||||||
sd.mainSession.Values["incoming_path"] = path
|
sd.mainSession.Values["incoming_path"] = path
|
||||||
}
|
}
|
||||||
|
|||||||
+2
-8
@@ -5,14 +5,8 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Initialize random seed
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateRandomString creates a random string of specified length
|
// generateRandomString creates a random string of specified length
|
||||||
func generateRandomString(length int) string {
|
func generateRandomString(length int) string {
|
||||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
@@ -83,7 +77,7 @@ func TestCookiePrefix(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||||
session, err := sm.GetSession(req)
|
session, err := sm.GetSession(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to get session: %v", err)
|
t.Fatalf("Failed to get session: %v", err)
|
||||||
@@ -117,7 +111,7 @@ func TestTokenRefreshCleanup(t *testing.T) {
|
|||||||
req := httptest.NewRequest("GET", "/test", nil)
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
sm := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
sm, _ := NewSessionManager("0123456789abcdef0123456789abcdef", true, NewLogger("debug"))
|
||||||
session, err := sm.GetSession(req)
|
session, err := sm.GetSession(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to get session: %v", err)
|
t.Fatalf("Failed to get session: %v", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user