Refactor codebase for clarity and consistency.

This commit is contained in:
2024-07-24 23:46:27 +01:00
parent 6de1ccbd17
commit 88c566ee9a
5 changed files with 456 additions and 481 deletions
+141 -191
View File
@@ -5,7 +5,6 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -15,256 +14,207 @@ import (
) )
func generateNonce() (string, error) { func generateNonce() (string, error) {
nonceBytes := make([]byte, 32) nonceBytes := make([]byte, 32)
_, err := rand.Read(nonceBytes) _, err := rand.Read(nonceBytes)
if err != nil { if err != nil {
return "", fmt.Errorf("could not generate nonce") return "", fmt.Errorf("could not generate nonce: %w", err)
} }
return base64.URLEncoding.EncodeToString(nonceBytes), nil return base64.URLEncoding.EncodeToString(nonceBytes), nil
} }
func assembleRedirectURL(scheme, host, path string) string { func buildFullURL(scheme, host, path string) string {
if scheme == "" { if scheme == "" {
// infoLogger.Println("Scheme is empty, defaulting to http") scheme = "http"
scheme = "http" }
} return fmt.Sprintf("%s://%s%s", scheme, host, path)
return scheme + "://" + host + path
} }
func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code string, redirectURL string) (map[string]interface{}, error) { func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectURL string) (map[string]interface{}, error) {
data := url.Values{} data := url.Values{
data.Set("grant_type", "authorization_code") "grant_type": {"authorization_code"},
data.Set("code", code) "code": {code},
data.Set("client_id", t.clientID) "client_id": {t.clientID},
data.Set("client_secret", t.clientSecret) "client_secret": {t.clientSecret},
data.Set("redirect_uri", redirectURL) // Use the full redirect URL "redirect_uri": {redirectURL},
}
// infoLogger.Printf("Exchanging code for token with redirect_uri: %s", redirectURL) req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("failed to create token request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req, err := http.NewRequestWithContext(ctx, "POST", t.tokenURL, strings.NewReader(data.Encode())) resp, err := t.httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to exchange code for token: %w", err)
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") defer resp.Body.Close()
resp, err := http.DefaultClient.Do(req) var result map[string]interface{}
if err != nil { if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err return nil, fmt.Errorf("failed to decode token response: %w", err)
} }
defer resp.Body.Close()
var result map[string]interface{} return result, nil
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, err
}
// infoLogger.Printf("Token response: %+v", result)
return result, nil
} }
func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) { func (t *TraefikOidc) handleCallback(rw http.ResponseWriter, req *http.Request) (bool, string) {
ctx := req.Context() ctx := req.Context()
session, err := t.store.Get(req, cookie_name) session, err := t.store.Get(req, cookieName)
if err != nil { if err != nil {
// infoLogger.Printf("Error getting session: %v", err) handleError(rw, "Session error", http.StatusInternalServerError)
http.Error(rw, "Session error", http.StatusInternalServerError) return false, ""
return false, "" }
}
// infoLogger.Printf("Session values: %+v", session.Values) callbackState := req.URL.Query().Get("state")
sessionState, ok := session.Values["csrf"].(string)
if !ok || callbackState != sessionState {
handleError(rw, "Invalid state parameter", http.StatusBadRequest)
return false, ""
}
callbackState := req.URL.Query().Get("state") code := req.URL.Query().Get("code")
sessionState, ok := session.Values["csrf"].(string) redirectURL := buildFullURL(t.scheme, req.Host, t.redirURLPath)
// infoLogger.Printf("Callback state: %s, Session state: %s, Match: %v", callbackState, sessionState, ok && callbackState == sessionState) oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL)
if err != nil {
handleError(rw, "Failed to exchange token", http.StatusInternalServerError)
return false, ""
}
if !ok || callbackState != sessionState { rawIDToken, ok := oauth2Token["id_token"].(string)
// infoLogger.Printf("Invalid state parameter: callback=%s, session=%s", callbackState, sessionState) if !ok {
http.Error(rw, "Invalid state parameter", http.StatusBadRequest) handleError(rw, "No id_token field in oauth2 token", http.StatusInternalServerError)
return false, "" return false, ""
} }
code := req.URL.Query().Get("code") if err := t.verifyToken(rawIDToken); err != nil {
redirectURL := assembleRedirectURL(t.scheme, req.Host, t.redirURLPath) handleError(rw, "Failed to verify token", http.StatusUnauthorized)
oauth2Token, err := t.exchangeCodeForToken(ctx, code, redirectURL) return false, ""
if err != nil { }
// infoLogger.Printf("Failed to exchange token: %v", err)
http.Error(rw, "Failed to exchange token", http.StatusInternalServerError)
return false, ""
}
rawIDToken, ok := oauth2Token["id_token"].(string) claims, err := extractClaims(rawIDToken)
if !ok { if err != nil {
// infoLogger.Printf("No id_token field in oauth2 token") handleError(rw, "Failed to extract claims", http.StatusInternalServerError)
http.Error(rw, "No id_token field in oauth2 token", http.StatusInternalServerError) return false, ""
return false, "" }
}
if err := t.verifyToken(rawIDToken); err != nil { email, _ := claims["email"].(string)
// infoLogger.Printf("Token verification failed: %v", err)
http.Error(rw, "Failed to verify token", http.StatusUnauthorized)
return false, ""
}
// infoLogger.Printf("Token verification successful")
claims, err := extractClaims(rawIDToken) session.Values["authenticated"] = true
if err != nil { session.Values["id_token"] = rawIDToken
// infoLogger.Printf("Failed to extract claims: %v", err) session.Values["email"] = email
http.Error(rw, "Failed to extract claims", http.StatusInternalServerError) if err := session.Save(req, rw); err != nil {
return false, "" handleError(rw, "Failed to save session", http.StatusInternalServerError)
} return false, ""
}
email, _ := claims["email"].(string) originalPath, ok := session.Values["incoming_path"].(string)
if !ok {
originalPath = "/"
}
delete(session.Values, "incoming_path")
session.Values["authenticated"] = true return true, originalPath
session.Values["id_token"] = rawIDToken
session.Values["email"] = email
if err := session.Save(req, rw); err != nil {
// infoLogger.Printf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session", http.StatusInternalServerError)
return false, ""
}
// infoLogger.Printf("User %s authenticated\n", email)
originalPath, ok := session.Values["incoming_path"].(string)
if !ok {
originalPath = "/"
}
delete(session.Values, "incoming_path")
return true, originalPath
} }
func extractClaims(tokenString string) (map[string]interface{}, error) { func extractClaims(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".") parts := strings.Split(tokenString, ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, errors.New("invalid token format") return nil, fmt.Errorf("invalid token format")
} }
payload, err := base64.RawURLEncoding.DecodeString(parts[1]) payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode token payload: %w", err)
} }
var claims map[string]interface{} var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil { if err := json.Unmarshal(payload, &claims); err != nil {
return nil, err return nil, fmt.Errorf("failed to unmarshal claims: %w", err)
} }
return claims, nil return claims, nil
}
func verifyToken(token string, publicKey []byte) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, errors.New("invalid token format")
}
payloadJson, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
var claims map[string]interface{}
err = json.Unmarshal(payloadJson, &claims)
if err != nil {
return nil, err
}
if exp, ok := claims["exp"].(float64); ok {
if time.Now().Unix() > int64(exp) {
return nil, errors.New("token expired")
}
}
// Placeholder for signature verification
// err = verifySignature(parts[0]+"."+parts[1], parts[2], publicKey)
// if err != nil {
// return nil, err
// }
return claims, nil
} }
type UsedTokens struct { type UsedTokens struct {
tokens map[string]bool tokens map[string]bool
mutex sync.RWMutex mutex sync.RWMutex
} }
type TokenBlacklist struct { type TokenBlacklist struct {
blacklist map[string]time.Time blacklist map[string]time.Time
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewTokenBlacklist() *TokenBlacklist { func NewTokenBlacklist() *TokenBlacklist {
return &TokenBlacklist{ return &TokenBlacklist{
blacklist: make(map[string]time.Time), blacklist: make(map[string]time.Time),
} }
} }
func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) { func (tb *TokenBlacklist) Add(tokenID string, expiration time.Time) {
tb.mutex.Lock() tb.mutex.Lock()
defer tb.mutex.Unlock() defer tb.mutex.Unlock()
tb.blacklist[tokenID] = expiration tb.blacklist[tokenID] = expiration
} }
func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool { func (tb *TokenBlacklist) IsBlacklisted(tokenID string) bool {
tb.mutex.RLock() tb.mutex.RLock()
defer tb.mutex.RUnlock() defer tb.mutex.RUnlock()
expiration, exists := tb.blacklist[tokenID] expiration, exists := tb.blacklist[tokenID]
return exists && time.Now().Before(expiration) return exists && time.Now().Before(expiration)
} }
func (tb *TokenBlacklist) Cleanup() { func (tb *TokenBlacklist) Cleanup() {
tb.mutex.Lock() tb.mutex.Lock()
defer tb.mutex.Unlock() defer tb.mutex.Unlock()
now := time.Now() now := time.Now()
for tokenID, expiration := range tb.blacklist { for tokenID, expiration := range tb.blacklist {
if now.After(expiration) { if now.After(expiration) {
delete(tb.blacklist, tokenID) delete(tb.blacklist, tokenID)
} }
} }
} }
type TokenCache struct { type TokenCache struct {
cache map[string]*TokenInfo cache map[string]*TokenInfo
mutex sync.RWMutex mutex sync.RWMutex
} }
type TokenInfo struct { type TokenInfo struct {
Token string Token string
ExpiresAt time.Time ExpiresAt time.Time
} }
func NewTokenCache() *TokenCache { func NewTokenCache() *TokenCache {
return &TokenCache{ return &TokenCache{
cache: make(map[string]*TokenInfo), cache: make(map[string]*TokenInfo),
} }
} }
func (tc *TokenCache) Set(token string, expiresAt time.Time) { func (tc *TokenCache) Set(token string, expiresAt time.Time) {
tc.mutex.Lock() tc.mutex.Lock()
defer tc.mutex.Unlock() defer tc.mutex.Unlock()
tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt} tc.cache[token] = &TokenInfo{Token: token, ExpiresAt: expiresAt}
} }
func (tc *TokenCache) Get(token string) (*TokenInfo, bool) { func (tc *TokenCache) Get(token string) (*TokenInfo, bool) {
tc.mutex.RLock() tc.mutex.RLock()
defer tc.mutex.RUnlock() defer tc.mutex.RUnlock()
info, exists := tc.cache[token] info, exists := tc.cache[token]
if exists && time.Now().Before(info.ExpiresAt) { if exists && time.Now().Before(info.ExpiresAt) {
return info, true return info, true
} }
return nil, false return nil, false
} }
func (tc *TokenCache) Cleanup() { func (tc *TokenCache) Cleanup() {
tc.mutex.Lock() tc.mutex.Lock()
defer tc.mutex.Unlock() defer tc.mutex.Unlock()
now := time.Now() now := time.Now()
for token, info := range tc.cache { for token, info := range tc.cache {
if now.After(info.ExpiresAt) { if now.After(info.ExpiresAt) {
delete(tc.cache, token) delete(tc.cache, token)
} }
} }
} }
+91 -92
View File
@@ -6,7 +6,6 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"fmt" "fmt"
"math/big" "math/big"
"net/http" "net/http"
@@ -15,135 +14,135 @@ import (
) )
type JWK struct { type JWK struct {
Kty string `json:"kty"` Kty string `json:"kty"`
Kid string `json:"kid"` Kid string `json:"kid"`
Use string `json:"use"` Use string `json:"use"`
N string `json:"n"` N string `json:"n"`
E string `json:"e"` E string `json:"e"`
Alg string `json:"alg"` Alg string `json:"alg"`
} }
type JWKSet struct { type JWKSet struct {
Keys []JWK `json:"keys"` Keys []JWK `json:"keys"`
} }
type JWKCache struct { type JWKCache struct {
jwks *JWKSet jwks *JWKSet
expiresAt time.Time expiresAt time.Time
mutex sync.RWMutex mutex sync.RWMutex
} }
func (c *JWKCache) GetJWKS(jwksURL string) (*JWKSet, error) { func (c *JWKCache) GetJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) {
c.mutex.RLock() c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) { if c.jwks != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
return c.jwks, nil return c.jwks, nil
} }
c.mutex.RUnlock() c.mutex.RUnlock()
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if c.jwks != nil && time.Now().Before(c.expiresAt) { if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil return c.jwks, nil
} }
jwks, err := fetchJWKS(jwksURL) jwks, err := fetchJWKS(jwksURL, httpClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.jwks = jwks c.jwks = jwks
c.expiresAt = time.Now().Add(1 * time.Hour) c.expiresAt = time.Now().Add(1 * time.Hour)
return jwks, nil return jwks, nil
} }
func fetchJWKS(jwksURL string) (*JWKSet, error) { func fetchJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) {
resp, err := http.Get(jwksURL) resp, err := httpClient.Get(jwksURL)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, errors.New("failed to fetch JWKS") return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode)
} }
var jwks JWKSet var jwks JWKSet
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, err return nil, fmt.Errorf("failed to decode JWKS: %w", err)
} }
return &jwks, nil return &jwks, nil
} }
func verifyNonce(tokenNonce, expectedNonce string) error { func verifyNonce(tokenNonce, expectedNonce string) error {
if tokenNonce != expectedNonce { if tokenNonce != expectedNonce {
return errors.New("invalid nonce") return fmt.Errorf("invalid nonce")
} }
return nil return nil
} }
func verifyAudience(tokenAudience, expectedAudience string) error { func verifyAudience(tokenAudience, expectedAudience string) error {
if tokenAudience != expectedAudience { if tokenAudience != expectedAudience {
return errors.New("invalid audience") return fmt.Errorf("invalid audience")
} }
return nil return nil
} }
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error { func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
now := time.Now().Unix() now := time.Now().Unix()
if now < issuedAt-int64(allowedClockSkew.Seconds()) { if now < issuedAt-int64(allowedClockSkew.Seconds()) {
return errors.New("token used before issued") return fmt.Errorf("token used before issued")
} }
if now > expiration+int64(allowedClockSkew.Seconds()) { if now > expiration+int64(allowedClockSkew.Seconds()) {
return errors.New("token is expired") return fmt.Errorf("token is expired")
} }
return nil return nil
} }
func verifyIssuer(tokenIssuer, expectedIssuer string) error { func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer { if tokenIssuer != expectedIssuer {
return errors.New("invalid issuer") return fmt.Errorf("invalid issuer")
} }
return nil return nil
} }
func validateClaims(claims map[string]interface{}) error { func validateClaims(claims map[string]interface{}) error {
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
for _, claim := range requiredClaims { for _, claim := range requiredClaims {
if _, ok := claims[claim]; !ok { if _, ok := claims[claim]; !ok {
return fmt.Errorf("missing required claim: %s", claim) return fmt.Errorf("missing required claim: %s", claim)
} }
} }
return nil return nil
} }
func jwkToPEM(jwk *JWK) ([]byte, error) { func jwkToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N) n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
} }
e, err := base64.RawURLEncoding.DecodeString(jwk.E) e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
} }
publicKey := &rsa.PublicKey{ publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n), N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()), E: int(new(big.Int).SetBytes(e).Int64()),
} }
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to marshal public key: %w", err)
} }
publicKeyPEM := pem.EncodeToMemory(&pem.Block{ publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes, Bytes: publicKeyBytes,
}) })
return publicKeyPEM, nil return publicKeyPEM, nil
} }
+82 -79
View File
@@ -8,122 +8,125 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors" "fmt"
"strings" "strings"
"time" "time"
) )
type JWT struct { type JWT struct {
Header map[string]interface{} Header map[string]interface{}
Claims map[string]interface{} Claims map[string]interface{}
Signature string Signature string
} }
func parseJWT(token string) (*JWT, error) { func parseJWT(token string) (*JWT, error) {
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
if len(parts) != 3 { if len(parts) != 3 {
return nil, errors.New("invalid token format") return nil, fmt.Errorf("invalid token format")
} }
header, err := decodeSegment(parts[0]) header, err := decodeSegment(parts[0])
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode header: %w", err)
} }
claims, err := decodeSegment(parts[1]) claims, err := decodeSegment(parts[1])
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode claims: %w", err)
} }
return &JWT{ return &JWT{
Header: header, Header: header,
Claims: claims, Claims: claims,
Signature: parts[2], Signature: parts[2],
}, nil }, nil
} }
func (j *JWT) Verify(issuerURL, clientID string) error { func (j *JWT) Verify(issuerURL, clientID string) error {
claims := j.Claims claims := j.Claims
if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil { if err := verifyIssuer(claims["iss"].(string), issuerURL); err != nil {
return err return err
} }
if err := verifyAudience(claims["aud"].(string), clientID); err != nil { if err := verifyAudience(claims["aud"].(string), clientID); err != nil {
return err return err
} }
if err := verifyExpiration(claims["exp"].(float64)); err != nil { if err := verifyExpiration(claims["exp"].(float64)); err != nil {
return err return err
} }
if err := verifyIssuedAt(claims["iat"].(float64)); err != nil { if err := verifyIssuedAt(claims["iat"].(float64)); err != nil {
return err return err
} }
return nil return nil
} }
func verifyExpiration(expiration float64) error { func verifyExpiration(expiration float64) error {
expirationTime := time.Unix(int64(expiration), 0) expirationTime := time.Unix(int64(expiration), 0)
if time.Now().After(expirationTime) { if time.Now().After(expirationTime) {
return errors.New("token has expired") return fmt.Errorf("token has expired")
} }
return nil return nil
} }
func verifySignature(token string, publicKeyPEM []byte) error { func verifySignature(token string, publicKeyPEM []byte) error {
parts := strings.Split(token, ".") parts := strings.Split(token, ".")
if len(parts) != 3 { if len(parts) != 3 {
return errors.New("invalid token format") return fmt.Errorf("invalid token format")
} }
block, _ := pem.Decode(publicKeyPEM) block, _ := pem.Decode(publicKeyPEM)
if block == nil { if block == nil {
return errors.New("failed to parse PEM block containing the public key") return fmt.Errorf("failed to parse PEM block containing the public key")
} }
pub, err := x509.ParsePKIXPublicKey(block.Bytes) pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to parse public key: %w", err)
} }
rsaPublicKey, ok := pub.(*rsa.PublicKey) rsaPublicKey, ok := pub.(*rsa.PublicKey)
if !ok { if !ok {
return errors.New("not an RSA public key") return fmt.Errorf("not an RSA public key")
} }
signedContent := parts[0] + "." + parts[1] signedContent := parts[0] + "." + parts[1]
signature, _ := base64.RawURLEncoding.DecodeString(parts[2]) signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil {
return fmt.Errorf("failed to decode signature: %w", err)
}
hash := sha256.Sum256([]byte(signedContent)) hash := sha256.Sum256([]byte(signedContent))
err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature) err = rsa.VerifyPKCS1v15(rsaPublicKey, crypto.SHA256, hash[:], signature)
if err != nil { if err != nil {
return errors.New("invalid token signature") return fmt.Errorf("invalid token signature: %w", err)
} }
return nil return nil
} }
func verifyIssuedAt(issuedAt float64) error { func verifyIssuedAt(issuedAt float64) error {
issuedAtTime := time.Unix(int64(issuedAt), 0) issuedAtTime := time.Unix(int64(issuedAt), 0)
if time.Now().Before(issuedAtTime) { if time.Now().Before(issuedAtTime) {
return errors.New("token used before issued") return fmt.Errorf("token used before issued")
} }
return nil return nil
} }
func decodeSegment(seg string) (map[string]interface{}, error) { func decodeSegment(seg string) (map[string]interface{}, error) {
data, err := base64.RawURLEncoding.DecodeString(seg) data, err := base64.RawURLEncoding.DecodeString(seg)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to decode segment: %w", err)
} }
var result map[string]interface{} var result map[string]interface{}
err = json.Unmarshal(data, &result) err = json.Unmarshal(data, &result)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to unmarshal segment: %w", err)
} }
return result, nil return result, nil
} }
+81 -114
View File
@@ -3,10 +3,7 @@ package traefikoidc
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@@ -17,14 +14,10 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
var (
infoLogger = log.New(io.Discard, "INFO: traefikoidc: ", log.Ldate|log.Ltime)
)
type TraefikOidc struct { type TraefikOidc struct {
next http.Handler next http.Handler
name string name string
store *sessions.CookieStore store sessions.Store
redirURLPath string redirURLPath string
issuerURL string issuerURL string
jwkCache *JWKCache jwkCache *JWKCache
@@ -39,6 +32,8 @@ type TraefikOidc struct {
forceHTTPS bool forceHTTPS bool
scheme string scheme string
tokenCache *TokenCache tokenCache *TokenCache
httpClient HTTPClient
logger Logger
} }
type ProviderMetadata struct { type ProviderMetadata struct {
@@ -58,10 +53,11 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
metadata, err := discoverProviderMetadata(config.ProviderURL) metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to discover provider metadata: %v", err) return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
} }
logger := NewLogger(config.LogLevel)
t := &TraefikOidc{ t := &TraefikOidc{
next: next, next: next,
@@ -80,17 +76,19 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
scopes: config.Scopes, scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), 100), limiter: rate.NewLimiter(rate.Every(time.Second), 100),
tokenCache: NewTokenCache(), tokenCache: NewTokenCache(),
httpClient: &http.Client{},
logger: logger,
} }
t.startTokenCleanup() t.startTokenCleanup()
return t, nil return t, nil
} }
func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) { func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
resp, err := http.Get(wellKnownURL) resp, err := httpClient.Get(wellKnownURL)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to fetch provider metadata: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
@@ -100,13 +98,47 @@ func discoverProviderMetadata(providerURL string) (*ProviderMetadata, error) {
var metadata ProviderMetadata var metadata ProviderMetadata
if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil {
return nil, err return nil, fmt.Errorf("failed to decode provider metadata: %w", err)
} }
return &metadata, nil return &metadata, nil
} }
func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
t.scheme = t.determineScheme(req)
host := t.determineHost(req)
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Infof("Final redirect URL: %s", redirectURL)
session, err := t.store.Get(req, cookieName)
if err != nil {
t.logger.Errorf("Error getting session: %v", err)
http.Error(rw, "Session error", http.StatusInternalServerError)
return
}
if req.URL.Path == t.redirURLPath {
t.logger.Infof("Handling callback, URL: %s", req.URL.String())
authSuccess, originalPath := t.handleCallback(rw, req)
if authSuccess {
http.Redirect(rw, req, originalPath, http.StatusFound)
return
}
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
return
}
if t.isUserAuthenticated(session) {
t.next.ServeHTTP(rw, req)
return
}
// User is not authenticated, start the auth process
t.initiateAuthentication(rw, req, session, redirectURL)
}
func (t *TraefikOidc) determineScheme(req *http.Request) string {
scheme := req.URL.Scheme scheme := req.URL.Scheme
if scheme == "" { if scheme == "" {
scheme = req.Header.Get("X-Forwarded-Proto") scheme = req.Header.Get("X-Forwarded-Proto")
@@ -121,8 +153,10 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if t.forceHTTPS { if t.forceHTTPS {
scheme = "https" scheme = "https"
} }
t.scheme = scheme return scheme
}
func (t *TraefikOidc) determineHost(req *http.Request) string {
host := req.URL.Host host := req.URL.Host
if host == "" { if host == "" {
host = req.Header.Get("X-Forwarded-Host") host = req.Header.Get("X-Forwarded-Host")
@@ -130,74 +164,36 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if host == "" { if host == "" {
host = req.Host host = req.Host
} }
return host
}
// infoLogger.Printf("Scheme: %s, Host: %s, Path: %s", scheme, host, t.redirURLPath) func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool {
// infoLogger.Printf("X-Forwarded-Proto: %s", req.Header.Get("X-Forwarded-Proto"))
// infoLogger.Printf("X-Forwarded-Host: %s", req.Header.Get("X-Forwarded-Host"))
redirectURL := assembleRedirectURL(t.scheme, host, t.redirURLPath)
// infoLogger.Printf("Final redirect URL: %s", redirectURL)
session, err := t.store.Get(req, cookie_name)
if err != nil {
// infoLogger.Printf("Error getting session: %v", err)
http.Error(rw, "Session error: "+err.Error(), http.StatusInternalServerError)
return
}
if req.URL.Path == t.redirURLPath {
// infoLogger.Printf("Handling callback, URL: %s", req.URL.String())
authSuccess, originalPath := t.handleCallback(rw, req)
if authSuccess {
http.Redirect(rw, req, originalPath, http.StatusFound)
return
}
// If auth was not successful, return an error instead of re-authenticating
http.Error(rw, "Authentication failed", http.StatusUnauthorized)
return
}
authenticated, _ := session.Values["authenticated"].(bool) authenticated, _ := session.Values["authenticated"].(bool)
if authenticated { if authenticated {
idToken, ok := session.Values["id_token"].(string) idToken, ok := session.Values["id_token"].(string)
if !ok || idToken == "" { if !ok || idToken == "" {
http.Error(rw, "Invalid session", http.StatusUnauthorized) return false
return
} }
return t.verifyToken(idToken) == nil
if err := t.verifyToken(idToken); err != nil {
http.Error(rw, "Invalid token", http.StatusUnauthorized)
return
}
// Proceed with the request
t.next.ServeHTTP(rw, req)
return
} }
return false
}
// User is not authenticated, start the auth process func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.Request, session *sessions.Session, redirectURL string) {
csrfToken := uuid.New().String() csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path session.Values["incoming_path"] = req.URL.Path
// infoLogger.Printf("Setting CSRF token: %s", csrfToken) t.logger.Infof("Setting CSRF token: %s", csrfToken)
err = session.Save(req, rw)
if err != nil {
// infoLogger.Printf("Failed to save session: %v", err)
http.Error(rw, "Failed to save session: "+err.Error(), http.StatusInternalServerError)
return
}
// Verify the session was saved correctly if err := session.Save(req, rw); err != nil {
verifySession, _ := t.store.Get(req, cookie_name) t.logger.Errorf("Failed to save session: %v", err)
savedCSRF, ok := verifySession.Values["csrf"].(string) http.Error(rw, "Failed to save session", http.StatusInternalServerError)
if !ok || savedCSRF != csrfToken {
// infoLogger.Printf("Failed to save CSRF token. Saved: %s, Expected: %s", savedCSRF, csrfToken)
http.Error(rw, "Failed to save CSRF token", http.StatusInternalServerError)
return return
} }
nonce, err := generateNonce() nonce, err := generateNonce()
if err != nil { if err != nil {
http.Error(rw, "Failed to generate nonce: "+err.Error(), http.StatusInternalServerError) http.Error(rw, "Failed to generate nonce", http.StatusInternalServerError)
return return
} }
@@ -205,22 +201,9 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
http.Redirect(rw, req, authURL, http.StatusFound) http.Redirect(rw, req, authURL, http.StatusFound)
} }
func (t *TraefikOidc) isUserAuthenticated(req *http.Request) bool {
session, err := t.store.Get(req, cookie_name)
if err != nil {
return false
}
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
return false
}
return true
}
func (t *TraefikOidc) verifyToken(token string) error { func (t *TraefikOidc) verifyToken(token string) error {
if !t.limiter.Allow() { if !t.limiter.Allow() {
return errors.New("rate limit exceeded") return fmt.Errorf("rate limit exceeded")
} }
if _, exists := t.tokenCache.Get(token); exists { if _, exists := t.tokenCache.Get(token); exists {
@@ -229,17 +212,17 @@ func (t *TraefikOidc) verifyToken(token string) error {
jwt, err := parseJWT(token) jwt, err := parseJWT(token)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to parse JWT: %w", err)
} }
jwks, err := t.jwkCache.GetJWKS(t.jwksURL) jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to get JWKS: %w", err)
} }
kid, ok := jwt.Header["kid"].(string) kid, ok := jwt.Header["kid"].(string)
if !ok { if !ok {
return errors.New("missing key ID in token header") return fmt.Errorf("missing key ID in token header")
} }
var publicKeyPEM []byte var publicKeyPEM []byte
@@ -247,38 +230,22 @@ func (t *TraefikOidc) verifyToken(token string) error {
if key.Kid == kid { if key.Kid == kid {
publicKeyPEM, err = jwkToPEM(&key) publicKeyPEM, err = jwkToPEM(&key)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to convert JWK to PEM: %w", err)
} }
break break
} }
} }
if publicKeyPEM == nil { if publicKeyPEM == nil {
return errors.New("unable to find matching public key") return fmt.Errorf("unable to find matching public key")
} }
if err := verifySignature(token, publicKeyPEM); err != nil { if err := verifySignature(token, publicKeyPEM); err != nil {
return err return fmt.Errorf("signature verification failed: %w", err)
}
if err := verifyAudience(jwt.Claims["aud"].(string), t.clientID); err != nil {
return err
} }
if err := jwt.Verify(t.issuerURL, t.clientID); err != nil { if err := jwt.Verify(t.issuerURL, t.clientID); err != nil {
return err return fmt.Errorf("JWT verification failed: %w", err)
}
if err := verifyTokenTimes(
int64(jwt.Claims["iat"].(float64)),
int64(jwt.Claims["exp"].(float64)),
5*time.Minute, // Allowed clock skew
); err != nil {
return err
}
if err := validateClaims(jwt.Claims); err != nil {
return err
} }
expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0) expirationTime := time.Unix(int64(jwt.Claims["exp"].(float64)), 0)
@@ -288,17 +255,16 @@ func (t *TraefikOidc) verifyToken(token string) error {
} }
func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string { func (t *TraefikOidc) buildAuthURL(redirectURL, state, nonce string) string {
params := url.Values{} params := url.Values{
params.Add("client_id", t.clientID) "client_id": {t.clientID},
params.Add("response_type", "code") "response_type": {"code"},
params.Add("redirect_uri", redirectURL) "redirect_uri": {redirectURL},
params.Add("scope", strings.Join(t.scopes, " ")) "scope": {strings.Join(t.scopes, " ")},
params.Add("state", state) "state": {state},
params.Add("nonce", nonce) "nonce": {nonce},
}
authURL := t.authURL + "?" + params.Encode() return fmt.Sprintf("%s?%s", t.authURL, params.Encode())
// infoLogger.Printf("Built auth URL: %s", authURL)
return authURL
} }
func (t *TraefikOidc) startTokenCleanup() { func (t *TraefikOidc) startTokenCleanup() {
@@ -306,6 +272,7 @@ func (t *TraefikOidc) startTokenCleanup() {
go func() { go func() {
for range ticker.C { for range ticker.C {
t.tokenCache.Cleanup() t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup()
} }
}() }()
} }
+61 -5
View File
@@ -1,10 +1,13 @@
package traefikoidc package traefikoidc
import "os" import (
"fmt"
"net/http"
"os"
)
// constants
const ( const (
cookie_name = "_raczylo_oidc" cookieName = "_raczylo_oidc"
) )
type Config struct { type Config struct {
@@ -19,6 +22,59 @@ type Config struct {
} }
func CreateConfig() *Config { func CreateConfig() *Config {
infoLogger.SetOutput(os.Stdout) return &Config{
return &Config{} Scopes: []string{"openid", "profile", "email"},
LogLevel: "info",
}
}
func (c *Config) Validate() error {
if c.ProviderURL == "" {
return fmt.Errorf("providerURL is required")
}
if c.CallbackURL == "" {
return fmt.Errorf("callbackURL is required")
}
if c.ClientID == "" {
return fmt.Errorf("clientID is required")
}
if c.ClientSecret == "" {
return fmt.Errorf("clientSecret is required")
}
if c.SessionEncryptionKey == "" {
return fmt.Errorf("sessionEncryptionKey is required")
}
return nil
}
type defaultLogger struct {
level string
}
func NewLogger(level string) Logger {
return &defaultLogger{level: level}
}
func (l *defaultLogger) Infof(format string, args ...interface{}) {
if l.level == "info" || l.level == "debug" {
fmt.Printf("INFO: "+format+"\n", args...)
}
}
func (l *defaultLogger) Errorf(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...)
}
type HTTPClient interface {
Get(url string) (*http.Response, error)
Do(req *http.Request) (*http.Response, error)
}
type Logger interface {
Infof(format string, args ...interface{})
Errorf(format string, args ...interface{})
}
func handleError(w http.ResponseWriter, message string, code int) {
http.Error(w, message, code)
} }