Fix logging, add additional settings for to the middleware.

This commit is contained in:
2024-07-25 18:00:41 +01:00
parent 622b11f586
commit 4b99a4c5fa
7 changed files with 170 additions and 136 deletions
+3 -1
View File
@@ -4,7 +4,7 @@ type: middleware
import: github.com/lukaszraczylo/traefikoidc import: github.com/lukaszraczylo/traefikoidc
summary: | summary: |
WIP [do not use, yet] Middleware adding OIDC authentication to traefik. Middleware adding OIDC authentication to traefik routes.
testData: testData:
providerURL: https://accounts.google.com providerURL: https://accounts.google.com
@@ -18,3 +18,5 @@ testData:
- profile - profile
sessionEncryptionKey: potato-secret sessionEncryptionKey: potato-secret
forceHTTPS: false forceHTTPS: false
logLevel: debug
rateLimit: 100
+4 -4
View File
@@ -1,9 +1,7 @@
## Traefik OIDC middleware ## Traefik OIDC middleware
WIP warning! This middleware is under active development - things should NOT break, but they might.
This middleware is under active development. This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication.
This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy.
### Configuration options ### Configuration options
@@ -98,4 +96,6 @@ http:
- profile - profile
sessionEncryptionKey: potato-secret sessionEncryptionKey: potato-secret
forceHTTPS: false forceHTTPS: false
logLevel: info
rateLimit: 100 # 100 requests per minute
``` ```
+1
View File
@@ -60,6 +60,7 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR
func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) {
session, err := t.store.Get(req, cookieName) session, err := t.store.Get(req, cookieName)
t.logger.Debugf("Logging out user")
if err != nil { if err != nil {
handleError(rw, "Session error", http.StatusInternalServerError, t.logger) handleError(rw, "Session error", http.StatusInternalServerError, t.logger)
return return
+91 -91
View File
@@ -14,135 +14,135 @@ import (
) )
type JWK struct { type JWK struct {
Kty string `json:"kty"` Kty string `json:"kty"`
Kid string `json:"kid"` Kid string `json:"kid"`
Use string `json:"use"` Use string `json:"use"`
N string `json:"n"` N string `json:"n"`
E string `json:"e"` E string `json:"e"`
Alg string `json:"alg"` Alg string `json:"alg"`
} }
type JWKSet struct { type JWKSet struct {
Keys []JWK `json:"keys"` Keys []JWK `json:"keys"`
} }
type JWKCache struct { type JWKCache struct {
jwks *JWKSet jwks *JWKSet
expiresAt time.Time expiresAt time.Time
mutex sync.RWMutex mutex sync.RWMutex
} }
func (c *JWKCache) GetJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
c.mutex.RLock() c.mutex.RLock()
if c.jwks != nil && time.Now().Before(c.expiresAt) { if c.jwks != nil && time.Now().Before(c.expiresAt) {
defer c.mutex.RUnlock() defer c.mutex.RUnlock()
return c.jwks, nil return c.jwks, nil
} }
c.mutex.RUnlock() c.mutex.RUnlock()
c.mutex.Lock() c.mutex.Lock()
defer c.mutex.Unlock() defer c.mutex.Unlock()
if c.jwks != nil && time.Now().Before(c.expiresAt) { if c.jwks != nil && time.Now().Before(c.expiresAt) {
return c.jwks, nil return c.jwks, nil
} }
jwks, err := fetchJWKS(jwksURL, httpClient) jwks, err := fetchJWKS(jwksURL, httpClient)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.jwks = jwks c.jwks = jwks
c.expiresAt = time.Now().Add(1 * time.Hour) c.expiresAt = time.Now().Add(1 * time.Hour)
return jwks, nil return jwks, nil
} }
func fetchJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) {
resp, err := httpClient.Get(jwksURL) resp, err := httpClient.Get(jwksURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to fetch JWKS: %w", err) return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode) return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode)
} }
var jwks JWKSet var jwks JWKSet
if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil {
return nil, fmt.Errorf("failed to decode JWKS: %w", err) return nil, fmt.Errorf("failed to decode JWKS: %w", err)
} }
return &jwks, nil return &jwks, nil
} }
func verifyNonce(tokenNonce, expectedNonce string) error { func verifyNonce(tokenNonce, expectedNonce string) error {
if tokenNonce != expectedNonce { if tokenNonce != expectedNonce {
return fmt.Errorf("invalid nonce") return fmt.Errorf("invalid nonce")
} }
return nil return nil
} }
func verifyAudience(tokenAudience, expectedAudience string) error { func verifyAudience(tokenAudience, expectedAudience string) error {
if tokenAudience != expectedAudience { if tokenAudience != expectedAudience {
return fmt.Errorf("invalid audience") return fmt.Errorf("invalid audience")
} }
return nil return nil
} }
func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error { func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error {
now := time.Now().Unix() now := time.Now().Unix()
if now < issuedAt-int64(allowedClockSkew.Seconds()) { if now < issuedAt-int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token used before issued") return fmt.Errorf("token used before issued")
} }
if now > expiration+int64(allowedClockSkew.Seconds()) { if now > expiration+int64(allowedClockSkew.Seconds()) {
return fmt.Errorf("token is expired") return fmt.Errorf("token is expired")
} }
return nil return nil
} }
func verifyIssuer(tokenIssuer, expectedIssuer string) error { func verifyIssuer(tokenIssuer, expectedIssuer string) error {
if tokenIssuer != expectedIssuer { if tokenIssuer != expectedIssuer {
return fmt.Errorf("invalid issuer") return fmt.Errorf("invalid issuer")
} }
return nil return nil
} }
func validateClaims(claims map[string]interface{}) error { func validateClaims(claims map[string]interface{}) error {
requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"}
for _, claim := range requiredClaims { for _, claim := range requiredClaims {
if _, ok := claims[claim]; !ok { if _, ok := claims[claim]; !ok {
return fmt.Errorf("missing required claim: %s", claim) return fmt.Errorf("missing required claim: %s", claim)
} }
} }
return nil return nil
} }
func jwkToPEM(jwk *JWK) ([]byte, error) { func jwkToPEM(jwk *JWK) ([]byte, error) {
n, err := base64.RawURLEncoding.DecodeString(jwk.N) n, err := base64.RawURLEncoding.DecodeString(jwk.N)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err)
} }
e, err := base64.RawURLEncoding.DecodeString(jwk.E) e, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err) return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err)
} }
publicKey := &rsa.PublicKey{ publicKey := &rsa.PublicKey{
N: new(big.Int).SetBytes(n), N: new(big.Int).SetBytes(n),
E: int(new(big.Int).SetBytes(e).Int64()), E: int(new(big.Int).SetBytes(e).Int64()),
} }
publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to marshal public key: %w", err) return nil, fmt.Errorf("failed to marshal public key: %w", err)
} }
publicKeyPEM := pem.EncodeToMemory(&pem.Block{ publicKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PUBLIC KEY", Type: "RSA PUBLIC KEY",
Bytes: publicKeyBytes, Bytes: publicKeyBytes,
}) })
return publicKeyPEM, nil return publicKeyPEM, nil
} }
+3
View File
@@ -132,6 +132,7 @@ func decodeSegment(seg string) (map[string]interface{}, error) {
} }
func (t *TraefikOidc) verifyAndCacheToken(token string) error { func (t *TraefikOidc) verifyAndCacheToken(token string) error {
t.logger.Debugf("Verifying token")
if !t.limiter.Allow() { if !t.limiter.Allow() {
return fmt.Errorf("rate limit exceeded") return fmt.Errorf("rate limit exceeded")
} }
@@ -141,6 +142,7 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error {
} }
if _, exists := t.tokenCache.Get(token); exists { if _, exists := t.tokenCache.Get(token); exists {
t.logger.Debugf("Token is valid and cached")
return nil // Token is valid and cached return nil // Token is valid and cached
} }
@@ -160,6 +162,7 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error {
} }
func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error { func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error {
t.logger.Debugf("Verifying JWT signature and claims")
jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient)
if err != nil { if err != nil {
return fmt.Errorf("failed to get JWKS: %w", err) return fmt.Errorf("failed to get JWKS: %w", err)
+18 -12
View File
@@ -33,8 +33,9 @@ type TraefikOidc struct {
forceHTTPS bool forceHTTPS bool
scheme string scheme string
tokenCache *TokenCache tokenCache *TokenCache
httpClient HTTPClient httpClient *http.Client
logger Logger logger *Logger
redirectURL string
} }
type ProviderMetadata struct { type ProviderMetadata struct {
@@ -54,7 +55,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
SameSite: http.SameSiteLaxMode, SameSite: http.SameSiteLaxMode,
} }
metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{}) metadata, err := discoverProviderMetadata(config.ProviderURL, http.Client{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to discover provider metadata: %w", err) return nil, fmt.Errorf("failed to discover provider metadata: %w", err)
} }
@@ -75,16 +76,17 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
authURL: metadata.AuthURL, authURL: metadata.AuthURL,
tokenURL: metadata.TokenURL, tokenURL: metadata.TokenURL,
scopes: config.Scopes, scopes: config.Scopes,
limiter: rate.NewLimiter(rate.Every(time.Second), 100), limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit),
tokenCache: NewTokenCache(), tokenCache: NewTokenCache(),
httpClient: &http.Client{}, httpClient: &http.Client{},
logger: NewLogger(config.LogLevel), logger: NewLogger(config.LogLevel),
redirectURL: "",
} }
t.startTokenCleanup() t.startTokenCleanup()
return t, nil return t, nil
} }
func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) { func discoverProviderMetadata(providerURL string, httpClient http.Client) (*ProviderMetadata, error) {
wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration"
resp, err := httpClient.Get(wellKnownURL) resp, err := httpClient.Get(wellKnownURL)
if err != nil { if err != nil {
@@ -110,12 +112,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == t.logoutURLPath { if req.URL.Path == t.logoutURLPath {
t.handleLogout(rw, req) t.handleLogout(rw, req)
http.Error(rw, "Logged out", http.StatusOK) http.Error(rw, "Logged out", http.StatusForbidden)
return return
} }
redirectURL := buildFullURL(t.scheme, host, t.redirURLPath) if t.redirectURL == "" {
t.logger.Infof("Final redirect URL: %s", redirectURL) t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath)
t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL)
}
session, err := t.store.Get(req, cookieName) session, err := t.store.Get(req, cookieName)
if err != nil { if err != nil {
@@ -125,7 +129,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
if req.URL.Path == t.redirURLPath { if req.URL.Path == t.redirURLPath {
t.logger.Infof("Handling callback, URL: %s", req.URL.String()) t.logger.Debugf("Handling callback, URL: %s", req.URL.String())
authSuccess, originalPath := t.handleCallback(rw, req) authSuccess, originalPath := t.handleCallback(rw, req)
if authSuccess { if authSuccess {
http.Redirect(rw, req, originalPath, http.StatusFound) http.Redirect(rw, req, originalPath, http.StatusFound)
@@ -136,12 +140,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
if t.isUserAuthenticated(session) { if t.isUserAuthenticated(session) {
t.logger.Debugf("User is authenticated, serving content")
t.next.ServeHTTP(rw, req) t.next.ServeHTTP(rw, req)
return return
} }
// User is not authenticated or session has expired, start the auth process // User is not authenticated or session has expired, start the auth process
t.initiateAuthentication(rw, req, session, redirectURL) t.initiateAuthentication(rw, req, session, t.redirectURL)
} }
func (t *TraefikOidc) determineScheme(req *http.Request) string { func (t *TraefikOidc) determineScheme(req *http.Request) string {
@@ -195,7 +200,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool {
} }
if time.Now().Unix() > int64(exp) { if time.Now().Unix() > int64(exp) {
t.logger.Infof("Session has expired") t.logger.Debugf("Session has expired")
return false return false
} }
@@ -208,7 +213,7 @@ func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.R
csrfToken := uuid.New().String() csrfToken := uuid.New().String()
session.Values["csrf"] = csrfToken session.Values["csrf"] = csrfToken
session.Values["incoming_path"] = req.URL.Path session.Values["incoming_path"] = req.URL.Path
t.logger.Infof("Setting CSRF token: %s", csrfToken) t.logger.Debugf("Setting CSRF token: %s", csrfToken)
if err := session.Save(req, rw); err != nil { if err := session.Save(req, rw); err != nil {
t.logger.Errorf("Failed to save session: %v", err) t.logger.Errorf("Failed to save session: %v", err)
@@ -247,6 +252,7 @@ func (t *TraefikOidc) startTokenCleanup() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
go func() { go func() {
for range ticker.C { for range ticker.C {
t.logger.Debug("Cleaning up token cache")
t.tokenCache.Cleanup() t.tokenCache.Cleanup()
t.tokenBlacklist.Cleanup() t.tokenBlacklist.Cleanup()
} }
+50 -28
View File
@@ -2,6 +2,8 @@ package traefikoidc
import ( import (
"fmt" "fmt"
"io"
"log"
"net/http" "net/http"
"os" "os"
) )
@@ -20,17 +22,28 @@ type Config struct {
LogLevel string `json:"logLevel"` LogLevel string `json:"logLevel"`
SessionEncryptionKey string `json:"sessionEncryptionKey"` SessionEncryptionKey string `json:"sessionEncryptionKey"`
ForceHTTPS bool `json:"forceHTTPS"` ForceHTTPS bool `json:"forceHTTPS"`
RateLimit int `json:"rateLimit"`
} }
func CreateConfig() *Config { func CreateConfig() *Config {
c := &Config{ c := &Config{}
Scopes: []string{"openid", "profile", "email"},
LogLevel: "info", if c.Scopes == nil {
c.Scopes = []string{"openid", "profile", "email"}
}
if c.LogLevel == "" {
c.LogLevel = "info"
} }
if c.LogoutURL == "" { if c.LogoutURL == "" {
c.LogoutURL = c.CallbackURL + "/logout" c.LogoutURL = c.CallbackURL + "/logout"
} }
if c.RateLimit == 0 {
c.RateLimit = 100
}
return c return c
} }
@@ -53,47 +66,56 @@ func (c *Config) Validate() error {
return nil return nil
} }
type defaultLogger struct { type Logger struct {
level string logError *log.Logger
logInfo *log.Logger
logDebug *log.Logger
} }
func NewLogger(level string) Logger { func NewLogger(logLevel string) *Logger {
return &defaultLogger{level: level} logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
} logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
func (l *defaultLogger) Info(args ...interface{}) { logError.SetOutput(os.Stderr)
if l.level == "info" || l.level == "debug" { logInfo.SetOutput(os.Stdout)
fmt.Println(append([]interface{}{"INFO:"}, args...)...)
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
return &Logger{
logError: logError,
logInfo: logInfo,
logDebug: logDebug,
} }
} }
func (l *defaultLogger) Infof(format string, args ...interface{}) { func (l *Logger) Info(format string, args ...interface{}) {
if l.level == "info" || l.level == "debug" { l.logInfo.Printf(format, args...)
fmt.Printf("INFO: "+format+"\n", args...)
}
} }
func (l *defaultLogger) Error(args ...interface{}) { func (l *Logger) Debug(format string, args ...interface{}) {
fmt.Fprintln(os.Stderr, append([]interface{}{"ERROR:"}, args...)...) l.logDebug.Printf(format, args...)
} }
func (l *defaultLogger) Errorf(format string, args ...interface{}) { func (l *Logger) Error(format string, args ...interface{}) {
fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...) l.logError.Printf(format, args...)
} }
type HTTPClient interface { func (l *Logger) Infof(format string, args ...interface{}) {
Get(url string) (*http.Response, error) l.logInfo.Printf(format, args...)
Do(req *http.Request) (*http.Response, error)
} }
type Logger interface { func (l *Logger) Debugf(format string, args ...interface{}) {
Info(args ...interface{}) l.logDebug.Printf(format, args...)
Infof(format string, args ...interface{})
Error(args ...interface{})
Errorf(format string, args ...interface{})
} }
func handleError(w http.ResponseWriter, message string, code int, logger Logger) { func (l *Logger) Errorf(format string, args ...interface{}) {
l.logError.Printf(format, args...)
}
func handleError(w http.ResponseWriter, message string, code int, logger *Logger) {
logger.Errorf(message) logger.Errorf(message)
http.Error(w, message, code) http.Error(w, message, code)
} }