diff --git a/main_servehttp_test.go b/main_servehttp_test.go index cf4cace..3ca87d1 100644 --- a/main_servehttp_test.go +++ b/main_servehttp_test.go @@ -79,34 +79,186 @@ func TestServeHTTP_ExcludedURLs(t *testing.T) { } } -// TestServeHTTP_EventStream tests the event-stream bypass functionality +// TestServeHTTP_EventStream tests the event-stream (SSE) bypass: the +// handshake must skip the OIDC redirect dance (clients can't follow it +// mid-stream) but it must STILL require an authenticated session, otherwise +// any caller could reach the backend by setting Accept: text/event-stream. func TestServeHTTP_EventStream(t *testing.T) { - nextCalled := false - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - nextCalled = true - w.WriteHeader(http.StatusOK) + sessionManager := createTestSessionManager(t) + + newOidc := func(next http.Handler) *TraefikOidc { + oidc := &TraefikOidc{ + next: next, + logger: NewLogger("debug"), + initComplete: make(chan struct{}), + sessionManager: sessionManager, + firstRequestReceived: true, + metadataRefreshStarted: true, + issuerURL: "https://provider.example.com", + } + close(oidc.initComplete) + return oidc + } + + t.Run("unauthenticated_request_is_rejected", func(t *testing.T) { + nextCalled := false + oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/events", nil) + req.Header.Set("Accept", "text/event-stream") + rw := httptest.NewRecorder() + + oidc.ServeHTTP(rw, req) + + if rw.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for unauthenticated SSE request, got %d", rw.Code) + } + if nextCalled { + t.Error("backend handler must NOT be called for unauthenticated SSE bypass") + } }) - oidc := &TraefikOidc{ - next: next, - logger: NewLogger("debug"), - initComplete: make(chan struct{}), - sessionManager: createTestSessionManager(t), - firstRequestReceived: true, - metadataRefreshStarted: true, - issuerURL: "https://provider.example.com", + t.Run("authenticated_request_bypasses_to_backend", func(t *testing.T) { + nextCalled := false + var forwardedUser string + oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + forwardedUser = r.Header.Get("X-Forwarded-User") + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/events", nil) + req.Header.Set("Accept", "text/event-stream") + + // Build an authenticated session and inject its cookies onto req. + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("failed to create test session: %v", err) + } + session.SetEmail("user@example.com") + if err := session.SetAuthenticated(true); err != nil { + t.Fatalf("failed to mark session authenticated: %v", err) + } + setupRW := httptest.NewRecorder() + if err := session.Save(req, setupRW); err != nil { + t.Fatalf("failed to save session: %v", err) + } + for _, c := range setupRW.Result().Cookies() { + req.AddCookie(c) + } + + rw := httptest.NewRecorder() + oidc.ServeHTTP(rw, req) + + if !nextCalled { + t.Fatal("expected authenticated SSE request to be forwarded to backend") + } + if forwardedUser != "user@example.com" { + t.Errorf("expected X-Forwarded-User=user@example.com, got %q", forwardedUser) + } + }) +} + +// TestServeHTTP_WebSocketUpgrade mirrors the SSE behavior: WebSocket +// handshake bypasses the OIDC redirect (clients can't follow it) but the +// session must already be authenticated, otherwise the backend is exposed +// to any caller setting `Connection: Upgrade` + `Upgrade: websocket`. +func TestServeHTTP_WebSocketUpgrade(t *testing.T) { + sessionManager := createTestSessionManager(t) + + newOidc := func(next http.Handler) *TraefikOidc { + oidc := &TraefikOidc{ + next: next, + logger: NewLogger("debug"), + initComplete: make(chan struct{}), + sessionManager: sessionManager, + firstRequestReceived: true, + metadataRefreshStarted: true, + issuerURL: "https://provider.example.com", + } + close(oidc.initComplete) + return oidc } - close(oidc.initComplete) - req := httptest.NewRequest("GET", "/events", nil) - req.Header.Set("Accept", "text/event-stream") - rw := httptest.NewRecorder() + t.Run("unauthenticated_upgrade_is_rejected", func(t *testing.T) { + nextCalled := false + oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + })) - oidc.ServeHTTP(rw, req) + req := httptest.NewRequest("GET", "/ws", nil) + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + rw := httptest.NewRecorder() - if !nextCalled { - t.Error("expected event-stream request to bypass OIDC") - } + oidc.ServeHTTP(rw, req) + + if rw.Code != http.StatusUnauthorized { + t.Errorf("expected 401 for unauthenticated WS upgrade, got %d", rw.Code) + } + if nextCalled { + t.Error("backend handler must NOT be called for unauthenticated WS bypass") + } + }) + + t.Run("authenticated_upgrade_bypasses_to_backend", func(t *testing.T) { + nextCalled := false + var forwardedUser string + oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nextCalled = true + forwardedUser = r.Header.Get("X-Forwarded-User") + })) + + req := httptest.NewRequest("GET", "/ws", nil) + // Mixed-case + multi-token Connection header to exercise parsing. + req.Header.Set("Connection", "keep-alive, Upgrade") + req.Header.Set("Upgrade", "WebSocket") + + session, err := sessionManager.GetSession(req) + if err != nil { + t.Fatalf("failed to create test session: %v", err) + } + session.SetEmail("ws-user@example.com") + if err := session.SetAuthenticated(true); err != nil { + t.Fatalf("failed to mark session authenticated: %v", err) + } + setupRW := httptest.NewRecorder() + if err := session.Save(req, setupRW); err != nil { + t.Fatalf("failed to save session: %v", err) + } + for _, c := range setupRW.Result().Cookies() { + req.AddCookie(c) + } + + rw := httptest.NewRecorder() + oidc.ServeHTTP(rw, req) + + if !nextCalled { + t.Fatal("expected authenticated WS handshake to be forwarded to backend") + } + if forwardedUser != "ws-user@example.com" { + t.Errorf("expected X-Forwarded-User=ws-user@example.com, got %q", forwardedUser) + } + }) + + t.Run("plain_http_does_not_bypass", func(t *testing.T) { + // Sanity: requests without Upgrade headers must NOT hit the WS + // bypass branch (otherwise the new code path could short-circuit + // normal authentication). + oidc := newOidc(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("backend must not be called for unauthenticated plain HTTP") + })) + req := httptest.NewRequest("GET", "/ws", nil) + req.Header.Set("Connection", "keep-alive") + rw := httptest.NewRecorder() + oidc.ServeHTTP(rw, req) + if rw.Code == http.StatusOK { + t.Errorf("expected redirect or 401 for plain HTTP without auth, got 200") + } + }) } // TestServeHTTP_InitializationTimeout tests initialization timeout handling diff --git a/middleware.go b/middleware.go index 7ee012d..fa636f8 100644 --- a/middleware.go +++ b/middleware.go @@ -14,21 +14,40 @@ import ( ) // bypassReason describes why a request is being forwarded without OIDC auth. -// It is only used for logging and to decide whether extra SSE-specific work +// It is only used for logging and to decide whether extra side-effects // (propagating the user header from an existing session) should run. const ( - bypassReasonExcluded = "excluded-url" - bypassReasonSSE = "sse" + bypassReasonExcluded = "excluded-url" + bypassReasonSSE = "sse" + bypassReasonWebSocket = "websocket" ) +// isWebSocketUpgrade reports whether req is a WebSocket upgrade handshake +// (RFC 6455). The middleware can only see the handshake; once Traefik +// completes the upgrade it forwards frames directly, so we never re-process +// per-frame traffic. We bypass auth on the handshake the same way we do for +// SSE, because browser WebSocket clients cannot follow an OIDC redirect. +func isWebSocketUpgrade(req *http.Request) bool { + if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") { + return false + } + for _, token := range strings.Split(req.Header.Get("Connection"), ",") { + if strings.EqualFold(strings.TrimSpace(token), "upgrade") { + return true + } + } + return false +} + // shouldBypassAuth decides whether a request must skip OIDC authentication // entirely. It returns (true, reason) when either the request path matches a -// configured excluded URL or the Accept header asks for a text/event-stream -// response (SSE). The reason lets ServeHTTP apply any side-effects that are -// unique to the bypass kind (e.g. propagating user headers for SSE). +// configured excluded URL, the Accept header asks for a text/event-stream +// response (SSE), or the request is a WebSocket upgrade handshake. The +// reason lets ServeHTTP apply any side-effects that are unique to the bypass +// kind (e.g. propagating user headers). // -// This must be called BEFORE waiting on t.initComplete so excluded and SSE -// traffic is never blocked by a slow/broken provider. +// This must be called BEFORE waiting on t.initComplete so excluded, SSE and +// WebSocket traffic is never blocked by a slow/broken provider. func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) { if t.determineExcludedURL(req.URL.Path) { return true, bypassReasonExcluded @@ -36,38 +55,55 @@ func (t *TraefikOidc) shouldBypassAuth(req *http.Request) (bool, string) { if strings.Contains(req.Header.Get("Accept"), "text/event-stream") { return true, bypassReasonSSE } + if isWebSocketUpgrade(req) { + return true, bypassReasonWebSocket + } return false, "" } -// applySSEUserHeaders attempts to copy the authenticated user's identity from -// an existing session onto the outgoing SSE request so downstream services -// can still see who the user is. Failures are logged (not silenced) because -// they indicate either a corrupt cookie or a misconfigured session manager -// and are useful for debugging, but they never block the bypass itself. -func (t *TraefikOidc) applySSEUserHeaders(req *http.Request) { +// applyBypassUserHeaders enforces authentication on SSE / WebSocket bypass +// requests and, on success, copies the authenticated user's identity onto +// the outgoing request so downstream services can see who the user is. +// +// Returns true when the request carries a valid authenticated session and +// the bypass should proceed. Returns false when no usable session is +// present; callers must then reject the request (typically with 401) to +// prevent unauthenticated traffic from reaching the backend just by setting +// `Accept: text/event-stream` or sending a WebSocket upgrade. +// +// The check is cookie-only: the session cookie is sealed by our encryption +// key, so the authenticated flag cannot be forged. We do NOT run full token +// signature verification here so that SSE/WS keeps working when the OIDC +// provider is briefly unavailable for JWK fetches. +func (t *TraefikOidc) applyBypassUserHeaders(req *http.Request, reason string) bool { if t.sessionManager == nil { - return + return false } session, err := t.sessionManager.GetSession(req) if err != nil { - // Intentionally not fatal: SSE requests bypass auth, we just lose the - // forwarded-user header for this request. - t.logger.Debugf("SSE bypass: unable to load session for user header propagation: %v", err) - return + t.logger.Debugf("%s bypass: unable to load session: %v", reason, err) + return false } defer session.returnToPoolSafely() + if !session.GetAuthenticated() { + t.logger.Debugf("%s bypass: rejecting request without authenticated session", reason) + return false + } + email := session.GetEmail() if email == "" { - return + t.logger.Debugf("%s bypass: rejecting request, session has no user identifier", reason) + return false } req.Header.Set("X-Forwarded-User", email) if !t.minimalHeaders { req.Header.Set("X-Auth-Request-User", email) } - t.logger.Debugf("SSE bypass: forwarded user %s from session", email) + t.logger.Debugf("%s bypass: forwarded user %s from session", reason, email) + return true } // ServeHTTP implements the main middleware logic for processing HTTP requests. @@ -124,16 +160,32 @@ func (t *TraefikOidc) ServeHTTP(rw http.ResponseWriter, req *http.Request) { t.firstRequestMutex.Unlock() } - // Evaluate auth-bypass once, before waiting for initialization. Excluded URLs - // and SSE requests must not block on provider init. For SSE we additionally - // attempt to forward the user identity from an existing session (best - // effort) so downstream handlers still see X-Forwarded-User. + // Evaluate auth-bypass once, before waiting for initialization. Excluded + // URLs, SSE and WebSocket upgrade requests must not block on provider + // init. For SSE/WebSocket we ALSO require an authenticated session + // (cookie-only check, no JWK fetch) and otherwise return 401 — clients + // of in-flight streams can't follow an OIDC redirect, so forwarding + // unauthenticated traffic would silently expose the backend. if bypass, reason := t.shouldBypassAuth(req); bypass { t.logger.Debugf("Bypassing OIDC for %s (%s)", req.URL.Path, reason) - if reason == bypassReasonSSE { - t.applySSEUserHeaders(req) + switch reason { + case bypassReasonExcluded: + // Operator-declared excluded URLs forward unconditionally. + t.next.ServeHTTP(rw, req) + case bypassReasonSSE, bypassReasonWebSocket: + // Skip the OIDC redirect dance (clients can't follow it + // mid-stream) but still require an authenticated session. + // Otherwise an unauthenticated client could hit the backend + // just by setting Accept: text/event-stream or sending a + // WebSocket upgrade. + if !t.applyBypassUserHeaders(req, reason) { + t.sendErrorResponse(rw, req, "Authentication required", http.StatusUnauthorized) + return + } + t.next.ServeHTTP(rw, req) + default: + t.next.ServeHTTP(rw, req) } - t.next.ServeHTTP(rw, req) return } @@ -386,31 +438,52 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http } } - tokenForClaims := session.GetIDToken() - if tokenForClaims == "" { - tokenForClaims = session.GetAccessToken() - if tokenForClaims == "" && len(t.allowedRolesAndGroups) > 0 { - t.logger.Error("No token available but roles/groups checks are required") - // Reset redirect count to prevent loops when token is missing + // Resolve ID-token claims at most once per request. SessionData caches + // the parsed claims keyed on the raw ID token, so concurrent dashboard + // panel requests on the same session don't repeatedly base64-decode and + // JSON-unmarshal the same JWT (a real cost under the yaegi interpreter + // that hosts Traefik plugins). idClaims is reused below by the + // header-templates branch. + idToken := session.GetIDToken() + var ( + idClaims map[string]interface{} + idClaimsErr error + ) + if idToken != "" { + idClaims, idClaimsErr = session.GetIDTokenClaims(t.extractClaimsFunc) + } + + // Choose which claims drive groups/roles extraction. Prefer the ID + // token (cached) and fall back to the access token if there is no ID + // token in the session — matching the prior behavior for opaque + // ID-token providers. + var ( + groupClaims map[string]interface{} + groupClaimsErr error + ) + if idToken != "" { + groupClaims, groupClaimsErr = idClaims, idClaimsErr + } else if accessToken := session.GetAccessToken(); accessToken != "" { + groupClaims, groupClaimsErr = t.extractClaimsFunc(accessToken) + } else if len(t.allowedRolesAndGroups) > 0 { + t.logger.Error("No token available but roles/groups checks are required") + session.ResetRedirectCount() + t.defaultInitiateAuthentication(rw, req, session, redirectURL) + return + } + + var groups, roles []string + + if groupClaimsErr == nil && groupClaims != nil { + var err error + groups, roles, err = t.extractGroupsAndRolesFromClaims(groupClaims) + if err != nil && len(t.allowedRolesAndGroups) > 0 { + t.logger.Errorf("Failed to extract groups and roles: %v", err) session.ResetRedirectCount() t.defaultInitiateAuthentication(rw, req, session, redirectURL) return } - } - - // Initialize empty slices - var groups, roles []string - - if tokenForClaims != "" { - var err error - groups, roles, err = t.extractGroupsAndRoles(tokenForClaims) - if err != nil && len(t.allowedRolesAndGroups) > 0 { - t.logger.Errorf("Failed to extract groups and roles: %v", err) - // Reset redirect count to prevent loops when claim extraction fails - session.ResetRedirectCount() - t.defaultInitiateAuthentication(rw, req, session, redirectURL) - return - } else if err == nil { + if err == nil { if len(groups) > 0 { req.Header.Set("X-User-Groups", strings.Join(groups, ",")) } @@ -442,41 +515,40 @@ func (t *TraefikOidc) processAuthorizedRequest(rw http.ResponseWriter, req *http if !t.minimalHeaders { req.Header.Set("X-Auth-Request-Redirect", req.URL.RequestURI()) req.Header.Set("X-Auth-Request-User", email) - if idToken := session.GetIDToken(); idToken != "" { + if idToken != "" { req.Header.Set("X-Auth-Request-Token", idToken) } } if len(t.headerTemplates) > 0 { - // Reuse claims parsed earlier in this request if the ID token has not - // changed. Saves an unnecessary JWT parse on every authenticated - // request that uses headerTemplates. - claims, err := session.GetIDTokenClaims(t.extractClaimsFunc) - if err != nil { - t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", err) + if idClaimsErr != nil { + t.logger.Errorf("Failed to extract claims from ID Token for template headers: %v", idClaimsErr) } else { + // idClaims may be nil when no ID token is present; templates + // referencing .Claims.* will simply produce empty values, which + // matches the prior behavior. templateData := map[string]interface{}{ "AccessToken": session.GetAccessToken(), - "IDToken": session.GetIDToken(), + "IDToken": idToken, "RefreshToken": session.GetRefreshToken(), - "Claims": claims, + "Claims": idClaims, } for headerName, tmpl := range t.headerTemplates { var buf bytes.Buffer - if err := tmpl.Execute(&buf, templateData); err != nil { t.logger.Errorf("Failed to execute template for header %s: %v", headerName, err) continue } headerValue := buf.String() - req.Header.Set(headerName, headerValue) - t.logger.Debugf("Set templated header %s = %s", headerName, headerValue) } - session.MarkDirty() - t.logger.Debugf("Session marked dirty after templated header processing.") + // NOTE: templates only mutate request headers (not session state), + // so we deliberately do NOT MarkDirty / Save here. Previously every + // authenticated request with header templates re-encrypted and + // rewrote all session cookies, which was a measurable CPU and + // Set-Cookie tax on dashboards that poll many panels per second. } } diff --git a/token_manager.go b/token_manager.go index f0e5e9f..a20fcaf 100644 --- a/token_manager.go +++ b/token_manager.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/url" + "runtime" "strings" "time" ) @@ -1193,9 +1194,14 @@ func (t *TraefikOidc) startTokenCleanup() { sessionManager := t.sessionManager logger := t.logger + // Only use the fast cleanup interval when actually running under `go test`. + // runtime.Compiler == "yaegi" makes isTestMode() return true in production + // (Traefik interprets the plugin via yaegi), which would otherwise pin this + // ticker to 20 Hz on a real cluster despite tokenCache.Cleanup and + // jwkCache.Cleanup both being no-ops there. cleanupInterval := 1 * time.Minute - if isTestMode() { - cleanupInterval = 50 * time.Millisecond // Fast interval for tests + if isTestMode() && runtime.Compiler != "yaegi" { + cleanupInterval = 50 * time.Millisecond } // Create cleanup function @@ -1237,25 +1243,27 @@ func (t *TraefikOidc) startTokenCleanup() { } // extractGroupsAndRoles extracts group and role information from token claims. -// It parses the 'groups' and 'roles' claims from the ID token and validates their format. -// Parameters: -// - idToken: The ID token containing claims to extract. +// It parses the configured group/role claims from the supplied ID token. // -// Returns: -// - groups: Array of group names from the 'groups' claim. -// - roles: Array of role names from the 'roles' claim. -// - An error if claim extraction fails or if the 'groups' or 'roles' claims are present -// but not arrays of strings. +// Most callers should prefer extractGroupsAndRolesFromClaims when claims have +// already been parsed for the request (e.g. via SessionData.GetIDTokenClaims), +// to avoid re-parsing the JWT. func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, error) { claims, err := t.extractClaimsFunc(idToken) if err != nil { return nil, nil, fmt.Errorf("failed to extract claims: %w", err) } + return t.extractGroupsAndRolesFromClaims(claims) +} +// extractGroupsAndRolesFromClaims extracts group and role information from +// already-parsed claims. Hot path: callers that have a cached claims map (such +// as SessionData.GetIDTokenClaims) should use this to skip a redundant +// base64+JSON decode of the JWT on every authenticated request. +func (t *TraefikOidc) extractGroupsAndRolesFromClaims(claims map[string]interface{}) ([]string, []string, error) { var groups []string var roles []string - // Extract groups using configurable claim name (defaults to "groups") if groupsClaim, exists := claims[t.groupClaimName]; exists { groupsSlice, ok := groupsClaim.([]interface{}) if !ok { @@ -1271,7 +1279,6 @@ func (t *TraefikOidc) extractGroupsAndRoles(idToken string) ([]string, []string, } } - // Extract roles using configurable claim name (defaults to "roles") if rolesClaim, exists := claims[t.roleClaimName]; exists { rolesSlice, ok := rolesClaim.([]interface{}) if !ok {