Cleanup and optimise the code.

This commit is contained in:
2024-10-08 14:14:47 +01:00
parent dc4c4824cd
commit 218165d365
7 changed files with 679 additions and 349 deletions
+7
View File
@@ -5,22 +5,26 @@ import (
"time"
)
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
}
// Cache is a simple in-memory cache
type Cache struct {
items map[string]CacheItem
mutex sync.RWMutex
}
// NewCache creates a new Cache
func NewCache() *Cache {
return &Cache{
items: make(map[string]CacheItem),
}
}
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
@@ -30,6 +34,7 @@ func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
}
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
@@ -44,12 +49,14 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
+137 -81
View File
@@ -6,6 +6,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
@@ -16,6 +17,7 @@ import (
"github.com/gorilla/sessions"
)
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
@@ -25,6 +27,7 @@ func generateNonce() (string, error) {
return base64.URLEncoding.EncodeToString(nonceBytes), nil
}
// buildFullURL constructs a full URL from scheme, host, and path
func buildFullURL(scheme, host, path string) string {
if scheme == "" {
scheme = "http"
@@ -32,7 +35,8 @@ func buildFullURL(scheme, host, path string) string {
return fmt.Sprintf("%s://%s%s", scheme, host, path)
}
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (map[string]interface{}, error) {
// exchangeTokens exchanges a code or refresh token for tokens
func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken, redirectURL string) (*TokenResponse, error) {
data := url.Values{
"grant_type": {grantType},
"client_id": {t.clientID},
@@ -58,14 +62,20 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("token endpoint returned status %d: %s", resp.StatusCode, string(bodyBytes))
}
var tokenResponse TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResponse); err != nil {
return nil, fmt.Errorf("failed to decode token response: %w", err)
}
return result, nil
return &tokenResponse, nil
}
// TokenResponse represents the response from the token endpoint
type TokenResponse struct {
IDToken string `json:"id_token"`
AccessToken string `json:"access_token"`
@@ -74,47 +84,20 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}
// getNewTokenWithRefreshToken refreshes the token using the refresh token
func (t *TraefikOidc) getNewTokenWithRefreshToken(refreshToken string) (*TokenResponse, error) {
ctx := context.Background()
result, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
tokenResponse, err := t.exchangeTokens(ctx, "refresh_token", refreshToken, "")
if err != nil {
return nil, fmt.Errorf("failed to refresh token: %w", err)
}
newAccessToken, ok := result["access_token"].(string)
if !ok || newAccessToken == "" {
return nil, fmt.Errorf("no access_token field in token response")
}
t.logger.Debugf("Token response: %+v", tokenResponse)
rawIDToken, ok := result["id_token"].(string)
if !ok || rawIDToken == "" {
return nil, fmt.Errorf("no id_token field in token response")
}
newRefreshToken, ok := result["refresh_token"].(string)
if !ok || newRefreshToken == "" {
return nil, fmt.Errorf("no refresh_token field in token response")
}
response := &TokenResponse{
IDToken: rawIDToken,
AccessToken: newAccessToken,
ExpiresIn: int(result["expires_in"].(float64)),
TokenType: result["token_type"].(string),
}
// The refresh token might not be returned if it hasn't changed
if newRefreshToken != refreshToken {
response.RefreshToken = newRefreshToken
} else {
response.RefreshToken = refreshToken
}
t.logger.Debug("Token response: %+v", response)
return response, nil
return tokenResponse, nil
}
// handleLogout handles the user logout
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
@@ -123,28 +106,41 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
return
}
if idToken, ok := session.Values["id_token"].(string); ok {
err := t.RevokeTokenWithProvider(idToken)
if err != nil {
handleError(rw, "Failed to revoke token", http.StatusInternalServerError, t.logger)
return
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
}
t.RevokeToken(idToken)
t.RevokeToken(refreshToken)
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
t.RevokeToken(accessToken)
}
// Remove tokens from session
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
// Clear the session
session.Options.MaxAge = -1
session.Values = make(map[interface{}]interface{})
err = session.Save(req, rw)
if err != nil {
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
http.Error(rw, "Logged out", http.StatusForbidden)
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
// Clear the existing session
session.Options.MaxAge = -1
@@ -169,6 +165,7 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
// handleCallback handles the callback from the OIDC provider
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName)
if err != nil {
@@ -179,6 +176,34 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -186,20 +211,29 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
token, err := t.exchangeCodeForTokenFunc(code)
tokenResponse, err := t.exchangeCodeForTokenFunc(code)
if err != nil {
t.logger.Errorf("Failed to exchange code for token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
idToken, ok := token["id_token"].(string)
if !ok || idToken == "" {
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
@@ -207,6 +241,26 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -214,11 +268,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
@@ -226,16 +286,17 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
}
t.logger.Debugf("Authentication successful. User email: %s", email)
http.Redirect(rw, req, func() string {
if path, ok := session.Values["incoming_path"].(string); ok {
t.logger.Debug("Redirecting to incoming path from original request: %s", path)
return path
}
t.logger.Debug("Redirecting to root path as no incoming path found")
return "/"
}(), http.StatusFound)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
redirectPath = path
}
http.Redirect(rw, req, redirectPath, http.StatusFound)
}
// extractClaims extracts claims from a JWT token
func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
@@ -255,23 +316,27 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
return claims, nil
}
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
}
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
@@ -279,6 +344,7 @@ func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
@@ -290,26 +356,25 @@ func (tb *TokenBlacklist) Cleanup() {
}
}
// TokenCache caches tokens
type TokenCache struct {
cache *Cache
}
type TokenInfo struct {
Token string
ExpiresAt time.Time
}
// NewTokenCache creates a new TokenCache
func NewTokenCache() *TokenCache {
return &TokenCache{
cache: NewCache(),
}
}
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
@@ -320,37 +385,28 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
return claims, ok
}
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
}
// Cleanup cleans up expired tokens from the cache
func (tc *TokenCache) Cleanup() {
tc.cache.Cleanup()
}
func (t *TraefikOidc) exchangeCodeForToken(code string) (map[string]interface{}, error) {
data := url.Values{}
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.clientID)
data.Set("client_secret", t.clientSecret)
data.Set("code", code)
data.Set("redirect_uri", t.redirectURL)
resp, err := t.httpClient.PostForm(t.tokenURL, data)
// exchangeCodeForToken exchanges the authorization code for tokens
func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error) {
ctx := context.Background()
tokenResponse, err := t.exchangeTokens(ctx, "authorization_code", code, t.redirectURL)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %v", err)
return nil, fmt.Errorf("failed to exchange code for token: %w", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("failed to decode token response: %v", err)
}
return result, nil
return tokenResponse, nil
}
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
for _, key := range keys {
+35 -53
View File
@@ -4,17 +4,19 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
)
// JWK represents a JSON Web Key
type JWK struct {
Kty string `json:"kty"`
Kid string `json:"kid"`
@@ -27,20 +29,24 @@ type JWK struct {
Y string `json:"y"`
}
// JWKSet represents a set of JWKs
type JWKSet struct {
Keys []JWK `json:"keys"`
}
// JWKCache caches the JWKs
type JWKCache struct {
jwks *JWKSet
expiresAt time.Time
mutex sync.RWMutex
}
// JWKCacheInterface defines the interface for the JWK cache
type JWKCacheInterface interface {
GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error)
}
// GetJWKS gets the JWKS, either from cache or by fetching it
func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) {
@@ -67,6 +73,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
return jwks, nil
}
// fetchJWKS fetches the JWKS from the provider
func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL)
if err != nil {
@@ -86,36 +93,7 @@ func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
return &jwks, nil
}
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// jwkToPEM converts a JWK to PEM format
func jwkToPEM(jwk *JWK) ([]byte, error) {
converter, ok := jwkConverters[jwk.Kty]
if !ok {
@@ -131,41 +109,45 @@ var jwkConverters = map[string]jwkToPEMConverter{
"EC": ecJWKToPEM,
}
// rsaJWKToPEM converts an RSA JWK to PEM
func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N)
nBytes, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
}
e, err := base64.RawURLEncoding.DecodeString(jwk.E)
eBytes, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()),
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes,
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return publicKeyPEM, nil
return pubKeyPEM, nil
}
// ecJWKToPEM converts an EC JWK to PEM
func ecJWKToPEM(jwk *JWK) ([]byte, error) {
x, err := base64.RawURLEncoding.DecodeString(jwk.X)
xBytes, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'x' parameter: %w", err)
}
y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
yBytes, err := base64.RawURLEncoding.DecodeString(jwk.Y)
if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
@@ -182,21 +164,21 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
}
publicKey := &ecdsa.PublicKey{
pubKey := &ecdsa.PublicKey{
Curve: curve,
X: new(big.Int).SetBytes(x),
Y: new(big.Int).SetBytes(y),
X: new(big.Int).SetBytes(xBytes),
Y: new(big.Int).SetBytes(yBytes),
}
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err)
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
}
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: publicKeyBytes,
Bytes: pubKeyBytes,
})
return publicKeyPEM, nil
return pubKeyPEM, nil
}
+79 -67
View File
@@ -4,29 +4,36 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha256"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"strings"
"time"
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
Claims map[string]interface{}
Signature string
Signature []byte
Token string
}
// parseJWT parses a JWT token string into a JWT struct
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
}
jwt := &JWT{}
jwt := &JWT{
Token: tokenString,
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
@@ -46,12 +53,17 @@ func parseJWT(tokenString string) (*JWT, error) {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Set the signature
jwt.Signature = parts[2]
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
return jwt, nil
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
@@ -95,6 +107,39 @@ func (j *JWT) Verify(issuerURL, clientID string) error {
return nil
}
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
}
}
if !found {
return fmt.Errorf("invalid audience")
}
default:
return fmt.Errorf("invalid 'aud' claim type")
}
return nil
}
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
}
return nil
}
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
@@ -103,7 +148,24 @@ func verifyExpiration(expiration float64) error {
return nil
}
func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte, alg string) error {
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
// verifySignature verifies the token signature
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
parts := strings.Split(tokenString, ".")
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
@@ -114,69 +176,19 @@ func verifySignature(signedContent string, signature []byte, publicKeyPEM []byte
return fmt.Errorf("failed to parse public key: %w", err)
}
var hash crypto.Hash
var verifyFunc func(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error
switch alg {
case "RS256", "RS384", "RS512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for RS384 and RS512 respectively.
verifyFunc = rsaVerifyPKCS1v15
case "PS256", "PS384", "PS512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for PS384 and PS512 respectively.
verifyFunc = rsaVerifyPSS
case "ES256", "ES384", "ES512":
hash = crypto.SHA256 // SHA384 and SHA512 are used for ES384 and ES512 respectively.
verifyFunc = ecdsaVerify
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
}
h := hash.New()
h := sha256.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
return verifyFunc(pubKey, hashed, signature, hash)
}
func rsaVerifyPKCS1v15(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for RSA: %T", publicKey)
}
return rsa.VerifyPKCS1v15(pubKey, hash, hashed, signature)
}
func rsaVerifyPSS(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for RSA: %T", publicKey)
}
opts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
return rsa.VerifyPSS(pubKey, crypto.SHA256, hashed, signature, opts)
}
func ecdsaVerify(publicKey interface{}, hashed []byte, signature []byte, hash crypto.Hash) error {
pubKey, ok := publicKey.(*ecdsa.PublicKey)
if !ok {
return fmt.Errorf("invalid public key type for ECDSA: %T", publicKey)
}
keyBytes := (pubKey.Params().BitSize + 7) / 8
if len(signature) != 2*keyBytes {
return fmt.Errorf("invalid signature length for ECDSA: expected %d bytes, got %d bytes", 2*keyBytes, len(signature))
}
r := new(big.Int).SetBytes(signature[:keyBytes])
s := new(big.Int).SetBytes(signature[keyBytes:])
if ecdsa.Verify(pubKey, hashed, r, s) {
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hashed, signature)
case *ecdsa.PublicKey:
if !ecdsa.VerifyASN1(pubKey, hashed, signature) {
return fmt.Errorf("invalid ECDSA signature")
}
return nil
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
return fmt.Errorf("invalid ECDSA signature")
}
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
}
return nil
}
+136 -112
View File
@@ -2,7 +2,6 @@ package traefikoidc
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
@@ -19,16 +18,19 @@ import (
"golang.org/x/time/rate"
)
const ConstSessionTimeout = 86400
const ConstSessionTimeout = 86400 // Session timeout in seconds
// TokenVerifier interface for token verification
type TokenVerifier interface {
VerifyToken(token string) error
}
// JWTVerifier interface for JWT verification
type JWTVerifier interface {
VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
}
// TraefikOidc is the main struct for the OIDC middleware
type TraefikOidc struct {
next http.Handler
name string
@@ -58,12 +60,13 @@ type TraefikOidc struct {
allowedUserDomains map[string]struct{}
allowedRolesAndGroups map[string]struct{}
initiateAuthenticationFunc func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string)
exchangeCodeForTokenFunc func(code string) (map[string]interface{}, error)
exchangeCodeForTokenFunc func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
initOnce sync.Once
initComplete chan struct{}
}
// ProviderMetadata holds OIDC provider metadata
type ProviderMetadata struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
@@ -72,36 +75,45 @@ type ProviderMetadata struct {
RevokeURL string `json:"revocation_endpoint"`
}
// defaultExcludedURLs are the paths that are excluded from authentication
var defaultExcludedURLs = map[string]struct{}{
"/favicon": {},
}
var newTicker = time.NewTicker
// VerifyToken verifies the provided JWT token
func (t *TraefikOidc) VerifyToken(token string) error {
t.logger.Debugf("Verifying token: %s", token)
t.logger.Debugf("Verifying token")
// Rate limiting
if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded")
}
// Check if token is blacklisted
if t.tokenBlacklist.IsBlacklisted(token) {
return fmt.Errorf("token is blacklisted")
}
// Check if token is cached
if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached
}
// Parse the JWT
jwt, err := parseJWT(token)
if err != nil {
return fmt.Errorf("failed to parse JWT: %w", err)
}
// Verify JWT signature and claims
if err := t.VerifyJWTSignatureAndClaims(jwt, token); err != nil {
return err
}
// Cache the token until it expires
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
now := time.Now()
duration := expirationTime.Sub(now)
@@ -110,26 +122,27 @@ func (t *TraefikOidc) VerifyToken(token string) error {
return nil
}
// VerifyJWTSignatureAndClaims verifies the JWT signature and standard claims
func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT. Header: %+v", jwt.Header)
t.logger.Debugf("Verifying JWT signature and claims")
// Get JWKS
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err)
}
// Retrieve key ID and algorithm from JWT header
kid, ok := jwt.Header["kid"].(string)
if !ok {
return fmt.Errorf("missing key ID in token header")
}
t.logger.Debugf("Token kid: %s", kid)
alg, ok := jwt.Header["alg"].(string)
if !ok {
return fmt.Errorf("missing algorithm in token header")
}
t.logger.Debugf("Token alg: %s", alg)
// Find the matching key in JWKS
var matchingKey *JWK
for _, key := range jwks.Keys {
if key.Kid == kid {
@@ -137,48 +150,35 @@ func (t *TraefikOidc) VerifyJWTSignatureAndClaims(jwt *JWT, token string) error
break
}
}
if matchingKey == nil {
return fmt.Errorf("no matching public key found for kid: %s", kid)
}
t.logger.Debugf("Matching key found. Type: %s, Algorithm: %s", matchingKey.Kty, matchingKey.Alg)
// Convert JWK to PEM format
publicKeyPEM, err := jwkToPEM(matchingKey)
if err != nil {
return fmt.Errorf("failed to convert JWK to PEM: %w", err)
}
t.logger.Debugf("Public key PEM generated. Length: %d", len(publicKeyPEM))
parts := strings.Split(token, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
}
signedContent := parts[0] + "." + parts[1]
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
if err := verifySignature(signedContent, signature, publicKeyPEM, alg); err != nil {
t.logger.Errorf("Signature verification failed: %v", err)
// Verify the signature
if err := verifySignature(token, publicKeyPEM, alg); err != nil {
return fmt.Errorf("signature verification failed: %w", err)
}
t.logger.Debug("Signature verified successfully")
// Verify standard claims
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return fmt.Errorf("standard claim verification failed: %w", err)
}
t.logger.Debug("Standard claims verified successfully")
return nil
}
// New creates a new instance of the OIDC middleware
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
store := sessions.NewCookieStore([]byte(config.SessionEncryptionKey))
store.Options = defaultSessionOptions
// Setup HTTP client
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -217,9 +217,8 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
}
return config.LogoutURL
}(),
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
tokenBlacklist: NewTokenBlacklist(),
jwkCache: &JWKCache{},
clientID: config.ClientID,
clientSecret: config.ClientSecret,
forceHTTPS: config.ForceHTTPS,
@@ -229,17 +228,18 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
httpClient: httpClient,
logger: NewLogger(config.LogLevel),
excludedURLs: createStringMap(config.ExcludedURLs),
redirectURL: "",
allowedUserDomains: createStringMap(config.AllowedUserDomains),
allowedRolesAndGroups: createStringMap(config.AllowedRolesAndGroups),
initComplete: make(chan struct{}),
}
t.initiateAuthenticationFunc = t.defaultInitiateAuthentication
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.extractClaimsFunc = extractClaims
t.exchangeCodeForTokenFunc = t.exchangeCodeForToken
t.initiateAuthenticationFunc = func(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
t.defaultInitiateAuthentication(rw, req, session, redirectURL)
}
// add defaultExcludedURLs to excludedURLs
// Add default excluded URLs
for k, v := range defaultExcludedURLs {
t.excludedURLs[k] = v
}
@@ -248,14 +248,16 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
t.jwtVerifier = t
t.startTokenCleanup()
go t.initializeMetadata(config.ProviderURL)
return t, nil
}
// initializeMetadata discovers and initializes the provider metadata
func (t *TraefikOidc) initializeMetadata(providerURL string) {
t.initOnce.Do(func() {
metadata, err := discoverProviderMetadata(providerURL, t.httpClient, t.logger)
if err != nil {
t.logger.Error("Failed to discover provider metadata: %v", err)
t.logger.Errorf("Failed to discover provider metadata: %v", err)
} else {
t.logger.Debug("Provider metadata discovered successfully")
t.jwksURL = metadata.JWKSURL
@@ -268,6 +270,7 @@ func (t *TraefikOidc) initializeMetadata(providerURL string) {
})
}
// discoverProviderMetadata fetches the OIDC provider metadata
func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Logger) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
@@ -281,7 +284,7 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if time.Since(start) > totalTimeout {
l.Error("Timeout exceeded while fetching provider metadata")
l.Errorf("Timeout exceeded while fetching provider metadata")
return nil, fmt.Errorf("timeout exceeded while fetching provider metadata: %w", lastErr)
}
@@ -293,18 +296,20 @@ func discoverProviderMetadata(providerURL string, httpClient *http.Client, l *Lo
lastErr = err
// Exponential backoff
delay := time.Duration(math.Pow(2, float64(attempt))) * baseDelay
if delay > maxDelay {
delay = maxDelay
}
l.Debug("Failed to fetch provider metadata, retrying in %s", delay)
l.Debugf("Failed to fetch provider metadata, retrying in %s", delay)
time.Sleep(delay)
}
l.Error("Max retries exceeded while fetching provider metadata")
l.Errorf("Max retries exceeded while fetching provider metadata")
return nil, fmt.Errorf("max retries exceeded while fetching provider metadata: %w", lastErr)
}
// fetchMetadata fetches metadata from the well-known OIDC configuration endpoint
func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetadata, error) {
resp, err := httpClient.Get(wellKnownURL)
if err != nil {
@@ -327,6 +332,7 @@ func fetchMetadata(wellKnownURL string, httpClient *http.Client) (*ProviderMetad
return &metadata, nil
}
// ServeHTTP is the main handler for the middleware
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
select {
case <-t.initComplete:
@@ -342,20 +348,24 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
// Check if the URL is excluded from authentication
if t.determineExcludedURL(req.URL.Path) {
t.next.ServeHTTP(rw, req)
return
}
// Determine the scheme (http/https) and host
t.scheme = t.determineScheme(req)
defaultSessionOptions.Secure = t.scheme == "https"
host := t.determineHost(req)
// Build the redirect URL if not already set
if t.redirectURL == "" {
t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
// Get the session
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
@@ -365,16 +375,19 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.logger.Debugf("Session contents at start: %+v", session.Values)
// Handle logout URL
if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req)
return
}
// Handle callback URL
if req.URL.Path == t.redirURLPath {
t.handleCallback(rw, req)
return
}
// Check if the user is authenticated
authenticated, needsRefresh, expired := t.isUserAuthenticated(session)
if expired {
@@ -395,84 +408,80 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
// authenticated, _ := session.Values["authenticated"].(bool)
if authenticated {
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
// At this point, the user is authenticated
idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" {
t.logger.Errorf("No id_token found in session")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
t.logger.Debug("User is not authenticated, initiating authentication")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
claims, err := extractClaims(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
email, _ := claims["email"].(string)
if email == "" {
t.logger.Debugf("No email found in token claims")
t.defaultInitiateAuthentication(rw, req, session, t.redirectURL)
return
}
if !t.isAllowedDomain(email) {
t.logger.Infof("User with email %s is not from an allowed domain", email)
http.Error(rw, fmt.Sprintf("Access denied: Your email domain is not allowed. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
// Set headers for groups and roles
if len(groups) > 0 {
req.Header.Set("X-User-Groups", strings.Join(groups, ","))
}
if len(roles) > 0 {
req.Header.Set("X-User-Roles", strings.Join(roles, ","))
}
}
if len(t.allowedRolesAndGroups) > 0 {
allowed := false
for _, roleOrGroup := range append(groups, roles...) {
if _, ok := t.allowedRolesAndGroups[roleOrGroup]; ok {
allowed = true
break
}
}
if !allowed {
t.logger.Infof("User with email %s does not have any allowed roles or groups", email)
http.Error(rw, fmt.Sprintf("Access denied: You do not have any allowed roles or groups. To log out, visit: %s", t.logoutURLPath), http.StatusForbidden)
return
}
}
req.Header.Set("X-Forwarded-User", email)
t.next.ServeHTTP(rw, req)
}
// determineExcludedURL checks if the current request URL is in the excluded list
func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
for excludedURL := range t.excludedURLs {
if strings.HasPrefix(currentRequest, excludedURL) {
t.logger.Debug("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
t.logger.Debugf("URL is excluded - got %s / excluded hit: %s", currentRequest, excludedURL)
return true
}
}
t.logger.Debug("URL is not excluded - got %s", currentRequest)
t.logger.Debugf("URL is not excluded - got %s", currentRequest)
return false
}
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
return "https"
@@ -486,6 +495,7 @@ func (t *TraefikOidc) determineScheme(req *http.Request) string {
return "http"
}
// determineHost determines the host of the request
func (t *TraefikOidc) determineHost(req *http.Request) string {
if host := req.Header.Get("X-Forwarded-Host"); host != "" {
return host
@@ -493,6 +503,7 @@ func (t *TraefikOidc) determineHost(req *http.Request) string {
return req.Host
}
// isUserAuthenticated checks if the user is authenticated
func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool, bool) {
authenticated, _ := session.Values["authenticated"].(bool)
t.logger.Debugf("Session authenticated value: %v", authenticated)
@@ -543,13 +554,16 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) (bool, bool
return true, false, false // Token is valid and not expiring soon
}
// defaultInitiateAuthentication initiates the authentication process
func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
// Generate CSRF token
csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path
session.Options = defaultSessionOptions
t.logger.Debugf("Setting CSRF token: %s", csrfToken)
// Generate nonce
nonce, err := generateNonce()
if err != nil {
http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
@@ -558,20 +572,24 @@ func (t *TraefikOidc) defaultInitiateAuthentication(rw http.ResponseWriter, req
session.Values["nonce"] = nonce
t.logger.Debugf("Setting nonce: %s", nonce)
// Save the session
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return
}
// Build the authentication URL
authURL := t.buildAuthURL(redirectURL, csrfToken, nonce)
http.Redirect(rw, req, authURL, http.StatusFound)
}
// verifyToken verifies the token using the token verifier
func (t *TraefikOidc) verifyToken(token string) error {
return t.tokenVerifier.VerifyToken(token)
}
// buildAuthURL constructs the authentication URL
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{}
params.Set("client_id", t.clientID)
@@ -585,6 +603,7 @@ func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
return t.authURL + "?" + params.Encode()
}
// startTokenCleanup starts the token cleanup goroutine
func (t *TraefikOidc) startTokenCleanup() {
ticker := newTicker(1 * time.Minute)
go func() {
@@ -596,6 +615,7 @@ func (t *TraefikOidc) startTokenCleanup() {
}()
}
// RevokeToken adds the token to the blacklist
func (t *TraefikOidc) RevokeToken(token string) {
// Remove from cache
t.tokenCache.Delete(token)
@@ -610,12 +630,13 @@ func (t *TraefikOidc) RevokeToken(token string) {
}
}
func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
// RevokeTokenWithProvider revokes the token with the provider
func (t *TraefikOidc) RevokeTokenWithProvider(token, tokenType string) error {
t.logger.Debugf("Revoking token with provider")
data := url.Values{
"token": {token},
"token_type_hint": {"access_token", "refresh_token"},
"token_type_hint": {tokenType},
"client_id": {t.clientID},
"client_secret": {t.clientSecret},
}
@@ -646,10 +667,12 @@ func (t *TraefikOidc) RevokeTokenWithProvider(token string) error {
return nil
}
// refreshToken refreshes the user's token
func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) bool {
t.logger.Debug("Refreshing token")
refreshToken, ok := session.Values["refresh_token"].(string)
if !ok || refreshToken == "" {
t.logger.Debug("No refresh token found in session")
return false
}
@@ -659,6 +682,13 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return false
}
// Verify the new id_token
if err := t.verifyToken(newToken.IDToken); err != nil {
t.logger.Errorf("Failed to verify new id_token: %v", err)
return false
}
// Update session with new tokens
session.Values["id_token"] = newToken.IDToken
session.Values["refresh_token"] = newToken.RefreshToken
session.Options = defaultSessionOptions
@@ -670,6 +700,7 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
return true
}
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
@@ -685,6 +716,7 @@ func (t *TraefikOidc) isAllowedDomain(email string) bool {
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
@@ -706,25 +738,17 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string,
}
}
if len(groups) == 0 {
t.logger.Debug("No groups found in groups claim, checking roles claim")
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debug("Found role: %s", roleStr)
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
if len(roles) == 0 {
t.logger.Debug("No roles found in roles claim")
}
return groups, roles, nil
}
+273 -36
View File
@@ -72,6 +72,7 @@ func (ts *TestSuite) Setup() {
"iat": time.Now().Unix(),
"sub": "test-subject",
"email": "user@example.com",
"nonce": "test-nonce",
})
if err != nil {
ts.t.Fatalf("Failed to create test JWT: %v", err)
@@ -79,33 +80,34 @@ func (ts *TestSuite) Setup() {
// Common TraefikOidc instance
ts.tOidc = &TraefikOidc{
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
exchangeCodeForTokenFunc: ts.exchangeCodeForTokenFunc,
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
issuerURL: "https://test-issuer.com",
clientID: "test-client-id",
clientSecret: "test-client-secret",
jwkCache: ts.mockJWKCache,
jwksURL: "https://test-jwks-url.com",
revocationURL: "https://revocation-endpoint.com",
limiter: rate.NewLimiter(rate.Every(time.Second), 10),
tokenBlacklist: NewTokenBlacklist(),
tokenCache: NewTokenCache(),
logger: NewLogger("info"),
store: sessions.NewCookieStore([]byte("test-secret-key")),
allowedUserDomains: map[string]struct{}{"example.com": {}},
excludedURLs: map[string]struct{}{"/favicon": {}},
httpClient: &http.Client{},
extractClaimsFunc: extractClaims,
initComplete: make(chan struct{}),
}
close(ts.tOidc.initComplete)
ts.tOidc.exchangeCodeForTokenFunc = ts.exchangeCodeForTokenFunc
ts.tOidc.tokenVerifier = ts.tOidc
ts.tOidc.jwtVerifier = ts.tOidc
}
// Helper functions used by TraefikOidc
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": ts.token,
func (ts *TestSuite) exchangeCodeForTokenFunc(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
}
@@ -453,61 +455,149 @@ func TestHandleCallback(t *testing.T) {
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (map[string]interface{}, error)
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
}{
{
name: "Success",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusFound,
},
{
name: "Missing Code",
queryParams: "",
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Exchange Code Error",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return nil, fmt.Errorf("exchange code error")
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing ID Token",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{}, nil
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Disallowed Email",
queryParams: "?code=test-code",
exchangeCodeForToken: func(code string) (map[string]interface{}, error) {
return map[string]interface{}{
"id_token": "test-id-token",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@disallowed.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusForbidden,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
return map[string]interface{}{
"email": "user@example.com",
// Missing nonce
}, nil
},
sessionSetupFunc: func(session *sessions.Session) {
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
expectedStatus: http.StatusInternalServerError,
},
}
for _, tc := range tests {
@@ -519,6 +609,8 @@ func TestHandleCallback(t *testing.T) {
logger: NewLogger("info"),
exchangeCodeForTokenFunc: tc.exchangeCodeForToken,
extractClaimsFunc: tc.extractClaimsFunc,
tokenVerifier: ts.tOidc.tokenVerifier,
jwtVerifier: ts.tOidc.jwtVerifier,
}
// Create request and response recorder
@@ -527,6 +619,9 @@ func TestHandleCallback(t *testing.T) {
// Create session
session, _ := tOidc.store.New(req, cookieName)
if tc.sessionSetupFunc != nil {
tc.sessionSetupFunc(session)
}
session.Save(req, rr)
// Copy session cookie to request
@@ -583,3 +678,145 @@ func TestIsAllowedDomain(t *testing.T) {
})
}
}
func TestOIDCHandler(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.token = "valid.jwt.token"
tests := []struct {
name string
queryParams string
exchangeCodeForToken func(code string) (*TokenResponse, error)
extractClaimsFunc func(tokenString string) (map[string]interface{}, error)
sessionSetupFunc func(session *sessions.Session)
expectedStatus int
blacklist bool
rateLimit bool
cacheToken bool
}{
{
name: "Missing Code",
queryParams: "",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with invalid nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusInternalServerError,
},
{
name: "Missing Nonce in Claims",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims without nonce
return map[string]interface{}{
"email": "user@example.com",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Invalid State Parameter",
queryParams: "?code=test-code&state=invalid-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims
return map[string]interface{}{
"email": "user@example.com",
"nonce": "test-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
{
name: "Nonce Mismatch",
queryParams: "?code=test-code&state=test-csrf-token",
sessionSetupFunc: func(session *sessions.Session) {
// Set CSRF and nonce values in session
session.Values["csrf"] = "test-csrf-token"
session.Values["nonce"] = "test-nonce"
},
exchangeCodeForToken: func(code string) (*TokenResponse, error) {
// Simulate token exchange
return &TokenResponse{
IDToken: ts.token,
RefreshToken: "test-refresh-token",
}, nil
},
extractClaimsFunc: func(tokenString string) (map[string]interface{}, error) {
// Simulate extraction of claims with mismatched nonce
return map[string]interface{}{
"email": "user@example.com",
"nonce": "invalid-nonce",
}, nil
},
expectedStatus: http.StatusBadRequest,
},
}
for _, tc := range tests {
tc := tc // Capture range variable
t.Run(tc.name, func(t *testing.T) {
// Reset token blacklist and cache
ts.tOidc.tokenBlacklist = NewTokenBlacklist()
ts.tOidc.tokenCache = NewTokenCache()
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Second), 10)
// Set up the test case
if tc.blacklist {
ts.tOidc.tokenBlacklist.Add(ts.token, time.Now().Add(1*time.Hour))
}
if tc.rateLimit {
// Exceed rate limit
ts.tOidc.limiter = rate.NewLimiter(rate.Every(time.Hour), 0)
}
if tc.cacheToken {
// Cache the token with dummy claims
ts.tOidc.tokenCache.Set(ts.token, map[string]interface{}{
"empty": "claim",
}, 60)
}
})
}
}
+12
View File
@@ -14,6 +14,7 @@ const (
cookieName = "_raczylo_oidc"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`
RevocationURL string `json:"revocationURL"`
@@ -40,6 +41,7 @@ var defaultSessionOptions = &sessions.Options{
Path: "/",
}
// CreateConfig creates a new Config with default values
func CreateConfig() *Config {
c := &Config{}
@@ -62,6 +64,7 @@ func CreateConfig() *Config {
return c
}
// Validate validates the Config
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
@@ -81,12 +84,14 @@ func (c *Config) Validate() error {
return nil
}
// Logger is a simple logger with different levels
type Logger struct {
logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
}
// NewLogger creates a new Logger
func NewLogger(logLevel string) *Logger {
logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
@@ -106,30 +111,37 @@ func NewLogger(logLevel string) *Logger {
}
}
// Info logs an info message
func (l *Logger) Info(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debug logs a debug message
func (l *Logger) Debug(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Error logs an error message
func (l *Logger) Error(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// Infof logs an info message
func (l *Logger) Infof(format string, args ...interface{}) {
l.logInfo.Printf(format, args...)
}
// Debugf logs a debug message
func (l *Logger) Debugf(format string, args ...interface{}) {
l.logDebug.Printf(format, args...)
}
// Errorf logs an error message
func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
// handleError writes an error message to the response and logs it
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Error(message)
http.Error(w, message, code)