mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 24ecf00053 | |||
| e338992f84 | |||
| a9a596031b | |||
| 23afcad2ba | |||
| d06f9fcf90 |
@@ -13,7 +13,6 @@ testData:
|
||||
clientSecret: secret
|
||||
callbackURL: /oauth2/callback
|
||||
logoutURL: /oauth2/logout
|
||||
postLogoutRedirectURI: /oidc/different-logout # If not provided it will redirect to the "/" URL
|
||||
scopes: # If not provided, default scopes will be used (openid, email, profile)
|
||||
- openid
|
||||
- email
|
||||
|
||||
@@ -4,10 +4,6 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
|
||||
|
||||
Middleware has been tested with Auth0 and Logto.
|
||||
|
||||
### Traefik version compatibility
|
||||
|
||||
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
|
||||
|
||||
### Configuration options
|
||||
|
||||
Middleware currently supports following scenarios:
|
||||
@@ -19,35 +15,6 @@ Middleware currently supports following scenarios:
|
||||
|
||||
#### How to configure...
|
||||
|
||||
##### Keeping secrets secret
|
||||
|
||||
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
|
||||
|
||||
```
|
||||
apiVersion: traefik.io/v1alpha1
|
||||
kind: Middleware
|
||||
metadata:
|
||||
name: oidc-with-open-urls
|
||||
namespace: traefik
|
||||
spec:
|
||||
plugin:
|
||||
traefikoidc:
|
||||
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
|
||||
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
|
||||
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
|
||||
sessionEncryptionKey: vvv
|
||||
callbackURL: /cool-oidc/callback
|
||||
logoutURL: /cool-oidc/logout
|
||||
postLogoutRedirectURI: /my-website/you-have-logged-out # Optional post logout URL redirection
|
||||
scopes:
|
||||
- openid
|
||||
- email
|
||||
- profile
|
||||
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
|
||||
- /login # covers /login, /login/me, /login/reminder etc.
|
||||
- /my-public-data
|
||||
```
|
||||
|
||||
##### Excluded URLs with open access
|
||||
|
||||
```
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CacheItem represents an item in the cache
|
||||
type CacheItem struct {
|
||||
Value interface{}
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache is a simple in-memory cache
|
||||
type Cache struct {
|
||||
items map[string]CacheItem
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCache creates a new Cache
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
items: make(map[string]CacheItem),
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds an item to the cache
|
||||
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
c.items[key] = CacheItem{
|
||||
Value: value,
|
||||
ExpiresAt: time.Now().Add(expiration),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves an item from the cache
|
||||
func (c *Cache) Get(key string) (interface{}, bool) {
|
||||
c.mutex.RLock()
|
||||
defer c.mutex.RUnlock()
|
||||
item, found := c.items[key]
|
||||
if !found {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
return nil, false
|
||||
}
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Delete removes an item from the cache
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
delete(c.items, key)
|
||||
}
|
||||
|
||||
// Cleanup removes expired items from the cache
|
||||
func (c *Cache) Cleanup() {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for key, item := range c.items {
|
||||
if now.After(item.ExpiresAt) {
|
||||
delete(c.items, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,11 @@
|
||||
module github.com/lukaszraczylo/traefikoidc
|
||||
|
||||
go 1.23
|
||||
|
||||
toolchain go1.23.1
|
||||
go 1.22.2
|
||||
|
||||
require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/sessions v1.3.0
|
||||
golang.org/x/time v0.7.0
|
||||
golang.org/x/time v0.5.0
|
||||
)
|
||||
|
||||
require github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
|
||||
@@ -6,5 +6,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
|
||||
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
|
||||
+154
-211
@@ -6,27 +6,16 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func newSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// generateNonce generates a random nonce
|
||||
func generateNonce() (string, error) {
|
||||
nonceBytes := make([]byte, 32)
|
||||
_, err := rand.Read(nonceBytes)
|
||||
@@ -36,8 +25,14 @@ func generateNonce() (string, error) {
|
||||
return base64.URLEncoding.EncodeToString(nonceBytes), nil
|
||||
}
|
||||
|
||||
// exchangeTokens exchanges a code or refresh token for tokens
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
if scheme == "" {
|
||||
scheme = "http"
|
||||
}
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
|
||||
data := url.Values{
|
||||
"grant_type": {grantType},
|
||||
"client_id": {t.clientID},
|
||||
@@ -63,20 +58,14 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResponse TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
return &tokenResponse, nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TokenResponse represents the response from the token endpoint
|
||||
type TokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
@@ -85,35 +74,103 @@ type TokenResponse struct {
|
||||
TokenType string `json:"token_type"`
|
||||
}
|
||||
|
||||
// getNewTokenWithRefreshToken refreshes the token using the refresh token
|
||||
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token: %w", err)
|
||||
}
|
||||
|
||||
t.logger.Debugf("Token response: %+v", tokenResponse)
|
||||
newAccessToken, ok := result["access_token"].(string)
|
||||
if !ok || newAccessToken == "" {
|
||||
return nil, fmt.Errorf("no access_token field in token response")
|
||||
}
|
||||
|
||||
return tokenResponse, nil
|
||||
rawIDToken, ok := result["id_token"].(string)
|
||||
if !ok || rawIDToken == "" {
|
||||
return nil, fmt.Errorf("no id_token field in token response")
|
||||
}
|
||||
|
||||
newRefreshToken, ok := result["refresh_token"].(string)
|
||||
if !ok || newRefreshToken == "" {
|
||||
return nil, fmt.Errorf("no refresh_token field in token response")
|
||||
}
|
||||
|
||||
response := &TokenResponse{
|
||||
IDToken: rawIDToken,
|
||||
AccessToken: newAccessToken,
|
||||
ExpiresIn: int(result["expires_in"].(float64)),
|
||||
TokenType: result["token_type"].(string),
|
||||
}
|
||||
|
||||
// The refresh token might not be returned if it hasn't changed
|
||||
if newRefreshToken != refreshToken {
|
||||
response.RefreshToken = newRefreshToken
|
||||
} else {
|
||||
response.RefreshToken = refreshToken
|
||||
}
|
||||
|
||||
t.logger.Debug("Token response: %+v", response)
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// handleExpiredToken handles the case when a token has expired
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
t.logger.Debugf("Logging out user")
|
||||
if err != nil {
|
||||
handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken, ok := session.Values["id_token"].(string); ok {
|
||||
err := t.RevokeTokenWithProvider(idToken)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
t.RevokeToken(idToken)
|
||||
}
|
||||
|
||||
session.Options = defaultSessionOptions
|
||||
// Clear the session
|
||||
session.Options.MaxAge = -1
|
||||
session.Values = make(map[interface{}]interface{})
|
||||
err = session.Save(req, rw)
|
||||
if err != nil {
|
||||
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(rw, "Logged out", http.StatusForbidden)
|
||||
}
|
||||
|
||||
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
|
||||
// Clear the existing session
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to clear session: %v", err)
|
||||
session.Options.MaxAge = -1
|
||||
for k := range session.Values {
|
||||
delete(session.Values, k)
|
||||
}
|
||||
|
||||
// Set new values
|
||||
session.Values["csrf"] = uuid.New().String()
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Values["nonce"], _ = generateNonce()
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Save the session before initiating authentication
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Initialize new authentication
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
// Initiate a new authentication flow
|
||||
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
|
||||
}
|
||||
|
||||
// handleCallback handles the callback from the OIDC provider
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request, redirectURL string) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Session error: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
@@ -122,36 +179,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
|
||||
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
|
||||
|
||||
// Check for errors in the query parameters
|
||||
if req.URL.Query().Get("error") != "" {
|
||||
errorDescription := req.URL.Query().Get("error_description")
|
||||
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
|
||||
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state parameter matches the session's CSRF token
|
||||
state := req.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
t.logger.Error("No state in callback")
|
||||
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
csrfToken := session.GetCSRF()
|
||||
if csrfToken == "" {
|
||||
t.logger.Error("CSRF token missing in session")
|
||||
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if state != csrfToken {
|
||||
t.logger.Error("State parameter does not match CSRF token in session")
|
||||
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
code := req.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.logger.Error("No code in callback")
|
||||
@@ -159,49 +186,27 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := t.exchangeCodeForTokenFunc(code, redirectURL)
|
||||
token, err := t.exchangeCodeForTokenFunc(code)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to exchange code for token: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify and process tokens
|
||||
if err := t.verifyToken(tokenResponse.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify id_token: %v", err)
|
||||
idToken, ok := token["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Error("No id_token in token response")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := t.extractClaimsFunc(tokenResponse.IDToken)
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce
|
||||
nonceClaim, ok := claims["nonce"].(string)
|
||||
if !ok || nonceClaim == "" {
|
||||
t.logger.Error("Nonce claim missing in id_token")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sessionNonce := session.GetNonce()
|
||||
if sessionNonce == "" {
|
||||
t.logger.Error("Nonce not found in session")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if nonceClaim != sessionNonce {
|
||||
t.logger.Error("Nonce claim does not match session nonce")
|
||||
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Process email
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" || !t.isAllowedDomain(email) {
|
||||
t.logger.Errorf("Invalid or disallowed email: %s", email)
|
||||
@@ -209,29 +214,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request,
|
||||
return
|
||||
}
|
||||
|
||||
// Update session with new values
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail(email)
|
||||
session.SetAccessToken(tokenResponse.IDToken)
|
||||
session.SetRefreshToken(tokenResponse.RefreshToken)
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["email"] = email
|
||||
session.Values["id_token"] = idToken
|
||||
session.Options = defaultSessionOptions
|
||||
|
||||
// Save session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to original path or root
|
||||
redirectPath := "/"
|
||||
if incomingPath := session.GetIncomingPath(); incomingPath != "" && incomingPath != t.redirURLPath {
|
||||
redirectPath = incomingPath
|
||||
}
|
||||
|
||||
http.Redirect(rw, req, redirectPath, http.StatusFound)
|
||||
t.logger.Debugf("Authentication successful. User email: %s", email)
|
||||
http.Redirect(rw, req, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
// extractClaims extracts claims from a JWT token
|
||||
func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
@@ -251,27 +248,28 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// TokenBlacklist maintains a blacklist of tokens
|
||||
type UsedTokens struct {
|
||||
tokens map[string]bool
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type TokenBlacklist struct {
|
||||
blacklist map[string]time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewTokenBlacklist creates a new TokenBlacklist
|
||||
func NewTokenBlacklist() *TokenBlacklist {
|
||||
return &TokenBlacklist{
|
||||
blacklist: make(map[string]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
// Add adds a token to the blacklist
|
||||
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
tb.blacklist[tokenID] = expiration
|
||||
}
|
||||
|
||||
// IsBlacklisted checks if a token is blacklisted
|
||||
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
tb.mutex.RLock()
|
||||
defer tb.mutex.RUnlock()
|
||||
@@ -279,7 +277,6 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
|
||||
return exists && time.Now().Before(expiration)
|
||||
}
|
||||
|
||||
// Cleanup removes expired tokens from the blacklist
|
||||
func (tb *TokenBlacklist) Cleanup() {
|
||||
tb.mutex.Lock()
|
||||
defer tb.mutex.Unlock()
|
||||
@@ -291,127 +288,73 @@ func (tb *TokenBlacklist) Cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCache caches tokens
|
||||
type TokenCache struct {
|
||||
cache *Cache
|
||||
cache map[string]*TokenInfo
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
type TokenInfo struct {
|
||||
Token string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// NewTokenCache creates a new TokenCache
|
||||
func NewTokenCache() *TokenCache {
|
||||
return &TokenCache{
|
||||
cache: NewCache(),
|
||||
cache: make(map[string]*TokenInfo),
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets a token in the cache
|
||||
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
|
||||
token = "t-" + token
|
||||
tc.cache.Set(token, claims, expiration)
|
||||
func (tc *TokenCache) Set(token string, expiresAt time.Time) {
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
|
||||
}
|
||||
|
||||
// Get retrieves a token from the cache
|
||||
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
|
||||
token = "t-" + token
|
||||
value, found := tc.cache.Get(token)
|
||||
if !found {
|
||||
return nil, false
|
||||
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
|
||||
tc.mutex.RLock()
|
||||
defer tc.mutex.RUnlock()
|
||||
info, exists := tc.cache[token]
|
||||
if exists && time.Now().Before(info.ExpiresAt) {
|
||||
return info, true
|
||||
}
|
||||
claims, ok := value.(map[string]interface{})
|
||||
return claims, ok
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Delete removes a token from the cache
|
||||
func (tc *TokenCache) Delete(token string) {
|
||||
token = "t-" + token
|
||||
tc.cache.Delete(token)
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
delete(tc.cache, token)
|
||||
}
|
||||
|
||||
// Cleanup cleans up expired tokens from the cache
|
||||
func (tc *TokenCache) Cleanup() {
|
||||
tc.cache.Cleanup()
|
||||
}
|
||||
|
||||
// exchangeCodeForToken exchanges the authorization code for tokens
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string, redirectURL string) (*TokenResponse, error) {
|
||||
ctx := context.Background()
|
||||
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, redirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
|
||||
}
|
||||
return tokenResponse, nil
|
||||
}
|
||||
|
||||
// createStringMap creates a map from a slice of strings
|
||||
func createStringMap(keys []string) map[string]struct{} {
|
||||
result := make(map[string]struct{})
|
||||
for _, key := range keys {
|
||||
result[key] = struct{}{}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// handleLogout handles the logout request
|
||||
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the access token before clearing session
|
||||
accessToken := session.GetAccessToken()
|
||||
|
||||
// Clear all session data
|
||||
if err := session.Clear(req, rw); err != nil {
|
||||
t.logger.Errorf("Error clearing session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Get the base URL for redirects
|
||||
host := t.determineHost(req)
|
||||
scheme := t.determineScheme(req)
|
||||
baseURL := fmt.Sprintf("%s://%s", scheme, host)
|
||||
|
||||
// Determine post logout redirect URI
|
||||
postLogoutRedirectURI := t.postLogoutRedirectURI
|
||||
if postLogoutRedirectURI == "" {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s/", baseURL)
|
||||
} else if !strings.HasPrefix(postLogoutRedirectURI, "http") {
|
||||
postLogoutRedirectURI = fmt.Sprintf("%s%s", baseURL, postLogoutRedirectURI)
|
||||
}
|
||||
|
||||
// If we have an end session endpoint and an access token, use OIDC end session
|
||||
if t.endSessionURL != "" && accessToken != "" {
|
||||
logoutURL, err := BuildLogoutURL(t.endSessionURL, accessToken, postLogoutRedirectURI)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to build logout URL: %v", err)
|
||||
http.Error(rw, "Logout error", http.StatusInternalServerError)
|
||||
return
|
||||
tc.mutex.Lock()
|
||||
defer tc.mutex.Unlock()
|
||||
now := time.Now()
|
||||
for token, info := range tc.cache {
|
||||
if now.After(info.ExpiresAt) {
|
||||
delete(tc.cache, token)
|
||||
}
|
||||
http.Redirect(rw, req, logoutURL, http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, redirect to post logout URI
|
||||
http.Redirect(rw, req, postLogoutRedirectURI, http.StatusFound)
|
||||
}
|
||||
|
||||
// BuildLogoutURL constructs the OIDC end session URL
|
||||
func BuildLogoutURL(endSessionURL, idToken, postLogoutRedirectURI string) (string, error) {
|
||||
u, err := url.Parse(endSessionURL)
|
||||
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "authorization_code")
|
||||
data.Set("client_id", t.clientID)
|
||||
data.Set("client_secret", t.clientSecret)
|
||||
data.Set("code", code)
|
||||
data.Set("redirect_uri", t.redirectURL)
|
||||
|
||||
resp, err := t.httpClient.PostForm(t.tokenURL, data)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse end session URL: %w", err)
|
||||
return nil, fmt.Errorf("failed to exchange token: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode token response: %v", err)
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
q.Set("id_token_hint", idToken)
|
||||
if postLogoutRedirectURI != "" {
|
||||
// Ensure postLogoutRedirectURI is properly URL encoded
|
||||
q.Set("post_logout_redirect_uri", postLogoutRedirectURI)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
return u.String(), nil
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -4,19 +4,17 @@ import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWK represents a JSON Web Key
|
||||
type JWK struct {
|
||||
Kty string `json:"kty"`
|
||||
Kid string `json:"kid"`
|
||||
@@ -29,24 +27,20 @@ type JWK struct {
|
||||
Y string `json:"y"`
|
||||
}
|
||||
|
||||
// JWKSet represents a set of JWKs
|
||||
type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
// JWKCache caches the JWKs
|
||||
type JWKCache struct {
|
||||
jwks *JWKSet
|
||||
expiresAt time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// JWKCacheInterface defines the interface for the JWK cache
|
||||
type JWKCacheInterface interface {
|
||||
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
|
||||
}
|
||||
|
||||
// GetJWKS gets the JWKS, either from cache or by fetching it
|
||||
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
c.mutex.RLock()
|
||||
if c.jwks != nil && time.Now().Before(c.expiresAt) {
|
||||
@@ -73,7 +67,6 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
// fetchJWKS fetches the JWKS from the provider
|
||||
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
resp, err := httpClient.Get(jwksURL)
|
||||
if err != nil {
|
||||
@@ -93,61 +86,82 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
|
||||
return &jwks, nil
|
||||
}
|
||||
|
||||
// jwkToPEM converts a JWK to PEM format
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []interface{}:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid 'aud' claim type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jwkToPEM(jwk *JWK) ([]byte, error) {
|
||||
converter, ok := jwkConverters[jwk.Kty]
|
||||
if !ok {
|
||||
switch jwk.Kty {
|
||||
case "RSA":
|
||||
return rsaJWKToPEM(jwk)
|
||||
case "EC":
|
||||
return ecJWKToPEM(jwk)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
|
||||
}
|
||||
return converter(jwk)
|
||||
}
|
||||
|
||||
type jwkToPEMConverter func(*JWK) ([]byte, error)
|
||||
|
||||
var jwkConverters = map[string]jwkToPEMConverter{
|
||||
"RSA": rsaJWKToPEM,
|
||||
"EC": ecJWKToPEM,
|
||||
}
|
||||
|
||||
// rsaJWKToPEM converts an RSA JWK to PEM
|
||||
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
|
||||
}
|
||||
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
|
||||
}
|
||||
|
||||
n := new(big.Int).SetBytes(nBytes)
|
||||
e := new(big.Int).SetBytes(eBytes)
|
||||
|
||||
pubKey := &rsa.PublicKey{
|
||||
N: n,
|
||||
E: int(e.Int64()),
|
||||
publicKey := &rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(n),
|
||||
E: int(new(big.Int).SetBytes(e).Int64()),
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PUBLIC KEY",
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
return pubKeyPEM, nil
|
||||
return publicKeyPEM, nil
|
||||
}
|
||||
|
||||
// ecJWKToPEM converts an EC JWK to PEM
|
||||
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
|
||||
}
|
||||
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
|
||||
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
|
||||
}
|
||||
@@ -164,21 +178,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
|
||||
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
|
||||
}
|
||||
|
||||
pubKey := &ecdsa.PublicKey{
|
||||
publicKey := &ecdsa.PublicKey{
|
||||
Curve: curve,
|
||||
X: new(big.Int).SetBytes(xBytes),
|
||||
Y: new(big.Int).SetBytes(yBytes),
|
||||
X: new(big.Int).SetBytes(x),
|
||||
Y: new(big.Int).SetBytes(y),
|
||||
}
|
||||
|
||||
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
|
||||
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||
}
|
||||
|
||||
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "PUBLIC KEY",
|
||||
Bytes: pubKeyBytes,
|
||||
Bytes: publicKeyBytes,
|
||||
})
|
||||
|
||||
return pubKeyPEM, nil
|
||||
return publicKeyPEM, nil
|
||||
}
|
||||
|
||||
@@ -4,36 +4,29 @@ import (
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// JWT represents a JSON Web Token
|
||||
type JWT struct {
|
||||
Header map[string]interface{}
|
||||
Claims map[string]interface{}
|
||||
Signature []byte
|
||||
Token string
|
||||
Signature string
|
||||
}
|
||||
|
||||
// parseJWT parses a JWT token string into a JWT struct
|
||||
func parseJWT(tokenString string) (*JWT, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
jwt := &JWT{
|
||||
Token: tokenString,
|
||||
}
|
||||
jwt := &JWT{}
|
||||
|
||||
// Decode and unmarshal the header
|
||||
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
@@ -53,17 +46,12 @@ func parseJWT(tokenString string) (*JWT, error) {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
|
||||
}
|
||||
|
||||
// Decode the signature
|
||||
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
|
||||
}
|
||||
jwt.Signature = signatureBytes
|
||||
// Set the signature
|
||||
jwt.Signature = parts[2]
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
|
||||
// Verify verifies the standard claims in the JWT
|
||||
func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
claims := j.Claims
|
||||
|
||||
@@ -107,39 +95,6 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyAudience verifies the audience claim
|
||||
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
|
||||
switch aud := tokenAudience.(type) {
|
||||
case string:
|
||||
if aud != expectedAudience {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
case []interface{}:
|
||||
found := false
|
||||
for _, v := range aud {
|
||||
if str, ok := v.(string); ok && str == expectedAudience {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("invalid audience")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid 'aud' claim type")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuer verifies the issuer claim
|
||||
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
|
||||
if tokenIssuer != expectedIssuer {
|
||||
return fmt.Errorf("invalid issuer")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyExpiration checks if the token has expired
|
||||
func verifyExpiration(expiration float64) error {
|
||||
expirationTime := time.Unix(int64(expiration), 0)
|
||||
if time.Now().After(expirationTime) {
|
||||
@@ -148,43 +103,17 @@ func verifyExpiration(expiration float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyIssuedAt checks if the token was issued in the future
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifySignature verifies the token signature using the provided public key and algorithm
|
||||
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
|
||||
// Split the token into its three parts
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
|
||||
// Decode the signature from the token
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
// Decode the PEM-encoded public key
|
||||
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
|
||||
block, _ := pem.Decode(publicKeyPEM)
|
||||
if block == nil {
|
||||
return fmt.Errorf("failed to parse PEM block containing the public key")
|
||||
}
|
||||
|
||||
// Parse the public key
|
||||
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse public key: %w", err)
|
||||
}
|
||||
|
||||
// Determine the hash function to use based on the algorithm
|
||||
var hashFunc crypto.Hash
|
||||
|
||||
switch alg {
|
||||
@@ -198,42 +127,45 @@ func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error
|
||||
return fmt.Errorf("unsupported algorithm: %s", alg)
|
||||
}
|
||||
|
||||
// Hash the signed content
|
||||
h := hashFunc.New()
|
||||
h.Write([]byte(signedContent))
|
||||
hashed := h.Sum(nil)
|
||||
|
||||
// Verify the signature based on the key type and algorithm
|
||||
switch pubKey := pubKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
// RSA PKCS#1 v1.5 signature
|
||||
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
|
||||
} else if strings.HasPrefix(alg, "PS") {
|
||||
// RSA PSS signature
|
||||
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
}
|
||||
switch pub := pubKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "ES") {
|
||||
// ECDSA signature
|
||||
var r, s big.Int
|
||||
sigLen := len(signature)
|
||||
if sigLen%2 != 0 {
|
||||
return fmt.Errorf("invalid ECDSA signature length")
|
||||
// ECDSA signature handling
|
||||
keyBytes := (pub.Params().BitSize + 7) / 8
|
||||
if len(signature) != 2*keyBytes {
|
||||
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
|
||||
}
|
||||
r.SetBytes(signature[:sigLen/2])
|
||||
s.SetBytes(signature[sigLen/2:])
|
||||
if ecdsa.Verify(pubKey, hashed, &r, &s) {
|
||||
r := new(big.Int).SetBytes(signature[:keyBytes])
|
||||
s := new(big.Int).SetBytes(signature[keyBytes:])
|
||||
|
||||
if ecdsa.Verify(pub, hashed, r, s) {
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unexpected key type for algorithm %s", alg)
|
||||
return fmt.Errorf("invalid ECDSA signature")
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
|
||||
case *rsa.PublicKey:
|
||||
if strings.HasPrefix(alg, "RS") {
|
||||
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("RSA signature verification failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
|
||||
default:
|
||||
return fmt.Errorf("unsupported public key type: %T", pubKey)
|
||||
return fmt.Errorf("unsupported public key type: %T", pub)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyIssuedAt(issuedAt float64) error {
|
||||
issuedAtTime := time.Unix(int64(issuedAt), 0)
|
||||
if time.Now().Before(issuedAtTime) {
|
||||
return fmt.Errorf("token used before issued")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package traefikoidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,29 +11,27 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/sessions"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const ConstSessionTimeout = 86400 // Session timeout in seconds
|
||||
const ConstSessionTimeout = 86400
|
||||
|
||||
// TokenVerifier interface for token verification
|
||||
type TokenVerifier interface {
|
||||
VerifyToken(token string) error
|
||||
}
|
||||
|
||||
// JWTVerifier interface for JWT verification
|
||||
type JWTVerifier interface {
|
||||
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
}
|
||||
|
||||
// TraefikOidc is the main struct for the OIDC middleware
|
||||
type TraefikOidc struct {
|
||||
next http.Handler
|
||||
name string
|
||||
store sessions.Store
|
||||
redirURLPath string
|
||||
logoutURLPath string
|
||||
issuerURL string
|
||||
@@ -51,107 +50,81 @@ type TraefikOidc struct {
|
||||
tokenCache *TokenCache
|
||||
httpClient *http.Client
|
||||
logger *Logger
|
||||
redirectURL string
|
||||
tokenVerifier TokenVerifier
|
||||
jwtVerifier JWTVerifier
|
||||
excludedURLs map[string]struct{}
|
||||
allowedUserDomains map[string]struct{}
|
||||
allowedRolesAndGroups map[string]struct{}
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string, redirectURL string) (*TokenResponse, error)
|
||||
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
|
||||
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
|
||||
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
|
||||
initComplete chan struct{}
|
||||
endSessionURL string
|
||||
baseURL string
|
||||
postLogoutRedirectURI string
|
||||
sessionManager *SessionManager
|
||||
}
|
||||
|
||||
// ProviderMetadata holds OIDC provider metadata
|
||||
type ProviderMetadata struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
EndSessionURL string `json:"end_session_endpoint"`
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
RevokeURL string `json:"revocation_endpoint"`
|
||||
}
|
||||
|
||||
// defaultExcludedURLs are the paths that are excluded from authentication
|
||||
var defaultExcludedURLs = map[string]struct{}{
|
||||
"/favicon": {},
|
||||
}
|
||||
|
||||
var newTicker = time.NewTicker
|
||||
|
||||
var (
|
||||
globalMetadataCache struct {
|
||||
sync.Once
|
||||
metadata *ProviderMetadata
|
||||
err error
|
||||
}
|
||||
)
|
||||
|
||||
// VerifyToken verifies the provided JWT token
|
||||
func (t *TraefikOidc) VerifyToken(token string) error {
|
||||
t.logger.Debugf("Verifying token")
|
||||
|
||||
// Rate limiting
|
||||
t.logger.Debugf("Verifying token: %s", token)
|
||||
if !t.limiter.Allow() {
|
||||
return fmt.Errorf("rate limit exceeded")
|
||||
}
|
||||
|
||||
// Check if token is blacklisted
|
||||
if t.tokenBlacklist.IsBlacklisted(token) {
|
||||
return fmt.Errorf("token is blacklisted")
|
||||
}
|
||||
|
||||
// Check if token is cached
|
||||
if _, exists := t.tokenCache.Get(token); exists {
|
||||
t.logger.Debugf("Token is valid and cached")
|
||||
return nil // Token is valid and cached
|
||||
}
|
||||
|
||||
// Parse the JWT
|
||||
jwt, err := parseJWT(token)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
|
||||
// Verify JWT signature and claims
|
||||
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Cache the token until it expires
|
||||
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
|
||||
now := time.Now()
|
||||
duration := expirationTime.Sub(now)
|
||||
t.tokenCache.Set(token, jwt.Claims, duration)
|
||||
t.tokenCache.Set(token, expirationTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
|
||||
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
|
||||
t.logger.Debugf("Verifying JWT signature and claims")
|
||||
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
|
||||
|
||||
// Get JWKS
|
||||
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get JWKS: %w", err)
|
||||
}
|
||||
|
||||
// Retrieve key ID and algorithm from JWT header
|
||||
kid, ok := jwt.Header["kid"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing key ID in token header")
|
||||
}
|
||||
t.logger.Debugf("Token kid: %s", kid)
|
||||
|
||||
alg, ok := jwt.Header["alg"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing algorithm in token header")
|
||||
}
|
||||
t.logger.Debugf("Token alg: %s", alg)
|
||||
|
||||
// Find the matching key in JWKS
|
||||
var matchingKey *JWK
|
||||
for _, key := range jwks.Keys {
|
||||
if key.Kid == kid {
|
||||
@@ -159,32 +132,48 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchingKey == nil {
|
||||
return fmt.Errorf("no matching public key found for kid: %s", kid)
|
||||
}
|
||||
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
|
||||
|
||||
// Convert JWK to PEM format
|
||||
publicKeyPEM, err := jwkToPEM(matchingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
|
||||
}
|
||||
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
|
||||
|
||||
// Verify the signature
|
||||
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
signedContent := parts[0] + "." + parts[1]
|
||||
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode signature: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
|
||||
t.logger.Errorf("Signature verification failed: %v", err)
|
||||
return fmt.Errorf("signature verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Signature verified successfully")
|
||||
|
||||
// Verify standard claims
|
||||
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
|
||||
return fmt.Errorf("standard claim verification failed: %w", err)
|
||||
}
|
||||
t.logger.Debug("Standard claims verified successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new instance of the OIDC middleware
|
||||
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
|
||||
// Setup HTTP client
|
||||
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
|
||||
store.Options = defaultSessionOptions
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
@@ -195,11 +184,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
return dialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 0,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
}
|
||||
|
||||
var httpClient *http.Client
|
||||
@@ -212,9 +201,15 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
}
|
||||
|
||||
metadata, err := discoverProviderMetadata(config.ProviderURL, httpClient, NewLogger(config.LogLevel))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
|
||||
}
|
||||
|
||||
t := &TraefikOidc{
|
||||
next: next,
|
||||
name: name,
|
||||
store: store,
|
||||
redirURLPath: config.CallbackURL,
|
||||
logoutURLPath: func() string {
|
||||
if config.LogoutURL == "" {
|
||||
@@ -222,36 +217,50 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
}
|
||||
return config.LogoutURL
|
||||
}(),
|
||||
postLogoutRedirectURI: func() string {
|
||||
if config.PostLogoutRedirectURI == "" {
|
||||
return "/"
|
||||
issuerURL: metadata.Issuer,
|
||||
revocationURL: metadata.RevokeURL,
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
jwksURL: metadata.JWKSURL,
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
authURL: metadata.AuthURL,
|
||||
tokenURL: metadata.TokenURL,
|
||||
scopes: config.Scopes,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: func() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, url := range config.ExcludedURLs {
|
||||
m[url] = struct{}{}
|
||||
}
|
||||
return config.PostLogoutRedirectURI
|
||||
return m
|
||||
}(),
|
||||
redirectURL: "",
|
||||
allowedUserDomains: func() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, domain := range config.AllowedUserDomains {
|
||||
m[domain] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
allowedRolesAndGroups: func() map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, roleOrGroup := range config.AllowedRolesAndGroups {
|
||||
m[roleOrGroup] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
tokenBlacklist: NewTokenBlacklist(),
|
||||
jwkCache: &JWKCache{},
|
||||
clientID: config.ClientID,
|
||||
clientSecret: config.ClientSecret,
|
||||
forceHTTPS: config.ForceHTTPS,
|
||||
scopes: config.Scopes,
|
||||
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
|
||||
tokenCache: NewTokenCache(),
|
||||
httpClient: httpClient,
|
||||
logger: NewLogger(config.LogLevel),
|
||||
excludedURLs: createStringMap(config.ExcludedURLs),
|
||||
allowedUserDomains: createStringMap(config.AllowedUserDomains),
|
||||
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
|
||||
initComplete: make(chan struct{}),
|
||||
}
|
||||
|
||||
t.sessionManager = NewSessionManager(config.SessionEncryptionKey, config.ForceHTTPS, t.logger)
|
||||
t.extractClaimsFunc = extractClaims
|
||||
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
|
||||
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
|
||||
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
}
|
||||
t.extractClaimsFunc = extractClaims
|
||||
|
||||
// Add default excluded URLs
|
||||
// add defaultExcludedURLs to excludedURLs
|
||||
for k, v := range defaultExcludedURLs {
|
||||
t.excludedURLs[k] = v
|
||||
}
|
||||
@@ -259,36 +268,9 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
|
||||
t.tokenVerifier = t
|
||||
t.jwtVerifier = t
|
||||
t.startTokenCleanup()
|
||||
go t.initializeMetadata(config.ProviderURL)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// initializeMetadata discovers and initializes the provider metadata
|
||||
func (t *TraefikOidc) initializeMetadata(providerURL string) {
|
||||
globalMetadataCache.Once.Do(func() {
|
||||
t.logger.Debug("Starting global provider metadata discovery")
|
||||
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
|
||||
globalMetadataCache.metadata = metadata
|
||||
globalMetadataCache.err = err
|
||||
})
|
||||
|
||||
if globalMetadataCache.err != nil {
|
||||
t.logger.Errorf("Failed to discover provider metadata: %v", globalMetadataCache.err)
|
||||
} else if globalMetadataCache.metadata != nil {
|
||||
t.logger.Debug("Using cached provider metadata")
|
||||
t.jwksURL = globalMetadataCache.metadata.JWKSURL
|
||||
t.authURL = globalMetadataCache.metadata.AuthURL
|
||||
t.tokenURL = globalMetadataCache.metadata.TokenURL
|
||||
t.issuerURL = globalMetadataCache.metadata.Issuer
|
||||
t.revocationURL = globalMetadataCache.metadata.RevokeURL
|
||||
t.endSessionURL = globalMetadataCache.metadata.EndSessionURL
|
||||
}
|
||||
|
||||
close(t.initComplete)
|
||||
}
|
||||
|
||||
// discoverProviderMetadata fetches the OIDC provider metadata
|
||||
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
|
||||
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
|
||||
|
||||
@@ -302,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
|
||||
var lastErr error
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if time.Since(start) > totalTimeout {
|
||||
l.Errorf("Timeout exceeded while fetching provider metadata")
|
||||
l.Error("Timeout exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
@@ -314,20 +296,18 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
|
||||
|
||||
lastErr = err
|
||||
|
||||
// Exponential backoff
|
||||
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
|
||||
if delay > maxDelay {
|
||||
delay = maxDelay
|
||||
}
|
||||
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
|
||||
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
l.Errorf("Max retries exceeded while fetching provider metadata")
|
||||
l.Error("Max retries exceeded while fetching provider metadata")
|
||||
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
|
||||
}
|
||||
|
||||
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
|
||||
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
|
||||
resp, err := httpClient.Get(wellKnownURL)
|
||||
if err != nil {
|
||||
@@ -350,134 +330,138 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
|
||||
return &metadata, nil
|
||||
}
|
||||
|
||||
// ServeHTTP is the main handler for the middleware
|
||||
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
select {
|
||||
case <-t.initComplete:
|
||||
if t.issuerURL == "" {
|
||||
t.logger.Debug("OIDC middleware not yet initialized")
|
||||
http.Error(rw, "OIDC middleware not yet initialized", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case <-req.Context().Done():
|
||||
t.logger.Debug("Request cancelled")
|
||||
http.Error(rw, "Request cancelled", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if URL is excluded
|
||||
if t.determineExcludedURL(req.URL.Path) {
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Get session
|
||||
session, err := t.sessionManager.GetSession(req)
|
||||
t.scheme = t.determineScheme(req)
|
||||
defaultSessionOptions.Secure = t.scheme == "https"
|
||||
host := t.determineHost(req)
|
||||
|
||||
if t.redirectURL == "" {
|
||||
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
|
||||
}
|
||||
|
||||
session, err := t.store.Get(req, cookieName)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Error getting session: %v", err)
|
||||
http.Error(rw, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build redirect URL
|
||||
scheme := t.determineScheme(req)
|
||||
host := t.determineHost(req)
|
||||
redirectURL := buildFullURL(scheme, host, t.redirURLPath)
|
||||
t.logger.Debugf("Session contents at start: %+v", session.Values)
|
||||
|
||||
// Handle special URLs
|
||||
if req.URL.Path == t.logoutURLPath {
|
||||
t.handleLogout(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Path == t.redirURLPath {
|
||||
t.handleCallback(rw, req, redirectURL)
|
||||
t.handleCallback(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Check authentication status
|
||||
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
|
||||
|
||||
if expired {
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
return
|
||||
}
|
||||
|
||||
if !authenticated {
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
refreshed := t.refreshToken(rw, req, session)
|
||||
if !refreshed {
|
||||
t.handleExpiredToken(rw, req, session, redirectURL)
|
||||
t.handleExpiredToken(rw, req, session)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Process authenticated request
|
||||
email := session.GetEmail()
|
||||
if email == "" {
|
||||
t.logger.Debug("No email found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(session.GetAccessToken())
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
// authenticated, _ := session.Values["authenticated"].(bool)
|
||||
if authenticated {
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Errorf("No id_token found in session")
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Check allowed roles and groups
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
claims, err := extractClaims(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
email, _ := claims["email"].(string)
|
||||
if email == "" {
|
||||
t.logger.Debugf("No email found in token claims")
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
return
|
||||
}
|
||||
|
||||
if !t.isAllowedDomain(email) {
|
||||
t.logger.Infof("User with email %s is not from an allowed domain", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
groups, roles, err := t.extractGroupsAndRoles(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract groups and roles: %v", err)
|
||||
} else {
|
||||
// Set headers for groups and roles
|
||||
if len(groups) > 0 {
|
||||
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
|
||||
}
|
||||
if len(roles) > 0 {
|
||||
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any of the allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
|
||||
if len(t.allowedRolesAndGroups) > 0 {
|
||||
allowed := false
|
||||
for _, roleOrGroup := range append(groups, roles...) {
|
||||
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
|
||||
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
t.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// Set user information in headers
|
||||
req.Header.Set("X-Forwarded-User", email)
|
||||
|
||||
// Process the request
|
||||
t.next.ServeHTTP(rw, req)
|
||||
t.logger.Debug("User is not authenticated, initiating authentication")
|
||||
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
|
||||
}
|
||||
|
||||
// determineExcludedURL checks if the current request URL is in the excluded list
|
||||
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
|
||||
for excludedURL := range t.excludedURLs {
|
||||
if strings.HasPrefix(currentRequest, excludedURL) {
|
||||
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
|
||||
return true
|
||||
}
|
||||
}
|
||||
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
|
||||
t.logger.Debug("URL is not excluded - got %s", currentRequest)
|
||||
return false
|
||||
}
|
||||
|
||||
// determineScheme determines the scheme (http or https) of the request
|
||||
func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
if t.forceHTTPS {
|
||||
return "https"
|
||||
@@ -491,7 +475,6 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
|
||||
return "http"
|
||||
}
|
||||
|
||||
// determineHost determines the host of the request
|
||||
func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
|
||||
return host
|
||||
@@ -499,35 +482,37 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
|
||||
return req.Host
|
||||
}
|
||||
|
||||
// isUserAuthenticated checks if the user is authenticated
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, bool) {
|
||||
if !session.GetAuthenticated() {
|
||||
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
|
||||
authenticated, _ := session.Values["authenticated"].(bool)
|
||||
t.logger.Debugf("Session authenticated value: %v", authenticated)
|
||||
|
||||
if !authenticated {
|
||||
t.logger.Debug("User is not authenticated according to session")
|
||||
return false, false, false
|
||||
}
|
||||
|
||||
accessToken := session.GetAccessToken()
|
||||
if accessToken == "" {
|
||||
t.logger.Debug("No access token found in session")
|
||||
idToken, ok := session.Values["id_token"].(string)
|
||||
if !ok || idToken == "" {
|
||||
t.logger.Debug("No id_token found in session")
|
||||
return false, false, true // Session is invalid, consider it expired
|
||||
}
|
||||
|
||||
// Verify the token
|
||||
if err := t.verifyToken(accessToken); err != nil {
|
||||
if err := t.verifyToken(idToken); err != nil {
|
||||
t.logger.Errorf("Token verification failed: %v", err)
|
||||
return false, false, true // Token is invalid, consider it expired
|
||||
}
|
||||
|
||||
claims, err := extractClaims(accessToken)
|
||||
claims, err := extractClaims(idToken)
|
||||
if err != nil {
|
||||
t.logger.Errorf("Failed to extract claims: %v", err)
|
||||
return false, false, true
|
||||
return false, false, true // Can't read claims, consider it expired
|
||||
}
|
||||
|
||||
expClaim, ok := claims["exp"].(float64)
|
||||
if !ok {
|
||||
t.logger.Error("Failed to get expiration time from claims")
|
||||
return false, false, true
|
||||
t.logger.Errorf("Failed to get expiration time from claims")
|
||||
return false, false, true // No expiration, consider it expired
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
@@ -535,7 +520,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
|
||||
if now > expTime {
|
||||
t.logger.Debug("Token has expired")
|
||||
return false, false, true
|
||||
return false, false, true // Token has expired
|
||||
}
|
||||
|
||||
gracePeriod := time.Minute * 5
|
||||
@@ -544,42 +529,38 @@ func (t *TraefikOidc) isUserAuthenticated(session *SessionData) (bool, bool, boo
|
||||
return true, true, false // Token will expire soon, needs refresh
|
||||
}
|
||||
|
||||
return true, false, false
|
||||
return true, false, false // Token is valid and not expiring soon
|
||||
}
|
||||
|
||||
// defaultInitiateAuthentication initiates the authentication process
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *SessionData, redirectURL string) {
|
||||
// Generate CSRF token and nonce
|
||||
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
|
||||
csrfToken := uuid.New().String()
|
||||
session.Values["csrf"] = csrfToken
|
||||
session.Values["incoming_path"] = req.URL.Path
|
||||
session.Options = defaultSessionOptions
|
||||
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
|
||||
|
||||
nonce, err := generateNonce()
|
||||
if err != nil {
|
||||
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
session.Values["nonce"] = nonce
|
||||
t.logger.Debugf("Setting nonce: %s", nonce)
|
||||
|
||||
// Set session values
|
||||
session.SetCSRF(csrfToken)
|
||||
session.SetNonce(nonce)
|
||||
session.SetIncomingPath(req.URL.Path)
|
||||
|
||||
// Save the session
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save session: %v", err)
|
||||
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build and redirect to auth URL
|
||||
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
|
||||
http.Redirect(rw, req, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// verifyToken verifies the token using the token verifier
|
||||
func (t *TraefikOidc) verifyToken(token string) error {
|
||||
return t.tokenVerifier.VerifyToken(token)
|
||||
}
|
||||
|
||||
// buildAuthURL constructs the authentication URL
|
||||
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
params := url.Values{}
|
||||
params.Set("client_id", t.clientID)
|
||||
@@ -593,7 +574,6 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
|
||||
return t.authURL + "?" + params.Encode()
|
||||
}
|
||||
|
||||
// startTokenCleanup starts the token cleanup goroutine
|
||||
func (t *TraefikOidc) startTokenCleanup() {
|
||||
ticker := newTicker(1 * time.Minute)
|
||||
go func() {
|
||||
@@ -605,23 +585,26 @@ func (t *TraefikOidc) startTokenCleanup() {
|
||||
}()
|
||||
}
|
||||
|
||||
// RevokeToken adds the token to the blacklist
|
||||
func (t *TraefikOidc) RevokeToken(token string) {
|
||||
// Remove from cache
|
||||
t.tokenCache.Delete(token)
|
||||
|
||||
// Add to blacklist with default expiration
|
||||
expiry := time.Now().Add(24 * time.Hour) // or other appropriate duration
|
||||
t.tokenBlacklist.Add(token, expiry)
|
||||
// Add to blacklist
|
||||
claims, err := extractClaims(token)
|
||||
if err == nil {
|
||||
if exp, ok := claims["exp"].(float64); ok {
|
||||
expTime := time.Unix(int64(exp), 0)
|
||||
t.tokenBlacklist.Add(token, expTime)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RevokeTokenWithProvider revokes the token with the provider
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
|
||||
t.logger.Debugf("Revoking token with provider")
|
||||
|
||||
data := url.Values{
|
||||
"token": {token},
|
||||
"token_type_hint": {tokenType},
|
||||
"token_type_hint": {"access_token", "refresh_token"},
|
||||
"client_id": {t.clientID},
|
||||
"client_secret": {t.clientSecret},
|
||||
}
|
||||
@@ -652,12 +635,10 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// refreshToken refreshes the user's token
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *SessionData) bool {
|
||||
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
|
||||
t.logger.Debug("Refreshing token")
|
||||
refreshToken := session.GetRefreshToken()
|
||||
if refreshToken == "" {
|
||||
t.logger.Debug("No refresh token found in session")
|
||||
refreshToken, ok := session.Values["refresh_token"].(string)
|
||||
if !ok || refreshToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -667,17 +648,9 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the new access token
|
||||
if err := t.verifyToken(newToken.IDToken); err != nil {
|
||||
t.logger.Errorf("Failed to verify new access token: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Update session with new tokens
|
||||
session.SetAccessToken(newToken.IDToken)
|
||||
session.SetRefreshToken(newToken.RefreshToken)
|
||||
|
||||
// Save the session
|
||||
session.Values["id_token"] = newToken.IDToken
|
||||
session.Values["refresh_token"] = newToken.RefreshToken
|
||||
session.Options = defaultSessionOptions
|
||||
if err := session.Save(req, rw); err != nil {
|
||||
t.logger.Errorf("Failed to save refreshed session: %v", err)
|
||||
return false
|
||||
@@ -686,7 +659,6 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
|
||||
return true
|
||||
}
|
||||
|
||||
// isAllowedDomain checks if the user's email domain is allowed
|
||||
func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
if len(t.allowedUserDomains) == 0 {
|
||||
return true // If no domains are specified, all are allowed
|
||||
@@ -702,7 +674,6 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
// extractGroupsAndRoles extracts groups and roles from the id_token
|
||||
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
|
||||
claims, err := t.extractClaimsFunc(idToken)
|
||||
if err != nil {
|
||||
@@ -712,48 +683,37 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
|
||||
var groups []string
|
||||
var roles []string
|
||||
|
||||
// Extract groups with type checking
|
||||
if groupsClaim, exists := claims["groups"]; exists {
|
||||
groupsSlice, ok := groupsClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("groups claim is not an array")
|
||||
}
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
// Check for groups claim
|
||||
if groupsClaim, ok := claims["groups"]; ok {
|
||||
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
|
||||
for _, group := range groupsSlice {
|
||||
if groupStr, ok := group.(string); ok {
|
||||
t.logger.Debugf("Found group: %s", groupStr)
|
||||
groups = append(groups, groupStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract roles with type checking
|
||||
if rolesClaim, exists := claims["roles"]; exists {
|
||||
rolesSlice, ok := rolesClaim.([]interface{})
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("roles claim is not an array")
|
||||
}
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debugf("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
if len(groups) == 0 {
|
||||
t.logger.Debug("No groups found in groups claim, checking roles claim")
|
||||
}
|
||||
|
||||
// Check for roles claim
|
||||
if rolesClaim, ok := claims["roles"]; ok {
|
||||
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
|
||||
for _, role := range rolesSlice {
|
||||
if roleStr, ok := role.(string); ok {
|
||||
t.logger.Debug("Found role: %s", roleStr)
|
||||
roles = append(roles, roleStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(roles) == 0 {
|
||||
t.logger.Debug("No roles found in roles claim")
|
||||
}
|
||||
|
||||
return groups, roles, nil
|
||||
}
|
||||
|
||||
// buildFullURL constructs a full URL from scheme, host and path
|
||||
func buildFullURL(scheme, host, path string) string {
|
||||
// If the path is already a full URL, return it as-is
|
||||
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
|
||||
return path
|
||||
}
|
||||
|
||||
// Ensure the path starts with a forward slash
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s%s", scheme, host, path)
|
||||
}
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
|
||||
func BenchmarkOIDCMiddleware(b *testing.B) {
|
||||
// Setup test environment
|
||||
|
||||
ts := &TestSuite{}
|
||||
ts.Setup()
|
||||
ts.token = "valid.jwt.token"
|
||||
|
||||
// Define the handler with OIDC middleware
|
||||
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
// Create test server
|
||||
server := httptest.NewServer(ts.tOidc.next)
|
||||
defer server.Close()
|
||||
|
||||
// Prepare HTTP client
|
||||
client := &http.Client{}
|
||||
|
||||
// Reset timer to exclude setup time
|
||||
b.ResetTimer()
|
||||
|
||||
// Run benchmark
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Create new request
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Set necessary headers or cookies
|
||||
req.Header.Set("Authorization", "Bearer "+ts.token)
|
||||
|
||||
// Send the request
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Close response body
|
||||
resp.Body.Close()
|
||||
|
||||
// Check response status code
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
+84
-1043
File diff suppressed because it is too large
Load Diff
-196
@@ -1,196 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
mainCookieName = "_raczylo_oidc" // Main session cookie
|
||||
accessTokenCookie = "_raczylo_oidc_access" // Access token cookie
|
||||
refreshTokenCookie = "_raczylo_oidc_refresh" // Refresh token cookie
|
||||
)
|
||||
|
||||
// SessionManager handles multiple session cookies
|
||||
type SessionManager struct {
|
||||
store sessions.Store
|
||||
forceHTTPS bool
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewSessionManager creates a new session manager
|
||||
func NewSessionManager(encryptionKey string, forceHTTPS bool, logger *Logger) *SessionManager {
|
||||
return &SessionManager{
|
||||
store: sessions.NewCookieStore([]byte(encryptionKey)),
|
||||
forceHTTPS: forceHTTPS,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// getSessionOptions returns session options based on scheme
|
||||
func (sm *SessionManager) getSessionOptions(isSecure bool) *sessions.Options {
|
||||
return &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: isSecure || sm.forceHTTPS,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSession retrieves all session data
|
||||
func (sm *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
|
||||
mainSession, err := sm.store.Get(r, mainCookieName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get main session: %w", err)
|
||||
}
|
||||
|
||||
accessSession, err := sm.store.Get(r, accessTokenCookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get access token session: %w", err)
|
||||
}
|
||||
|
||||
refreshSession, err := sm.store.Get(r, refreshTokenCookie)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get refresh token session: %w", err)
|
||||
}
|
||||
|
||||
sessionData := &SessionData{
|
||||
manager: sm,
|
||||
mainSession: mainSession,
|
||||
accessSession: accessSession,
|
||||
refreshSession: refreshSession,
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
// SessionData holds all session information
|
||||
type SessionData struct {
|
||||
manager *SessionManager
|
||||
mainSession *sessions.Session
|
||||
accessSession *sessions.Session
|
||||
refreshSession *sessions.Session
|
||||
}
|
||||
|
||||
// Save saves all session data
|
||||
func (sd *SessionData) Save(r *http.Request, w http.ResponseWriter) error {
|
||||
isSecure := strings.HasPrefix(r.URL.Scheme, "https") || sd.manager.forceHTTPS
|
||||
|
||||
// Set options for all sessions
|
||||
sd.mainSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.accessSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
sd.refreshSession.Options = sd.manager.getSessionOptions(isSecure)
|
||||
|
||||
if err := sd.mainSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save main session: %w", err)
|
||||
}
|
||||
if err := sd.accessSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save access token session: %w", err)
|
||||
}
|
||||
if err := sd.refreshSession.Save(r, w); err != nil {
|
||||
return fmt.Errorf("failed to save refresh token session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear clears all session data
|
||||
func (sd *SessionData) Clear(r *http.Request, w http.ResponseWriter) error {
|
||||
// Clear and expire all sessions
|
||||
sd.mainSession.Options.MaxAge = -1
|
||||
sd.accessSession.Options.MaxAge = -1
|
||||
sd.refreshSession.Options.MaxAge = -1
|
||||
|
||||
for k := range sd.mainSession.Values {
|
||||
delete(sd.mainSession.Values, k)
|
||||
}
|
||||
for k := range sd.accessSession.Values {
|
||||
delete(sd.accessSession.Values, k)
|
||||
}
|
||||
for k := range sd.refreshSession.Values {
|
||||
delete(sd.refreshSession.Values, k)
|
||||
}
|
||||
|
||||
return sd.Save(r, w)
|
||||
}
|
||||
|
||||
// GetAuthenticated returns authentication status
|
||||
func (sd *SessionData) GetAuthenticated() bool {
|
||||
auth, _ := sd.mainSession.Values["authenticated"].(bool)
|
||||
return auth
|
||||
}
|
||||
|
||||
// SetAuthenticated sets authentication status
|
||||
func (sd *SessionData) SetAuthenticated(value bool) {
|
||||
sd.mainSession.Values["authenticated"] = value
|
||||
}
|
||||
|
||||
// GetAccessToken returns the access token
|
||||
func (sd *SessionData) GetAccessToken() string {
|
||||
token, _ := sd.accessSession.Values["token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
// SetAccessToken sets the access token
|
||||
func (sd *SessionData) SetAccessToken(token string) {
|
||||
sd.accessSession.Values["token"] = token
|
||||
}
|
||||
|
||||
// GetRefreshToken returns the refresh token
|
||||
func (sd *SessionData) GetRefreshToken() string {
|
||||
token, _ := sd.refreshSession.Values["token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
// SetRefreshToken sets the refresh token
|
||||
func (sd *SessionData) SetRefreshToken(token string) {
|
||||
sd.refreshSession.Values["token"] = token
|
||||
}
|
||||
|
||||
// GetCSRF returns the CSRF token
|
||||
func (sd *SessionData) GetCSRF() string {
|
||||
csrf, _ := sd.mainSession.Values["csrf"].(string)
|
||||
return csrf
|
||||
}
|
||||
|
||||
// SetCSRF sets the CSRF token
|
||||
func (sd *SessionData) SetCSRF(token string) {
|
||||
sd.mainSession.Values["csrf"] = token
|
||||
}
|
||||
|
||||
// GetNonce returns the nonce
|
||||
func (sd *SessionData) GetNonce() string {
|
||||
nonce, _ := sd.mainSession.Values["nonce"].(string)
|
||||
return nonce
|
||||
}
|
||||
|
||||
// SetNonce sets the nonce
|
||||
func (sd *SessionData) SetNonce(nonce string) {
|
||||
sd.mainSession.Values["nonce"] = nonce
|
||||
}
|
||||
|
||||
// GetEmail returns the user's email
|
||||
func (sd *SessionData) GetEmail() string {
|
||||
email, _ := sd.mainSession.Values["email"].(string)
|
||||
return email
|
||||
}
|
||||
|
||||
// SetEmail sets the user's email
|
||||
func (sd *SessionData) SetEmail(email string) {
|
||||
sd.mainSession.Values["email"] = email
|
||||
}
|
||||
|
||||
// GetIncomingPath returns the original incoming path
|
||||
func (sd *SessionData) GetIncomingPath() string {
|
||||
path, _ := sd.mainSession.Values["incoming_path"].(string)
|
||||
return path
|
||||
}
|
||||
|
||||
// SetIncomingPath sets the original incoming path
|
||||
func (sd *SessionData) SetIncomingPath(path string) {
|
||||
sd.mainSession.Values["incoming_path"] = path
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
logger := NewLogger("info")
|
||||
manager := NewSessionManager("test-secret-key", false, logger)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
session, err := manager.GetSession(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get session: %v", err)
|
||||
}
|
||||
|
||||
// Test setting and getting values
|
||||
session.SetAuthenticated(true)
|
||||
session.SetEmail("test@example.com")
|
||||
session.SetAccessToken("test.access.token")
|
||||
session.SetRefreshToken("test.refresh.token")
|
||||
|
||||
if err := session.Save(req, rr); err != nil {
|
||||
t.Fatalf("Failed to save session: %v", err)
|
||||
}
|
||||
|
||||
// Verify cookies are set
|
||||
cookies := rr.Result().Cookies()
|
||||
if len(cookies) != 3 {
|
||||
t.Errorf("Expected 3 cookies, got %d", len(cookies))
|
||||
}
|
||||
|
||||
// Create a new request with the cookies
|
||||
newReq := httptest.NewRequest("GET", "/test", nil)
|
||||
for _, cookie := range cookies {
|
||||
newReq.AddCookie(cookie)
|
||||
}
|
||||
|
||||
// Get the session again and verify values
|
||||
newSession, err := manager.GetSession(newReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get new session: %v", err)
|
||||
}
|
||||
|
||||
if !newSession.GetAuthenticated() {
|
||||
t.Error("Authentication status not preserved")
|
||||
}
|
||||
if email := newSession.GetEmail(); email != "test@example.com" {
|
||||
t.Errorf("Expected email test@example.com, got %s", email)
|
||||
}
|
||||
if token := newSession.GetAccessToken(); token != "test.access.token" {
|
||||
t.Errorf("Expected access token test.access.token, got %s", token)
|
||||
}
|
||||
if token := newSession.GetRefreshToken(); token != "test.refresh.token" {
|
||||
t.Errorf("Expected refresh token test.refresh.token, got %s", token)
|
||||
}
|
||||
}
|
||||
+10
-14
@@ -6,13 +6,14 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const (
|
||||
cookieName = "_raczylo_oidc"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the OIDC middleware
|
||||
type Config struct {
|
||||
ProviderURL string `json:"providerURL"`
|
||||
RevocationURL string `json:"revocationURL"`
|
||||
@@ -28,12 +29,17 @@ type Config struct {
|
||||
ExcludedURLs []string `json:"excludedURLs"`
|
||||
AllowedUserDomains []string `json:"allowedUserDomains"`
|
||||
AllowedRolesAndGroups []string `json:"allowedRolesAndGroups"`
|
||||
OIDCEndSessionURL string `json:"oidcEndSessionURL"`
|
||||
PostLogoutRedirectURI string `json:"postLogoutRedirectURI"`
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
// CreateConfig creates a new Config with default values
|
||||
var defaultSessionOptions = &sessions.Options{
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: ConstSessionTimeout,
|
||||
Path: "/",
|
||||
}
|
||||
|
||||
func CreateConfig() *Config {
|
||||
c := &Config{}
|
||||
|
||||
@@ -56,7 +62,6 @@ func CreateConfig() *Config {
|
||||
return c
|
||||
}
|
||||
|
||||
// Validate validates the Config
|
||||
func (c *Config) Validate() error {
|
||||
if c.ProviderURL == "" {
|
||||
return fmt.Errorf("providerURL is required")
|
||||
@@ -76,14 +81,12 @@ func (c *Config) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Logger is a simple logger with different levels
|
||||
type Logger struct {
|
||||
logError *log.Logger
|
||||
logInfo *log.Logger
|
||||
logDebug *log.Logger
|
||||
}
|
||||
|
||||
// NewLogger creates a new Logger
|
||||
func NewLogger(logLevel string) *Logger {
|
||||
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
|
||||
@@ -103,37 +106,30 @@ func NewLogger(logLevel string) *Logger {
|
||||
}
|
||||
}
|
||||
|
||||
// Info logs an info message
|
||||
func (l *Logger) Info(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debug logs a debug message
|
||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs an error message
|
||||
func (l *Logger) Error(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs an info message
|
||||
func (l *Logger) Infof(format string, args ...interface{}) {
|
||||
l.logInfo.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Debugf logs a debug message
|
||||
func (l *Logger) Debugf(format string, args ...interface{}) {
|
||||
l.logDebug.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs an error message
|
||||
func (l *Logger) Errorf(format string, args ...interface{}) {
|
||||
l.logError.Printf(format, args...)
|
||||
}
|
||||
|
||||
// handleError writes an error message to the response and logs it
|
||||
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
|
||||
logger.Error(message)
|
||||
http.Error(w, message, code)
|
||||
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
Copyright 2009 The Go Authors.
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google LLC nor the names of its
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
|
||||
+14
-3
@@ -99,9 +99,8 @@ func (lim *Limiter) Tokens() float64 {
|
||||
// bursts of at most b tokens.
|
||||
func NewLimiter(r Limit, b int) *Limiter {
|
||||
return &Limiter{
|
||||
limit: r,
|
||||
burst: b,
|
||||
tokens: float64(b),
|
||||
limit: r,
|
||||
burst: b,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -345,6 +344,18 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
|
||||
tokens: n,
|
||||
timeToAct: t,
|
||||
}
|
||||
} else if lim.limit == 0 {
|
||||
var ok bool
|
||||
if lim.burst >= n {
|
||||
ok = true
|
||||
lim.burst -= n
|
||||
}
|
||||
return Reservation{
|
||||
ok: ok,
|
||||
lim: lim,
|
||||
tokens: lim.burst,
|
||||
timeToAct: t,
|
||||
}
|
||||
}
|
||||
|
||||
t, tokens := lim.advance(t)
|
||||
|
||||
Vendored
+1
-1
@@ -7,6 +7,6 @@ github.com/gorilla/securecookie
|
||||
# github.com/gorilla/sessions v1.3.0
|
||||
## explicit; go 1.20
|
||||
github.com/gorilla/sessions
|
||||
# golang.org/x/time v0.7.0
|
||||
# golang.org/x/time v0.5.0
|
||||
## explicit; go 1.18
|
||||
golang.org/x/time/rate
|
||||
|
||||
Reference in New Issue
Block a user