Compare commits

...

12 Commits

20 changed files with 843 additions and 439 deletions
+4
View File
@@ -4,6 +4,10 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
Middleware has been tested with Auth0 and Logto.
### Traefik version compatibility
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
### Configuration options
Middleware currently supports following scenarios:
+7
View File
@@ -5,22 +5,26 @@ import (
"time"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
}
// Cache is a simple in-memory cache
type Cache struct {
items map[string]CacheItem
mutex sync.RWMutex
}
// NewCache creates a new Cache
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -30,6 +34,7 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
}
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
@@ -44,12 +49,14 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
+5 -3
View File
@@ -1,11 +1,13 @@
module github.com/lukaszraczylo/traefikoidc
go 1.22.2
go 1.23
toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.5.0
github.com/gorilla/sessions v1.4.0
golang.org/x/time v0.7.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
+4
View File
@@ -6,5 +6,9 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+145 -81
View File
@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
@@ -16,6 +17,7 @@ import (
"github.com/gorilla/sessions"
)
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -25,6 +27,7 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
@@ -32,7 +35,8 @@ func buildFullURL(scheme, host, path string) string {
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -58,14 +62,20 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return result, nil
return &tokenResponse, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
@@ -74,47 +84,20 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
newAccessToken, ok := result["access_token"].(string)
if !ok || newAccessToken == "" {
return nil, fmt.Errorf("no access_token field in token response")
}
t.logger.Debugf("Token response: %+v", tokenResponse)
rawIDToken, ok := result["id_token"].(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token field in token response")
}
newRefreshToken, ok := result["refresh_token"].(string)
if !ok || newRefreshToken == "" {
return nil, fmt.Errorf("no refresh_token field in token response")
}
response := &TokenResponse{
IDToken: rawIDToken,
AccessToken: newAccessToken,
ExpiresIn: int(result["expires_in"].(float64)),
TokenType: result["token_type"].(string),
}
// The refresh token might not be returned if it hasn't changed
if newRefreshToken != refreshToken {
response.RefreshToken = newRefreshToken
} else {
response.RefreshToken = refreshToken
}
t.logger.Debug("Token response: %+v", response)
return response, nil
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
@@ -123,28 +106,41 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
err := t.RevokeTokenWithProvider(idToken)
if err != nil {
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
return
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
}
t.RevokeToken(idToken)
t.RevokeToken(refreshToken)
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
t.RevokeToken(accessToken)
}
// Remove tokens from session
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
http.Error(rw, "Logged out", http.StatusForbidden)
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
// Clear the existing session
session.Options.MaxAge = -1
@@ -169,6 +165,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
@@ -179,6 +176,34 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -186,20 +211,29 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
token, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
idToken, ok := token["id_token"].(string)
if !ok || idToken == "" {
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
@@ -207,6 +241,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -214,11 +268,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
@@ -226,16 +286,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
}
t.logger.Debugf("Authentication successful. User email: %s", email)
http.Redirect(rw, req, func() string {
if path, ok := session.Values["incoming_path"].(string); ok {
t.logger.Debug("Redirecting to incoming path from original request: %s", path)
return path
}
t.logger.Debug("Redirecting to root path as no incoming path found")
return "/"
}(), http.StatusFound)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims extracts claims from a JWT token
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -255,23 +316,27 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -279,6 +344,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -290,26 +356,25 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache caches tokens
type TokenCache struct {
cache *Cache
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
// NewTokenCache creates a new TokenCache
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
@@ -320,33 +385,32 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup cleans up expired tokens from the cache
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.clientID)
data.Set("client_secret", t.clientSecret)
data.Set("code", code)
data.Set("redirect_uri", t.redirectURL)
resp, err := t.httpClient.PostForm(t.tokenURL, data)
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %v", err)
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %v", err)
}
return result, nil
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
+45 -59
View File
@@ -4,17 +4,19 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key
type JWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
@@ -27,20 +29,24 @@ type JWK struct {
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
type JWKCache struct {
jwks *JWKSet
expiresAt time.Time
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -67,6 +73,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -86,82 +93,61 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// jwkToPEM converts a JWK to PEM format
func jwkToPEM(jwk *JWK) ([]byte, error) {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
converter, ok := jwkConverters[jwk.Kty]
if !ok {
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return converter(jwk)
}
type jwkToPEMConverter func(*JWK) ([]byte, error)
var jwkConverters = map[string]jwkToPEMConverter{
"RSA": rsaJWKToPEM,
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()),
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return publicKeyPEM, nil
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC JWK to PEM
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
@@ -178,21 +164,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
publicKey := &ecdsa.PublicKey{
pubKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
Bytes: pubKeyBytes,
})
return publicKeyPEM, nil
return pubKeyPEM, nil
}
+107 -39
View File
@@ -4,29 +4,36 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
Signature string
Signature []byte
Token string
}
// parseJWT parses a JWT token string into a JWT struct
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
jwt := &JWT{}
jwt := &JWT{
Token: tokenString,
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -46,12 +53,17 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Set the signature
jwt.Signature = parts[2]
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
return jwt, nil
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
@@ -95,6 +107,39 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
@@ -103,17 +148,43 @@ func verifyExpiration(expiration float64) error {
return nil
}
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
// verifySignature verifies the token signature using the provided public key and algorithm
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
@@ -127,45 +198,42 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte
return fmt.Errorf("unsupported algorithm: %s", alg)
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pub, hashed, r, s) {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
default:
return fmt.Errorf("unsupported public key type: %T", pub)
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
}
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
+148 -142
View File
@@ -2,7 +2,6 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -19,16 +18,19 @@ import (
"golang.org/x/time/rate"
)
const ConstSessionTimeout = 86400
const ConstSessionTimeout = 86400 // Session timeout in seconds
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// JWTVerifier interface for JWT verification
type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
name string
@@ -58,12 +60,13 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
@@ -72,36 +75,45 @@ type ProviderMetadata struct {
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token: %s", token)
t.logger.Debugf("Verifying token")
// Rate limiting
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
@@ -110,26 +122,27 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return nil
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
t.logger.Debugf("Verifying JWT signature and claims")
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Retrieve key ID and algorithm from JWT header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
// Find the matching key in JWKS
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
@@ -137,48 +150,35 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
// Convert JWK to PEM format
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
// Verify the signature
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -217,47 +217,29 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: func() map[string]struct{} {
m := make(map[string]struct{})
for _, url := range config.ExcludedURLs {
m[url] = struct{}{}
}
return m
}(),
redirectURL: "",
allowedUserDomains: func() map[string]struct{} {
m := make(map[string]struct{})
for _, domain := range config.AllowedUserDomains {
m[domain] = struct{}{}
}
return m
}(),
allowedRolesAndGroups: func() map[string]struct{} {
m := make(map[string]struct{})
for _, roleOrGroup := range config.AllowedRolesAndGroups {
m[roleOrGroup] = struct{}{}
}
return m
}(),
initComplete: make(chan struct{}),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// add defaultExcludedURLs to excludedURLs
// Add default excluded URLs
for k, v := range defaultExcludedURLs {
t.excludedURLs[k] = v
}
@@ -266,14 +248,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.jwtVerifier = t
t.startTokenCleanup()
go t.initializeMetadata(config.ProviderURL)
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Error("Failed to discover provider metadata: %v", err)
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
@@ -286,6 +270,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
})
}
// discoverProviderMetadata fetches the OIDC provider metadata
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -299,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if time.Since(start) > totalTimeout {
l.Error("Timeout exceeded while fetching provider metadata")
l.Errorf("Timeout exceeded while fetching provider metadata")
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
}
@@ -311,18 +296,20 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
lastErr = err
// Exponential backoff
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
if delay > maxDelay {
delay = maxDelay
}
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
time.Sleep(delay)
}
l.Error("Max retries exceeded while fetching provider metadata")
l.Errorf("Max retries exceeded while fetching provider metadata")
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -345,6 +332,7 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -360,20 +348,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Check if the URL is excluded from authentication
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
@@ -383,16 +375,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle logout URL
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
return
}
// Check if the user is authenticated
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
@@ -413,84 +408,80 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// authenticated, _ := session.Values["authenticated"].(bool)
if authenticated {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
t.logger.Debug("User is not authenticated, initiating authentication")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
t.logger.Debug("URL is not excluded - got %s", currentRequest)
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
return false
}
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
@@ -504,6 +495,7 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
return "http"
}
// determineHost determines the host of the request
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -511,6 +503,7 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
@@ -561,13 +554,16 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
@@ -576,20 +572,24 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build the authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -603,6 +603,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
go func() {
@@ -614,6 +615,7 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
@@ -628,12 +630,13 @@ func (t *TraefikOidc) RevokeToken(token string) {
}
}
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
// RevokeTokenWithProvider revokes the token with the provider
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
data := url.Values{
"token": {token},
"token_type_hint": {"access_token", "refresh_token"},
"token_type_hint": {tokenType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
@@ -664,10 +667,12 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
return nil
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -677,6 +682,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
@@ -688,6 +700,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
@@ -703,6 +716,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
@@ -724,25 +738,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
}
}
if len(groups) == 0 {
t.logger.Debug("No groups found in groups claim, checking roles claim")
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debug("Found role: %s", roleStr)
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
if len(roles) == 0 {
t.logger.Debug("No roles found in roles claim")
}
return groups, roles, nil
}
+57
View File
@@ -0,0 +1,57 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
func BenchmarkOIDCMiddleware(b *testing.B) {
// Setup test environment
ts := &TestSuite{}
ts.Setup()
ts.token = "valid.jwt.token"
// Define the handler with OIDC middleware
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create test server
server := httptest.NewServer(ts.tOidc.next)
defer server.Close()
// Prepare HTTP client
client := &http.Client{}
// Reset timer to exclude setup time
b.ResetTimer()
// Run benchmark
for i := 0; i < b.N; i++ {
// Create new request
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
b.Fatal(err)
}
// Set necessary headers or cookies
req.Header.Set("Authorization", "Bearer "+ts.token)
// Send the request
resp, err := client.Do(req)
if err != nil {
b.Fatal(err)
}
// Close response body
resp.Body.Close()
// Check response status code
if resp.StatusCode != http.StatusOK {
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
}
}
}
+273 -36
View File
@@ -72,6 +72,7 @@ func (ts *TestSuite) Setup() {
"iat": time.Now().Unix(),
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
@@ -79,33 +80,34 @@ func (ts *TestSuite) Setup() {
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
exchangeCodeForTokenFunc: ts.exchangeCodeForTokenFunc,
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
ts.tOidc.tokenVerifier = ts.tOidc
ts.tOidc.jwtVerifier = ts.tOidc
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": ts.token,
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
}
@@ -453,61 +455,149 @@ func TestHandleCallback(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (map[string]interface{}, error)
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
}{
{
name: "Success",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Exchange Code Error",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing ID Token",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{}, nil
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Disallowed Email",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@disallowed.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusForbidden,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tests {
@@ -519,6 +609,8 @@ func TestHandleCallback(t *testing.T) {
logger: NewLogger("info"),
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
}
// Create request and response recorder
@@ -527,6 +619,9 @@ func TestHandleCallback(t *testing.T) {
// Create session
session, _ := tOidc.store.New(req, cookieName)
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
session.Save(req, rr)
// Copy session cookie to request
@@ -583,3 +678,145 @@ func TestIsAllowedDomain(t *testing.T) {
})
}
}
func TestOIDCHandler(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.token = "valid.jwt.token"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
blacklist bool
rateLimit bool
cacheToken bool
}{
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with invalid nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims without nonce
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with mismatched nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour))
}
if tc.rateLimit {
// Exceed rate limit
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
}
if tc.cacheToken {
// Cache the token with dummy claims
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
"empty": "claim",
}, 60)
}
})
}
}
+12
View File
@@ -14,6 +14,7 @@ const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
@@ -40,6 +41,7 @@ var defaultSessionOptions = &sessions.Options{
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{}
@@ -62,6 +64,7 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
@@ -81,12 +84,14 @@ func (c *Config) Validate() error {
return nil
}
// Logger is a simple logger with different levels
type Logger struct {
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
@@ -106,30 +111,37 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+5 -1
View File
@@ -1,4 +1,7 @@
# sessions
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -74,6 +77,7 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+12 -9
View File
@@ -1,5 +1,6 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
@@ -8,13 +9,15 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
}
}
-21
View File
@@ -1,21 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+10 -5
View File
@@ -1,8 +1,11 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -13,7 +16,9 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
}
-23
View File
@@ -1,23 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
+3 -14
View File
@@ -99,8 +99,9 @@ func (lim *Limiter) Tokens() float64 {
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
limit: r,
burst: b,
tokens: float64(b),
}
}
@@ -344,18 +345,6 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
tokens: n,
timeToAct: t,
}
} else if lim.limit == 0 {
var ok bool
if lim.burst >= n {
ok = true
lim.burst -= n
}
return Reservation{
ok: ok,
lim: lim,
tokens: lim.burst,
timeToAct: t,
}
}
t, tokens := lim.advance(t)
+3 -3
View File
@@ -4,9 +4,9 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
github.com/gorilla/sessions
# golang.org/x/time v0.5.0
# golang.org/x/time v0.7.0
## explicit; go 1.18
golang.org/x/time/rate