Compare commits

..

30 Commits

Author SHA1 Message Date
lukaszraczylo 7e3dc46b6e Improve initial fetch of the provider metadata until successful. 2025-01-06 12:19:11 +00:00
lukaszraczylo 147aa0b169 Fix the issue #16
Removed global metadata cache and sync.Once
Each middleware instance now handles its own metadata initialization
Added tests to verify multiple instances work correctly
The changes ensure that:

Each route gets its own properly initialized middleware instance
Metadata is fetched and set correctly for each instance
No shared state between instances that could cause conflicts
Each instance can handle requests independently
The added test verifies this by creating multiple middleware instances with different routes and confirming they all initialize and function correctly. The test specifically checks that:

Each instance initializes successfully
Each instance gets its own metadata configuration
Each instance can handle requests independently
Callback URLs are correctly set per route
2025-01-06 11:23:12 +00:00
lukaszraczylo eecb7dfc92 Improve test coverage 2025-01-06 11:01:20 +00:00
lukaszraczylo a8d65688c4 Improve documentation. 2025-01-06 10:44:49 +00:00
lukaszraczylo bef4212c57 Add support for the large tokens, which exceed the standard 4096 limit for cookie. 2024-12-11 12:55:16 +00:00
lukaszraczylo 1fee2f9e9a fixup! Re-introduce user roles separation with additional tests. 2024-12-11 09:11:34 +00:00
lukaszraczylo 11bc6f3e31 Re-introduce user roles separation with additional tests. 2024-12-11 09:08:50 +00:00
lukaszraczylo 2b7af88ff9 Move session management into session manager. Split the cookies to avoid the 4k limit ( resolves issue: #15 ) 2024-12-10 10:19:35 +00:00
lukaszraczylo 01ee7c4dc8 Improve cookie setting. 2024-12-10 10:19:35 +00:00
lukaszraczylo a6fa4d8789 Downgrade gorilla sessions preventing the publishing by traefik hub temporarily. 2024-12-10 10:19:34 +00:00
lukaszraczylo 8101fb2bf6 Clean up dependencies. 2024-11-06 11:51:20 +00:00
lukaszraczylo 8ca669105b Fix OIDC logout issue, improve test coverage, load provider once. 2024-11-06 11:33:29 +00:00
lukaszraczylo 555164160d Update dependencies. 2024-11-06 11:33:06 +00:00
lukaszraczylo 3fe537d38f Add ability to verify default ECDSA keys provided by logto as well. 2024-11-06 11:33:06 +00:00
lukaszraczylo 31de2c63b2 Revert "Update go mod dependencies."
This reverts commit dedbdf63c3.
2024-11-06 11:33:04 +00:00
lukaszraczylo 7dd9205277 Update go mod dependencies. 2024-11-06 11:33:04 +00:00
lukaszraczylo f3598e4ab8 Add simple benchmark to track the allocations and speed for future improvements. 2024-11-06 11:33:03 +00:00
lukaszraczylo 218165d365 Cleanup and optimise the code. 2024-11-06 11:33:03 +00:00
lukaszraczylo dc4c4824cd Add support for more algorithms. 2024-11-06 11:33:03 +00:00
lukaszraczylo 345c0c4a11 Abstract filling up maps. 2024-11-06 11:32:37 +00:00
lukaszraczylo da4f97de04 Fix the bug with user not being redirected to originally requested URL post authentication. 2024-11-06 11:32:36 +00:00
lukaszraczylo ce916f3ca3 Update documentation - setting secrets in kubernetes. 2024-11-06 11:32:36 +00:00
lukaszraczylo 6f2cf65d49 Fix the tests hanging on the open channel. 2024-11-06 11:32:36 +00:00
lukaszraczylo 78b9d611f0 Improvement - startup time.
Previous implementations blocked the traefik startup until OIDC plugin was loaded.
This caused chicken-or-egg issue when called OIDC endpoint was hosted by the same traefik as well,
generating rather ridiculous situation when traefik couldn't come up because plugin tried to call the
discovery endpoint which was hosted by the same traefik.

This version resolves the issue allowing for quickstart and lazy loading of the provider metadata.
Disadvantage is - until discovery is done, the plugin will not provide any access to the client.
2024-11-06 11:32:36 +00:00
lukaszraczylo 2bb1debeb3 First step in improvement of caching mechanism. 2024-11-06 11:32:36 +00:00
lukaszraczylo 93b49b6d17 Add support for roles and groups. 2024-11-06 11:32:35 +00:00
lukaszraczylo 7a53da6080 Update tests and additional fixups. 2024-11-06 11:32:35 +00:00
lukaszraczylo 66e08755c1 Update the tests to handle nonce 2024-11-06 11:32:35 +00:00
lukaszraczylo d6fd3467c3 Support additional verification of the token to ensure OIDC compliance 2024-11-06 11:32:35 +00:00
lukaszraczylo 6196a72a8e Update dependencies. 2024-11-06 11:32:34 +00:00
22 changed files with 3107 additions and 662 deletions
+1
View File
@@ -13,6 +13,7 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
+1
View File
@@ -38,6 +38,7 @@ spec:
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
+44 -22
View File
@@ -5,69 +5,91 @@ import (
"time"
)
// CacheItem represents an item in the cache
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
Value interface{}
ExpiresAt int64 // Changed to int64 for faster comparisons
// Value is the cached data of any type
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
ExpiresAt time.Time
}
// Cache is a simple in-memory cache
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
}
// NewCache creates a new Cache
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
// Removed defer for slightly better performance
defer c.mutex.Unlock()
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration).UnixNano(), // Store as UnixNano for faster comparisons
ExpiresAt: time.Now().Add(expiration),
}
c.mutex.Unlock()
}
// Get retrieves an item from the cache
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
if time.Now().UnixNano() > item.ExpiresAt {
c.mutex.RUnlock()
// Use a separate goroutine to delete expired items to avoid blocking
go c.Delete(key)
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
return nil, false
}
c.mutex.RUnlock()
return item.Value, true
}
// Delete removes an item from the cache
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.mutex.Unlock()
}
// Cleanup removes expired items from the cache
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Cleanup() {
c.mutex.Lock()
now := time.Now().UnixNano()
defer c.mutex.Unlock()
now := time.Now()
for key, item := range c.items {
if now > item.ExpiresAt {
if now.After(item.ExpiresAt) {
delete(c.items, key)
}
}
c.mutex.Unlock()
}
+306
View File
@@ -0,0 +1,306 @@
package traefikoidc
import (
"reflect"
"testing"
"time"
)
func TestCache(t *testing.T) {
t.Run("Basic Set and Get", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Test Set
cache.Set(key, value, expiration)
// Test Get
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value {
t.Errorf("Expected value %v, got %v", value, got)
}
})
t.Run("Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 10 * time.Millisecond
// Set with short expiration
cache.Set(key, value, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be expired")
}
})
t.Run("Delete", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Set and then delete
cache.Set(key, value, expiration)
cache.Delete(key)
// Should not find deleted key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be deleted")
}
})
t.Run("Cleanup", func(t *testing.T) {
cache := NewCache()
// Add multiple items with different expirations
cache.Set("expired1", "value1", 10*time.Millisecond)
cache.Set("expired2", "value2", 10*time.Millisecond)
cache.Set("valid", "value3", 1*time.Second)
// Wait for some items to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Check expired items are removed
_, found1 := cache.Get("expired1")
_, found2 := cache.Get("expired2")
_, found3 := cache.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid item to remain in cache")
}
})
t.Run("Concurrent Access", func(t *testing.T) {
cache := NewCache()
done := make(chan bool)
// Start multiple goroutines to access cache concurrently
for i := 0; i < 10; i++ {
go func(id int) {
key := "key"
value := "value"
expiration := 1 * time.Second
// Perform multiple operations
cache.Set(key, value, expiration)
cache.Get(key)
cache.Delete(key)
cache.Cleanup()
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
})
t.Run("Zero Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with zero expiration
cache.Set(key, value, 0)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with zero expiration to be immediately expired")
}
})
t.Run("Negative Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with negative expiration
cache.Set(key, value, -1*time.Second)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with negative expiration to be immediately expired")
}
})
t.Run("Update Existing Key", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value1 := "value1"
value2 := "value2"
expiration := 1 * time.Second
// Set initial value
cache.Set(key, value1, expiration)
// Update value
cache.Set(key, value2, expiration)
// Check updated value
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value2 {
t.Errorf("Expected updated value %v, got %v", value2, got)
}
})
t.Run("Different Value Types", func(t *testing.T) {
cache := NewCache()
expiration := 1 * time.Second
// Test with different value types
testCases := []struct {
key string
value interface{}
}{
{"string", "test"},
{"int", 42},
{"float", 3.14},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"a": 1, "b": 2}},
{"struct", struct{ Name string }{"test"}},
}
for _, tc := range testCases {
t.Run(tc.key, func(t *testing.T) {
cache.Set(tc.key, tc.value, expiration)
got, found := cache.Get(tc.key)
if !found {
t.Error("Expected to find key in cache")
}
// Use reflect.DeepEqual for comparing complex types like slices and maps
if !reflect.DeepEqual(got, tc.value) {
t.Errorf("Expected value %v, got %v", tc.value, got)
}
})
}
})
}
func TestTokenCache(t *testing.T) {
t.Run("Basic Operations", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}
expiration := 1 * time.Second
// Test Set and Get
tc.Set(token, claims, expiration)
gotClaims, found := tc.Get(token)
if !found {
t.Error("Expected to find token in cache")
}
if len(gotClaims) != len(claims) {
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
}
for k, v := range claims {
if gotClaims[k] != v {
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
}
}
// Test Delete
tc.Delete(token)
_, found = tc.Get(token)
if found {
t.Error("Expected token to be deleted")
}
})
t.Run("Expiration", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 10 * time.Millisecond
// Set with short expiration
tc.Set(token, claims, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired token
_, found := tc.Get(token)
if found {
t.Error("Expected token to be expired")
}
})
t.Run("Cleanup", func(t *testing.T) {
tc := NewTokenCache()
// Add multiple tokens with different expirations
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
// Wait for some tokens to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
tc.Cleanup()
// Check expired tokens are removed
_, found1 := tc.Get("expired1")
_, found2 := tc.Get("expired2")
_, found3 := tc.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid token to remain in cache")
}
})
t.Run("Token Prefix", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 1 * time.Second
// Set token
tc.Set(token, claims, expiration)
// Verify internal storage uses prefix
_, found := tc.cache.Get("t-" + token)
if !found {
t.Error("Expected to find prefixed token in underlying cache")
}
})
}
+1 -1
View File
@@ -6,7 +6,7 @@ toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.4.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
)
-4
View File
@@ -6,9 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+236 -150
View File
@@ -13,28 +13,66 @@ import (
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
// generateNonce generates a random nonce
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
// the authentication request.
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
if _, err := rand.Read(nonceBytes); err != nil {
_, err := rand.Read(nonceBytes)
if err != nil {
return "", fmt.Errorf("could not generate nonce: %w", err)
}
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
type TokenResponse struct {
// IDToken is the OIDC ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth 2.0 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds of the access token
ExpiresIn int `json:"expires_in"`
// TokenType is the type of token, typically "Bearer"
TokenType string `json:"token_type"`
}
// exchangeTokens exchanges a code or refresh token for tokens
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
// It supports both authorization code and refresh token grant types.
// Parameters:
// - ctx: Context for the HTTP request
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
@@ -42,15 +80,14 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
"client_secret": {t.clientSecret},
}
switch grantType {
case "authorization_code":
if grantType == "authorization_code" {
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
case "refresh_token":
} else if grantType == "refresh_token" {
data.Set("refresh_token", codeOrToken)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.tokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -75,16 +112,8 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
return &tokenResponse, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
// This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
@@ -93,76 +122,25 @@ func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenRe
}
t.logger.Debugf("Token response: %+v", tokenResponse)
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
// Revoke tokens if available
for _, tokenType := range []string{"refresh_token", "access_token"} {
if token, ok := session.Values[tokenType].(string); ok && token != "" {
if err := t.RevokeTokenWithProvider(token, tokenType); err != nil {
t.logger.Errorf("Failed to revoke %s: %v", tokenType, err)
}
t.RevokeToken(token)
}
delete(session.Values, tokenType)
}
// Remove other session values
delete(session.Values, "id_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = &sessions.Options{MaxAge: -1, Path: "/", HttpOnly: true, Secure: true}
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
if session == nil {
t.logger.Error("Session is nil in handleExpiredToken")
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
// Clear the existing session
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -171,21 +149,36 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
if errParam := req.URL.Query().Get("error"); errParam != "" {
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", errParam, errorDescription)
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate CSRF state
state := req.URL.Query().Get("state")
csrfToken, ok := session.Values["csrf"].(string)
if !ok || state == "" || csrfToken == "" || state != csrfToken {
t.logger.Error("Invalid state parameter or CSRF token")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -193,41 +186,49 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if err := t.verifyToken(idToken); err != nil {
// Verify tokens and claims
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(idToken)
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
sessionNonce, ok2 := session.Values["nonce"].(string)
if !ok || !ok2 || nonceClaim == "" || sessionNonce == "" || nonceClaim != sessionNonce {
t.logger.Error("Invalid nonce")
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Validate user's email domain
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -235,14 +236,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
delete(session.Values, "csrf")
delete(session.Values, "nonce")
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
@@ -250,17 +248,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to original path or root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims extracts claims from a JWT token
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -280,56 +278,77 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
blacklist sync.Map
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{}
}
func (tb *TokenBlacklist) Add(token string, expiration time.Time) {
tb.blacklist.Store(token, expiration)
}
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
if exp, ok := tb.blacklist.Load(token); ok {
return time.Now().Before(exp.(time.Time))
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
return false
}
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
tb.blacklist.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
tb.blacklist.Delete(key)
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
}
return true
})
}
}
// TokenCache caches tokens
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
type TokenCache struct {
// cache is the underlying cache implementation
cache *Cache
}
// NewTokenCache creates a new TokenCache
// NewTokenCache creates a new TokenCache instance.
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set sets a token in the cache
// Set stores a token's claims in the cache with an expiration time.
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
tc.cache.Set("t-"+token, claims, expiration)
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token from the cache
// Get retrieves a token's claims from the cache.
// Returns the claims and a boolean indicating if the token was found.
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
value, found := tc.cache.Get("t-" + token)
token = "t-" + token
value, found := tc.cache.Get(token)
if !found {
return nil, false
}
@@ -337,31 +356,98 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache
// Delete removes a token from the cache.
func (tc *TokenCache) Delete(token string) {
tc.cache.Delete("t-" + token)
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup cleans up expired tokens from the cache
// Cleanup removes expired tokens from the cache.
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
// exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
// createStringMap creates a map from a slice of strings.
// Used for efficient lookups in allowed domains and roles.
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{}, len(keys))
result := make(map[string]struct{})
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
// handleLogout manages the OIDC logout process.
// It clears the session and redirects either to the OIDC provider's
// end session endpoint (if available) or to the configured post-logout URL.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
accessToken := session.GetAccessToken()
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
}
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
// Parameters:
// - endSessionURL: The OIDC provider's end session endpoint
// - idToken: The ID token to be invalidated
// - postLogoutRedirectURI: Where to redirect after logout completes
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
}
+99 -45
View File
@@ -4,48 +4,86 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
X string `json:"x"`
Y string `json:"y"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
type JWKCache struct {
jwks *JWKSet
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
expiresAt time.Time
mutex sync.RWMutex
// mutex protects concurrent access to the cache
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -57,7 +95,6 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check locking pattern
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
@@ -73,7 +110,14 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -93,7 +137,9 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JWK to PEM format
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -109,7 +155,9 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
@@ -120,15 +168,31 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(nBytes),
E: int(new(big.Int).SetBytes(eBytes).Int64()),
N: n,
E: int(e.Int64()),
}
return marshalPublicKey(pubKey)
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC JWK to PEM
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
@@ -139,9 +203,16 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
curve, err := getCurve(jwk.Crv)
if err != nil {
return nil, err
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
pubKey := &ecdsa.PublicKey{
@@ -150,32 +221,15 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
Y: new(big.Int).SetBytes(yBytes),
}
return marshalPublicKey(pubKey)
}
// getCurve returns the elliptic curve based on the JWK curve parameter
func getCurve(crv string) (elliptic.Curve, error) {
switch crv {
case "P-256":
return elliptic.P256(), nil
case "P-384":
return elliptic.P384(), nil
case "P-521":
return elliptic.P521(), nil
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", crv)
}
}
// marshalPublicKey marshals a public key to PEM format
func marshalPublicKey(pubKey interface{}) ([]byte, error) {
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
}
return pem.EncodeToMemory(&pem.Block{
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
}), nil
})
return pubKeyPEM, nil
}
+191 -125
View File
@@ -4,222 +4,288 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"time"
)
var (
ErrInvalidJWTFormat = errors.New("invalid JWT format")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrTokenExpired = errors.New("token has expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrMissingClaim = errors.New("missing claim")
ErrInvalidClaimType = errors.New("invalid claim type")
ErrUnsupportedAlgorithm = errors.New("unsupported algorithm")
ErrInvalidSignature = errors.New("invalid signature")
)
// JWT represents a JSON Web Token
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Signature []byte
Token string
// Token is the original JWT string
Token string
}
// parseJWT parses a JWT token string into a JWT struct
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("%w: expected 3 parts, got %d", ErrInvalidJWTFormat, len(parts))
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
jwt := &JWT{Token: tokenString}
if err := decodeJSONPart(parts[0], &jwt.Header); err != nil {
return nil, fmt.Errorf("failed to decode header: %w", err)
jwt := &JWT{
Token: tokenString,
}
if err := decodeJSONPart(parts[1], &jwt.Claims); err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
var err error
jwt.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, fmt.Errorf("failed to decode signature: %w", err)
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
}
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
}
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
return jwt, nil
}
func decodeJSONPart(part string, target interface{}) error {
bytes, err := base64.RawURLEncoding.DecodeString(part)
if err != nil {
return err
}
return json.Unmarshal(bytes, target)
}
// Verify verifies the standard claims in the JWT
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future
// - issued at time (iat) is in the past
// - subject (sub) is present and not empty
// Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error {
if err := verifyIssuer(j.Claims["iss"], issuerURL); err != nil {
claims := j.Claims
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
}
if err := verifyIssuer(iss, issuerURL); err != nil {
return err
}
if err := verifyAudience(j.Claims["aud"], clientID); err != nil {
aud, ok := claims["aud"]
if !ok {
return fmt.Errorf("missing 'aud' claim")
}
if err := verifyAudience(aud, clientID); err != nil {
return err
}
if err := verifyExpiration(j.Claims["exp"]); err != nil {
exp, ok := claims["exp"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'exp' claim")
}
if err := verifyExpiration(exp); err != nil {
return err
}
if err := verifyIssuedAt(j.Claims["iat"]); err != nil {
iat, ok := claims["iat"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'iat' claim")
}
if err := verifyIssuedAt(iat); err != nil {
return err
}
if sub, ok := j.Claims["sub"].(string); !ok || sub == "" {
return fmt.Errorf("%w: sub", ErrMissingClaim)
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
}
return nil
}
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return ErrInvalidAudience
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
return nil
found = true
break
}
}
return ErrInvalidAudience
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("%w: aud", ErrInvalidClaimType)
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer interface{}, expectedIssuer string) error {
iss, ok := tokenIssuer.(string)
if !ok {
return fmt.Errorf("%w: iss", ErrMissingClaim)
}
if iss != expectedIssuer {
return ErrInvalidIssuer
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
func verifyExpiration(expiration interface{}) error {
exp, ok := expiration.(float64)
if !ok {
return fmt.Errorf("%w: exp", ErrInvalidClaimType)
}
if time.Now().After(time.Unix(int64(exp), 0)) {
return ErrTokenExpired
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time.
// Parameters:
// - expiration: The expiration timestamp from the token
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
}
return nil
}
func verifyIssuedAt(issuedAt interface{}) error {
iat, ok := issuedAt.(float64)
if !ok {
return fmt.Errorf("%w: iat", ErrInvalidClaimType)
}
if time.Now().Before(time.Unix(int64(iat), 0)) {
return ErrTokenUsedBeforeIssued
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, which could
// indicate clock skew or a malicious token.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return ErrInvalidJWTFormat
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
pubKey, err := parsePublicKey(publicKeyPEM)
if err != nil {
return err
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
hashFunc, err := getHashFunc(alg)
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return err
return fmt.Errorf("failed to parse public key: %w", err)
}
hashed := hashFunc.New().Sum([]byte(signedContent))
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
return verifyRSASignature(pubKey, hashFunc, hashed, signature, alg)
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
case *ecdsa.PublicKey:
return verifyECDSASignature(pubKey, hashed, signature)
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
}
func parsePublicKey(publicKeyPEM []byte) (interface{}, error) {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse PEM block containing the public key")
}
return x509.ParsePKIXPublicKey(block.Bytes)
}
func getHashFunc(alg string) (crypto.Hash, error) {
switch alg {
case "RS256", "PS256", "ES256":
return crypto.SHA256, nil
case "RS384", "PS384", "ES384":
return crypto.SHA384, nil
case "RS512", "PS512", "ES512":
return crypto.SHA512, nil
default:
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
}
func verifyRSASignature(pubKey *rsa.PublicKey, hashFunc crypto.Hash, hashed, signature []byte, alg string) error {
if strings.HasPrefix(alg, "RS") {
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
}
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
func verifyECDSASignature(pubKey *ecdsa.PublicKey, hashed, signature []byte) error {
sigLen := len(signature)
if sigLen%2 != 0 {
return errors.New("invalid ECDSA signature length")
}
r, s := new(big.Int), new(big.Int)
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, r, s) {
return nil
}
return ErrInvalidSignature
}
+174 -138
View File
@@ -10,11 +10,9 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
@@ -34,7 +32,6 @@ type JWTVerifier interface {
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -53,26 +50,29 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
@@ -175,9 +175,6 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
@@ -190,7 +187,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ExpectContinueTimeout: 0,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
@@ -209,7 +206,6 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -217,6 +213,12 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
}
return config.PostLogoutRedirectURI
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
@@ -233,9 +235,10 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
@@ -254,20 +257,43 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
}
close(t.initComplete)
})
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
// discoverProviderMetadata fetches the OIDC provider metadata
@@ -337,96 +363,76 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Debug("OIDC middleware not yet initialized")
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
return
}
// Process the request as normal
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
case <-time.After(30 * time.Second):
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
return
}
// Check if the URL is excluded from authentication
// Check if URL is excluded
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
// Get session
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
// Handle logout URL
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
t.handleCallback(rw, req, redirectURL)
return
}
// Check if the user is authenticated
// Check authentication status
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session)
t.handleExpiredToken(rw, req, session, redirectURL)
return
}
}
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
// Process authenticated request
email := session.GetEmail()
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
t.logger.Debug("No email found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
@@ -436,11 +442,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles := t.extractGroupsAndRoles(claims)
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
@@ -449,6 +454,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
@@ -459,13 +465,15 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
}
@@ -483,16 +491,16 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
switch {
case t.forceHTTPS:
if t.forceHTTPS {
return "https"
case req.Header.Get(headerXForwardedProto) != "":
return req.Header.Get(headerXForwardedProto)
case req.TLS != nil:
return "https"
default:
return "http"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
return "https"
}
return "http"
}
// determineHost determines the host of the request
@@ -504,37 +512,34 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(idToken); err != nil {
if err := t.verifyToken(accessToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(idToken)
claims, err := extractClaims(accessToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true // Can't read claims, consider it expired
return false, false, true
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
}
now := time.Now().Unix()
@@ -542,7 +547,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true // Token has expired
return false, false, true
}
gracePeriod := time.Minute * 5
@@ -551,26 +556,23 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false // Token is valid and not expiring soon
return true, false, false
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
// Save the session
if err := session.Save(req, rw); err != nil {
@@ -579,7 +581,7 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
return
}
// Build the authentication URL
// Build and redirect to auth URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
@@ -620,14 +622,9 @@ func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
}
// RevokeTokenWithProvider revokes the token with the provider
@@ -668,10 +665,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -682,16 +679,17 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
// Verify the new access token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
t.logger.Errorf("Failed to verify new access token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
@@ -703,33 +701,71 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true
return true // If no domains are specified, all are allowed
}
atIndex := strings.LastIndex(email, "@")
if atIndex == -1 {
return false
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false // Invalid email format
}
domain := email[atIndex+1:]
domain := parts[1]
_, ok := t.allowedUserDomains[domain]
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(claims map[string]interface{}) ([]string, []string) {
groups := extractStringSlice(claims, "groups")
roles := extractStringSlice(claims, "roles")
return groups, roles
}
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
func extractStringSlice(claims map[string]interface{}, key string) []string {
if slice, ok := claims[key].([]interface{}); ok {
result := make([]string, 0, len(slice))
for _, item := range slice {
if str, ok := item.(string); ok {
result = append(result, str)
var groups []string
var roles []string
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
return result
}
return nil
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
+901 -100
View File
File diff suppressed because it is too large Load Diff
+456
View File
@@ -0,0 +1,456 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
)
// Cookie names and configuration constants used for session management
const (
// mainCookieName is the name of the main session cookie that stores authentication state
// and basic user information like email and CSRF tokens
mainCookieName = "_raczylo_oidc"
// accessTokenCookie is the name of the cookie that stores the OIDC access token
// This may be split into multiple cookies if the token is large
accessTokenCookie = "_raczylo_oidc_access"
// refreshTokenCookie is the name of the cookie that stores the OIDC refresh token
// This may be split into multiple cookies if the token is large
refreshTokenCookie = "_raczylo_oidc_refresh"
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
// 3. Calculation:
// - Let x be the chunk size
// - After encryption: x + 28 bytes
// - After base64: ((x + 28) * 4/3) bytes
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
)
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
type SessionManager struct {
// store is the underlying session store for cookie management
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
forceHTTPS bool
// logger provides structured logging capabilities
logger *Logger
}
// NewSessionManager creates a new session manager with the specified configuration.
// Parameters:
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
// - Using SameSite=Lax for CSRF protection
// - Set with appropriate timeout and path settings
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// GetSession retrieves all session data for the current request.
// It loads the main session and token sessions, including any chunked token data,
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
if err != nil {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
request: r,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
}
// Retrieve chunked access token sessions
sessionData.accessTokenChunks = sm.getTokenChunkSessions(r, accessTokenCookie)
// Retrieve chunked refresh token sessions
sessionData.refreshTokenChunks = sm.getTokenChunkSessions(r, refreshTokenCookie)
return sessionData, nil
}
// getTokenChunkSessions retrieves all session chunks for a given token type.
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// Returns a map of chunk index to session, used for handling large tokens
// that exceed single cookie size limits.
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string) map[int]*sessions.Session {
chunks := make(map[int]*sessions.Session)
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
return chunks
}
// SessionData holds all session information for an authenticated user.
// It manages multiple session cookies to handle the main session state
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
manager *SessionManager
// request is the current HTTP request associated with this session
request *http.Request
// mainSession stores authentication state and basic user info
mainSession *sessions.Session
// accessSession stores the primary access token cookie
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
refreshTokenChunks map[int]*sessions.Session
}
// Save persists all session data to cookies in the HTTP response.
// It saves the main session, token sessions, and any token chunks,
// applying appropriate security options to each cookie. All cookies
// are saved with consistent security settings based on the request scheme.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token chunk session: %w", err)
}
}
// Save refresh token chunks
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
}
}
return nil
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
return sd.Save(r, w)
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
}
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication,
// false otherwise or if the authentication status cannot be determined.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
}
// SetAuthenticated updates the session's authentication status.
// This should be called after successful OIDC authentication or during logout.
func (sd *SessionData) SetAuthenticated(value bool) {
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.accessTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.accessTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
sd.accessTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.accessSession.Values["token"] = token
} else {
// Split token into chunks
sd.accessSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.accessTokenChunks[i] = session
}
}
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
return token
}
// Reassemble token from chunks
if len(sd.refreshTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.refreshTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
return strings.Join(chunks, "")
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
sd.refreshTokenChunks = make(map[int]*sessions.Session)
if len(token) <= maxCookieSize {
sd.refreshSession.Values["token"] = token
} else {
// Split token into chunks
sd.refreshSession.Values["token"] = ""
chunks := splitIntoChunks(token, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.refreshTokenChunks[i] = session
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
if len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
+129
View File
@@ -0,0 +1,129 @@
package traefikoidc
import (
"net/http/httptest"
"strings"
"testing"
)
// TestSessionManager tests the SessionManager functionality
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
// Recalculate expected cookies based on new maxCookieSize
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved")
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Calculate number of chunks for access token
accessChunks := len(splitIntoChunks(accessToken, maxCookieSize))
if accessChunks > 1 {
count += accessChunks
}
// Calculate number of chunks for refresh token
refreshChunks := len(splitIntoChunks(refreshToken, maxCookieSize))
if refreshChunks > 1 {
count += refreshChunks
}
return count
}
+144 -47
View File
@@ -5,51 +5,93 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"github.com/gorilla/sessions"
"strings"
)
const (
cookieName = "_raczylo_oidc"
)
const (
headerXForwardedProto = "X-Forwarded-Proto"
headerXForwardedHost = "X-Forwarded-Host"
headerXForwardedUser = "X-Forwarded-User"
headerXUserGroups = "X-User-Groups"
headerXUserRoles = "X-User-Roles"
)
// Config holds the configuration for the OIDC middleware
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHTTPS"`
RateLimit int `json:"rateLimit"`
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
// ProviderURL is the base URL of the OIDC provider (required)
// Example: https://accounts.google.com
ProviderURL string `json:"providerURL"`
// RevocationURL is the endpoint for revoking tokens (optional)
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
// LogoutURL is the path for handling logout requests (optional)
// If not provided, it will be set to CallbackURL + "/logout"
LogoutURL string `json:"logoutURL"`
// ClientID is the OAuth 2.0 client identifier (required)
ClientID string `json:"clientID"`
// ClientSecret is the OAuth 2.0 client secret (required)
ClientSecret string `json:"clientSecret"`
// Scopes defines the OAuth 2.0 scopes to request (optional)
// Defaults to ["openid", "profile", "email"] if not provided
Scopes []string `json:"scopes"`
// LogLevel sets the logging verbosity (optional)
// Valid values: "debug", "info", "error"
// Default: "info"
LogLevel string `json:"logLevel"`
// SessionEncryptionKey is used to encrypt session data (required)
// Must be a secure random string
SessionEncryptionKey string `json:"sessionEncryptionKey"`
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
// Default: false
ForceHTTPS bool `json:"forceHTTPS"`
// RateLimit sets the maximum number of requests per second (optional)
// Default: 100
RateLimit int `json:"rateLimit"`
// ExcludedURLs lists paths that bypass authentication (optional)
// Example: ["/health", "/metrics"]
ExcludedURLs []string `json:"excludedURLs"`
// AllowedUserDomains restricts access to specific email domains (optional)
// Example: ["company.com", "subsidiary.com"]
AllowedUserDomains []string `json:"allowedUserDomains"`
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
// Example: ["admin", "developer"]
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
HTTPClient *http.Client
// OIDCEndSessionURL is the provider's end session endpoint (optional)
// If not provided, it will be discovered from provider metadata
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
// Default: "/"
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
// CreateConfig creates a new Config with default values
// CreateConfig creates a new Config with sensible default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
func CreateConfig() *Config {
c := &Config{}
@@ -72,14 +114,22 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
// Validate performs validation checks on the Config.
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid URL")
}
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
@@ -89,25 +139,58 @@ func (c *Config) Validate() error {
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < 32 {
return fmt.Errorf("sessionEncryptionKey must be at least 32 characters long")
}
if c.RateLimit < 0 {
return fmt.Errorf("rateLimit must be non-negative")
}
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
return nil
}
// Logger is a simple logger with different levels
// isValidURL checks if the provided string is a valid URL
func isValidURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme != "" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
func isValidLogLevel(level string) bool {
return level == "debug" || level == "info" || level == "error"
}
// Logger provides structured logging capabilities with different severity levels.
// It supports error, info, and debug levels with appropriate output streams
// and formatting for each level.
type Logger struct {
// logError handles error-level messages, writing to stderr
logError *log.Logger
logInfo *log.Logger
// logInfo handles informational messages, writing to stdout
logInfo *log.Logger
// logDebug handles debug-level messages, writing to stdout when debug is enabled
logDebug *log.Logger
}
// NewLogger creates a new Logger
// NewLogger creates a new Logger with the specified log level.
// The log level determines which messages are output:
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
logInfo.SetOutput(os.Stdout)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
@@ -119,37 +202,51 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
// Info logs an informational message.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
// Debug logs a debug message.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
// Error logs an error message.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
// Infof logs an informational message using Printf formatting.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
// Debugf logs a debug message using Printf formatting.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
// Errorf logs an error message using Printf formatting.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
// handleError writes an error message to both the HTTP response and the error log.
// It ensures consistent error handling across the middleware by logging the error
// and sending an appropriate HTTP response to the client.
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
+362
View File
@@ -0,0 +1,362 @@
package traefikoidc
import (
"bytes"
"log"
"net/http"
"testing"
)
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
// Check default scopes
expectedScopes := []string{"openid", "profile", "email"}
if len(config.Scopes) != len(expectedScopes) {
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
}
for i, scope := range expectedScopes {
if config.Scopes[i] != scope {
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
}
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
})
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
config *Config
expectedError string
}{
{
name: "Empty Config",
config: &Config{},
expectedError: "providerURL is required",
},
{
name: "Missing CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
},
expectedError: "callbackURL is required",
},
{
name: "Missing ClientID",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
},
expectedError: "clientID is required",
},
{
name: "Missing ClientSecret",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
},
expectedError: "clientSecret is required",
},
{
name: "Missing SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
},
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
},
{
name: "Invalid CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "callback", // Missing leading slash
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "callbackURL must start with /",
},
{
name: "Short SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "short",
},
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
},
expectedError: "rateLimit must be non-negative",
},
{
name: "Invalid LogLevel",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "invalid",
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Valid Config",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
},
expectedError: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
}
}
})
}
}
func TestLogger(t *testing.T) {
// Capture log output
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
name: "Debug Level",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut == "" {
t.Error("Expected debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Info Level",
logLevel: "info",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Error Level",
logLevel: "error",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut != "" {
t.Error("Did not expect info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Printf Methods",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debugf("debug %s", "formatted")
l.Infof("info %s", "formatted")
l.Errorf("error %s", "formatted")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
t.Error("Expected formatted debug message")
}
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
t.Error("Expected formatted info message")
}
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
t.Error("Expected formatted error message")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset buffers
debugBuf.Reset()
infoBuf.Reset()
errorBuf.Reset()
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
if tc.logLevel == "debug" {
logger.logDebug.SetOutput(&debugBuf)
}
// Run test
tc.testFunc(logger)
// Check results
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
})
}
}
func TestHandleError(t *testing.T) {
// Create a test logger with captured output
var errorBuf bytes.Buffer
logger := &Logger{
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
}
logger.logError.SetOutput(&errorBuf)
// Create a test response recorder
rr := &testResponseRecorder{
headers: make(map[string][]string),
}
// Test error handling
message := "test error message"
code := 400
handleError(rr, message, code, logger)
// Check response code
if rr.statusCode != code {
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
}
// Check response body
expectedBody := message + "\n"
if rr.body != expectedBody {
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
}
// Check error was logged
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
t.Error("Error message was not logged")
}
}
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
}
func (r *testResponseRecorder) Header() http.Header {
return r.headers
}
func (r *testResponseRecorder) Write(b []byte) (int, error) {
r.body = string(b)
return len(b), nil
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+1 -5
View File
@@ -1,7 +1,4 @@
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
# sessions
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -77,7 +74,6 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+9 -12
View File
@@ -1,6 +1,5 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
@@ -9,15 +8,13 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
}
}
+21
View File
@@ -0,0 +1,21 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+5 -10
View File
@@ -1,11 +1,8 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.11
// +build !go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -16,9 +13,7 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
MaxAge int
Secure bool
HttpOnly bool
}
+23
View File
@@ -0,0 +1,23 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -4,8 +4,8 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
## explicit; go 1.18