Compare commits

..

16 Commits

Author SHA1 Message Date
lukaszraczylo ab36f10a70 Draft. 2024-10-13 18:21:13 +01:00
lukaszraczylo 4972b21373 Improve speed of the cache module. 2024-10-13 18:00:43 +01:00
lukaszraczylo 0be45f06a5 Merge pull request #10 from lukaszraczylo/code-improvements
Code improvements
2024-10-09 09:34:45 +01:00
lukaszraczylo 2cc89f0f31 fixup! Update dependencies. 2024-10-09 09:33:20 +01:00
lukaszraczylo b9928f2f0c Update dependencies. 2024-10-09 09:32:25 +01:00
lukaszraczylo 2e1a3a9320 Add ability to verify default ECDSA keys provided by logto as well. 2024-10-09 09:30:42 +01:00
lukaszraczylo 9dabd0e5cf Revert "Update go mod dependencies."
This reverts commit dedbdf63c3.
2024-10-09 09:11:07 +01:00
lukaszraczylo dedbdf63c3 Update go mod dependencies. 2024-10-09 09:07:25 +01:00
lukaszraczylo af032c6cd3 Add simple benchmark to track the allocations and speed for future improvements. 2024-10-08 14:41:43 +01:00
lukaszraczylo 9938cff053 fixup! Cleanup and optimise the code. 2024-10-08 14:26:26 +01:00
lukaszraczylo 7a404ef76f Cleanup and optimise the code. 2024-10-08 14:14:47 +01:00
lukaszraczylo 63922f362f fixup! Add support for more algorithms. 2024-10-07 16:07:07 +01:00
lukaszraczylo 2de9297ab6 Add support for more algorithms. 2024-10-07 16:01:07 +01:00
lukaszraczylo 971c84f762 Abstract filling up maps. 2024-10-07 15:56:24 +01:00
lukaszraczylo d2a0d2167e Fix the bug with user not being redirected to originally requested URL post authentication. 2024-10-05 09:33:56 +01:00
lukaszraczylo c46d958397 Update documentation - setting secrets in kubernetes. 2024-10-04 17:15:43 +01:00
20 changed files with 1008 additions and 630 deletions
+32
View File
@@ -4,6 +4,10 @@ This middleware is supposed to replace the need for the forward-auth and oauth2-
Middleware has been tested with Auth0 and Logto.
### Traefik version compatibility
Code follows closely the current traefik helm chart versions. If plugin fails to load - it's time to update to the latest version of the traefik helm chart.
### Configuration options
Middleware currently supports following scenarios:
@@ -15,6 +19,34 @@ Middleware currently supports following scenarios:
#### How to configure...
##### Keeping secrets secret
This works ONLY in kubernetes environments. Don't forget to create secret traefik-middleware-oidc with fields ISSUER, CLIENT_ID and SECRET keys.
```
apiVersion: traefik.io/v1alpha1
kind: Middleware
metadata:
name: oidc-with-open-urls
namespace: traefik
spec:
plugin:
traefikoidc:
providerURL: urn:k8s:secret:traefik-middleware-oidc:ISSUER
clientID: urn:k8s:secret:traefik-middleware-oidc:CLIENT_ID
clientSecret: urn:k8s:secret:traefik-middleware-oidc:SECRET
sessionEncryptionKey: vvv
callbackURL: /cool-oidc/callback
logoutURL: /cool-oidc/logout
scopes:
- openid
- email
- profile
excludedURLs: # Determines the list of URLs which are NOT a subject to authentication
- /login # covers /login, /login/me, /login/reminder etc.
- /my-public-data
```
##### Excluded URLs with open access
```
+21 -10
View File
@@ -5,58 +5,69 @@ import (
"time"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
ExpiresAt int64 // Changed to int64 for faster comparisons
}
// Cache is a simple in-memory cache
type Cache struct {
items map[string]CacheItem
mutex sync.RWMutex
}
// NewCache creates a new Cache
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Removed defer for slightly better performance
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: time.Now().Add(expiration).UnixNano(), // Store as UnixNano for faster comparisons
}
c.mutex.Unlock()
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
if time.Now().UnixNano() > item.ExpiresAt {
c.mutex.RUnlock()
// Use a separate goroutine to delete expired items to avoid blocking
go c.Delete(key)
return nil, false
}
c.mutex.RUnlock()
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.mutex.Unlock()
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
now := time.Now().UnixNano()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
if now > item.ExpiresAt {
delete(c.items, key)
}
}
c.mutex.Unlock()
}
+5 -3
View File
@@ -1,11 +1,13 @@
module github.com/lukaszraczylo/traefikoidc
go 1.22.2
go 1.23
toolchain go1.23.1
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.3.0
golang.org/x/time v0.5.0
github.com/gorilla/sessions v1.4.0
golang.org/x/time v0.7.0
)
require github.com/gorilla/securecookie v1.1.2 // indirect
+4
View File
@@ -6,5 +6,9 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFzg=
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+137 -115
View File
@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
@@ -16,15 +17,16 @@ import (
"github.com/gorilla/sessions"
)
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
if err != nil {
if _, err := rand.Read(nonceBytes); err != nil {
return "", fmt.Errorf("could not generate nonce: %w", err)
}
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
@@ -32,21 +34,23 @@ func buildFullURL(scheme, host, path string) string {
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
if grantType == "authorization_code" {
switch grantType {
case "authorization_code":
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
} else if grantType == "refresh_token" {
case "refresh_token":
data.Set("refresh_token", codeOrToken)
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -58,14 +62,20 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return result, nil
return &tokenResponse, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
@@ -74,47 +84,20 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
newAccessToken, ok := result["access_token"].(string)
if !ok || newAccessToken == "" {
return nil, fmt.Errorf("no access_token field in token response")
}
t.logger.Debugf("Token response: %+v", tokenResponse)
rawIDToken, ok := result["id_token"].(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token field in token response")
}
newRefreshToken, ok := result["refresh_token"].(string)
if !ok || newRefreshToken == "" {
return nil, fmt.Errorf("no refresh_token field in token response")
}
response := &TokenResponse{
IDToken: rawIDToken,
AccessToken: newAccessToken,
ExpiresIn: int(result["expires_in"].(float64)),
TokenType: result["token_type"].(string),
}
// The refresh token might not be returned if it hasn't changed
if newRefreshToken != refreshToken {
response.RefreshToken = newRefreshToken
} else {
response.RefreshToken = refreshToken
}
t.logger.Debug("Token response: %+v", response)
return response, nil
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
@@ -123,31 +106,41 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
err := t.RevokeTokenWithProvider(idToken)
if err != nil {
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
return
// Revoke tokens if available
for _, tokenType := range []string{"refresh_token", "access_token"} {
if token, ok := session.Values[tokenType].(string); ok && token != "" {
if err := t.RevokeTokenWithProvider(token, tokenType); err != nil {
t.logger.Errorf("Failed to revoke %s: %v", tokenType, err)
}
t.RevokeToken(token)
}
t.RevokeToken(idToken)
delete(session.Values, tokenType)
}
session.Options = defaultSessionOptions
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
// Remove other session values
delete(session.Values, "id_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = &sessions.Options{MaxAge: -1, Path: "/", HttpOnly: true, Secure: true}
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
http.Error(rw, "Logged out", http.StatusForbidden)
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
if session == nil {
t.logger.Error("Session is nil in handleExpiredToken")
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
@@ -156,19 +149,18 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
@@ -179,6 +171,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
if errParam := req.URL.Query().Get("error"); errParam != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", errParam, errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
state := req.URL.Query().Get("state")
csrfToken, ok := session.Values["csrf"].(string)
if !ok || state == "" || csrfToken == "" || state != csrfToken {
t.logger.Error("Invalid state parameter or CSRF token")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -186,20 +193,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
token, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
idToken, ok := token["id_token"].(string)
if !ok || idToken == "" {
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
@@ -207,6 +220,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
nonceClaim, ok := claims["nonce"].(string)
sessionNonce, ok2 := session.Values["nonce"].(string)
if !ok || !ok2 || nonceClaim == "" || sessionNonce == "" || nonceClaim != sessionNonce {
t.logger.Error("Invalid nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -217,7 +238,11 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Options = defaultSessionOptions
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
delete(session.Values, "csrf")
delete(session.Values, "nonce")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
@@ -226,9 +251,16 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
}
t.logger.Debugf("Authentication successful. User email: %s", email)
http.Redirect(rw, req, "/", http.StatusFound)
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims extracts claims from a JWT token
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -248,64 +280,56 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
blacklist sync.Map
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
return &TokenBlacklist{}
}
func (tb *TokenBlacklist) Add(token string, expiration time.Time) {
tb.blacklist.Store(token, expiration)
}
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
if exp, ok := tb.blacklist.Load(token); ok {
return time.Now().Before(exp.(time.Time))
}
}
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
return false
}
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
tb.blacklist.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
tb.blacklist.Delete(key)
}
}
return true
})
}
// TokenCache caches tokens
type TokenCache struct {
cache *Cache
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
// NewTokenCache creates a new TokenCache
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
tc.cache.Set("t-"+token, claims, expiration)
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
value, found := tc.cache.Get("t-" + token)
if !found {
return nil, false
}
@@ -313,33 +337,31 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
tc.cache.Delete("t-" + token)
}
// Cleanup cleans up expired tokens from the cache
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.clientID)
data.Set("client_secret", t.clientSecret)
data.Set("code", code)
data.Set("redirect_uri", t.redirectURL)
resp, err := t.httpClient.PostForm(t.tokenURL, data)
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %v", err)
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %v", err)
}
return result, nil
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{}, len(keys))
for _, key := range keys {
result[key] = struct{}{}
}
return result
}
+57 -74
View File
@@ -15,6 +15,7 @@ import (
"time"
)
// JWK represents a JSON Web Key
type JWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
@@ -27,20 +28,24 @@ type JWK struct {
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
type JWKCache struct {
jwks *JWKSet
expiresAt time.Time
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -52,6 +57,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check locking pattern
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
@@ -67,6 +73,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -86,113 +93,89 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// jwkToPEM converts a JWK to PEM format
func jwkToPEM(jwk *JWK) ([]byte, error) {
switch jwk.Kty {
case "RSA":
return rsaJWKToPEM(jwk)
case "EC":
return ecJWKToPEM(jwk)
default:
converter, ok := jwkConverters[jwk.Kty]
if !ok {
return nil, fmt.Errorf("unsupported key type: %s", jwk.Kty)
}
return converter(jwk)
}
type jwkToPEMConverter func(*JWK) ([]byte, error)
var jwkConverters = map[string]jwkToPEMConverter{
"RSA": rsaJWKToPEM,
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()),
pubKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(nBytes),
E: int(new(big.Int).SetBytes(eBytes).Int64()),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
})
return publicKeyPEM, nil
return marshalPublicKey(pubKey)
}
// ecJWKToPEM converts an EC JWK to PEM
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
curve, err := getCurve(jwk.Crv)
if err != nil {
return nil, err
}
publicKey := &ecdsa.PublicKey{
pubKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
return marshalPublicKey(pubKey)
}
// getCurve returns the elliptic curve based on the JWK curve parameter
func getCurve(crv string) (elliptic.Curve, error) {
switch crv {
case "P-256":
return elliptic.P256(), nil
case "P-384":
return elliptic.P384(), nil
case "P-521":
return elliptic.P521(), nil
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", crv)
}
}
// marshalPublicKey marshals a public key to PEM format
func marshalPublicKey(pubKey interface{}) ([]byte, error) {
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
return pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
})
return publicKeyPEM, nil
Bytes: pubKeyBytes,
}), nil
}
+178 -124
View File
@@ -8,164 +8,218 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"time"
)
var (
ErrInvalidJWTFormat = errors.New("invalid JWT format")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrTokenExpired = errors.New("token has expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrMissingClaim = errors.New("missing claim")
ErrInvalidClaimType = errors.New("invalid claim type")
ErrUnsupportedAlgorithm = errors.New("unsupported algorithm")
ErrInvalidSignature = errors.New("invalid signature")
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
Signature string
Signature []byte
Token string
}
// parseJWT parses a JWT token string into a JWT struct
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
return nil, fmt.Errorf("%w: expected 3 parts, got %d", ErrInvalidJWTFormat, len(parts))
}
jwt := &JWT{}
jwt := &JWT{Token: tokenString}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err := decodeJSONPart(parts[0], &jwt.Header); err != nil {
return nil, fmt.Errorf("failed to decode header: %w", err)
}
if err := decodeJSONPart(parts[1], &jwt.Claims); err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
var err error
jwt.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
return nil, fmt.Errorf("failed to decode signature: %w", err)
}
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
}
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Set the signature
jwt.Signature = parts[2]
return jwt, nil
}
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
}
if err := verifyIssuer(iss, issuerURL); err != nil {
return err
}
aud, ok := claims["aud"]
if !ok {
return fmt.Errorf("missing 'aud' claim")
}
if err := verifyAudience(aud, clientID); err != nil {
return err
}
exp, ok := claims["exp"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'exp' claim")
}
if err := verifyExpiration(exp); err != nil {
return err
}
iat, ok := claims["iat"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'iat' claim")
}
if err := verifyIssuedAt(iat); err != nil {
return err
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
}
return nil
}
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
}
return nil
}
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
func decodeJSONPart(part string, target interface{}) error {
bytes, err := base64.RawURLEncoding.DecodeString(part)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
return err
}
return json.Unmarshal(bytes, target)
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
if err := verifyIssuer(j.Claims["iss"], issuerURL); err != nil {
return err
}
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
if err := verifyAudience(j.Claims["aud"], clientID); err != nil {
return err
}
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
if err := verifyExpiration(j.Claims["exp"]); err != nil {
return err
}
switch pub := pubKey.(type) {
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature handling
keyBytes := (pub.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if err := verifyIssuedAt(j.Claims["iat"]); err != nil {
return err
}
if ecdsa.Verify(pub, hashed, r, s) {
if sub, ok := j.Claims["sub"].(string); !ok || sub == "" {
return fmt.Errorf("%w: sub", ErrMissingClaim)
}
return nil
}
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return ErrInvalidAudience
}
case []interface{}:
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
return nil
}
return fmt.Errorf("invalid ECDSA signature")
}
return fmt.Errorf("algorithm %s is not compatible with ECDSA public key", alg)
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
err := rsa.VerifyPKCS1v15(pub, hashFunc, hashed, signature)
if err != nil {
return fmt.Errorf("RSA signature verification failed: %w", err)
}
return nil
}
return fmt.Errorf("algorithm %s is not compatible with RSA public key", alg)
return ErrInvalidAudience
default:
return fmt.Errorf("unsupported public key type: %T", pub)
}
}
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
return fmt.Errorf("%w: aud", ErrInvalidClaimType)
}
return nil
}
func verifyIssuer(tokenIssuer interface{}, expectedIssuer string) error {
iss, ok := tokenIssuer.(string)
if !ok {
return fmt.Errorf("%w: iss", ErrMissingClaim)
}
if iss != expectedIssuer {
return ErrInvalidIssuer
}
return nil
}
func verifyExpiration(expiration interface{}) error {
exp, ok := expiration.(float64)
if !ok {
return fmt.Errorf("%w: exp", ErrInvalidClaimType)
}
if time.Now().After(time.Unix(int64(exp), 0)) {
return ErrTokenExpired
}
return nil
}
func verifyIssuedAt(issuedAt interface{}) error {
iat, ok := issuedAt.(float64)
if !ok {
return fmt.Errorf("%w: iat", ErrInvalidClaimType)
}
if time.Now().Before(time.Unix(int64(iat), 0)) {
return ErrTokenUsedBeforeIssued
}
return nil
}
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return ErrInvalidJWTFormat
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
pubKey, err := parsePublicKey(publicKeyPEM)
if err != nil {
return err
}
hashFunc, err := getHashFunc(alg)
if err != nil {
return err
}
hashed := hashFunc.New().Sum([]byte(signedContent))
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
return verifyRSASignature(pubKey, hashFunc, hashed, signature, alg)
case *ecdsa.PublicKey:
return verifyECDSASignature(pubKey, hashed, signature)
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
}
func parsePublicKey(publicKeyPEM []byte) (interface{}, error) {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse PEM block containing the public key")
}
return x509.ParsePKIXPublicKey(block.Bytes)
}
func getHashFunc(alg string) (crypto.Hash, error) {
switch alg {
case "RS256", "PS256", "ES256":
return crypto.SHA256, nil
case "RS384", "PS384", "ES384":
return crypto.SHA384, nil
case "RS512", "PS512", "ES512":
return crypto.SHA512, nil
default:
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
}
func verifyRSASignature(pubKey *rsa.PublicKey, hashFunc crypto.Hash, hashed, signature []byte, alg string) error {
if strings.HasPrefix(alg, "RS") {
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
}
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
func verifyECDSASignature(pubKey *ecdsa.PublicKey, hashed, signature []byte) error {
sigLen := len(signature)
if sigLen%2 != 0 {
return errors.New("invalid ECDSA signature length")
}
r, s := new(big.Int), new(big.Int)
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, r, s) {
return nil
}
return ErrInvalidSignature
}
+176 -189
View File
@@ -2,7 +2,6 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -19,16 +18,19 @@ import (
"golang.org/x/time/rate"
)
const ConstSessionTimeout = 86400
const ConstSessionTimeout = 86400 // Session timeout in seconds
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// JWTVerifier interface for JWT verification
type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
name string
@@ -58,12 +60,13 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
@@ -72,36 +75,45 @@ type ProviderMetadata struct {
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token: %s", token)
t.logger.Debugf("Verifying token")
// Rate limiting
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
@@ -110,26 +122,27 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return nil
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
t.logger.Debugf("Verifying JWT signature and claims")
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Retrieve key ID and algorithm from JWT header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
// Find the matching key in JWKS
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
@@ -137,48 +150,35 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
// Convert JWK to PEM format
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
// Verify the signature
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -217,47 +217,29 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: func() map[string]struct{} {
m := make(map[string]struct{})
for _, url := range config.ExcludedURLs {
m[url] = struct{}{}
}
return m
}(),
redirectURL: "",
allowedUserDomains: func() map[string]struct{} {
m := make(map[string]struct{})
for _, domain := range config.AllowedUserDomains {
m[domain] = struct{}{}
}
return m
}(),
allowedRolesAndGroups: func() map[string]struct{} {
m := make(map[string]struct{})
for _, roleOrGroup := range config.AllowedRolesAndGroups {
m[roleOrGroup] = struct{}{}
}
return m
}(),
initComplete: make(chan struct{}),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(),
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// add defaultExcludedURLs to excludedURLs
// Add default excluded URLs
for k, v := range defaultExcludedURLs {
t.excludedURLs[k] = v
}
@@ -266,14 +248,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.jwtVerifier = t
t.startTokenCleanup()
go t.initializeMetadata(config.ProviderURL)
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Error("Failed to discover provider metadata: %v", err)
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
@@ -286,6 +270,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
})
}
// discoverProviderMetadata fetches the OIDC provider metadata
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -299,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if time.Since(start) > totalTimeout {
l.Error("Timeout exceeded while fetching provider metadata")
l.Errorf("Timeout exceeded while fetching provider metadata")
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
}
@@ -311,18 +296,20 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
lastErr = err
// Exponential backoff
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
if delay > maxDelay {
delay = maxDelay
}
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
time.Sleep(delay)
}
l.Error("Max retries exceeded while fetching provider metadata")
l.Errorf("Max retries exceeded while fetching provider metadata")
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -345,6 +332,7 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -360,20 +348,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Check if the URL is excluded from authentication
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
@@ -383,16 +375,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle logout URL
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
return
}
// Check if the user is authenticated
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
@@ -413,97 +408,94 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// authenticated, _ := session.Values["authenticated"].(bool)
if authenticated {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
t.logger.Debug("User is not authenticated, initiating authentication")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles := t.extractGroupsAndRoles(claims)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
t.logger.Debug("URL is not excluded - got %s", currentRequest)
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
return false
}
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
switch {
case t.forceHTTPS:
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
case req.Header.Get(headerXForwardedProto) != "":
return req.Header.Get(headerXForwardedProto)
case req.TLS != nil:
return "https"
default:
return "http"
}
return "http"
}
// determineHost determines the host of the request
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -511,6 +503,7 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
@@ -561,13 +554,16 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
@@ -576,20 +572,24 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build the authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -603,6 +603,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
go func() {
@@ -614,6 +615,7 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
@@ -628,12 +630,13 @@ func (t *TraefikOidc) RevokeToken(token string) {
}
}
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
// RevokeTokenWithProvider revokes the token with the provider
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
data := url.Values{
"token": {token},
"token_type_hint": {"access_token", "refresh_token"},
"token_type_hint": {tokenType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
@@ -664,10 +667,12 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
return nil
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -677,6 +682,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
@@ -688,61 +700,36 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
return true
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false // Invalid email format
atIndex := strings.LastIndex(email, "@")
if atIndex == -1 {
return false
}
domain := parts[1]
domain := email[atIndex+1:]
_, ok := t.allowedUserDomains[domain]
return ok
}
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
var groups []string
var roles []string
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
}
if len(groups) == 0 {
t.logger.Debug("No groups found in groups claim, checking roles claim")
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debug("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
if len(roles) == 0 {
t.logger.Debug("No roles found in roles claim")
}
return groups, roles, nil
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(claims map[string]interface{}) ([]string, []string) {
groups := extractStringSlice(claims, "groups")
roles := extractStringSlice(claims, "roles")
return groups, roles
}
func extractStringSlice(claims map[string]interface{}, key string) []string {
if slice, ok := claims[key].([]interface{}); ok {
result := make([]string, 0, len(slice))
for _, item := range slice {
if str, ok := item.(string); ok {
result = append(result, str)
}
}
return result
}
return nil
}
+57
View File
@@ -0,0 +1,57 @@
package traefikoidc
import (
"net/http"
"net/http/httptest"
"testing"
)
// BenchmarkOIDCMiddleware benchmarks the OIDC middleware's ability to handle concurrent requests.
func BenchmarkOIDCMiddleware(b *testing.B) {
// Setup test environment
ts := &TestSuite{}
ts.Setup()
ts.token = "valid.jwt.token"
// Define the handler with OIDC middleware
ts.tOidc.next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Create test server
server := httptest.NewServer(ts.tOidc.next)
defer server.Close()
// Prepare HTTP client
client := &http.Client{}
// Reset timer to exclude setup time
b.ResetTimer()
// Run benchmark
for i := 0; i < b.N; i++ {
// Create new request
req, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
b.Fatal(err)
}
// Set necessary headers or cookies
req.Header.Set("Authorization", "Bearer "+ts.token)
// Send the request
resp, err := client.Do(req)
if err != nil {
b.Fatal(err)
}
// Close response body
resp.Body.Close()
// Check response status code
if resp.StatusCode != http.StatusOK {
b.Errorf("Unexpected status code: got %v, want %v", resp.StatusCode, http.StatusOK)
}
}
}
+285 -36
View File
@@ -72,6 +72,7 @@ func (ts *TestSuite) Setup() {
"iat": time.Now().Unix(),
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
@@ -79,33 +80,34 @@ func (ts *TestSuite) Setup() {
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
exchangeCodeForTokenFunc: ts.exchangeCodeForTokenFunc,
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
ts.tOidc.tokenVerifier = ts.tOidc
ts.tOidc.jwtVerifier = ts.tOidc
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": ts.token,
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
}
@@ -165,6 +167,18 @@ func TestVerifyToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.mockJWKCache.JWKS = &JWKSet{
Keys: []JWK{
{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
},
},
}
tests := []struct {
name string
token string
@@ -453,61 +467,149 @@ func TestHandleCallback(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (map[string]interface{}, error)
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
}{
{
name: "Success",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Exchange Code Error",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing ID Token",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{}, nil
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Disallowed Email",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@disallowed.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusForbidden,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tests {
@@ -519,6 +621,8 @@ func TestHandleCallback(t *testing.T) {
logger: NewLogger("info"),
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
}
// Create request and response recorder
@@ -527,6 +631,9 @@ func TestHandleCallback(t *testing.T) {
// Create session
session, _ := tOidc.store.New(req, cookieName)
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
session.Save(req, rr)
// Copy session cookie to request
@@ -583,3 +690,145 @@ func TestIsAllowedDomain(t *testing.T) {
})
}
}
func TestOIDCHandler(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.token = "valid.jwt.token"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
blacklist bool
rateLimit bool
cacheToken bool
}{
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with invalid nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims without nonce
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with mismatched nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour))
}
if tc.rateLimit {
// Exceed rate limit
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
}
if tc.cacheToken {
// Cache the token with dummy claims
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
"empty": "claim",
}, 60)
}
})
}
}
+20
View File
@@ -14,6 +14,15 @@ const (
cookieName = "_raczylo_oidc"
)
const (
headerXForwardedProto = "X-Forwarded-Proto"
headerXForwardedHost = "X-Forwarded-Host"
headerXForwardedUser = "X-Forwarded-User"
headerXUserGroups = "X-User-Groups"
headerXUserRoles = "X-User-Roles"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
@@ -40,6 +49,7 @@ var defaultSessionOptions = &sessions.Options{
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{}
@@ -62,6 +72,7 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
@@ -81,12 +92,14 @@ func (c *Config) Validate() error {
return nil
}
// Logger is a simple logger with different levels
type Logger struct {
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
@@ -106,30 +119,37 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)
+1 -1
View File
@@ -1,4 +1,4 @@
Copyright (c) 2023 The Gorilla Authors. All rights reserved.
Copyright (c) 2024 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
+5 -1
View File
@@ -1,4 +1,7 @@
# sessions
# Gorilla Sessions
> [!IMPORTANT]
> The latest version of this repository requires go 1.23 because of the new partitioned attribute. The last version that is compatible with older versions of go is v1.3.0.
![testing](https://github.com/gorilla/sessions/actions/workflows/test.yml/badge.svg)
[![codecov](https://codecov.io/github/gorilla/sessions/branch/main/graph/badge.svg)](https://codecov.io/github/gorilla/sessions)
@@ -74,6 +77,7 @@ Other implementations of the `sessions.Store` interface:
- [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
- [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
- [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
- [github.com/danielepintore/gorilla-sessions-mysql](https://github.com/danielepintore/gorilla-sessions-mysql) - MySQL
- [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
- [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
- [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
+12 -9
View File
@@ -1,5 +1,6 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
@@ -8,13 +9,15 @@ import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
Partitioned: options.Partitioned,
SameSite: options.SameSite,
}
}
-21
View File
@@ -1,21 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// newCookieFromOptions returns an http.Cookie with the options set.
func newCookieFromOptions(name, value string, options *Options) *http.Cookie {
return &http.Cookie{
Name: name,
Value: value,
Path: options.Path,
Domain: options.Domain,
MaxAge: options.MaxAge,
Secure: options.Secure,
HttpOnly: options.HttpOnly,
SameSite: options.SameSite,
}
}
+10 -5
View File
@@ -1,8 +1,11 @@
//go:build !go1.11
// +build !go1.11
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
@@ -13,7 +16,9 @@ type Options struct {
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
MaxAge int
Secure bool
HttpOnly bool
Partitioned bool
SameSite http.SameSite
}
-23
View File
@@ -1,23 +0,0 @@
//go:build go1.11
// +build go1.11
package sessions
import "net/http"
// Options stores configuration for a session or session store.
//
// Fields are a subset of http.Cookie fields.
type Options struct {
Path string
Domain string
// MaxAge=0 means no Max-Age attribute specified and the cookie will be
// deleted after the browser session ends.
// MaxAge<0 means delete cookie immediately.
// MaxAge>0 means Max-Age attribute present and given in seconds.
MaxAge int
Secure bool
HttpOnly bool
// Defaults to http.SameSiteDefaultMode
SameSite http.SameSite
}
+2 -2
View File
@@ -1,4 +1,4 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Copyright 2009 The Go Authors.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -10,7 +10,7 @@ notice, this list of conditions and the following disclaimer.
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
* Neither the name of Google LLC nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
+3 -14
View File
@@ -99,8 +99,9 @@ func (lim *Limiter) Tokens() float64 {
// bursts of at most b tokens.
func NewLimiter(r Limit, b int) *Limiter {
return &Limiter{
limit: r,
burst: b,
limit: r,
burst: b,
tokens: float64(b),
}
}
@@ -344,18 +345,6 @@ func (lim *Limiter) reserveN(t time.Time, n int, maxFutureReserve time.Duration)
tokens: n,
timeToAct: t,
}
} else if lim.limit == 0 {
var ok bool
if lim.burst >= n {
ok = true
lim.burst -= n
}
return Reservation{
ok: ok,
lim: lim,
tokens: lim.burst,
timeToAct: t,
}
}
t, tokens := lim.advance(t)
+3 -3
View File
@@ -4,9 +4,9 @@ github.com/google/uuid
# github.com/gorilla/securecookie v1.1.2
## explicit; go 1.20
github.com/gorilla/securecookie
# github.com/gorilla/sessions v1.3.0
## explicit; go 1.20
# github.com/gorilla/sessions v1.4.0
## explicit; go 1.23
github.com/gorilla/sessions
# golang.org/x/time v0.5.0
# golang.org/x/time v0.7.0
## explicit; go 1.18
golang.org/x/time/rate