diff --git a/cmd/oidcgate/endpoints.go b/cmd/oidcgate/endpoints.go new file mode 100644 index 0000000..a577910 --- /dev/null +++ b/cmd/oidcgate/endpoints.go @@ -0,0 +1,67 @@ +package main + +import "net/http" + +// sentinelPath is the synthetic request path used when delegating /oauth2/auth +// and /oauth2/start into the traefikoidc middleware. It must NOT collide with +// callbackURL, logoutURL, /health*, or any plausible excludedURLs entry — +// the underscores and double-prefixing make accidental matches near-impossible. +const sentinelPath = "/__oidcgate_protected__" + +// newAuthHandler builds the /oauth2/auth (silent probe) handler. +// Rewrites the request path to sentinelPath, wraps the ResponseWriter to +// convert the middleware's 302→IdP into 401, and delegates. +func newAuthHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + ic := newAuthInterceptor(rw) + defer ic.Finalize() + r2 := cloneAndRewrite(req, sentinelPath) + next.ServeHTTP(ic, r2) + }) +} + +// newStartHandler builds the /oauth2/start (visible sign-in) handler. +// Rewrites the path to sentinelPath, forwards any ?rd= query as +// X-Forwarded-Uri so the middleware (with TrustForwardedURI=true) captures +// the right post-login redirect target, then delegates. The middleware's +// natural 302→IdP flows through unchanged. +func newStartHandler(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + r2 := cloneAndRewrite(req, sentinelPath) + if rd := req.URL.Query().Get("rd"); rd != "" && r2.Header.Get("X-Forwarded-Uri") == "" { + r2.Header.Set("X-Forwarded-Uri", rd) + } + next.ServeHTTP(rw, r2) + }) +} + +// newCallbackHandler builds the IdP callback endpoint. +// Rewrites the request path to the configured callbackURL so the middleware's +// path-match at the top of ServeHTTP triggers the callback flow. +func newCallbackHandler(next http.Handler, callbackURL string) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + r2 := cloneAndRewrite(req, callbackURL) + next.ServeHTTP(rw, r2) + }) +} + +// newLogoutHandler builds the logout endpoint. +// Rewrites the request path to the configured logoutURL so the middleware's +// path-match at the top of ServeHTTP triggers the logout flow. +func newLogoutHandler(next http.Handler, logoutURL string) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + r2 := cloneAndRewrite(req, logoutURL) + next.ServeHTTP(rw, r2) + }) +} + +// cloneAndRewrite returns a shallow clone of req with URL.Path set to newPath. +// The query string is preserved verbatim — middleware logic for code/state +// extraction reads URL.Query() which still works. +func cloneAndRewrite(req *http.Request, newPath string) *http.Request { + r2 := req.Clone(req.Context()) + u := *req.URL + u.Path = newPath + r2.URL = &u + return r2 +} diff --git a/cmd/oidcgate/endpoints_test.go b/cmd/oidcgate/endpoints_test.go new file mode 100644 index 0000000..c5c1e02 --- /dev/null +++ b/cmd/oidcgate/endpoints_test.go @@ -0,0 +1,150 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// stubMiddleware lets us test endpoint wiring without spinning up a full +// traefikoidc instance. Each test injects the behavior it wants. +type stubMiddleware struct { + calls []stubCall + fn func(rw http.ResponseWriter, req *http.Request) +} + +type stubCall struct { + path string + header http.Header +} + +func (s *stubMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + s.calls = append(s.calls, stubCall{path: req.URL.Path, header: req.Header.Clone()}) + if s.fn != nil { + s.fn(rw, req) + } +} + +func TestAuth_RewritesToSentinel_AndConverts302To401(t *testing.T) { + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Location", "https://idp.example/authorize?state=abc") + rw.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/") + rw.WriteHeader(http.StatusFound) + }, + } + h := newAuthHandler(stub) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil) + req.Header.Set("X-Forwarded-Uri", "/protected/page") + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: want 401, got %d", rec.Code) + } + if len(stub.calls) != 1 || stub.calls[0].path != sentinelPath { + t.Fatalf("middleware path: want %q, got %v", sentinelPath, stub.calls) + } + if rec.Header().Get("X-Auth-Redirect") == "" { + t.Error("X-Auth-Redirect should carry Location") + } +} + +func TestAuth_AuthenticatedReturnsHeadersAnd200(t *testing.T) { + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { + // Middleware would stamp X-Forwarded-User on req then call next. + req.Header.Set("X-Forwarded-User", "alice") + newSuccessHandler().ServeHTTP(rw, req) + }, + } + h := newAuthHandler(stub) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil) + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status: want 200, got %d", rec.Code) + } + if got := rec.Header().Get("X-Forwarded-User"); got != "alice" { + t.Errorf("X-Forwarded-User mirrored: want alice, got %q", got) + } +} + +func TestStart_DelegatesWithSentinel_NoInterception(t *testing.T) { + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("Location", "https://idp.example/authorize") + rw.WriteHeader(http.StatusFound) + }, + } + h := newStartHandler(stub) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back", nil) + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusFound { + t.Fatalf("start: 302 must flow through, got %d", rec.Code) + } + if stub.calls[0].path != sentinelPath { + t.Fatalf("start path rewrite: want %q, got %q", sentinelPath, stub.calls[0].path) + } +} + +func TestStart_ForwardsRdAsXForwardedURI(t *testing.T) { + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusFound) }, + } + h := newStartHandler(stub) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back/here", nil) + h.ServeHTTP(rec, req) + if got := stub.calls[0].header.Get("X-Forwarded-Uri"); got != "/back/here" { + t.Fatalf("?rd should become X-Forwarded-Uri: want /back/here, got %q", got) + } +} + +func TestCallback_RewritesToConfiguredCallbackURL(t *testing.T) { + var seenPath, seenQuery string + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { + seenPath = req.URL.Path + seenQuery = req.URL.RawQuery + rw.WriteHeader(http.StatusOK) + }, + } + h := newCallbackHandler(stub, "/oauth2/callback") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/oauth2/callback?code=abc&state=xyz", nil) + h.ServeHTTP(rec, req) + + if seenPath != "/oauth2/callback" { + t.Fatalf("callback path: want /oauth2/callback, got %q", seenPath) + } + if seenQuery != "code=abc&state=xyz" { + t.Fatalf("callback query must survive rewrite: want code=abc&state=xyz, got %q", seenQuery) + } +} + +func TestLogout_RewritesToConfiguredLogoutURL(t *testing.T) { + var seenPath string + stub := &stubMiddleware{ + fn: func(rw http.ResponseWriter, req *http.Request) { + seenPath = req.URL.Path + rw.WriteHeader(http.StatusOK) + }, + } + h := newLogoutHandler(stub, "/oauth2/logout") + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/oauth2/logout", nil) + h.ServeHTTP(rec, req) + + if seenPath != "/oauth2/logout" { + t.Fatalf("logout path: want /oauth2/logout, got %q", seenPath) + } +}