diff --git a/cmd/oidcgate/interceptor.go b/cmd/oidcgate/interceptor.go new file mode 100644 index 0000000..7558b03 --- /dev/null +++ b/cmd/oidcgate/interceptor.go @@ -0,0 +1,78 @@ +package main + +import "net/http" + +// authInterceptor wraps a ResponseWriter for the /oauth2/auth endpoint. +// The traefikoidc middleware emits an HTTP 302 to the IdP authorize URL +// when a request is unauthenticated, but nginx auth_request and similar +// silent-probe contracts cannot follow redirects. authInterceptor buffers +// the header/body and, at Finalize() time: +// +// - if status was 302 or 303 (redirect class we care about), rewrites +// it to 401, moves the original Location header to X-Auth-Redirect +// (advisory), strips Location, preserves Set-Cookie headers (state, +// PKCE, nonce — the browser will carry them into the next request), +// and writes an empty body. +// - otherwise: passes through verbatim. +type authInterceptor struct { + inner http.ResponseWriter + headers http.Header + status int + body []byte + wroteHeader bool +} + +func newAuthInterceptor(inner http.ResponseWriter) *authInterceptor { + return &authInterceptor{ + inner: inner, + headers: http.Header{}, + status: http.StatusOK, + } +} + +func (w *authInterceptor) Header() http.Header { return w.headers } + +func (w *authInterceptor) WriteHeader(status int) { + if w.wroteHeader { + return + } + w.status = status + w.wroteHeader = true +} + +func (w *authInterceptor) Write(b []byte) (int, error) { //nolint:unparam // signature mandated by http.ResponseWriter + if !w.wroteHeader { + w.WriteHeader(http.StatusOK) + } + w.body = append(w.body, b...) + return len(b), nil +} + +// Finalize flushes the buffered response, applying the 302/303 → 401 rewrite. +// Must be called exactly once after the wrapped handler returns. +func (w *authInterceptor) Finalize() { + switch w.status { + case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: + // Move Location → X-Auth-Redirect, strip Location, force 401, drop body. + if loc := w.headers.Get("Location"); loc != "" { + w.headers.Set("X-Auth-Redirect", loc) + w.headers.Del("Location") + } + copyHeaders(w.inner.Header(), w.headers) + w.inner.WriteHeader(http.StatusUnauthorized) + return + } + copyHeaders(w.inner.Header(), w.headers) + w.inner.WriteHeader(w.status) + if len(w.body) > 0 { + _, _ = w.inner.Write(w.body) + } +} + +func copyHeaders(dst, src http.Header) { + for k, vs := range src { + for _, v := range vs { + dst.Add(k, v) + } + } +} diff --git a/cmd/oidcgate/interceptor_test.go b/cmd/oidcgate/interceptor_test.go new file mode 100644 index 0000000..c19b4ba --- /dev/null +++ b/cmd/oidcgate/interceptor_test.go @@ -0,0 +1,70 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestInterceptor_302BecomesNot401(t *testing.T) { + rec := httptest.NewRecorder() + w := newAuthInterceptor(rec) + + w.Header().Set("Location", "https://idp.example/authorize?state=abc") + w.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/; HttpOnly") + w.Header().Add("Set-Cookie", "_oidc_pkce=xyz; Path=/; HttpOnly") + w.WriteHeader(http.StatusFound) + _, _ = w.Write([]byte("ignored body")) + + w.Finalize() + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: want 401, got %d", rec.Code) + } + if got := rec.Header().Get("X-Auth-Redirect"); got != "https://idp.example/authorize?state=abc" { + t.Errorf("X-Auth-Redirect: want preserved Location, got %q", got) + } + if got := rec.Header().Get("Location"); got != "" { + t.Errorf("Location must be stripped on 401, got %q", got) + } + cookies := rec.Header().Values("Set-Cookie") + if len(cookies) != 2 { + t.Fatalf("Set-Cookie count: want 2, got %d (%v)", len(cookies), cookies) + } + if body := strings.TrimSpace(rec.Body.String()); body != "" { + t.Errorf("body must be empty on 401, got %q", body) + } +} + +func TestInterceptor_NonRedirectPassthrough(t *testing.T) { + rec := httptest.NewRecorder() + w := newAuthInterceptor(rec) + + w.Header().Set("X-Forwarded-User", "alice") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + + w.Finalize() + + if rec.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d", rec.Code) + } + if got := rec.Header().Get("X-Forwarded-User"); got != "alice" { + t.Errorf("X-Forwarded-User: want preserved, got %q", got) + } + if !strings.Contains(rec.Body.String(), "ok") { + t.Errorf("body: want 'ok' preserved, got %q", rec.Body.String()) + } +} + +func TestInterceptor_303SeeOtherAlsoIntercepted(t *testing.T) { + rec := httptest.NewRecorder() + w := newAuthInterceptor(rec) + w.Header().Set("Location", "/elsewhere") + w.WriteHeader(http.StatusSeeOther) + w.Finalize() + if rec.Code != http.StatusUnauthorized { + t.Fatalf("303 should be intercepted to 401, got %d", rec.Code) + } +}