Compare commits

..

5 Commits

19 changed files with 658 additions and 4128 deletions
-1
View File
@@ -13,7 +13,6 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
-33
View File
@@ -4,10 +4,6 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
Middleware has been tested with Auth0 and Logto.
### Traefik version compatibility
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
### Configuration options
Middleware currently supports following scenarios:
@@ -19,35 +15,6 @@ Middleware currently supports following scenarios:
#### How to configure...
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```
##### Excluded URLs with open access
```
-169
View File
@@ -1,169 +0,0 @@
package traefikoidc
import (
"sync"
"time"
)
// CacheItem represents an item stored in the cache with its associated metadata.
type CacheItem struct {
// Value is the cached data of any type
Value interface{}
// ExpiresAt is the timestamp when this item should be considered expired
// and removed from the cache during cleanup operations
ExpiresAt time.Time
}
// Cache provides a thread-safe in-memory caching mechanism with expiration support.
// It uses a read-write mutex to ensure safe concurrent access to the cached items.
type Cache struct {
// items stores the cached data with string keys
items map[string]CacheItem
// mutex protects concurrent access to the items map
// Use RLock/RUnlock for reads and Lock/Unlock for writes
mutex sync.RWMutex
// maxSize is the maximum number of items allowed in the cache
maxSize int
// accessList maintains the order of item access for eviction
accessList []string
}
// DefaultMaxSize is the default maximum number of items in the cache
const DefaultMaxSize = 1000
// NewCache creates a new empty cache instance.
// The cache is immediately ready for use and is thread-safe.
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
maxSize: DefaultMaxSize,
accessList: make([]string, 0, DefaultMaxSize),
}
}
// Set adds or updates an item in the cache with the specified expiration duration.
// Parameters:
// - key: Unique identifier for the cached item
// - value: The data to cache (can be of any type)
// - expiration: How long the item should remain in the cache
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// If key exists, update it
if _, exists := c.items[key]; exists {
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
}
return
}
// If cache is full, remove oldest item
if len(c.items) >= c.maxSize {
c.evictOldest()
}
// Add new item
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
}
c.accessList = append(c.accessList, key)
}
// Get retrieves an item from the cache if it exists and hasn't expired.
// Parameters:
// - key: The identifier of the item to retrieve
// Returns:
// - value: The cached data (nil if not found or expired)
// - found: true if the item was found and is valid, false otherwise
// Thread-safe: Uses read locking to ensure safe concurrent access.
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
item, found := c.items[key]
c.mutex.RUnlock()
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
c.mutex.Lock()
c.removeItem(key)
c.mutex.Unlock()
return nil, false
}
// Update access order
c.mutex.Lock()
c.updateAccessOrder(key)
c.mutex.Unlock()
return item.Value, true
}
// Delete removes an item from the cache if it exists.
// If the item doesn't exist, this operation is a no-op.
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes all expired items from the cache.
// This should be called periodically to prevent memory leaks from
// expired items that haven't been accessed (and thus not removed during Get operations).
// Thread-safe: Uses write locking to ensure safe concurrent access.
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
var newAccessList []string
for _, key := range c.accessList {
if item, exists := c.items[key]; exists && !now.After(item.ExpiresAt) {
newAccessList = append(newAccessList, key)
} else {
delete(c.items, key)
}
}
c.accessList = newAccessList
}
// evictOldest removes the least recently used item from the cache
func (c *Cache) evictOldest() {
if len(c.accessList) > 0 {
oldest := c.accessList[0]
c.removeItem(oldest)
}
}
// removeItem removes an item from both the cache and access list
func (c *Cache) removeItem(key string) {
delete(c.items, key)
for i, k := range c.accessList {
if k == key {
c.accessList = append(c.accessList[:i], c.accessList[i+1:]...)
break
}
}
}
// updateAccessOrder moves the accessed key to the end of the access list
func (c *Cache) updateAccessOrder(key string) {
for i, k := range c.accessList {
if k == key {
c.accessList = append(append(c.accessList[:i], c.accessList[i+1:]...), key)
break
}
}
}
-306
View File
@@ -1,306 +0,0 @@
package traefikoidc
import (
"reflect"
"testing"
"time"
)
func TestCache(t *testing.T) {
t.Run("Basic Set and Get", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Test Set
cache.Set(key, value, expiration)
// Test Get
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value {
t.Errorf("Expected value %v, got %v", value, got)
}
})
t.Run("Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 10 * time.Millisecond
// Set with short expiration
cache.Set(key, value, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be expired")
}
})
t.Run("Delete", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Set and then delete
cache.Set(key, value, expiration)
cache.Delete(key)
// Should not find deleted key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be deleted")
}
})
t.Run("Cleanup", func(t *testing.T) {
cache := NewCache()
// Add multiple items with different expirations
cache.Set("expired1", "value1", 10*time.Millisecond)
cache.Set("expired2", "value2", 10*time.Millisecond)
cache.Set("valid", "value3", 1*time.Second)
// Wait for some items to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Check expired items are removed
_, found1 := cache.Get("expired1")
_, found2 := cache.Get("expired2")
_, found3 := cache.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid item to remain in cache")
}
})
t.Run("Concurrent Access", func(t *testing.T) {
cache := NewCache()
done := make(chan bool)
// Start multiple goroutines to access cache concurrently
for i := 0; i < 10; i++ {
go func(id int) {
key := "key"
value := "value"
expiration := 1 * time.Second
// Perform multiple operations
cache.Set(key, value, expiration)
cache.Get(key)
cache.Delete(key)
cache.Cleanup()
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
})
t.Run("Zero Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with zero expiration
cache.Set(key, value, 0)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with zero expiration to be immediately expired")
}
})
t.Run("Negative Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with negative expiration
cache.Set(key, value, -1*time.Second)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with negative expiration to be immediately expired")
}
})
t.Run("Update Existing Key", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value1 := "value1"
value2 := "value2"
expiration := 1 * time.Second
// Set initial value
cache.Set(key, value1, expiration)
// Update value
cache.Set(key, value2, expiration)
// Check updated value
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value2 {
t.Errorf("Expected updated value %v, got %v", value2, got)
}
})
t.Run("Different Value Types", func(t *testing.T) {
cache := NewCache()
expiration := 1 * time.Second
// Test with different value types
testCases := []struct {
key string
value interface{}
}{
{"string", "test"},
{"int", 42},
{"float", 3.14},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"a": 1, "b": 2}},
{"struct", struct{ Name string }{"test"}},
}
for _, tc := range testCases {
t.Run(tc.key, func(t *testing.T) {
cache.Set(tc.key, tc.value, expiration)
got, found := cache.Get(tc.key)
if !found {
t.Error("Expected to find key in cache")
}
// Use reflect.DeepEqual for comparing complex types like slices and maps
if !reflect.DeepEqual(got, tc.value) {
t.Errorf("Expected value %v, got %v", tc.value, got)
}
})
}
})
}
func TestTokenCache(t *testing.T) {
t.Run("Basic Operations", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}
expiration := 1 * time.Second
// Test Set and Get
tc.Set(token, claims, expiration)
gotClaims, found := tc.Get(token)
if !found {
t.Error("Expected to find token in cache")
}
if len(gotClaims) != len(claims) {
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
}
for k, v := range claims {
if gotClaims[k] != v {
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
}
}
// Test Delete
tc.Delete(token)
_, found = tc.Get(token)
if found {
t.Error("Expected token to be deleted")
}
})
t.Run("Expiration", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 10 * time.Millisecond
// Set with short expiration
tc.Set(token, claims, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired token
_, found := tc.Get(token)
if found {
t.Error("Expected token to be expired")
}
})
t.Run("Cleanup", func(t *testing.T) {
tc := NewTokenCache()
// Add multiple tokens with different expirations
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
// Wait for some tokens to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
tc.Cleanup()
// Check expired tokens are removed
_, found1 := tc.Get("expired1")
_, found2 := tc.Get("expired2")
_, found3 := tc.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid token to remain in cache")
}
})
t.Run("Token Prefix", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 1 * time.Second
// Set token
tc.Set(token, claims, expiration)
// Verify internal storage uses prefix
_, found := tc.cache.Get("t-" + token)
if !found {
t.Error("Expected to find prefixed token in underlying cache")
}
})
}
+2 -4
View File
@@ -1,13 +1,11 @@
module github.com/lukaszraczylo/traefikoidc
go 1.23
toolchain go1.23.1
go 1.22.2
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
golang.org/x/time v0.5.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
+2 -2
View File
@@ -6,5 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+166 -259
View File
@@ -6,37 +6,16 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
// newSessionOptions creates secure session cookie options.
// Parameters:
// - isSecure: Whether to set the Secure flag on cookies
// Returns session options configured for security with:
// - HttpOnly flag to prevent JavaScript access
// - SameSite=Lax for CSRF protection
// - Appropriate timeout and path settings
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce creates a cryptographically secure random nonce
// for use in the OIDC authentication flow. The nonce is used to
// prevent replay attacks by ensuring the token received matches
// the authentication request.
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -46,34 +25,14 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// TokenResponse represents the response from the OIDC token endpoint.
// It contains the various tokens and metadata returned after successful
// code exchange or token refresh operations.
type TokenResponse struct {
// IDToken is the OIDC ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth 2.0 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is the OAuth 2.0 refresh token for obtaining new tokens
RefreshToken string `json:"refresh_token"`
// ExpiresIn is the lifetime in seconds of the access token
ExpiresIn int `json:"expires_in"`
// TokenType is the type of token, typically "Bearer"
TokenType string `json:"token_type"`
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
// exchangeTokens performs the OAuth 2.0 token exchange with the OIDC provider.
// It supports both authorization code and refresh token grant types.
// Parameters:
// - ctx: Context for the HTTP request
// - grantType: The OAuth 2.0 grant type ("authorization_code" or "refresh_token")
// - codeOrToken: Either the authorization code or refresh token
// - redirectURL: The callback URL for authorization code grant
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -99,48 +58,119 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return &tokenResponse, nil
return result, nil
}
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken obtains new tokens using a refresh token.
// This is used to refresh access tokens before they expire.
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
t.logger.Debugf("Token response: %+v", tokenResponse)
return tokenResponse, nil
newAccessToken, ok := result["access_token"].(string)
if !ok || newAccessToken == "" {
return nil, fmt.Errorf("no access_token field in token response")
}
rawIDToken, ok := result["id_token"].(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token field in token response")
}
newRefreshToken, ok := result["refresh_token"].(string)
if !ok || newRefreshToken == "" {
return nil, fmt.Errorf("no refresh_token field in token response")
}
response := &TokenResponse{
IDToken: rawIDToken,
AccessToken: newAccessToken,
ExpiresIn: int(result["expires_in"].(float64)),
TokenType: result["token_type"].(string),
}
// The refresh token might not be returned if it hasn't changed
if newRefreshToken != refreshToken {
response.RefreshToken = newRefreshToken
} else {
response.RefreshToken = refreshToken
}
t.logger.Debug("Token response: %+v", response)
return response, nil
}
// handleExpiredToken manages token expiration by clearing the session
// and initiating a new authentication flow.
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
err := t.RevokeTokenWithProvider(idToken)
if err != nil {
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
return
}
t.RevokeToken(idToken)
}
session.Options = defaultSessionOptions
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
http.Error(rw, "Logged out", http.StatusForbidden)
}
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback processes the authentication callback from the OIDC provider.
// It validates the callback parameters, exchanges the authorization code for
// tokens, verifies the tokens, and establishes the user's session.
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -149,36 +179,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the callback
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate CSRF state
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -186,49 +186,27 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
token, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify tokens and claims
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
idToken, ok := token["id_token"].(string)
if !ok || idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce to prevent replay attacks
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Validate user's email domain
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -236,11 +214,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Update session with authentication data
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Options = defaultSessionOptions
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
@@ -248,17 +225,10 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
t.logger.Debugf("Authentication successful. User email: %s", email)
http.Redirect(rw, req, "/", http.StatusFound)
}
// extractClaims parses a JWT token and extracts its claims.
// It handles base64url decoding and JSON parsing of the token payload.
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -278,32 +248,28 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a thread-safe list of revoked tokens.
// It stores tokens with their expiration times and automatically
// removes expired entries during cleanup operations.
type TokenBlacklist struct {
// blacklist maps token IDs to their expiration times
blacklist map[string]time.Time
// mutex protects concurrent access to the blacklist
mutex sync.RWMutex
type UsedTokens struct {
tokens map[string]bool
mutex sync.RWMutex
}
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist instance.
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist with an expiration time.
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is in the blacklist and not expired.
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -311,7 +277,6 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist.
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -323,131 +288,73 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache provides a caching mechanism for validated tokens.
// It stores token claims to avoid repeated validation of the
// same token, improving performance for frequently used tokens.
type TokenCache struct {
// cache is the underlying cache implementation
cache *Cache
cache map[string]*TokenInfo
mutex sync.RWMutex
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
// NewTokenCache creates a new TokenCache instance.
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
cache: make(map[string]*TokenInfo),
}
}
// Set stores a token's claims in the cache with an expiration time.
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
tc.mutex.Lock()
defer tc.mutex.Unlock()
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
}
// Get retrieves a token's claims from the cache.
// Returns the claims and a boolean indicating if the token was found.
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
if !found {
return nil, false
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
tc.mutex.RLock()
defer tc.mutex.RUnlock()
info, exists := tc.cache[token]
if exists && time.Now().Before(info.ExpiresAt) {
return info, true
}
claims, ok := value.(map[string]interface{})
return claims, ok
return nil, false
}
// Delete removes a token from the cache.
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
tc.mutex.Lock()
defer tc.mutex.Unlock()
delete(tc.cache, token)
}
// Cleanup removes expired tokens from the cache.
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges an authorization code for tokens.
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings.
// Used for efficient lookups in allowed domains and roles.
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
// handleLogout manages the OIDC logout process.
// It clears the session and redirects either to the OIDC provider's
// end session endpoint (if available) or to the configured post-logout URL.
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
accessToken := session.GetAccessToken()
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
tc.mutex.Lock()
defer tc.mutex.Unlock()
now := time.Now()
for token, info := range tc.cache {
if now.After(info.ExpiresAt) {
delete(tc.cache, token)
}
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL with appropriate parameters.
// Parameters:
// - endSessionURL: The OIDC provider's end session endpoint
// - idToken: The ID token to be invalidated
// - postLogoutRedirectURI: Where to redirect after logout completes
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.clientID)
data.Set("client_secret", t.clientSecret)
data.Set("code", code)
data.Set("redirect_uri", t.redirectURL)
resp, err := t.httpClient.PostForm(t.tokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
return nil, fmt.Errorf("failed to exchange token: %v", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %v", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
return result, nil
}
+65 -102
View File
@@ -4,86 +4,43 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key as defined in RFC 7517.
// It contains the cryptographic key information used for token verification.
type JWK struct {
// Kty is the key type (e.g., "RSA", "EC")
Kty string `json:"kty"`
// Kid is the unique key identifier
Kid string `json:"kid"`
// Use specifies the intended use of the key (e.g., "sig" for signature)
Use string `json:"use"`
// N is the modulus for RSA keys
N string `json:"n"`
// E is the exponent for RSA keys
E string `json:"e"`
// Alg is the algorithm intended for use with the key
N string `json:"n"`
E string `json:"e"`
Alg string `json:"alg"`
// Crv is the curve for EC keys (e.g., "P-256", "P-384", "P-521")
Crv string `json:"crv"`
// X is the x-coordinate for EC keys
X string `json:"x"`
// Y is the y-coordinate for EC keys
Y string `json:"y"`
X string `json:"x"`
Y string `json:"y"`
}
// JWKSet represents a set of JSON Web Keys as returned by the JWKS endpoint.
// OIDC providers typically expose multiple keys to support key rotation.
type JWKSet struct {
// Keys is the array of JSON Web Keys
Keys []JWK `json:"keys"`
}
// JWKCache provides a thread-safe caching mechanism for JWK sets.
// It caches the keys for a configurable duration to reduce load on the OIDC provider
// while ensuring keys are refreshed periodically to handle key rotation.
type JWKCache struct {
// jwks holds the cached set of JSON Web Keys
jwks *JWKSet
// expiresAt is the timestamp when the cached keys should be refreshed
jwks *JWKSet
expiresAt time.Time
// mutex protects concurrent access to the cache
mutex sync.RWMutex
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for JWK caching operations.
// This interface allows for different caching implementations while
// maintaining consistent behavior in the token verification process.
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS retrieves the JSON Web Key Set, either from cache or by fetching it
// from the OIDC provider. It implements a thread-safe double-checked locking
// pattern to prevent multiple simultaneous fetches of the same keys.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for fetching keys
// Returns:
// - The JSON Web Key Set
// - An error if the keys cannot be retrieved or parsed
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -110,14 +67,6 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS retrieves the JSON Web Key Set from the OIDC provider's JWKS endpoint.
// It handles HTTP communication and JSON parsing of the response.
// Parameters:
// - jwksURL: The URL of the JWKS endpoint
// - httpClient: The HTTP client to use for the request
// Returns:
// - The parsed JSON Web Key Set
// - An error if the request fails or the response is invalid
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -137,68 +86,82 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JSON Web Key to PEM format for use with standard
// cryptographic functions. It supports both RSA and EC keys, delegating
// to the appropriate converter based on the key type.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return converter(jwk)
}
type jwkToPEMConverter func(*JWK) ([]byte, error)
var jwkConverters = map[string]jwkToPEMConverter{
"RSA": rsaJWKToPEM,
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JSON Web Key to PEM format.
// It handles base64url decoding of the modulus and exponent,
// constructs an RSA public key, and encodes it in PEM format.
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
}
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
})
return pubKeyPEM, nil
return publicKeyPEM, nil
}
// ecJWKToPEM converts an EC (Elliptic Curve) JSON Web Key to PEM format.
// It supports the P-256, P-384, and P-521 curves as defined in the
// OIDC specification, decoding the x and y coordinates and encoding
// the resulting public key in PEM format.
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
@@ -215,21 +178,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
pubKey := &ecdsa.PublicKey{
publicKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
Bytes: publicKeyBytes,
})
return pubKeyPEM, nil
return publicKeyPEM, nil
}
+40 -276
View File
@@ -4,52 +4,29 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
)
// JWT represents a JSON Web Token as defined in RFC 7519.
// It contains the three parts of a JWT: header, claims (payload),
// and signature, along with the original token string.
type JWT struct {
// Header contains the token metadata (algorithm, key ID, etc.)
Header map[string]interface{}
// Claims contains the token claims (subject, expiration, etc.)
Claims map[string]interface{}
// Signature contains the raw signature bytes
Signature []byte
// Token is the original JWT string
Token string
Header map[string]interface{}
Claims map[string]interface{}
Signature string
}
// parseJWT parses a JWT token string into a JWT struct.
// It validates the token format and decodes the three parts
// (header, claims, signature) using base64url decoding.
// Parameters:
// - tokenString: The raw JWT token string
// Returns:
// - A parsed JWT struct
// - An error if the token format is invalid or parsing fails
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
jwt := &JWT{
Token: tokenString,
}
jwt := &JWT{}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -69,51 +46,14 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
// Set the signature
jwt.Signature = parts[2]
return jwt, nil
}
// Verify validates the standard JWT claims as defined in RFC 7519.
// It checks:
// - issuer (iss) matches the expected issuer URL
// - audience (aud) includes the client ID
// - expiration time (exp) is in the future (with clock skew tolerance)
// - issued at time (iat) is in the past (with clock skew tolerance)
// - not before time (nbf) is in the past (with clock skew tolerance)
// - subject (sub) is present and not empty
// - algorithm matches expected value to prevent algorithm switching attacks
// Returns an error if any validation fails.
func (j *JWT) Verify(issuerURL, clientID string) error {
// Debug logging of validation parameters
fmt.Printf("Validating token against:\nIssuer: %s\nClient ID: %s\n", issuerURL, clientID)
// Debug logging of token header
fmt.Printf("Token header: %+v\n", j.Header)
// Validate algorithm to prevent algorithm switching attacks
alg, ok := j.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing 'alg' header")
}
// List of supported algorithms - should match those in verifySignature
supportedAlgs := map[string]bool{
"RS256": true, "RS384": true, "RS512": true,
"PS256": true, "PS384": true, "PS512": true,
"ES256": true, "ES384": true, "ES512": true,
}
if !supportedAlgs[alg] {
return fmt.Errorf("unsupported algorithm: %s", alg)
}
claims := j.Claims
// Debug logging of all claims
fmt.Printf("Token claims: %+v\n", claims)
iss, ok := claims["iss"].(string)
if !ok {
@@ -147,19 +87,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return err
}
// Validate nbf (not before) claim if present
if nbf, ok := claims["nbf"].(float64); ok {
if err := verifyNotBefore(nbf); err != nil {
return err
}
}
// Validate jti (JWT ID) claim if present
if jti, ok := claims["jti"].(string); ok {
// Could add replay detection here if needed
_ = jti
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
@@ -168,191 +95,25 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience validates the token's audience claim.
// The audience can be either a single string or an array of strings.
// For array audiences, the expected audience must match any one value.
// Parameters:
// - tokenAudience: The audience claim from the token
// - expectedAudience: The expected audience value
// Returns an error if validation fails.
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
// Debug logging
fmt.Printf("Verifying audience:\nToken aud: %+v\nExpected: %s\n",
tokenAudience, expectedAudience)
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
// verifyIssuer validates the token's issuer claim.
// The issuer URL must exactly match the expected issuer.
// Parameters:
// - tokenIssuer: The issuer claim from the token
// - expectedIssuer: The expected issuer URL
// Returns an error if validation fails.
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
// Debug logging
fmt.Printf("Verifying issuer:\nToken iss: %s\nExpected: %s\n",
tokenIssuer, expectedIssuer)
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer (token: %s, expected: %s)",
tokenIssuer, expectedIssuer)
}
return nil
}
// Clock skew tolerance for time-based validations
const clockSkewTolerance = 2 * time.Minute
// verifyExpiration checks if the token's expiration time has passed.
// The expiration time is compared against the current time with clock skew tolerance.
// Parameters:
// - expiration: The expiration timestamp from the token
// Returns an error if the token has expired.
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(clockSkewTolerance)
// Debug logging
fmt.Printf("Token exp: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
expirationTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that expire exactly now
if expirationTime.Equal(now) {
return nil
}
if skewedNow.After(expirationTime) {
return fmt.Errorf("token has expired (exp: %v, now: %v)",
expirationTime.UTC(), now.UTC())
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
}
return nil
}
// verifyIssuedAt validates the token's issued-at time.
// Ensures the token wasn't issued in the future, accounting for clock skew.
// Parameters:
// - issuedAt: The issued-at timestamp from the token
// Returns an error if the token was issued in the future.
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token iat: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
issuedAtTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens issued in the same second as current time
if issuedAtTime.Equal(now) {
return nil
}
if skewedNow.Before(issuedAtTime) {
return fmt.Errorf("token used before issued (iat: %v, now: %v)",
issuedAtTime.UTC(), now.UTC())
}
return nil
}
// verifyNotBefore validates the token's not-before time if present.
// Ensures the token is not used before its valid time period, accounting for clock skew.
// Parameters:
// - notBefore: The not-before timestamp from the token
// Returns an error if the token is not yet valid.
func verifyNotBefore(notBefore float64) error {
notBeforeTime := time.Unix(int64(notBefore), 0)
// Truncate current time to seconds for consistent comparison
now := time.Now().Truncate(time.Second)
skewedNow := now.Add(-clockSkewTolerance)
// Debug logging
fmt.Printf("Token nbf: %v\nCurrent time: %v\nSkewed time: %v\nSkew: %v\n",
notBeforeTime.UTC(),
now.UTC(),
skewedNow.UTC(),
clockSkewTolerance)
// Allow tokens that become valid exactly now
if notBeforeTime.Equal(now) {
return nil
}
if skewedNow.Before(notBeforeTime) {
return fmt.Errorf("token not yet valid (nbf: %v, now: %v)",
notBeforeTime.UTC(), now.UTC())
}
return nil
}
// verifySignature validates the token's cryptographic signature.
// Supports multiple signature algorithms:
// - RSA: RS256, RS384, RS512 (PKCS#1 v1.5)
// - RSA-PSS: PS256, PS384, PS512
// - ECDSA: ES256, ES384, ES512
// Parameters:
// - tokenString: The complete JWT token string
// - publicKeyPEM: The PEM-encoded public key for verification
// - alg: The signature algorithm identifier
// Returns an error if signature verification fails.
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Debug logging
fmt.Printf("Verifying signature with algorithm: %s\n", alg)
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
@@ -366,42 +127,45 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pub, hashed, r, s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
return fmt.Errorf("unsupported public key type: %T", pub)
}
}
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
+232 -286
View File
@@ -2,6 +2,7 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -13,25 +14,24 @@ import (
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
const ConstSessionTimeout = 86400 // Session timeout in seconds
const ConstSessionTimeout = 86400
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// JWTVerifier interface for JWT verification
type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -50,99 +50,81 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
// Rate limiting
t.logger.Debugf("Verifying token: %s", token)
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
t.tokenCache.Set(token, jwt.Claims, duration)
t.tokenCache.Set(token, expirationTime)
return nil
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT signature and claims")
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Retrieve key ID and algorithm from JWT header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
// Find the matching key in JWKS
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
@@ -150,49 +132,63 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
// Convert JWK to PEM format
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
// Verify the signature
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
// Setup HTTP client
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &net.Dialer{
Timeout: 15 * time.Second, // Reduced timeout
KeepAlive: 15 * time.Second, // Reduced keepalive
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 5 * time.Second, // Reduced from 10s
ExpectContinueTimeout: 0,
MaxIdleConns: 30, // Reduced from 100
MaxIdleConnsPerHost: 10, // Reduced from 100
IdleConnTimeout: 30 * time.Second, // Reduced from 90s
DisableKeepAlives: false, // Enable connection reuse
MaxConnsPerHost: 50, // Limit max connections
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: 10,
}
var httpClient *http.Client
@@ -200,14 +196,20 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient = config.HTTPClient
} else {
httpClient = &http.Client{
Timeout: time.Second * 15, // Reduced timeout
Timeout: time.Second * 30,
Transport: transport,
}
}
metadata, err := discoverProviderMetadata(config.ProviderURL, httpClient, NewLogger(config.LogLevel))
if err != nil {
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
}
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -215,36 +217,50 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
issuerURL: metadata.Issuer,
revocationURL: metadata.RevokeURL,
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
jwksURL: metadata.JWKSURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
authURL: metadata.AuthURL,
tokenURL: metadata.TokenURL,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: func() map[string]struct{} {
m := make(map[string]struct{})
for _, url := range config.ExcludedURLs {
m[url] = struct{}{}
}
return config.PostLogoutRedirectURI
return m
}(),
redirectURL: "",
allowedUserDomains: func() map[string]struct{} {
m := make(map[string]struct{})
for _, domain := range config.AllowedUserDomains {
m[domain] = struct{}{}
}
return m
}(),
allowedRolesAndGroups: func() map[string]struct{} {
m := make(map[string]struct{})
for _, roleOrGroup := range config.AllowedRolesAndGroups {
m[roleOrGroup] = struct{}{}
}
return m
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
t.extractClaimsFunc = extractClaims
// Add default excluded URLs
// add defaultExcludedURLs to excludedURLs
for k, v := range defaultExcludedURLs {
t.excludedURLs[k] = v
}
@@ -252,53 +268,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.tokenVerifier = t
t.jwtVerifier = t
t.startTokenCleanup()
go t.initializeMetadata(config.ProviderURL)
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.logger.Debug("Starting provider metadata discovery")
// Keep retrying until successful
backoff := time.Second
maxBackoff := 30 * time.Second
for {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v, retrying in %v", err, backoff)
time.Sleep(backoff)
// Exponential backoff with max
backoff *= 2
if backoff > maxBackoff {
backoff = maxBackoff
}
continue
}
if metadata != nil {
t.logger.Debug("Successfully initialized provider metadata")
t.jwksURL = metadata.JWKSURL
t.authURL = metadata.AuthURL
t.tokenURL = metadata.TokenURL
t.issuerURL = metadata.Issuer
t.revocationURL = metadata.RevokeURL
t.endSessionURL = metadata.EndSessionURL
// Only close channel on success
close(t.initComplete)
return
}
t.logger.Error("Received nil metadata, retrying")
time.Sleep(backoff)
}
}
// discoverProviderMetadata fetches the OIDC provider metadata
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -312,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if time.Since(start) > totalTimeout {
l.Errorf("Timeout exceeded while fetching provider metadata")
l.Error("Timeout exceeded while fetching provider metadata")
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
}
@@ -324,20 +296,18 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
lastErr = err
// Exponential backoff
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
if delay > maxDelay {
delay = maxDelay
}
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
time.Sleep(delay)
}
l.Errorf("Max retries exceeded while fetching provider metadata")
l.Error("Max retries exceeded while fetching provider metadata")
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -360,138 +330,138 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Error("OIDC provider metadata initialization failed")
http.Error(rw, "OIDC provider metadata initialization failed - please check provider availability", http.StatusServiceUnavailable)
return
}
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
case <-time.After(30 * time.Second):
t.logger.Error("Timeout waiting for OIDC initialization")
http.Error(rw, "Timeout waiting for OIDC provider initialization - please try again", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Get session
session, err := t.sessionManager.GetSession(req)
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
t.handleCallback(rw, req)
return
}
// Check authentication status
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
}
// Process authenticated request
email := session.GetEmail()
if email == "" {
t.logger.Debug("No email found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
// authenticated, _ := session.Values["authenticated"].(bool)
if authenticated {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
return
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
t.logger.Debug("User is not authenticated, initiating authentication")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
t.logger.Debug("URL is not excluded - got %s", currentRequest)
return false
}
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
@@ -505,7 +475,6 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
return "http"
}
// determineHost determines the host of the request
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -513,35 +482,37 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(accessToken); err != nil {
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(accessToken)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true
return false, false, true // Can't read claims, consider it expired
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
}
now := time.Now().Unix()
@@ -549,7 +520,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true
return false, false, true // Token has expired
}
gracePeriod := time.Minute * 5
@@ -558,42 +529,38 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build and redirect to auth URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -607,7 +574,6 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
go func() {
@@ -619,23 +585,26 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
}
// RevokeTokenWithProvider revokes the token with the provider
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
t.logger.Debugf("Revoking token with provider")
data := url.Values{
"token": {token},
"token_type_hint": {tokenType},
"token_type_hint": {"access_token", "refresh_token"},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
@@ -666,12 +635,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
return nil
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
return false
}
@@ -681,17 +648,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new access token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
return false
}
// Update session with new tokens
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
@@ -700,7 +659,6 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
@@ -716,7 +674,6 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
@@ -726,48 +683,37 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
var groups []string
var roles []string
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
if len(groups) == 0 {
t.logger.Debug("No groups found in groups claim, checking roles claim")
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debug("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
if len(roles) == 0 {
t.logger.Debug("No roles found in roles claim")
}
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
-57
View File
@@ -1,57 +0,0 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
func BenchmarkOIDCMiddleware(b *testing.B) {
// Setup test environment
ts := &TestSuite{}
ts.Setup()
ts.token = "valid.jwt.token"
// Define the handler with OIDC middleware
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create test server
server := httptest.NewServer(ts.tOidc.next)
defer server.Close()
// Prepare HTTP client
client := &http.Client{}
// Reset timer to exclude setup time
b.ResetTimer()
// Run benchmark
for i := 0; i < b.N; i++ {
// Create new request
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
b.Fatal(err)
}
// Set necessary headers or cookies
req.Header.Set("Authorization", "Bearer "+ts.token)
// Send the request
resp, err := client.Do(req)
if err != nil {
b.Fatal(err)
}
// Close response body
resp.Body.Close()
// Check response status code
if resp.StatusCode != http.StatusOK {
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
}
}
}
+86 -1163
View File
File diff suppressed because it is too large Load Diff
-592
View File
@@ -1,592 +0,0 @@
package traefikoidc
import (
"bytes"
"compress/gzip"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/sessions"
)
// generateSecureRandomString creates a cryptographically secure random string of specified length
func generateSecureRandomString(length int) string {
bytes := make([]byte, length)
if _, err := rand.Read(bytes); err != nil {
panic("failed to generate random string")
}
return hex.EncodeToString(bytes)
}
// Cookie names and configuration constants used for session management
var (
// Using random prefixes to make cookie names less predictable
mainCookieName = "_oidc_m_" + generateSecureRandomString(8)
accessTokenCookie = "_oidc_a_" + generateSecureRandomString(8)
refreshTokenCookie = "_oidc_r_" + generateSecureRandomString(8)
)
const (
// maxCookieSize is the maximum size for each cookie chunk.
// This value is calculated to ensure the final cookie size stays within browser limits:
// 1. Browser cookie size limit is typically 4096 bytes
// 2. Cookie content undergoes encryption (adds 28 bytes) and base64 encoding (4/3 ratio)
// 3. Calculation:
// - Let x be the chunk size
// - After encryption: x + 28 bytes
// - After base64: ((x + 28) * 4/3) bytes
// - Must satisfy: ((x + 28) * 4/3) ≤ 4096
// - Solving for x: x ≤ 3044
// 4. We use 2000 as a conservative limit to account for cookie metadata
maxCookieSize = 2000
// absoluteSessionTimeout defines the maximum lifetime of a session
// regardless of activity (24 hours)
absoluteSessionTimeout = 24 * time.Hour
// minEncryptionKeyLength defines the minimum length for the encryption key
minEncryptionKeyLength = 32
)
// compressToken compresses a token using gzip and base64 encodes it
func compressToken(token string) string {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
if _, err := gz.Write([]byte(token)); err != nil {
return token // fallback to uncompressed on error
}
if err := gz.Close(); err != nil {
return token
}
return base64.StdEncoding.EncodeToString(b.Bytes())
}
// decompressToken decompresses a base64 encoded gzipped token
func decompressToken(compressed string) string {
data, err := base64.StdEncoding.DecodeString(compressed)
if err != nil {
return compressed // return as-is if not base64
}
gz, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return compressed
}
defer gz.Close()
decompressed, err := io.ReadAll(gz)
if err != nil {
return compressed
}
return string(decompressed)
}
// SessionManager handles the management of multiple session cookies for OIDC authentication.
// It provides functionality for storing and retrieving authentication state, tokens,
// and other session-related data across multiple cookies to handle large tokens.
type SessionManager struct {
// store is the underlying session store for cookie management
store sessions.Store
// forceHTTPS enforces secure cookie attributes regardless of request scheme
forceHTTPS bool
// logger provides structured logging capabilities
logger *Logger
// sessionPool is a sync.Pool for reusing SessionData objects
sessionPool sync.Pool
}
// NewSessionManager creates a new session manager with the specified configuration.
// Parameters:
// - encryptionKey: Key used to encrypt session data (must be at least 32 bytes)
// - forceHTTPS: When true, forces secure cookie attributes regardless of request scheme
// - logger: Logger instance for recording session-related events
// The manager handles session creation, storage, and cookie security settings.
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
// Validate encryption key length
if len(encryptionKey) < minEncryptionKeyLength {
panic(fmt.Sprintf("encryption key must be at least %d bytes long", minEncryptionKeyLength))
}
sm := &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
// Initialize session pool
sm.sessionPool.New = func() interface{} {
return &SessionData{
manager: sm,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
}
}
return sm
}
// getSessionOptions returns secure session options configured for the current request.
// Parameters:
// - isSecure: Whether the current request is using HTTPS
// The options ensure cookies are:
// - HTTP-only (not accessible via JavaScript)
// - Secure when using HTTPS or when forceHTTPS is enabled
// - Using SameSite=Lax for CSRF protection
// - Set with appropriate timeout and path settings
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: int(absoluteSessionTimeout.Seconds()),
Path: "/",
}
}
// GetSession retrieves all session data for the current request.
// It loads the main session and token sessions, including any chunked token data,
// and combines them into a single SessionData structure for easy access.
// Returns an error if any session component cannot be loaded.
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
// Get session from pool
sessionData := sm.sessionPool.Get().(*SessionData)
sessionData.request = r
var err error
sessionData.mainSession, err = sm.store.Get(r, mainCookieName)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get main session: %w", err)
}
// Check for absolute session timeout
if createdAt, ok := sessionData.mainSession.Values["created_at"].(int64); ok {
if time.Since(time.Unix(createdAt, 0)) > absoluteSessionTimeout {
sessionData.Clear(r, nil) // Clear expired session
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("session expired")
}
}
sessionData.accessSession, err = sm.store.Get(r, accessTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
sessionData.refreshSession, err = sm.store.Get(r, refreshTokenCookie)
if err != nil {
sm.sessionPool.Put(sessionData)
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
// Clear and reuse chunk maps
for k := range sessionData.accessTokenChunks {
delete(sessionData.accessTokenChunks, k)
}
for k := range sessionData.refreshTokenChunks {
delete(sessionData.refreshTokenChunks, k)
}
// Retrieve chunked token sessions
sm.getTokenChunkSessions(r, accessTokenCookie, sessionData.accessTokenChunks)
sm.getTokenChunkSessions(r, refreshTokenCookie, sessionData.refreshTokenChunks)
return sessionData, nil
}
// getTokenChunkSessions retrieves all session chunks for a given token type.
// Parameters:
// - r: The HTTP request
// - baseName: The base name for the token's session cookies
// - chunks: Map to store the chunks in
func (sm *SessionManager) getTokenChunkSessions(r *http.Request, baseName string, chunks map[int]*sessions.Session) {
for i := 0; ; i++ {
sessionName := fmt.Sprintf("%s_%d", baseName, i)
session, err := sm.store.Get(r, sessionName)
if err != nil || session.IsNew {
// No more sessions
break
}
chunks[i] = session
}
}
// SessionData holds all session information for an authenticated user.
// It manages multiple session cookies to handle the main session state
// and potentially large access and refresh tokens that may need to be
// split across multiple cookies due to browser size limitations.
type SessionData struct {
// manager is the SessionManager that created this SessionData
manager *SessionManager
// request is the current HTTP request associated with this session
request *http.Request
// mainSession stores authentication state and basic user info
mainSession *sessions.Session
// accessSession stores the primary access token cookie
accessSession *sessions.Session
// refreshSession stores the primary refresh token cookie
refreshSession *sessions.Session
// accessTokenChunks stores additional chunks of the access token
// when it exceeds the maximum cookie size
accessTokenChunks map[int]*sessions.Session
// refreshTokenChunks stores additional chunks of the refresh token
// when it exceeds the maximum cookie size
refreshTokenChunks map[int]*sessions.Session
}
// Save persists all session data to cookies in the HTTP response.
// It saves the main session, token sessions, and any token chunks,
// applying appropriate security options to each cookie. All cookies
// are saved with consistent security settings based on the request scheme.
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
options := sd.manager.getSessionOptions(isSecure)
sd.mainSession.Options = options
sd.accessSession.Options = options
sd.refreshSession.Options = options
// Save main session
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
// Save access token session
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
// Save refresh token session
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
// Save access token chunks
for _, session := range sd.accessTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token chunk session: %w", err)
}
}
// Save refresh token chunks
for _, session := range sd.refreshTokenChunks {
session.Options = options
if err := session.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token chunk session: %w", err)
}
}
return nil
}
// Clear removes all session data by expiring all cookies and clearing their values.
// This is typically used during logout to ensure all session data is properly cleaned up.
// It handles both main session data and any token chunks that may exist.
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
// Clear chunk sessions
sd.clearTokenChunks(r, sd.accessTokenChunks)
sd.clearTokenChunks(r, sd.refreshTokenChunks)
var err error
if w != nil {
err = sd.Save(r, w)
}
// Return session to pool
sd.manager.sessionPool.Put(sd)
return err
}
// clearTokenChunks removes all session chunks for a given token type.
// It expires the cookies and removes all stored values to ensure
// no token data remains after logout or token invalidation.
func (sd *SessionData) clearTokenChunks(r *http.Request, chunks map[int]*sessions.Session) {
for _, session := range chunks {
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
}
}
// GetAuthenticated returns whether the current session is authenticated.
// Returns true if the user has successfully completed OIDC authentication
// and the session hasn't expired, false otherwise.
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
if !auth {
return false
}
// Check session expiration
createdAt, ok := sd.mainSession.Values["created_at"].(int64)
if !ok {
return false
}
return time.Since(time.Unix(createdAt, 0)) <= absoluteSessionTimeout
}
// SetAuthenticated updates the session's authentication status and rotates session ID.
// This should be called after successful OIDC authentication or during logout.
// Session ID rotation helps prevent session fixation attacks.
func (sd *SessionData) SetAuthenticated(value bool) {
if value {
// Generate new session ID and set creation time
sd.mainSession.ID = generateSecureRandomString(32)
sd.mainSession.Values["created_at"] = time.Now().Unix()
}
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken retrieves the complete access token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
if len(sd.accessTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.accessTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
token = strings.Join(chunks, "")
compressed, _ := sd.accessSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetAccessToken stores the access token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetAccessToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.accessTokenChunks)
sd.accessTokenChunks = make(map[int]*sessions.Session)
// Compress token
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.accessSession.Values["token"] = compressed
sd.accessSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
sd.accessSession.Values["token"] = ""
sd.accessSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", accessTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.accessTokenChunks[i] = session
}
}
}
// GetRefreshToken retrieves the complete refresh token from the session.
// If the token was split into chunks due to size limitations, it will
// automatically reassemble the complete token from all chunks.
// Returns an empty string if no token is found.
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
if token != "" {
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// Reassemble token from chunks
if len(sd.refreshTokenChunks) == 0 {
return ""
}
var chunks []string
for i := 0; ; i++ {
session, ok := sd.refreshTokenChunks[i]
if !ok {
break
}
chunk, _ := session.Values["token_chunk"].(string)
chunks = append(chunks, chunk)
}
token = strings.Join(chunks, "")
compressed, _ := sd.refreshSession.Values["compressed"].(bool)
if compressed {
return decompressToken(token)
}
return token
}
// SetRefreshToken stores the refresh token in the session.
// If the token exceeds maxCookieSize, it is automatically compressed and split into
// multiple cookie chunks to handle large tokens while staying within
// browser cookie size limits. Any existing token or chunks are cleared
// before setting the new token.
func (sd *SessionData) SetRefreshToken(token string) {
// Clear existing chunks
sd.clearTokenChunks(sd.request, sd.refreshTokenChunks)
sd.refreshTokenChunks = make(map[int]*sessions.Session)
// Compress token
compressed := compressToken(token)
if len(compressed) <= maxCookieSize {
sd.refreshSession.Values["token"] = compressed
sd.refreshSession.Values["compressed"] = true
} else {
// Split compressed token into chunks
sd.refreshSession.Values["token"] = ""
sd.refreshSession.Values["compressed"] = true
chunks := splitIntoChunks(compressed, maxCookieSize)
for i, chunk := range chunks {
sessionName := fmt.Sprintf("%s_%d", refreshTokenCookie, i)
session, _ := sd.manager.store.Get(sd.request, sessionName)
session.Values["token_chunk"] = chunk
sd.refreshTokenChunks[i] = session
}
}
}
// splitIntoChunks splits a string into chunks of specified size.
// This is used internally to handle large tokens that exceed cookie size limits.
// Parameters:
// - s: The string to split
// - chunkSize: Maximum size of each chunk
// Returns an array of string chunks, each no larger than chunkSize.
func splitIntoChunks(s string, chunkSize int) []string {
var chunks []string
for len(s) > 0 {
if len(s) > chunkSize {
chunks = append(chunks, s[:chunkSize])
s = s[chunkSize:]
} else {
chunks = append(chunks, s)
break
}
}
return chunks
}
// GetCSRF retrieves the CSRF token from the session.
// This token is used to prevent cross-site request forgery attacks
// by ensuring requests originate from the authenticated user.
// Returns an empty string if no CSRF token is found.
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF stores a new CSRF token in the session.
// This should be called when initiating authentication to generate
// a new token for the authentication flow.
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce retrieves the nonce value from the session.
// The nonce is used to prevent replay attacks in the OIDC flow
// by ensuring the token received matches the authentication request.
// Returns an empty string if no nonce is found.
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce stores a new nonce value in the session.
// This should be called when initiating authentication to generate
// a new nonce for the OIDC authentication flow.
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail retrieves the authenticated user's email address from the session.
// The email is typically extracted from the OIDC ID token claims.
// Returns an empty string if no email is found.
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail stores the user's email address in the session.
// This should be called after successful authentication when
// processing the OIDC ID token claims.
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath retrieves the original request path that triggered
// the authentication flow. This is used to redirect the user back
// to their intended destination after successful authentication.
// Returns an empty string if no path was stored.
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath stores the original request path that triggered
// the authentication flow. This should be called before redirecting
// to the OIDC provider to remember where to send the user afterward.
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
-268
View File
@@ -1,268 +0,0 @@
package traefikoidc
import (
"math/rand"
"net/http/httptest"
"strings"
"testing"
"time"
)
func init() {
// Initialize random seed
rand.Seed(time.Now().UnixNano())
}
// generateRandomString creates a random string of specified length
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[rand.Intn(len(charset))]
}
return string(b)
}
// TestTokenCompression tests the token compression functionality
func TestTokenCompression(t *testing.T) {
tests := []struct {
name string
token string
wantSize int // Expected size after compression (approximate)
}{
{
name: "Short token",
token: "shorttoken",
wantSize: 50, // Base64 encoded gzip has overhead for small content
},
{
name: "Repeating content",
token: strings.Repeat("abcdef", 1000),
wantSize: 100, // Should compress well due to repetition
},
{
name: "Random content",
token: generateRandomString(1000),
wantSize: 2000, // Random content won't compress much
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
compressed := compressToken(tt.token)
decompressed := decompressToken(compressed)
// Only verify compression ratio for non-short tokens
if len(tt.token) > 100 {
compressionRatio := float64(len(compressed)) / float64(len(tt.token))
t.Logf("Compression ratio for %s: %.2f", tt.name, compressionRatio)
if compressionRatio > 1.1 { // Allow up to 10% size increase
t.Errorf("Compression increased size too much: original=%d, compressed=%d, ratio=%.2f",
len(tt.token), len(compressed), compressionRatio)
}
}
// Verify decompression restores original
if decompressed != tt.token {
t.Error("Decompression failed to restore original token")
}
// Verify approximate compression ratio
if len(compressed) > tt.wantSize*2 {
t.Errorf("Compression ratio worse than expected: got=%d, want<%d", len(compressed), tt.wantSize*2)
}
})
}
}
// TestSessionManager tests the SessionManager functionality
func TestSessionManager(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
tests := []struct {
name string
authenticated bool
email string
accessToken string
refreshToken string
expectedCookieCount int
wantCompressed bool // Whether tokens should be compressed
}{
{
name: "Short tokens",
authenticated: true,
email: "test@example.com",
accessToken: "shortaccesstoken",
refreshToken: "shortrefreshtoken",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: true,
},
{
name: "Long tokens exceeding 4096 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 5000),
refreshToken: strings.Repeat("y", 6000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 5000), strings.Repeat("y", 6000)),
wantCompressed: true,
},
{
name: "REALLY long tokens, exceeding 25000 bytes",
authenticated: true,
email: "test@example.com",
accessToken: strings.Repeat("x", 25000),
refreshToken: strings.Repeat("y", 25000),
expectedCookieCount: calculateExpectedCookieCount(strings.Repeat("x", 25000), strings.Repeat("y", 25000)),
wantCompressed: true,
},
{
name: "Unauthenticated session",
authenticated: false,
email: "",
accessToken: "",
refreshToken: "",
expectedCookieCount: 3, // main, access, refresh
wantCompressed: false,
},
{
name: "Random content tokens",
authenticated: true,
email: "test@example.com",
accessToken: generateRandomString(5000),
refreshToken: generateRandomString(5000),
expectedCookieCount: calculateExpectedCookieCount(generateRandomString(5000), generateRandomString(5000)),
wantCompressed: true,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := ts.sessionManager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Set session values
session.SetAuthenticated(tc.authenticated)
session.SetEmail(tc.email)
session.SetAccessToken(tc.accessToken)
session.SetRefreshToken(tc.refreshToken)
// Save session
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set and compression is used when appropriate
cookies := rr.Result().Cookies()
if len(cookies) != tc.expectedCookieCount {
t.Errorf("Expected %d cookies, got %d", tc.expectedCookieCount, len(cookies))
}
// Verify compression is working by checking token sizes
for _, cookie := range cookies {
if strings.Contains(cookie.Name, accessTokenCookie) {
// Get original and stored sizes
originalSize := len(tc.accessToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
// For large tokens, verify some compression occurred
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Access token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 { // Allow some overhead, but should see compression
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
} else if strings.Contains(cookie.Name, refreshTokenCookie) {
originalSize := len(tc.refreshToken)
storedSize := len(cookie.Value)
if originalSize > 100 && tc.wantCompressed {
compressionRatio := float64(storedSize) / float64(originalSize)
t.Logf("Refresh token compression ratio: %.2f (original: %d, stored: %d)",
compressionRatio, originalSize, storedSize)
if compressionRatio > 0.9 {
t.Errorf("Expected compression for large token in cookie %s (ratio: %.2f)",
cookie.Name, compressionRatio)
}
}
}
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := ts.sessionManager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
// Verify session values
if newSession.GetAuthenticated() != tc.authenticated {
t.Errorf("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != tc.email {
t.Errorf("Expected email %s, got %s", tc.email, email)
}
if token := newSession.GetAccessToken(); token != tc.accessToken {
t.Errorf("Access token not preserved: got len=%d, want len=%d", len(token), len(tc.accessToken))
}
if token := newSession.GetRefreshToken(); token != tc.refreshToken {
t.Errorf("Refresh token not preserved: got len=%d, want len=%d", len(token), len(tc.refreshToken))
}
// Verify session pooling by checking if the session is reused
session2, _ := ts.sessionManager.GetSession(newReq)
if session2 == newSession {
t.Error("Session not properly pooled")
}
})
}
}
func calculateExpectedCookieCount(accessToken, refreshToken string) int {
count := 3 // main, access, refresh
// Helper to calculate chunks for compressed token
calculateChunks := func(token string) int {
// Compress token (matching the actual implementation)
compressed := compressToken(token)
// If compressed token fits in one cookie, no additional chunks needed
if len(compressed) <= maxCookieSize {
return 0
}
// Calculate chunks needed for compressed token
return len(splitIntoChunks(compressed, maxCookieSize))
}
// Add chunks for access token if needed
accessChunks := calculateChunks(accessToken)
if accessChunks > 0 {
count += accessChunks
}
// Add chunks for refresh token if needed
refreshChunks := calculateChunks(refreshToken)
if refreshChunks > 0 {
count += refreshChunks
}
return count
}
+48 -207
View File
@@ -5,234 +5,96 @@ import (
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"github.com/gorilla/sessions"
)
// Config holds the configuration for the OIDC middleware.
// It provides all necessary settings to configure OpenID Connect authentication
// with various providers like Auth0, Logto, or any standard OIDC provider.
type Config struct {
// ProviderURL is the base URL of the OIDC provider (required)
// Example: https://accounts.google.com
ProviderURL string `json:"providerURL"`
// RevocationURL is the endpoint for revoking tokens (optional)
// If not provided, it will be discovered from provider metadata
RevocationURL string `json:"revocationURL"`
// CallbackURL is the path where the OIDC provider will redirect after authentication (required)
// Example: /oauth2/callback
CallbackURL string `json:"callbackURL"`
// LogoutURL is the path for handling logout requests (optional)
// If not provided, it will be set to CallbackURL + "/logout"
LogoutURL string `json:"logoutURL"`
// ClientID is the OAuth 2.0 client identifier (required)
ClientID string `json:"clientID"`
// ClientSecret is the OAuth 2.0 client secret (required)
ClientSecret string `json:"clientSecret"`
// Scopes defines the OAuth 2.0 scopes to request (optional)
// Defaults to ["openid", "profile", "email"] if not provided
Scopes []string `json:"scopes"`
// LogLevel sets the logging verbosity (optional)
// Valid values: "debug", "info", "error"
// Default: "info"
LogLevel string `json:"logLevel"`
// SessionEncryptionKey is used to encrypt session data (required)
// Must be a secure random string
SessionEncryptionKey string `json:"sessionEncryptionKey"`
// ForceHTTPS forces the use of HTTPS for all URLs (optional)
// Default: false
ForceHTTPS bool `json:"forceHTTPS"`
// RateLimit sets the maximum number of requests per second (optional)
// Default: 100
RateLimit int `json:"rateLimit"`
// ExcludedURLs lists paths that bypass authentication (optional)
// Example: ["/health", "/metrics"]
ExcludedURLs []string `json:"excludedURLs"`
// AllowedUserDomains restricts access to specific email domains (optional)
// Example: ["company.com", "subsidiary.com"]
AllowedUserDomains []string `json:"allowedUserDomains"`
// AllowedRolesAndGroups restricts access to users with specific roles or groups (optional)
// Example: ["admin", "developer"]
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
// OIDCEndSessionURL is the provider's end session endpoint (optional)
// If not provided, it will be discovered from provider metadata
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
// PostLogoutRedirectURI is the URL to redirect to after logout (optional)
// Default: "/"
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
// HTTPClient allows customizing the HTTP client used for OIDC operations (optional)
HTTPClient *http.Client
}
const (
// DefaultRateLimit defines the default rate limit for requests per second
DefaultRateLimit = 100
// MinRateLimit defines the minimum allowed rate limit to prevent DOS
MinRateLimit = 10
// DefaultLogLevel defines the default logging level
DefaultLogLevel = "info"
// MinSessionEncryptionKeyLength defines the minimum length for session encryption key
MinSessionEncryptionKeyLength = 32
cookieName = "_raczylo_oidc"
)
// CreateConfig creates a new Config with secure default values.
// Default values are set for optional fields:
// - Scopes: ["openid", "profile", "email"]
// - LogLevel: "info"
// - LogoutURL: CallbackURL + "/logout"
// - RateLimit: 100 requests per second
// - PostLogoutRedirectURI: "/"
// - ForceHTTPS: true (for security)
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
CallbackURL string `json:"callbackURL"`
LogoutURL string `json:"logoutURL"`
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
Scopes []string `json:"scopes"`
LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHTTPS"`
RateLimit int `json:"rateLimit"`
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
HTTPClient *http.Client
}
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
func CreateConfig() *Config {
c := &Config{
Scopes: []string{"openid", "profile", "email"},
LogLevel: DefaultLogLevel,
RateLimit: DefaultRateLimit,
ForceHTTPS: true, // Secure by default
c := &Config{}
if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
}
if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout"
}
if c.RateLimit == 0 {
c.RateLimit = 100
}
return c
}
// Validate performs validation checks on the Config.
// It ensures all required fields are set and have valid values.
// Returns an error if any validation check fails.
func (c *Config) Validate() error {
// Validate provider URL
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if !isValidSecureURL(c.ProviderURL) {
return fmt.Errorf("providerURL must be a valid HTTPS URL")
}
// Validate callback URL
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if !strings.HasPrefix(c.CallbackURL, "/") {
return fmt.Errorf("callbackURL must start with /")
}
// Validate client credentials
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
// Validate session encryption key
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
if len(c.SessionEncryptionKey) < MinSessionEncryptionKeyLength {
return fmt.Errorf("sessionEncryptionKey must be at least %d characters long", MinSessionEncryptionKeyLength)
}
// Validate log level
if c.LogLevel != "" && !isValidLogLevel(c.LogLevel) {
return fmt.Errorf("logLevel must be one of: debug, info, error")
}
// Validate excluded URLs
for _, url := range c.ExcludedURLs {
if !strings.HasPrefix(url, "/") {
return fmt.Errorf("excluded URL must start with /: %s", url)
}
if strings.Contains(url, "..") {
return fmt.Errorf("excluded URL must not contain path traversal: %s", url)
}
if strings.Contains(url, "*") {
return fmt.Errorf("excluded URL must not contain wildcards: %s", url)
}
}
// Validate revocation URL if set
if c.RevocationURL != "" && !isValidSecureURL(c.RevocationURL) {
return fmt.Errorf("revocationURL must be a valid HTTPS URL")
}
// Validate end session URL if set
if c.OIDCEndSessionURL != "" && !isValidSecureURL(c.OIDCEndSessionURL) {
return fmt.Errorf("oidcEndSessionURL must be a valid HTTPS URL")
}
// Validate post-logout redirect URI if set
if c.PostLogoutRedirectURI != "" && c.PostLogoutRedirectURI != "/" {
if !isValidSecureURL(c.PostLogoutRedirectURI) && !strings.HasPrefix(c.PostLogoutRedirectURI, "/") {
return fmt.Errorf("postLogoutRedirectURI must be either a valid HTTPS URL or start with /")
}
}
// Validate rate limit
if c.RateLimit < MinRateLimit {
return fmt.Errorf("rateLimit must be at least %d", MinRateLimit)
}
return nil
}
// isValidSecureURL checks if the provided string is a valid HTTPS URL
func isValidSecureURL(s string) bool {
u, err := url.Parse(s)
return err == nil && u.Scheme == "https" && u.Host != ""
}
// isValidLogLevel checks if the provided log level is valid
func isValidLogLevel(level string) bool {
return level == "debug" || level == "info" || level == "error"
}
// Logger provides structured logging capabilities with different severity levels.
// It supports error, info, and debug levels with appropriate output streams
// and formatting for each level.
type Logger struct {
// logError handles error-level messages, writing to stderr
logError *log.Logger
// logInfo handles informational messages, writing to stdout
logInfo *log.Logger
// logDebug handles debug-level messages, writing to stdout when debug is enabled
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger with the specified log level.
// The log level determines which messages are output:
// - "debug": Outputs all messages (debug, info, error)
// - "info": Outputs info and error messages
// - "error": Outputs only error messages
// Error messages are always written to stderr, while info and debug
// messages are written to stdout when enabled.
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
logInfo.SetOutput(os.Stdout)
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
@@ -244,51 +106,30 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an informational message.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an informational message using Printf formatting.
// These messages are intended for general operational information
// and are written to stdout.
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message using Printf formatting.
// These messages are only output when debug level logging is enabled
// and are intended for detailed troubleshooting information.
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message using Printf formatting.
// These messages indicate problems that need attention and are
// always written to stderr regardless of the log level.
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to both the HTTP response and the error log.
// It ensures consistent error handling across the middleware by logging the error
// and sending an appropriate HTTP response to the client.
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
-397
View File
@@ -1,397 +0,0 @@
package traefikoidc
import (
"bytes"
"log"
"net/http"
"testing"
)
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
// Check default scopes
expectedScopes := []string{"openid", "profile", "email"}
if len(config.Scopes) != len(expectedScopes) {
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
}
for i, scope := range expectedScopes {
if config.Scopes[i] != scope {
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
}
}
// Check default log level
if config.LogLevel != DefaultLogLevel {
t.Errorf("Expected default log level '%s', got '%s'", DefaultLogLevel, config.LogLevel)
}
// Check default rate limit
if config.RateLimit != DefaultRateLimit {
t.Errorf("Expected default rate limit %d, got %d", DefaultRateLimit, config.RateLimit)
}
// Check ForceHTTPS default
if !config.ForceHTTPS {
t.Error("Expected ForceHTTPS to be true by default")
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
config.ForceHTTPS = false
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
if config.ForceHTTPS {
t.Error("Custom ForceHTTPS value was overwritten")
}
})
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
config *Config
expectedError string
}{
{
name: "Empty Config",
config: &Config{},
expectedError: "providerURL is required",
},
{
name: "Missing CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
},
expectedError: "callbackURL is required",
},
{
name: "Missing ClientID",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
},
expectedError: "clientID is required",
},
{
name: "Missing ClientSecret",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
},
expectedError: "clientSecret is required",
},
{
name: "Missing SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
},
expectedError: "sessionEncryptionKey is required",
},
{
name: "Non-HTTPS ProviderURL",
config: &Config{
ProviderURL: "http://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid HTTPS URL",
},
{
name: "Invalid CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "callback", // Missing leading slash
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "callbackURL must start with /",
},
{
name: "Short SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "short",
},
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Low RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: 5,
},
expectedError: "rateLimit must be at least 10",
},
{
name: "Invalid LogLevel",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "invalid",
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Non-HTTPS RevocationURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RevocationURL: "http://revoke.com",
},
expectedError: "revocationURL must be a valid HTTPS URL",
},
{
name: "Non-HTTPS OIDCEndSessionURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
OIDCEndSessionURL: "http://endsession.com",
},
expectedError: "oidcEndSessionURL must be a valid HTTPS URL",
},
{
name: "Valid Config",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
RevocationURL: "https://revoke.com",
OIDCEndSessionURL: "https://endsession.com",
},
expectedError: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
}
}
})
}
}
func TestLogger(t *testing.T) {
// Capture log output
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
name: "Debug Level",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut == "" {
t.Error("Expected debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Info Level",
logLevel: "info",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Error Level",
logLevel: "error",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut != "" {
t.Error("Did not expect info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Printf Methods",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debugf("debug %s", "formatted")
l.Infof("info %s", "formatted")
l.Errorf("error %s", "formatted")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
t.Error("Expected formatted debug message")
}
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
t.Error("Expected formatted info message")
}
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
t.Error("Expected formatted error message")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset buffers
debugBuf.Reset()
infoBuf.Reset()
errorBuf.Reset()
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
if tc.logLevel == "debug" {
logger.logDebug.SetOutput(&debugBuf)
}
// Run test
tc.testFunc(logger)
// Check results
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
})
}
}
func TestHandleError(t *testing.T) {
// Create a test logger with captured output
var errorBuf bytes.Buffer
logger := &Logger{
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
}
logger.logError.SetOutput(&errorBuf)
// Create a test response recorder
rr := &testResponseRecorder{
headers: make(map[string][]string),
}
// Test error handling
message := "test error message"
code := 400
handleError(rr, message, code, logger)
// Check response code
if rr.statusCode != code {
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
}
// Check response body
expectedBody := message + "\n"
if rr.body != expectedBody {
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
}
// Check error was logged
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
t.Error("Error message was not logged")
}
}
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
}
func (r *testResponseRecorder) Header() http.Header {
return r.headers
}
func (r *testResponseRecorder) Write(b []byte) (int, error) {
r.body = string(b)
return len(b), nil
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
+2 -2
View File
@@ -1,4 +1,4 @@
Copyright 2009 The Go Authors.
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
+14 -3
View File
@@ -99,9 +99,8 @@ func (lim *Limiter) Tokens() float64 {
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
tokens: float64(b),
limit: r,
burst: b,
}
}
@@ -345,6 +344,18 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
tokens: n,
timeToAct: t,
}
} else if lim.limit == 0 {
var ok bool
if lim.burst >= n {
ok = true
lim.burst -= n
}
return Reservation{
ok: ok,
lim: lim,
tokens: lim.burst,
timeToAct: t,
}
}
t, tokens := lim.advance(t)
+1 -1
View File
@@ -7,6 +7,6 @@ github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
# golang.org/x/time v0.5.0
## explicit; go 1.18
golang.org/x/time/rate