Compare commits

..

5 Commits

17 changed files with 591 additions and 2112 deletions
-1
View File
@@ -13,7 +13,6 @@ testData:
clientSecret: secret
callbackURL: /oauth2/callback
logoutURL: /oauth2/logout
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
scopes: # If not provided, default scopes will be used (openid, email, profile)
- openid
- email
-33
View File
@@ -4,10 +4,6 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
Middleware has been tested with Auth0 and Logto.
### Traefik version compatibility
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
### Configuration options
Middleware currently supports following scenarios:
@@ -19,35 +15,6 @@ Middleware currently supports following scenarios:
#### How to configure...
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```
##### Excluded URLs with open access
```
-69
View File
@@ -1,69 +0,0 @@
package traefikoidc
import (
"sync"
"time"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
}
// Cache is a simple in-memory cache
type Cache struct {
items map[string]CacheItem
mutex sync.RWMutex
}
// NewCache creates a new Cache
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
}
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
return nil, false
}
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
return nil, false
}
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
delete(c.items, key)
}
}
}
+2 -4
View File
@@ -1,13 +1,11 @@
module github.com/lukaszraczylo/traefikoidc
go 1.23
toolchain go1.23.1
go 1.22.2
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.7.0
golang.org/x/time v0.5.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
+2 -2
View File
@@ -6,5 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+154 -211
View File
@@ -6,27 +6,16 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
)
func newSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -36,8 +25,14 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -63,20 +58,14 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return &tokenResponse, nil
return result, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
@@ -85,35 +74,103 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
t.logger.Debugf("Token response: %+v", tokenResponse)
newAccessToken, ok := result["access_token"].(string)
if !ok || newAccessToken == "" {
return nil, fmt.Errorf("no access_token field in token response")
}
return tokenResponse, nil
rawIDToken, ok := result["id_token"].(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token field in token response")
}
newRefreshToken, ok := result["refresh_token"].(string)
if !ok || newRefreshToken == "" {
return nil, fmt.Errorf("no refresh_token field in token response")
}
response := &TokenResponse{
IDToken: rawIDToken,
AccessToken: newAccessToken,
ExpiresIn: int(result["expires_in"].(float64)),
TokenType: result["token_type"].(string),
}
// The refresh token might not be returned if it hasn't changed
if newRefreshToken != refreshToken {
response.RefreshToken = newRefreshToken
} else {
response.RefreshToken = refreshToken
}
t.logger.Debug("Token response: %+v", response)
return response, nil
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
err := t.RevokeTokenWithProvider(idToken)
if err != nil {
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
return
}
t.RevokeToken(idToken)
}
session.Options = defaultSessionOptions
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
http.Error(rw, "Logged out", http.StatusForbidden)
}
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
// Clear the existing session
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Failed to clear session: %v", err)
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
// Set new values
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initialize new authentication
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
session, err := t.sessionManager.GetSession(req)
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Session error: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
@@ -122,36 +179,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken := session.GetCSRF()
if csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Exchange code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -159,49 +186,27 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
token, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify and process tokens
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
idToken, ok := token["id_token"].(string)
if !ok || idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify nonce
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce := session.GetNonce()
if sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Process email
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -209,29 +214,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
return
}
// Update session with new values
session.SetAuthenticated(true)
session.SetEmail(email)
session.SetAccessToken(tokenResponse.IDToken)
session.SetRefreshToken(tokenResponse.RefreshToken)
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Options = defaultSessionOptions
// Save session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Redirect to original path or root
redirectPath := "/"
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
redirectPath = incomingPath
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
t.logger.Debugf("Authentication successful. User email: %s", email)
http.Redirect(rw, req, "/", http.StatusFound)
}
// extractClaims extracts claims from a JWT token
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -251,27 +248,28 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
type UsedTokens struct {
tokens map[string]bool
mutex sync.RWMutex
}
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -279,7 +277,6 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -291,127 +288,73 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache caches tokens
type TokenCache struct {
cache *Cache
cache map[string]*TokenInfo
mutex sync.RWMutex
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
// NewTokenCache creates a new TokenCache
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
cache: make(map[string]*TokenInfo),
}
}
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
tc.mutex.Lock()
defer tc.mutex.Unlock()
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
if !found {
return nil, false
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
tc.mutex.RLock()
defer tc.mutex.RUnlock()
info, exists := tc.cache[token]
if exists && time.Now().Before(info.ExpiresAt) {
return info, true
}
claims, ok := value.(map[string]interface{})
return claims, ok
return nil, false
}
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
tc.mutex.Lock()
defer tc.mutex.Unlock()
delete(tc.cache, token)
}
// Cleanup cleans up expired tokens from the cache
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
// handleLogout handles the logout request
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.sessionManager.GetSession(req)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the access token before clearing session
accessToken := session.GetAccessToken()
// Clear all session data
if err := session.Clear(req, rw); err != nil {
t.logger.Errorf("Error clearing session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Get the base URL for redirects
host := t.determineHost(req)
scheme := t.determineScheme(req)
baseURL := fmt.Sprintf("%s://%s", scheme, host)
// Determine post logout redirect URI
postLogoutRedirectURI := t.postLogoutRedirectURI
if postLogoutRedirectURI == "" {
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
}
// If we have an end session endpoint and an access token, use OIDC end session
if t.endSessionURL != "" && accessToken != "" {
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
if err != nil {
t.logger.Errorf("Failed to build logout URL: %v", err)
http.Error(rw, "Logout error", http.StatusInternalServerError)
return
tc.mutex.Lock()
defer tc.mutex.Unlock()
now := time.Now()
for token, info := range tc.cache {
if now.After(info.ExpiresAt) {
delete(tc.cache, token)
}
http.Redirect(rw, req, logoutURL, http.StatusFound)
return
}
// Otherwise, redirect to post logout URI
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
}
// BuildLogoutURL constructs the OIDC end session URL
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
u, err := url.Parse(endSessionURL)
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.clientID)
data.Set("client_secret", t.clientSecret)
data.Set("code", code)
data.Set("redirect_uri", t.redirectURL)
resp, err := t.httpClient.PostForm(t.tokenURL, data)
if err != nil {
return "", fmt.Errorf("failed to parse end session URL: %w", err)
return nil, fmt.Errorf("failed to exchange token: %v", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %v", err)
}
q := u.Query()
q.Set("id_token_hint", idToken)
if postLogoutRedirectURI != "" {
// Ensure postLogoutRedirectURI is properly URL encoded
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
}
u.RawQuery = q.Encode()
return u.String(), nil
return result, nil
}
+59 -45
View File
@@ -4,19 +4,17 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key
type JWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
@@ -29,24 +27,20 @@ type JWK struct {
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
type JWKCache struct {
jwks *JWKSet
expiresAt time.Time
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -73,7 +67,6 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -93,61 +86,82 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
// jwkToPEM converts a JWK to PEM format
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return converter(jwk)
}
type jwkToPEMConverter func(*JWK) ([]byte, error)
var jwkConverters = map[string]jwkToPEMConverter{
"RSA": rsaJWKToPEM,
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
}
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
})
return pubKeyPEM, nil
return publicKeyPEM, nil
}
// ecJWKToPEM converts an EC JWK to PEM
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
@@ -164,21 +178,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
pubKey := &ecdsa.PublicKey{
publicKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
Bytes: publicKeyBytes,
})
return pubKeyPEM, nil
return publicKeyPEM, nil
}
+36 -104
View File
@@ -4,36 +4,29 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
Signature []byte
Token string
Signature string
}
// parseJWT parses a JWT token string into a JWT struct
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
jwt := &JWT{
Token: tokenString,
}
jwt := &JWT{}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -53,17 +46,12 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
// Set the signature
jwt.Signature = parts[2]
return jwt, nil
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
@@ -107,39 +95,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
@@ -148,43 +103,17 @@ func verifyExpiration(expiration float64) error {
return nil
}
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
// verifySignature verifies the token signature using the provided public key and algorithm
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
@@ -198,42 +127,45 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pub, hashed, r, s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
return fmt.Errorf("unsupported public key type: %T", pub)
}
}
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
+227 -267
View File
@@ -2,6 +2,7 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -10,29 +11,27 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/sessions"
"golang.org/x/time/rate"
)
const ConstSessionTimeout = 86400 // Session timeout in seconds
const ConstSessionTimeout = 86400
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// JWTVerifier interface for JWT verification
type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
name string
store sessions.Store
redirURLPath string
logoutURLPath string
issuerURL string
@@ -51,107 +50,81 @@ type TraefikOidc struct {
tokenCache *TokenCache
httpClient *http.Client
logger *Logger
redirectURL string
tokenVerifier TokenVerifier
jwtVerifier JWTVerifier
excludedURLs map[string]struct{}
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initComplete chan struct{}
endSessionURL string
baseURL string
postLogoutRedirectURI string
sessionManager *SessionManager
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
EndSessionURL string `json:"end_session_endpoint"`
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
var (
globalMetadataCache struct {
sync.Once
metadata *ProviderMetadata
err error
}
)
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token")
// Rate limiting
t.logger.Debugf("Verifying token: %s", token)
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
t.tokenCache.Set(token, jwt.Claims, duration)
t.tokenCache.Set(token, expirationTime)
return nil
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT signature and claims")
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Retrieve key ID and algorithm from JWT header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
// Find the matching key in JWKS
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
@@ -159,32 +132,48 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
// Convert JWK to PEM format
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
// Verify the signature
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
// Setup HTTP client
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -195,11 +184,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
return dialer.DialContext(ctx, network, addr)
},
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 0,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
MaxIdleConnsPerHost: 10,
}
var httpClient *http.Client
@@ -212,9 +201,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
}
metadata, err := discoverProviderMetadata(config.ProviderURL, httpClient, NewLogger(config.LogLevel))
if err != nil {
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
}
t := &TraefikOidc{
next: next,
name: name,
store: store,
redirURLPath: config.CallbackURL,
logoutURLPath: func() string {
if config.LogoutURL == "" {
@@ -222,36 +217,50 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
postLogoutRedirectURI: func() string {
if config.PostLogoutRedirectURI == "" {
return "/"
issuerURL: metadata.Issuer,
revocationURL: metadata.RevokeURL,
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
jwksURL: metadata.JWKSURL,
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
authURL: metadata.AuthURL,
tokenURL: metadata.TokenURL,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: func() map[string]struct{} {
m := make(map[string]struct{})
for _, url := range config.ExcludedURLs {
m[url] = struct{}{}
}
return config.PostLogoutRedirectURI
return m
}(),
redirectURL: "",
allowedUserDomains: func() map[string]struct{} {
m := make(map[string]struct{})
for _, domain := range config.AllowedUserDomains {
m[domain] = struct{}{}
}
return m
}(),
allowedRolesAndGroups: func() map[string]struct{} {
m := make(map[string]struct{})
for _, roleOrGroup := range config.AllowedRolesAndGroups {
m[roleOrGroup] = struct{}{}
}
return m
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
t.extractClaimsFunc = extractClaims
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
t.extractClaimsFunc = extractClaims
// Add default excluded URLs
// add defaultExcludedURLs to excludedURLs
for k, v := range defaultExcludedURLs {
t.excludedURLs[k] = v
}
@@ -259,36 +268,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.tokenVerifier = t
t.jwtVerifier = t
t.startTokenCleanup()
go t.initializeMetadata(config.ProviderURL)
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
globalMetadataCache.Once.Do(func() {
t.logger.Debug("Starting global provider metadata discovery")
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
globalMetadataCache.metadata = metadata
globalMetadataCache.err = err
})
if globalMetadataCache.err != nil {
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
} else if globalMetadataCache.metadata != nil {
t.logger.Debug("Using cached provider metadata")
t.jwksURL = globalMetadataCache.metadata.JWKSURL
t.authURL = globalMetadataCache.metadata.AuthURL
t.tokenURL = globalMetadataCache.metadata.TokenURL
t.issuerURL = globalMetadataCache.metadata.Issuer
t.revocationURL = globalMetadataCache.metadata.RevokeURL
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
}
close(t.initComplete)
}
// discoverProviderMetadata fetches the OIDC provider metadata
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -302,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if time.Since(start) > totalTimeout {
l.Errorf("Timeout exceeded while fetching provider metadata")
l.Error("Timeout exceeded while fetching provider metadata")
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
}
@@ -314,20 +296,18 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
lastErr = err
// Exponential backoff
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
if delay > maxDelay {
delay = maxDelay
}
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
time.Sleep(delay)
}
l.Errorf("Max retries exceeded while fetching provider metadata")
l.Error("Max retries exceeded while fetching provider metadata")
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -350,134 +330,138 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
if t.issuerURL == "" {
t.logger.Debug("OIDC middleware not yet initialized")
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
return
}
case <-req.Context().Done():
t.logger.Debug("Request cancelled")
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
return
}
// Check if URL is excluded
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Get session
session, err := t.sessionManager.GetSession(req)
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
// Build redirect URL
scheme := t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle special URLs
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req, redirectURL)
t.handleCallback(rw, req)
return
}
// Check authentication status
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
if !authenticated {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if needsRefresh {
refreshed := t.refreshToken(rw, req, session)
if !refreshed {
t.handleExpiredToken(rw, req, session, redirectURL)
t.handleExpiredToken(rw, req, session)
return
}
}
// Process authenticated request
email := session.GetEmail()
if email == "" {
t.logger.Debug("No email found in session")
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
// authenticated, _ := session.Values["authenticated"].(bool)
if authenticated {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
// Check allowed roles and groups
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
return
}
// Set user information in headers
req.Header.Set("X-Forwarded-User", email)
// Process the request
t.next.ServeHTTP(rw, req)
t.logger.Debug("User is not authenticated, initiating authentication")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
t.logger.Debug("URL is not excluded - got %s", currentRequest)
return false
}
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
@@ -491,7 +475,6 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
return "http"
}
// determineHost determines the host of the request
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -499,35 +482,37 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
if !session.GetAuthenticated() {
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
if !authenticated {
t.logger.Debug("User is not authenticated according to session")
return false, false, false
}
accessToken := session.GetAccessToken()
if accessToken == "" {
t.logger.Debug("No access token found in session")
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Debug("No id_token found in session")
return false, false, true // Session is invalid, consider it expired
}
// Verify the token
if err := t.verifyToken(accessToken); err != nil {
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Token verification failed: %v", err)
return false, false, true // Token is invalid, consider it expired
}
claims, err := extractClaims(accessToken)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
return false, false, true
return false, false, true // Can't read claims, consider it expired
}
expClaim, ok := claims["exp"].(float64)
if !ok {
t.logger.Error("Failed to get expiration time from claims")
return false, false, true
t.logger.Errorf("Failed to get expiration time from claims")
return false, false, true // No expiration, consider it expired
}
now := time.Now().Unix()
@@ -535,7 +520,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
if now > expTime {
t.logger.Debug("Token has expired")
return false, false, true
return false, false, true // Token has expired
}
gracePeriod := time.Minute * 5
@@ -544,42 +529,38 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
return true, true, false // Token will expire soon, needs refresh
}
return true, false, false
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
// Generate CSRF token and nonce
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return
}
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Set session values
session.SetCSRF(csrfToken)
session.SetNonce(nonce)
session.SetIncomingPath(req.URL.Path)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build and redirect to auth URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -593,7 +574,6 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
go func() {
@@ -605,23 +585,26 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
// Add to blacklist with default expiration
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
t.tokenBlacklist.Add(token, expiry)
// Add to blacklist
claims, err := extractClaims(token)
if err == nil {
if exp, ok := claims["exp"].(float64); ok {
expTime := time.Unix(int64(exp), 0)
t.tokenBlacklist.Add(token, expTime)
}
}
}
// RevokeTokenWithProvider revokes the token with the provider
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
t.logger.Debugf("Revoking token with provider")
data := url.Values{
"token": {token},
"token_type_hint": {tokenType},
"token_type_hint": {"access_token", "refresh_token"},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
@@ -652,12 +635,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
return nil
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken := session.GetRefreshToken()
if refreshToken == "" {
t.logger.Debug("No refresh token found in session")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
return false
}
@@ -667,17 +648,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new access token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new access token: %v", err)
return false
}
// Update session with new tokens
session.SetAccessToken(newToken.IDToken)
session.SetRefreshToken(newToken.RefreshToken)
// Save the session
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save refreshed session: %v", err)
return false
@@ -686,7 +659,6 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
@@ -702,7 +674,6 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
@@ -712,48 +683,37 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
var groups []string
var roles []string
// Extract groups with type checking
if groupsClaim, exists := claims["groups"]; exists {
groupsSlice, ok := groupsClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("groups claim is not an array")
}
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
}
// Extract roles with type checking
if rolesClaim, exists := claims["roles"]; exists {
rolesSlice, ok := rolesClaim.([]interface{})
if !ok {
return nil, nil, fmt.Errorf("roles claim is not an array")
}
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
if len(groups) == 0 {
t.logger.Debug("No groups found in groups claim, checking roles claim")
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debug("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
if len(roles) == 0 {
t.logger.Debug("No roles found in roles claim")
}
return groups, roles, nil
}
// buildFullURL constructs a full URL from scheme, host and path
func buildFullURL(scheme, host, path string) string {
// If the path is already a full URL, return it as-is
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return path
}
// Ensure the path starts with a forward slash
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
-57
View File
@@ -1,57 +0,0 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
func BenchmarkOIDCMiddleware(b *testing.B) {
// Setup test environment
ts := &TestSuite{}
ts.Setup()
ts.token = "valid.jwt.token"
// Define the handler with OIDC middleware
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create test server
server := httptest.NewServer(ts.tOidc.next)
defer server.Close()
// Prepare HTTP client
client := &http.Client{}
// Reset timer to exclude setup time
b.ResetTimer()
// Run benchmark
for i := 0; i < b.N; i++ {
// Create new request
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
b.Fatal(err)
}
// Set necessary headers or cookies
req.Header.Set("Authorization", "Bearer "+ts.token)
// Send the request
resp, err := client.Do(req)
if err != nil {
b.Fatal(err)
}
// Close response body
resp.Body.Close()
// Check response status code
if resp.StatusCode != http.StatusOK {
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
}
}
}
+84 -1043
View File
File diff suppressed because it is too large Load Diff
-196
View File
@@ -1,196 +0,0 @@
package traefikoidc
import (
"fmt"
"net/http"
"strings"
"github.com/gorilla/sessions"
)
const (
mainCookieName = "_raczylo_oidc" // Main session cookie
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
)
// SessionManager handles multiple session cookies
type SessionManager struct {
store sessions.Store
forceHTTPS bool
logger *Logger
}
// NewSessionManager creates a new session manager
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
return &SessionManager{
store: sessions.NewCookieStore([]byte(encryptionKey)),
forceHTTPS: forceHTTPS,
logger: logger,
}
}
// getSessionOptions returns session options based on scheme
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
return &sessions.Options{
HttpOnly: true,
Secure: isSecure || sm.forceHTTPS,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
}
// GetSession retrieves all session data
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
mainSession, err := sm.store.Get(r, mainCookieName)
if err != nil {
return nil, fmt.Errorf("failed to get main session: %w", err)
}
accessSession, err := sm.store.Get(r, accessTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get access token session: %w", err)
}
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
if err != nil {
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
}
sessionData := &SessionData{
manager: sm,
mainSession: mainSession,
accessSession: accessSession,
refreshSession: refreshSession,
}
return sessionData, nil
}
// SessionData holds all session information
type SessionData struct {
manager *SessionManager
mainSession *sessions.Session
accessSession *sessions.Session
refreshSession *sessions.Session
}
// Save saves all session data
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
// Set options for all sessions
sd.mainSession.Options = sd.manager.getSessionOptions(isSecure)
sd.accessSession.Options = sd.manager.getSessionOptions(isSecure)
sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure)
if err := sd.mainSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save main session: %w", err)
}
if err := sd.accessSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save access token session: %w", err)
}
if err := sd.refreshSession.Save(r, w); err != nil {
return fmt.Errorf("failed to save refresh token session: %w", err)
}
return nil
}
// Clear clears all session data
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
// Clear and expire all sessions
sd.mainSession.Options.MaxAge = -1
sd.accessSession.Options.MaxAge = -1
sd.refreshSession.Options.MaxAge = -1
for k := range sd.mainSession.Values {
delete(sd.mainSession.Values, k)
}
for k := range sd.accessSession.Values {
delete(sd.accessSession.Values, k)
}
for k := range sd.refreshSession.Values {
delete(sd.refreshSession.Values, k)
}
return sd.Save(r, w)
}
// GetAuthenticated returns authentication status
func (sd *SessionData) GetAuthenticated() bool {
auth, _ := sd.mainSession.Values["authenticated"].(bool)
return auth
}
// SetAuthenticated sets authentication status
func (sd *SessionData) SetAuthenticated(value bool) {
sd.mainSession.Values["authenticated"] = value
}
// GetAccessToken returns the access token
func (sd *SessionData) GetAccessToken() string {
token, _ := sd.accessSession.Values["token"].(string)
return token
}
// SetAccessToken sets the access token
func (sd *SessionData) SetAccessToken(token string) {
sd.accessSession.Values["token"] = token
}
// GetRefreshToken returns the refresh token
func (sd *SessionData) GetRefreshToken() string {
token, _ := sd.refreshSession.Values["token"].(string)
return token
}
// SetRefreshToken sets the refresh token
func (sd *SessionData) SetRefreshToken(token string) {
sd.refreshSession.Values["token"] = token
}
// GetCSRF returns the CSRF token
func (sd *SessionData) GetCSRF() string {
csrf, _ := sd.mainSession.Values["csrf"].(string)
return csrf
}
// SetCSRF sets the CSRF token
func (sd *SessionData) SetCSRF(token string) {
sd.mainSession.Values["csrf"] = token
}
// GetNonce returns the nonce
func (sd *SessionData) GetNonce() string {
nonce, _ := sd.mainSession.Values["nonce"].(string)
return nonce
}
// SetNonce sets the nonce
func (sd *SessionData) SetNonce(nonce string) {
sd.mainSession.Values["nonce"] = nonce
}
// GetEmail returns the user's email
func (sd *SessionData) GetEmail() string {
email, _ := sd.mainSession.Values["email"].(string)
return email
}
// SetEmail sets the user's email
func (sd *SessionData) SetEmail(email string) {
sd.mainSession.Values["email"] = email
}
// GetIncomingPath returns the original incoming path
func (sd *SessionData) GetIncomingPath() string {
path, _ := sd.mainSession.Values["incoming_path"].(string)
return path
}
// SetIncomingPath sets the original incoming path
func (sd *SessionData) SetIncomingPath(path string) {
sd.mainSession.Values["incoming_path"] = path
}
-60
View File
@@ -1,60 +0,0 @@
package traefikoidc
import (
"net/http/httptest"
"testing"
)
func TestSessionManager(t *testing.T) {
logger := NewLogger("info")
manager := NewSessionManager("test-secret-key", false, logger)
req := httptest.NewRequest("GET", "/test", nil)
rr := httptest.NewRecorder()
session, err := manager.GetSession(req)
if err != nil {
t.Fatalf("Failed to get session: %v", err)
}
// Test setting and getting values
session.SetAuthenticated(true)
session.SetEmail("test@example.com")
session.SetAccessToken("test.access.token")
session.SetRefreshToken("test.refresh.token")
if err := session.Save(req, rr); err != nil {
t.Fatalf("Failed to save session: %v", err)
}
// Verify cookies are set
cookies := rr.Result().Cookies()
if len(cookies) != 3 {
t.Errorf("Expected 3 cookies, got %d", len(cookies))
}
// Create a new request with the cookies
newReq := httptest.NewRequest("GET", "/test", nil)
for _, cookie := range cookies {
newReq.AddCookie(cookie)
}
// Get the session again and verify values
newSession, err := manager.GetSession(newReq)
if err != nil {
t.Fatalf("Failed to get new session: %v", err)
}
if !newSession.GetAuthenticated() {
t.Error("Authentication status not preserved")
}
if email := newSession.GetEmail(); email != "test@example.com" {
t.Errorf("Expected email test@example.com, got %s", email)
}
if token := newSession.GetAccessToken(); token != "test.access.token" {
t.Errorf("Expected access token test.access.token, got %s", token)
}
if token := newSession.GetRefreshToken(); token != "test.refresh.token" {
t.Errorf("Expected refresh token test.refresh.token, got %s", token)
}
}
+10 -14
View File
@@ -6,13 +6,14 @@ import (
"log"
"net/http"
"os"
"github.com/gorilla/sessions"
)
const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
@@ -28,12 +29,17 @@ type Config struct {
ExcludedURLs []string `json:"excludedURLs"`
AllowedUserDomains []string `json:"allowedUserDomains"`
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
HTTPClient *http.Client
}
// CreateConfig creates a new Config with default values
var defaultSessionOptions = &sessions.Options{
HttpOnly: true,
Secure: false,
SameSite: http.SameSiteLaxMode,
MaxAge: ConstSessionTimeout,
Path: "/",
}
func CreateConfig() *Config {
c := &Config{}
@@ -56,7 +62,6 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
@@ -76,14 +81,12 @@ func (c *Config) Validate() error {
return nil
}
// Logger is a simple logger with different levels
type Logger struct {
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
@@ -103,37 +106,30 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
+2 -2
View File
@@ -1,4 +1,4 @@
Copyright 2009 The Go Authors.
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google LLC nor the names of its
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
+14 -3
View File
@@ -99,9 +99,8 @@ func (lim *Limiter) Tokens() float64 {
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
tokens: float64(b),
limit: r,
burst: b,
}
}
@@ -345,6 +344,18 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
tokens: n,
timeToAct: t,
}
} else if lim.limit == 0 {
var ok bool
if lim.burst >= n {
ok = true
lim.burst -= n
}
return Reservation{
ok: ok,
lim: lim,
tokens: lim.burst,
timeToAct: t,
}
}
t, tokens := lim.advance(t)
+1 -1
View File
@@ -7,6 +7,6 @@ github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
github.com/gorilla/sessions
# golang.org/x/time v0.7.0
# golang.org/x/time v0.5.0
## explicit; go 1.18
golang.org/x/time/rate