diff --git a/.traefik.yml b/.traefik.yml index ea605f9..9da7166 100644 --- a/.traefik.yml +++ b/.traefik.yml @@ -4,7 +4,7 @@ type: middleware import: github.com/lukaszraczylo/traefikoidc summary: | - WIP [do not use, yet] Middleware adding OIDC authentication to traefik. + Middleware adding OIDC authentication to traefik routes. testData: providerURL: https://accounts.google.com @@ -18,3 +18,5 @@ testData: - profile sessionEncryptionKey: potato-secret forceHTTPS: false + logLevel: debug + rateLimit: 100 diff --git a/README.md b/README.md index 70913ca..ffb9e36 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,7 @@ ## Traefik OIDC middleware -WIP warning! -This middleware is under active development. - -This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy. +This middleware is under active development - things should NOT break, but they might. +This middleware is supposed to replace the need for the forward-auth and oauth2-proxy when using traefik as a reverse proxy to support the OIDC authentication. ### Configuration options @@ -98,4 +96,6 @@ http: - profile sessionEncryptionKey: potato-secret forceHTTPS: false + logLevel: info + rateLimit: 100 # 100 requests per minute ``` diff --git a/helpers.go b/helpers.go index 9b5f3d6..a1f72bd 100644 --- a/helpers.go +++ b/helpers.go @@ -60,6 +60,7 @@ func (t *TraefikOidc) exchangeCodeForToken(ctx context.Context, code, redirectUR func (t *TraefikOidc) handleLogout(rw http.ResponseWriter, req *http.Request) { session, err := t.store.Get(req, cookieName) + t.logger.Debugf("Logging out user") if err != nil { handleError(rw, "Session error", http.StatusInternalServerError, t.logger) return diff --git a/jwk.go b/jwk.go index 816fdfe..8c0a218 100644 --- a/jwk.go +++ b/jwk.go @@ -14,135 +14,135 @@ import ( ) type JWK struct { - Kty string `json:"kty"` - Kid string `json:"kid"` - Use string `json:"use"` - N string `json:"n"` - E string `json:"e"` - Alg string `json:"alg"` + Kty string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"use"` + N string `json:"n"` + E string `json:"e"` + Alg string `json:"alg"` } type JWKSet struct { - Keys []JWK `json:"keys"` + Keys []JWK `json:"keys"` } type JWKCache struct { - jwks *JWKSet - expiresAt time.Time - mutex sync.RWMutex + jwks *JWKSet + expiresAt time.Time + mutex sync.RWMutex } -func (c *JWKCache) GetJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { - c.mutex.RLock() - if c.jwks != nil && time.Now().Before(c.expiresAt) { - defer c.mutex.RUnlock() - return c.jwks, nil - } - c.mutex.RUnlock() +func (c *JWKCache) GetJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { + c.mutex.RLock() + if c.jwks != nil && time.Now().Before(c.expiresAt) { + defer c.mutex.RUnlock() + return c.jwks, nil + } + c.mutex.RUnlock() - c.mutex.Lock() - defer c.mutex.Unlock() + c.mutex.Lock() + defer c.mutex.Unlock() - if c.jwks != nil && time.Now().Before(c.expiresAt) { - return c.jwks, nil - } + if c.jwks != nil && time.Now().Before(c.expiresAt) { + return c.jwks, nil + } - jwks, err := fetchJWKS(jwksURL, httpClient) - if err != nil { - return nil, err - } + jwks, err := fetchJWKS(jwksURL, httpClient) + if err != nil { + return nil, err + } - c.jwks = jwks - c.expiresAt = time.Now().Add(1 * time.Hour) + c.jwks = jwks + c.expiresAt = time.Now().Add(1 * time.Hour) - return jwks, nil + return jwks, nil } -func fetchJWKS(jwksURL string, httpClient HTTPClient) (*JWKSet, error) { - resp, err := httpClient.Get(jwksURL) - if err != nil { - return nil, fmt.Errorf("failed to fetch JWKS: %w", err) - } - defer resp.Body.Close() +func fetchJWKS(jwksURL string, httpClient *http.Client) (*JWKSet, error) { + resp, err := httpClient.Get(jwksURL) + if err != nil { + return nil, fmt.Errorf("failed to fetch JWKS: %w", err) + } + defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode) - } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to fetch JWKS: unexpected status code %d", resp.StatusCode) + } - var jwks JWKSet - if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { - return nil, fmt.Errorf("failed to decode JWKS: %w", err) - } + var jwks JWKSet + if err := json.NewDecoder(resp.Body).Decode(&jwks); err != nil { + return nil, fmt.Errorf("failed to decode JWKS: %w", err) + } - return &jwks, nil + return &jwks, nil } func verifyNonce(tokenNonce, expectedNonce string) error { - if tokenNonce != expectedNonce { - return fmt.Errorf("invalid nonce") - } - return nil + if tokenNonce != expectedNonce { + return fmt.Errorf("invalid nonce") + } + return nil } func verifyAudience(tokenAudience, expectedAudience string) error { - if tokenAudience != expectedAudience { - return fmt.Errorf("invalid audience") - } - return nil + if tokenAudience != expectedAudience { + return fmt.Errorf("invalid audience") + } + return nil } func verifyTokenTimes(issuedAt, expiration int64, allowedClockSkew time.Duration) error { - now := time.Now().Unix() - if now < issuedAt-int64(allowedClockSkew.Seconds()) { - return fmt.Errorf("token used before issued") - } - if now > expiration+int64(allowedClockSkew.Seconds()) { - return fmt.Errorf("token is expired") - } - return nil + now := time.Now().Unix() + if now < issuedAt-int64(allowedClockSkew.Seconds()) { + return fmt.Errorf("token used before issued") + } + if now > expiration+int64(allowedClockSkew.Seconds()) { + return fmt.Errorf("token is expired") + } + return nil } func verifyIssuer(tokenIssuer, expectedIssuer string) error { - if tokenIssuer != expectedIssuer { - return fmt.Errorf("invalid issuer") - } - return nil + if tokenIssuer != expectedIssuer { + return fmt.Errorf("invalid issuer") + } + return nil } func validateClaims(claims map[string]interface{}) error { - requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} - for _, claim := range requiredClaims { - if _, ok := claims[claim]; !ok { - return fmt.Errorf("missing required claim: %s", claim) - } - } - return nil + requiredClaims := []string{"sub", "iss", "aud", "exp", "iat"} + for _, claim := range requiredClaims { + if _, ok := claims[claim]; !ok { + return fmt.Errorf("missing required claim: %s", claim) + } + } + return nil } func jwkToPEM(jwk *JWK) ([]byte, error) { - n, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) - } - e, err := base64.RawURLEncoding.DecodeString(jwk.E) - if err != nil { - return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err) - } + n, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'n' parameter: %w", err) + } + e, err := base64.RawURLEncoding.DecodeString(jwk.E) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK 'e' parameter: %w", err) + } - publicKey := &rsa.PublicKey{ - N: new(big.Int).SetBytes(n), - E: int(new(big.Int).SetBytes(e).Int64()), - } + publicKey := &rsa.PublicKey{ + N: new(big.Int).SetBytes(n), + E: int(new(big.Int).SetBytes(e).Int64()), + } - publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return nil, fmt.Errorf("failed to marshal public key: %w", err) - } + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal public key: %w", err) + } - publicKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: publicKeyBytes, - }) + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) - return publicKeyPEM, nil + return publicKeyPEM, nil } diff --git a/jwt.go b/jwt.go index 2e9a23a..153bb4b 100644 --- a/jwt.go +++ b/jwt.go @@ -132,6 +132,7 @@ func decodeSegment(seg string) (map[string]interface{}, error) { } func (t *TraefikOidc) verifyAndCacheToken(token string) error { + t.logger.Debugf("Verifying token") if !t.limiter.Allow() { return fmt.Errorf("rate limit exceeded") } @@ -141,6 +142,7 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error { } if _, exists := t.tokenCache.Get(token); exists { + t.logger.Debugf("Token is valid and cached") return nil // Token is valid and cached } @@ -160,6 +162,7 @@ func (t *TraefikOidc) verifyAndCacheToken(token string) error { } func (t *TraefikOidc) verifyJWTSignatureAndClaims(jwt *JWT, token string) error { + t.logger.Debugf("Verifying JWT signature and claims") jwks, err := t.jwkCache.GetJWKS(t.jwksURL, t.httpClient) if err != nil { return fmt.Errorf("failed to get JWKS: %w", err) diff --git a/main.go b/main.go index 7312049..e8ca97b 100644 --- a/main.go +++ b/main.go @@ -33,8 +33,9 @@ type TraefikOidc struct { forceHTTPS bool scheme string tokenCache *TokenCache - httpClient HTTPClient - logger Logger + httpClient *http.Client + logger *Logger + redirectURL string } type ProviderMetadata struct { @@ -54,7 +55,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h SameSite: http.SameSiteLaxMode, } - metadata, err := discoverProviderMetadata(config.ProviderURL, &http.Client{}) + metadata, err := discoverProviderMetadata(config.ProviderURL, http.Client{}) if err != nil { return nil, fmt.Errorf("failed to discover provider metadata: %w", err) } @@ -75,16 +76,17 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h authURL: metadata.AuthURL, tokenURL: metadata.TokenURL, scopes: config.Scopes, - limiter: rate.NewLimiter(rate.Every(time.Second), 100), + limiter: rate.NewLimiter(rate.Every(time.Second), config.RateLimit), tokenCache: NewTokenCache(), httpClient: &http.Client{}, logger: NewLogger(config.LogLevel), + redirectURL: "", } t.startTokenCleanup() return t, nil } -func discoverProviderMetadata(providerURL string, httpClient HTTPClient) (*ProviderMetadata, error) { +func discoverProviderMetadata(providerURL string, httpClient http.Client) (*ProviderMetadata, error) { wellKnownURL := strings.TrimSuffix(providerURL, "/") + "/.well-known/openid-configuration" resp, err := httpClient.Get(wellKnownURL) if err != nil { @@ -110,12 +112,14 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.URL.Path == t.logoutURLPath { t.handleLogout(rw, req) - http.Error(rw, "Logged out", http.StatusOK) + http.Error(rw, "Logged out", http.StatusForbidden) return } - redirectURL := buildFullURL(t.scheme, host, t.redirURLPath) - t.logger.Infof("Final redirect URL: %s", redirectURL) + if t.redirectURL == "" { + t.redirectURL = buildFullURL(t.scheme, host, t.redirURLPath) + t.logger.Debugf("Redirect URL updated to: %s", t.redirectURL) + } session, err := t.store.Get(req, cookieName) if err != nil { @@ -125,7 +129,7 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if req.URL.Path == t.redirURLPath { - t.logger.Infof("Handling callback, URL: %s", req.URL.String()) + t.logger.Debugf("Handling callback, URL: %s", req.URL.String()) authSuccess, originalPath := t.handleCallback(rw, req) if authSuccess { http.Redirect(rw, req, originalPath, http.StatusFound) @@ -136,12 +140,13 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if t.isUserAuthenticated(session) { + t.logger.Debugf("User is authenticated, serving content") t.next.ServeHTTP(rw, req) return } // User is not authenticated or session has expired, start the auth process - t.initiateAuthentication(rw, req, session, redirectURL) + t.initiateAuthentication(rw, req, session, t.redirectURL) } func (t *TraefikOidc) determineScheme(req *http.Request) string { @@ -195,7 +200,7 @@ func (t *TraefikOidc) isUserAuthenticated(session *sessions.Session) bool { } if time.Now().Unix() > int64(exp) { - t.logger.Infof("Session has expired") + t.logger.Debugf("Session has expired") return false } @@ -208,7 +213,7 @@ func (t *TraefikOidc) initiateAuthentication(rw http.ResponseWriter, req *http.R csrfToken := uuid.New().String() session.Values["csrf"] = csrfToken session.Values["incoming_path"] = req.URL.Path - t.logger.Infof("Setting CSRF token: %s", csrfToken) + t.logger.Debugf("Setting CSRF token: %s", csrfToken) if err := session.Save(req, rw); err != nil { t.logger.Errorf("Failed to save session: %v", err) @@ -247,6 +252,7 @@ func (t *TraefikOidc) startTokenCleanup() { ticker := time.NewTicker(5 * time.Minute) go func() { for range ticker.C { + t.logger.Debug("Cleaning up token cache") t.tokenCache.Cleanup() t.tokenBlacklist.Cleanup() } diff --git a/settings.go b/settings.go index 9598bc9..b59d359 100644 --- a/settings.go +++ b/settings.go @@ -2,6 +2,8 @@ package traefikoidc import ( "fmt" + "io" + "log" "net/http" "os" ) @@ -20,17 +22,28 @@ type Config struct { LogLevel string `json:"logLevel"` SessionEncryptionKey string `json:"sessionEncryptionKey"` ForceHTTPS bool `json:"forceHTTPS"` + RateLimit int `json:"rateLimit"` } func CreateConfig() *Config { - c := &Config{ - Scopes: []string{"openid", "profile", "email"}, - LogLevel: "info", + c := &Config{} + + if c.Scopes == nil { + c.Scopes = []string{"openid", "profile", "email"} + } + + if c.LogLevel == "" { + c.LogLevel = "info" } if c.LogoutURL == "" { c.LogoutURL = c.CallbackURL + "/logout" } + + if c.RateLimit == 0 { + c.RateLimit = 100 + } + return c } @@ -53,47 +66,56 @@ func (c *Config) Validate() error { return nil } -type defaultLogger struct { - level string +type Logger struct { + logError *log.Logger + logInfo *log.Logger + logDebug *log.Logger } -func NewLogger(level string) Logger { - return &defaultLogger{level: level} -} +func NewLogger(logLevel string) *Logger { + logError := log.New(io.Discard, "ERROR: TraefikOidcPlugin: ", log.Ldate|log.Ltime) + logInfo := log.New(io.Discard, "INFO: TraefikOidcPlugin: ", log.Ldate|log.Ltime) + logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime) -func (l *defaultLogger) Info(args ...interface{}) { - if l.level == "info" || l.level == "debug" { - fmt.Println(append([]interface{}{"INFO:"}, args...)...) + logError.SetOutput(os.Stderr) + logInfo.SetOutput(os.Stdout) + + if logLevel == "debug" { + logDebug.SetOutput(os.Stdout) + } + + return &Logger{ + logError: logError, + logInfo: logInfo, + logDebug: logDebug, } } -func (l *defaultLogger) Infof(format string, args ...interface{}) { - if l.level == "info" || l.level == "debug" { - fmt.Printf("INFO: "+format+"\n", args...) - } +func (l *Logger) Info(format string, args ...interface{}) { + l.logInfo.Printf(format, args...) } -func (l *defaultLogger) Error(args ...interface{}) { - fmt.Fprintln(os.Stderr, append([]interface{}{"ERROR:"}, args...)...) +func (l *Logger) Debug(format string, args ...interface{}) { + l.logDebug.Printf(format, args...) } -func (l *defaultLogger) Errorf(format string, args ...interface{}) { - fmt.Fprintf(os.Stderr, "ERROR: "+format+"\n", args...) +func (l *Logger) Error(format string, args ...interface{}) { + l.logError.Printf(format, args...) } -type HTTPClient interface { - Get(url string) (*http.Response, error) - Do(req *http.Request) (*http.Response, error) +func (l *Logger) Infof(format string, args ...interface{}) { + l.logInfo.Printf(format, args...) } -type Logger interface { - Info(args ...interface{}) - Infof(format string, args ...interface{}) - Error(args ...interface{}) - Errorf(format string, args ...interface{}) +func (l *Logger) Debugf(format string, args ...interface{}) { + l.logDebug.Printf(format, args...) } -func handleError(w http.ResponseWriter, message string, code int, logger Logger) { +func (l *Logger) Errorf(format string, args ...interface{}) { + l.logError.Printf(format, args...) +} + +func handleError(w http.ResponseWriter, message string, code int, logger *Logger) { logger.Errorf(message) http.Error(w, message, code) }