Compare commits

..

2 Commits

Author SHA1 Message Date
lukaszraczylo ab36f10a70 Draft. 2024-10-13 18:21:13 +01:00
lukaszraczylo 4972b21373 Improve speed of the cache module. 2024-10-13 18:00:43 +01:00
7 changed files with 261 additions and 322 deletions
+14 -10
View File
@@ -8,7 +8,7 @@ import (
// CacheItem represents an item in the cache
type CacheItem struct {
Value interface{}
ExpiresAt time.Time
ExpiresAt int64 // Changed to int64 for faster comparisons
}
// Cache is a simple in-memory cache
@@ -27,43 +27,47 @@ func NewCache() *Cache {
// Set adds an item to the cache
func (c *Cache) Set(key string, value interface{}, expiration time.Duration) {
c.mutex.Lock()
defer c.mutex.Unlock()
// Removed defer for slightly better performance
c.items[key] = CacheItem{
Value: value,
ExpiresAt: time.Now().Add(expiration),
ExpiresAt: time.Now().Add(expiration).UnixNano(), // Store as UnixNano for faster comparisons
}
c.mutex.Unlock()
}
// Get retrieves an item from the cache
func (c *Cache) Get(key string) (interface{}, bool) {
c.mutex.RLock()
defer c.mutex.RUnlock()
item, found := c.items[key]
if !found {
c.mutex.RUnlock()
return nil, false
}
if time.Now().After(item.ExpiresAt) {
delete(c.items, key)
if time.Now().UnixNano() > item.ExpiresAt {
c.mutex.RUnlock()
// Use a separate goroutine to delete expired items to avoid blocking
go c.Delete(key)
return nil, false
}
c.mutex.RUnlock()
return item.Value, true
}
// Delete removes an item from the cache
func (c *Cache) Delete(key string) {
c.mutex.Lock()
defer c.mutex.Unlock()
delete(c.items, key)
c.mutex.Unlock()
}
// Cleanup removes expired items from the cache
func (c *Cache) Cleanup() {
c.mutex.Lock()
defer c.mutex.Unlock()
now := time.Now()
now := time.Now().UnixNano()
for key, item := range c.items {
if now.After(item.ExpiresAt) {
if now > item.ExpiresAt {
delete(c.items, key)
}
}
c.mutex.Unlock()
}
+48 -97
View File
@@ -20,8 +20,7 @@ import (
// generateNonce generates a random nonce
func generateNonce() (string, error) {
nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes)
if err != nil {
if _, err := rand.Read(nonceBytes); err != nil {
return "", fmt.Errorf("could not generate nonce: %w", err)
}
return base64.URLEncoding.EncodeToString(nonceBytes), nil
@@ -43,14 +42,15 @@ func (t *TraefikOidc) exchangeTokens(ctx context.Context, grantType, codeOrToken
"client_secret": {t.clientSecret},
}
if grantType == "authorization_code" {
switch grantType {
case "authorization_code":
data.Set("code", codeOrToken)
data.Set("redirect_uri", redirectURL)
} else if grantType == "refresh_token" {
case "refresh_token":
data.Set("refresh_token", codeOrToken)
}
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
@@ -107,43 +107,40 @@ func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
}
// Revoke tokens if available
if refreshToken, ok := session.Values["refresh_token"].(string); ok && refreshToken != "" {
if err := t.RevokeTokenWithProvider(refreshToken, "refresh_token"); err != nil {
t.logger.Errorf("Failed to revoke refresh token: %v", err)
for _, tokenType := range []string{"refresh_token", "access_token"} {
if token, ok := session.Values[tokenType].(string); ok && token != "" {
if err := t.RevokeTokenWithProvider(token, tokenType); err != nil {
t.logger.Errorf("Failed to revoke %s: %v", tokenType, err)
}
t.RevokeToken(token)
}
t.RevokeToken(refreshToken)
}
if accessToken, ok := session.Values["access_token"].(string); ok && accessToken != "" {
if err := t.RevokeTokenWithProvider(accessToken, "access_token"); err != nil {
t.logger.Errorf("Failed to revoke access token: %v", err)
}
t.RevokeToken(accessToken)
delete(session.Values, tokenType)
}
// Remove tokens from session
// Remove other session values
delete(session.Values, "id_token")
delete(session.Values, "refresh_token")
delete(session.Values, "access_token")
delete(session.Values, "authenticated")
// Set session options to delete the session
session.Options = defaultSessionOptions
session.Options.MaxAge = -1
session.Options = &sessions.Options{MaxAge: -1, Path: "/", HttpOnly: true, Secure: true}
if err := session.Save(req, rw); err != nil {
handleError(rw, "Failed to save session", http.StatusInternalServerError, t.logger)
return
}
// Redirect or display logout message
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("Logged out successfully"))
}
// handleExpiredToken handles the case when a token has expired
func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Request, session *sessions.Session) {
if session == nil {
t.logger.Error("Session is nil in handleExpiredToken")
http.Error(rw, "Internal server error", http.StatusInternalServerError)
return
}
// Clear the existing session
session.Options.MaxAge = -1
for k := range session.Values {
delete(session.Values, k)
}
@@ -152,16 +149,14 @@ func (t *TraefikOidc) handleExpiredToken(rw http.ResponseWriter, req *http.Reque
session.Values["csrf"] = uuid.New().String()
session.Values["incoming_path"] = req.URL.Path
session.Values["nonce"], _ = generateNonce()
session.Options = defaultSessionOptions
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
// Save the session before initiating authentication
if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err)
http.Error(rw, "Internal Server Error", http.StatusInternalServerError)
return
}
// Initiate a new authentication flow
t.initiateAuthenticationFunc(rw, req, session, t.redirectURL)
}
@@ -176,34 +171,21 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
// Check for errors in the query parameters
if req.URL.Query().Get("error") != "" {
if errParam := req.URL.Query().Get("error"); errParam != "" {
errorDescription := req.URL.Query().Get("error_description")
t.logger.Errorf("Authentication error: %s - %s", req.URL.Query().Get("error"), errorDescription)
t.logger.Errorf("Authentication error: %s - %s", errParam, errorDescription)
http.Error(rw, fmt.Sprintf("Authentication error: %s", errorDescription), http.StatusBadRequest)
return
}
// Validate the state parameter matches the session's CSRF token
state := req.URL.Query().Get("state")
if state == "" {
t.logger.Error("No state in callback")
http.Error(rw, "State parameter missing in callback", http.StatusBadRequest)
return
}
csrfToken, ok := session.Values["csrf"].(string)
if !ok || csrfToken == "" {
t.logger.Error("CSRF token missing in session")
http.Error(rw, "CSRF token missing", http.StatusBadRequest)
return
}
if state != csrfToken {
t.logger.Error("State parameter does not match CSRF token in session")
if !ok || state == "" || csrfToken == "" || state != csrfToken {
t.logger.Error("Invalid state parameter or CSRF token")
http.Error(rw, "Invalid state parameter", http.StatusBadRequest)
return
}
// Proceed to exchange the code for tokens
code := req.URL.Query().Get("code")
if code == "" {
t.logger.Error("No code in callback")
@@ -218,7 +200,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Extract id_token
idToken := tokenResponse.IDToken
if idToken == "" {
t.logger.Error("No id_token in token response")
@@ -226,14 +207,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Verify the id_token
if err := t.verifyToken(idToken); err != nil {
t.logger.Errorf("Failed to verify id_token: %v", err)
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Extract claims from id_token
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
t.logger.Errorf("Failed to extract claims: %v", err)
@@ -241,26 +220,14 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Verify the nonce claim matches the one stored in session
nonceClaim, ok := claims["nonce"].(string)
if !ok || nonceClaim == "" {
t.logger.Error("Nonce claim missing in id_token")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
sessionNonce, ok := session.Values["nonce"].(string)
if !ok || sessionNonce == "" {
t.logger.Error("Nonce not found in session")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
if nonceClaim != sessionNonce {
t.logger.Error("Nonce claim does not match session nonce")
sessionNonce, ok2 := session.Values["nonce"].(string)
if !ok || !ok2 || nonceClaim == "" || sessionNonce == "" || nonceClaim != sessionNonce {
t.logger.Error("Invalid nonce")
http.Error(rw, "Authentication failed", http.StatusInternalServerError)
return
}
// Get the email from claims
email, _ := claims["email"].(string)
if email == "" || !t.isAllowedDomain(email) {
t.logger.Errorf("Invalid or disallowed email: %s", email)
@@ -268,14 +235,12 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
return
}
// Store tokens and authentication status in session
session.Values["authenticated"] = true
session.Values["email"] = email
session.Values["id_token"] = idToken
session.Values["refresh_token"] = tokenResponse.RefreshToken
session.Options = defaultSessionOptions
session.Options = &sessions.Options{MaxAge: 3600, Path: "/", HttpOnly: true, Secure: true}
// Remove CSRF and nonce from session
delete(session.Values, "csrf")
delete(session.Values, "nonce")
@@ -287,7 +252,6 @@ func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request)
t.logger.Debugf("Authentication successful. User email: %s", email)
// Redirect to the original requested path or default to root
redirectPath := "/"
if path, ok := session.Values["incoming_path"].(string); ok && path != t.redirURLPath {
t.logger.Debugf("Redirecting to incoming path from original request: %s", path)
@@ -318,42 +282,32 @@ func extractClaims(tokenString string) (map[string]interface{}, error) {
// TokenBlacklist maintains a blacklist of tokens
type TokenBlacklist struct {
blacklist map[string]time.Time
mutex sync.RWMutex
blacklist sync.Map
}
// NewTokenBlacklist creates a new TokenBlacklist
func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{
blacklist: make(map[string]time.Time),
return &TokenBlacklist{}
}
func (tb *TokenBlacklist) Add(token string, expiration time.Time) {
tb.blacklist.Store(token, expiration)
}
func (tb *TokenBlacklist) IsBlacklisted(token string) bool {
if exp, ok := tb.blacklist.Load(token); ok {
return time.Now().Before(exp.(time.Time))
}
return false
}
// Add adds a token to the blacklist
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock()
defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration
}
// IsBlacklisted checks if a token is blacklisted
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock()
defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration)
}
// Cleanup removes expired tokens from the blacklist
func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock()
defer tb.mutex.Unlock()
now := time.Now()
for tokenID, expiration := range tb.blacklist {
if now.After(expiration) {
delete(tb.blacklist, tokenID)
tb.blacklist.Range(func(key, value interface{}) bool {
if now.After(value.(time.Time)) {
tb.blacklist.Delete(key)
}
}
return true
})
}
// TokenCache caches tokens
@@ -370,14 +324,12 @@ func NewTokenCache() *TokenCache {
// Set sets a token in the cache
func (tc *TokenCache) Set(token string, claims map[string]interface{}, expiration time.Duration) {
token = "t-" + token
tc.cache.Set(token, claims, expiration)
tc.cache.Set("t-"+token, claims, expiration)
}
// Get retrieves a token from the cache
func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
token = "t-" + token
value, found := tc.cache.Get(token)
value, found := tc.cache.Get("t-" + token)
if !found {
return nil, false
}
@@ -387,8 +339,7 @@ func (tc *TokenCache) Get(token string) (map[string]interface{}, bool) {
// Delete removes a token from the cache
func (tc *TokenCache) Delete(token string) {
token = "t-" + token
tc.cache.Delete(token)
tc.cache.Delete("t-" + token)
}
// Cleanup cleans up expired tokens from the cache
@@ -408,7 +359,7 @@ func (t *TraefikOidc) exchangeCodeForToken(code string) (*TokenResponse, error)
// createStringMap creates a map from a slice of strings
func createStringMap(keys []string) map[string]struct{} {
result := make(map[string]struct{})
result := make(map[string]struct{}, len(keys))
for _, key := range keys {
result[key] = struct{}{}
}
+30 -33
View File
@@ -4,13 +4,12 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"math/big"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
"net/http"
"sync"
"time"
@@ -58,6 +57,7 @@ func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, er
c.mutex.Lock()
defer c.mutex.Unlock()
// Double-check locking pattern
if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil
}
@@ -120,25 +120,12 @@ func rsaJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
}
n := new(big.Int).SetBytes(nBytes)
e := new(big.Int).SetBytes(eBytes)
pubKey := &rsa.PublicKey{
N: n,
E: int(e.Int64()),
N: new(big.Int).SetBytes(nBytes),
E: int(new(big.Int).SetBytes(eBytes).Int64()),
}
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal RSA public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return pubKeyPEM, nil
return marshalPublicKey(pubKey)
}
// ecJWKToPEM converts an EC JWK to PEM
@@ -152,16 +139,9 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
return nil, fmt.Errorf("failed to decode JWK 'y' parameter: %w", err)
}
var curve elliptic.Curve
switch jwk.Crv {
case "P-256":
curve = elliptic.P256()
case "P-384":
curve = elliptic.P384()
case "P-521":
curve = elliptic.P521()
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", jwk.Crv)
curve, err := getCurve(jwk.Crv)
if err != nil {
return nil, err
}
pubKey := &ecdsa.PublicKey{
@@ -170,15 +150,32 @@ func ecJWKToPEM(jwk *JWK) ([]byte, error) {
Y: new(big.Int).SetBytes(yBytes),
}
return marshalPublicKey(pubKey)
}
// getCurve returns the elliptic curve based on the JWK curve parameter
func getCurve(crv string) (elliptic.Curve, error) {
switch crv {
case "P-256":
return elliptic.P256(), nil
case "P-384":
return elliptic.P384(), nil
case "P-521":
return elliptic.P521(), nil
default:
return nil, fmt.Errorf("unsupported elliptic curve: %s", crv)
}
}
// marshalPublicKey marshals a public key to PEM format
func marshalPublicKey(pubKey interface{}) ([]byte, error) {
pubKeyBytes, err := x509.MarshalPKIXPublicKey(pubKey)
if err != nil {
return nil, fmt.Errorf("failed to marshal EC public key: %w", err)
return nil, fmt.Errorf("failed to marshal public key: %w", err)
}
pubKeyPEM := pem.EncodeToMemory(&pem.Block{
return pem.EncodeToMemory(&pem.Block{
Type: "PUBLIC KEY",
Bytes: pubKeyBytes,
})
return pubKeyPEM, nil
}), nil
}
+119 -133
View File
@@ -4,18 +4,29 @@ import (
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"math/big"
"strings"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math/big"
"strings"
"time"
)
var (
ErrInvalidJWTFormat = errors.New("invalid JWT format")
ErrInvalidAudience = errors.New("invalid audience")
ErrInvalidIssuer = errors.New("invalid issuer")
ErrTokenExpired = errors.New("token has expired")
ErrTokenUsedBeforeIssued = errors.New("token used before issued")
ErrMissingClaim = errors.New("missing claim")
ErrInvalidClaimType = errors.New("invalid claim type")
ErrUnsupportedAlgorithm = errors.New("unsupported algorithm")
ErrInvalidSignature = errors.New("invalid signature")
)
// JWT represents a JSON Web Token
type JWT struct {
Header map[string]interface{}
@@ -28,212 +39,187 @@ type JWT struct {
func parseJWT(tokenString string) (*JWT, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid JWT format: expected 3 parts, got %d", len(parts))
return nil, fmt.Errorf("%w: expected 3 parts, got %d", ErrInvalidJWTFormat, len(parts))
}
jwt := &JWT{
Token: tokenString,
jwt := &JWT{Token: tokenString}
if err := decodeJSONPart(parts[0], &jwt.Header); err != nil {
return nil, fmt.Errorf("failed to decode header: %w", err)
}
// Decode and unmarshal the header
headerBytes, err := base64.RawURLEncoding.DecodeString(parts[0])
if err := decodeJSONPart(parts[1], &jwt.Claims); err != nil {
return nil, fmt.Errorf("failed to decode claims: %w", err)
}
var err error
jwt.Signature, err = base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode header: %v", err)
return nil, fmt.Errorf("failed to decode signature: %w", err)
}
if err := json.Unmarshal(headerBytes, &jwt.Header); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal header: %v", err)
}
// Decode and unmarshal the claims
claimsBytes, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode claims: %v", err)
}
if err := json.Unmarshal(claimsBytes, &jwt.Claims); err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to unmarshal claims: %v", err)
}
// Decode the signature
signatureBytes, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return nil, fmt.Errorf("invalid JWT format: failed to decode signature: %v", err)
}
jwt.Signature = signatureBytes
return jwt, nil
}
func decodeJSONPart(part string, target interface{}) error {
bytes, err := base64.RawURLEncoding.DecodeString(part)
if err != nil {
return err
}
return json.Unmarshal(bytes, target)
}
// Verify verifies the standard claims in the JWT
func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing 'iss' claim")
}
if err := verifyIssuer(iss, issuerURL); err != nil {
if err := verifyIssuer(j.Claims["iss"], issuerURL); err != nil {
return err
}
aud, ok := claims["aud"]
if !ok {
return fmt.Errorf("missing 'aud' claim")
}
if err := verifyAudience(aud, clientID); err != nil {
if err := verifyAudience(j.Claims["aud"], clientID); err != nil {
return err
}
exp, ok := claims["exp"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'exp' claim")
}
if err := verifyExpiration(exp); err != nil {
if err := verifyExpiration(j.Claims["exp"]); err != nil {
return err
}
iat, ok := claims["iat"].(float64)
if !ok {
return fmt.Errorf("missing or invalid 'iat' claim")
}
if err := verifyIssuedAt(iat); err != nil {
if err := verifyIssuedAt(j.Claims["iat"]); err != nil {
return err
}
sub, ok := claims["sub"].(string)
if !ok || sub == "" {
return fmt.Errorf("missing or empty 'sub' claim")
if sub, ok := j.Claims["sub"].(string); !ok || sub == "" {
return fmt.Errorf("%w: sub", ErrMissingClaim)
}
return nil
}
// verifyAudience verifies the audience claim
func verifyAudience(tokenAudience interface{}, expectedAudience string) error {
switch aud := tokenAudience.(type) {
case string:
if aud != expectedAudience {
return fmt.Errorf("invalid audience")
return ErrInvalidAudience
}
case []interface{}:
found := false
for _, v := range aud {
if str, ok := v.(string); ok && str == expectedAudience {
found = true
break
return nil
}
}
if !found {
return fmt.Errorf("invalid audience")
}
return ErrInvalidAudience
default:
return fmt.Errorf("invalid 'aud' claim type")
return fmt.Errorf("%w: aud", ErrInvalidClaimType)
}
return nil
}
// verifyIssuer verifies the issuer claim
func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer")
func verifyIssuer(tokenIssuer interface{}, expectedIssuer string) error {
iss, ok := tokenIssuer.(string)
if !ok {
return fmt.Errorf("%w: iss", ErrMissingClaim)
}
if iss != expectedIssuer {
return ErrInvalidIssuer
}
return nil
}
// verifyExpiration checks if the token has expired
func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) {
return fmt.Errorf("token has expired")
func verifyExpiration(expiration interface{}) error {
exp, ok := expiration.(float64)
if !ok {
return fmt.Errorf("%w: exp", ErrInvalidClaimType)
}
if time.Now().After(time.Unix(int64(exp), 0)) {
return ErrTokenExpired
}
return nil
}
// verifyIssuedAt checks if the token was issued in the future
func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) {
return fmt.Errorf("token used before issued")
func verifyIssuedAt(issuedAt interface{}) error {
iat, ok := issuedAt.(float64)
if !ok {
return fmt.Errorf("%w: iat", ErrInvalidClaimType)
}
if time.Now().Before(time.Unix(int64(iat), 0)) {
return ErrTokenUsedBeforeIssued
}
return nil
}
// verifySignature verifies the token signature using the provided public key and algorithm
func verifySignature(tokenString string, publicKeyPEM []byte, alg string) error {
// Split the token into its three parts
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return fmt.Errorf("invalid token format")
return ErrInvalidJWTFormat
}
signedContent := parts[0] + "." + parts[1]
// Decode the signature from the token
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
// Decode the PEM-encoded public key
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return fmt.Errorf("failed to parse PEM block containing the public key")
}
// Parse the public key
pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
pubKey, err := parsePublicKey(publicKeyPEM)
if err != nil {
return fmt.Errorf("failed to parse public key: %w", err)
return err
}
// Determine the hash function to use based on the algorithm
var hashFunc crypto.Hash
switch alg {
case "RS256", "PS256", "ES256":
hashFunc = crypto.SHA256
case "RS384", "PS384", "ES384":
hashFunc = crypto.SHA384
case "RS512", "PS512", "ES512":
hashFunc = crypto.SHA512
default:
return fmt.Errorf("unsupported algorithm: %s", alg)
hashFunc, err := getHashFunc(alg)
if err != nil {
return err
}
// Hash the signed content
h := hashFunc.New()
h.Write([]byte(signedContent))
hashed := h.Sum(nil)
hashed := hashFunc.New().Sum([]byte(signedContent))
// Verify the signature based on the key type and algorithm
switch pubKey := pubKey.(type) {
case *rsa.PublicKey:
if strings.HasPrefix(alg, "RS") {
// RSA PKCS#1 v1.5 signature
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
// RSA PSS signature
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
return verifyRSASignature(pubKey, hashFunc, hashed, signature, alg)
case *ecdsa.PublicKey:
if strings.HasPrefix(alg, "ES") {
// ECDSA signature
var r, s big.Int
sigLen := len(signature)
if sigLen%2 != 0 {
return fmt.Errorf("invalid ECDSA signature length")
}
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, &r, &s) {
return nil
} else {
return fmt.Errorf("invalid ECDSA signature")
}
} else {
return fmt.Errorf("unexpected key type for algorithm %s", alg)
}
return verifyECDSASignature(pubKey, hashed, signature)
default:
return fmt.Errorf("unsupported public key type: %T", pubKey)
}
}
func parsePublicKey(publicKeyPEM []byte) (interface{}, error) {
block, _ := pem.Decode(publicKeyPEM)
if block == nil {
return nil, errors.New("failed to parse PEM block containing the public key")
}
return x509.ParsePKIXPublicKey(block.Bytes)
}
func getHashFunc(alg string) (crypto.Hash, error) {
switch alg {
case "RS256", "PS256", "ES256":
return crypto.SHA256, nil
case "RS384", "PS384", "ES384":
return crypto.SHA384, nil
case "RS512", "PS512", "ES512":
return crypto.SHA512, nil
default:
return 0, fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
}
func verifyRSASignature(pubKey *rsa.PublicKey, hashFunc crypto.Hash, hashed, signature []byte, alg string) error {
if strings.HasPrefix(alg, "RS") {
return rsa.VerifyPKCS1v15(pubKey, hashFunc, hashed, signature)
} else if strings.HasPrefix(alg, "PS") {
return rsa.VerifyPSS(pubKey, hashFunc, hashed, signature, nil)
}
return fmt.Errorf("%w: %s", ErrUnsupportedAlgorithm, alg)
}
func verifyECDSASignature(pubKey *ecdsa.PublicKey, hashed, signature []byte) error {
sigLen := len(signature)
if sigLen%2 != 0 {
return errors.New("invalid ECDSA signature length")
}
r, s := new(big.Int), new(big.Int)
r.SetBytes(signature[:sigLen/2])
s.SetBytes(signature[sigLen/2:])
if ecdsa.Verify(pubKey, hashed, r, s) {
return nil
}
return ErrInvalidSignature
}
+30 -49
View File
@@ -436,7 +436,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}
groups, roles, err := t.extractGroupsAndRoles(idToken)
groups, roles := t.extractGroupsAndRoles(claims)
if err != nil {
t.logger.Errorf("Failed to extract groups and roles: %v", err)
} else {
@@ -483,16 +483,16 @@ func (t *TraefikOidc) determineExcludedURL(currentRequest string) bool {
// determineScheme determines the scheme (http or https) of the request
func (t *TraefikOidc) determineScheme(req *http.Request) string {
if t.forceHTTPS {
switch {
case t.forceHTTPS:
return "https"
}
if scheme := req.Header.Get("X-Forwarded-Proto"); scheme != "" {
return scheme
}
if req.TLS != nil {
case req.Header.Get(headerXForwardedProto) != "":
return req.Header.Get(headerXForwardedProto)
case req.TLS != nil:
return "https"
default:
return "http"
}
return "http"
}
// determineHost determines the host of the request
@@ -703,52 +703,33 @@ func (t *TraefikOidc) refreshToken(rw http.ResponseWriter, req *http.Request, se
// isAllowedDomain checks if the user's email domain is allowed
func (t *TraefikOidc) isAllowedDomain(email string) bool {
if len(t.allowedUserDomains) == 0 {
return true // If no domains are specified, all are allowed
return true
}
parts := strings.Split(email, "@")
if len(parts) != 2 {
return false // Invalid email format
atIndex := strings.LastIndex(email, "@")
if atIndex == -1 {
return false
}
domain := parts[1]
domain := email[atIndex+1:]
_, ok := t.allowedUserDomains[domain]
return ok
}
// extractGroupsAndRoles extracts groups and roles from the id_token
func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) {
claims, err := t.extractClaimsFunc(idToken)
if err != nil {
return nil, nil, fmt.Errorf("failed to extract claims: %w", err)
}
var groups []string
var roles []string
// Check for groups claim
if groupsClaim, ok := claims["groups"]; ok {
if groupsSlice, ok := groupsClaim.([]interface{}); ok {
for _, group := range groupsSlice {
if groupStr, ok := group.(string); ok {
t.logger.Debugf("Found group: %s", groupStr)
groups = append(groups, groupStr)
}
}
}
}
// Check for roles claim
if rolesClaim, ok := claims["roles"]; ok {
if rolesSlice, ok := rolesClaim.([]interface{}); ok {
for _, role := range rolesSlice {
if roleStr, ok := role.(string); ok {
t.logger.Debugf("Found role: %s", roleStr)
roles = append(roles, roleStr)
}
}
}
}
return groups, roles, nil
func (t *TraefikOidc) extractGroupsAndRoles(claims map[string]interface{}) ([]string, []string) {
groups := extractStringSlice(claims, "groups")
roles := extractStringSlice(claims, "roles")
return groups, roles
}
func extractStringSlice(claims map[string]interface{}, key string) []string {
if slice, ok := claims[key].([]interface{}); ok {
result := make([]string, 0, len(slice))
for _, item := range slice {
if str, ok := item.(string); ok {
result = append(result, str)
}
}
return result
}
return nil
}
+12
View File
@@ -167,6 +167,18 @@ func TestVerifyToken(t *testing.T) {
ts := &TestSuite{t: t}
ts.Setup()
ts.mockJWKCache.JWKS = &JWKSet{
Keys: []JWK{
{
Kty: "RSA",
Kid: "test-key-id",
Alg: "RS256",
N: base64.RawURLEncoding.EncodeToString(ts.rsaPublicKey.N.Bytes()),
E: base64.RawURLEncoding.EncodeToString(bigIntToBytes(big.NewInt(int64(ts.rsaPublicKey.E)))),
},
},
}
tests := []struct {
name string
token string
+8
View File
@@ -14,6 +14,14 @@ const (
cookieName = "_raczylo_oidc"
)
const (
headerXForwardedProto = "X-Forwarded-Proto"
headerXForwardedHost = "X-Forwarded-Host"
headerXForwardedUser = "X-Forwarded-User"
headerXUserGroups = "X-User-Groups"
headerXUserRoles = "X-User-Roles"
)
// Config holds the configuration for the OIDC middleware
type Config struct {
ProviderURL string `json:"providerURL"`