feat(oidcgate): response-writer interceptor converts 302->401 for /oauth2/auth

This commit is contained in:
2026-05-19 13:50:03 +01:00
parent 047fea3c75
commit c465fc888b
2 changed files with 148 additions and 0 deletions
+78
View File
@@ -0,0 +1,78 @@
package main
import "net/http"
// authInterceptor wraps a ResponseWriter for the /oauth2/auth endpoint.
// The traefikoidc middleware emits an HTTP 302 to the IdP authorize URL
// when a request is unauthenticated, but nginx auth_request and similar
// silent-probe contracts cannot follow redirects. authInterceptor buffers
// the header/body and, at Finalize() time:
//
// - if status was 302 or 303 (redirect class we care about), rewrites
// it to 401, moves the original Location header to X-Auth-Redirect
// (advisory), strips Location, preserves Set-Cookie headers (state,
// PKCE, nonce — the browser will carry them into the next request),
// and writes an empty body.
// - otherwise: passes through verbatim.
type authInterceptor struct {
inner http.ResponseWriter
headers http.Header
status int
body []byte
wroteHeader bool
}
func newAuthInterceptor(inner http.ResponseWriter) *authInterceptor {
return &authInterceptor{
inner: inner,
headers: http.Header{},
status: http.StatusOK,
}
}
func (w *authInterceptor) Header() http.Header { return w.headers }
func (w *authInterceptor) WriteHeader(status int) {
if w.wroteHeader {
return
}
w.status = status
w.wroteHeader = true
}
func (w *authInterceptor) Write(b []byte) (int, error) { //nolint:unparam // signature mandated by http.ResponseWriter
if !w.wroteHeader {
w.WriteHeader(http.StatusOK)
}
w.body = append(w.body, b...)
return len(b), nil
}
// Finalize flushes the buffered response, applying the 302/303 → 401 rewrite.
// Must be called exactly once after the wrapped handler returns.
func (w *authInterceptor) Finalize() {
switch w.status {
case http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
// Move Location → X-Auth-Redirect, strip Location, force 401, drop body.
if loc := w.headers.Get("Location"); loc != "" {
w.headers.Set("X-Auth-Redirect", loc)
w.headers.Del("Location")
}
copyHeaders(w.inner.Header(), w.headers)
w.inner.WriteHeader(http.StatusUnauthorized)
return
}
copyHeaders(w.inner.Header(), w.headers)
w.inner.WriteHeader(w.status)
if len(w.body) > 0 {
_, _ = w.inner.Write(w.body)
}
}
func copyHeaders(dst, src http.Header) {
for k, vs := range src {
for _, v := range vs {
dst.Add(k, v)
}
}
}
+70
View File
@@ -0,0 +1,70 @@
package main
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestInterceptor_302BecomesNot401(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "https://idp.example/authorize?state=abc")
w.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/; HttpOnly")
w.Header().Add("Set-Cookie", "_oidc_pkce=xyz; Path=/; HttpOnly")
w.WriteHeader(http.StatusFound)
_, _ = w.Write([]byte("ignored body"))
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: want 401, got %d", rec.Code)
}
if got := rec.Header().Get("X-Auth-Redirect"); got != "https://idp.example/authorize?state=abc" {
t.Errorf("X-Auth-Redirect: want preserved Location, got %q", got)
}
if got := rec.Header().Get("Location"); got != "" {
t.Errorf("Location must be stripped on 401, got %q", got)
}
cookies := rec.Header().Values("Set-Cookie")
if len(cookies) != 2 {
t.Fatalf("Set-Cookie count: want 2, got %d (%v)", len(cookies), cookies)
}
if body := strings.TrimSpace(rec.Body.String()); body != "" {
t.Errorf("body must be empty on 401, got %q", body)
}
}
func TestInterceptor_NonRedirectPassthrough(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("X-Forwarded-User", "alice")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
w.Finalize()
if rec.Code != http.StatusOK {
t.Fatalf("status: want 200, got %d", rec.Code)
}
if got := rec.Header().Get("X-Forwarded-User"); got != "alice" {
t.Errorf("X-Forwarded-User: want preserved, got %q", got)
}
if !strings.Contains(rec.Body.String(), "ok") {
t.Errorf("body: want 'ok' preserved, got %q", rec.Body.String())
}
}
func TestInterceptor_303SeeOtherAlsoIntercepted(t *testing.T) {
rec := httptest.NewRecorder()
w := newAuthInterceptor(rec)
w.Header().Set("Location", "/elsewhere")
w.WriteHeader(http.StatusSeeOther)
w.Finalize()
if rec.Code != http.StatusUnauthorized {
t.Fatalf("303 should be intercepted to 401, got %d", rec.Code)
}
}