mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
feat(oidcgate): auth/start/callback/logout endpoint handlers
This commit is contained in:
@@ -0,0 +1,67 @@
|
||||
package main
|
||||
|
||||
import "net/http"
|
||||
|
||||
// sentinelPath is the synthetic request path used when delegating /oauth2/auth
|
||||
// and /oauth2/start into the traefikoidc middleware. It must NOT collide with
|
||||
// callbackURL, logoutURL, /health*, or any plausible excludedURLs entry —
|
||||
// the underscores and double-prefixing make accidental matches near-impossible.
|
||||
const sentinelPath = "/__oidcgate_protected__"
|
||||
|
||||
// newAuthHandler builds the /oauth2/auth (silent probe) handler.
|
||||
// Rewrites the request path to sentinelPath, wraps the ResponseWriter to
|
||||
// convert the middleware's 302→IdP into 401, and delegates.
|
||||
func newAuthHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
ic := newAuthInterceptor(rw)
|
||||
defer ic.Finalize()
|
||||
r2 := cloneAndRewrite(req, sentinelPath)
|
||||
next.ServeHTTP(ic, r2)
|
||||
})
|
||||
}
|
||||
|
||||
// newStartHandler builds the /oauth2/start (visible sign-in) handler.
|
||||
// Rewrites the path to sentinelPath, forwards any ?rd= query as
|
||||
// X-Forwarded-Uri so the middleware (with TrustForwardedURI=true) captures
|
||||
// the right post-login redirect target, then delegates. The middleware's
|
||||
// natural 302→IdP flows through unchanged.
|
||||
func newStartHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
r2 := cloneAndRewrite(req, sentinelPath)
|
||||
if rd := req.URL.Query().Get("rd"); rd != "" && r2.Header.Get("X-Forwarded-Uri") == "" {
|
||||
r2.Header.Set("X-Forwarded-Uri", rd)
|
||||
}
|
||||
next.ServeHTTP(rw, r2)
|
||||
})
|
||||
}
|
||||
|
||||
// newCallbackHandler builds the IdP callback endpoint.
|
||||
// Rewrites the request path to the configured callbackURL so the middleware's
|
||||
// path-match at the top of ServeHTTP triggers the callback flow.
|
||||
func newCallbackHandler(next http.Handler, callbackURL string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
r2 := cloneAndRewrite(req, callbackURL)
|
||||
next.ServeHTTP(rw, r2)
|
||||
})
|
||||
}
|
||||
|
||||
// newLogoutHandler builds the logout endpoint.
|
||||
// Rewrites the request path to the configured logoutURL so the middleware's
|
||||
// path-match at the top of ServeHTTP triggers the logout flow.
|
||||
func newLogoutHandler(next http.Handler, logoutURL string) http.Handler {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
r2 := cloneAndRewrite(req, logoutURL)
|
||||
next.ServeHTTP(rw, r2)
|
||||
})
|
||||
}
|
||||
|
||||
// cloneAndRewrite returns a shallow clone of req with URL.Path set to newPath.
|
||||
// The query string is preserved verbatim — middleware logic for code/state
|
||||
// extraction reads URL.Query() which still works.
|
||||
func cloneAndRewrite(req *http.Request, newPath string) *http.Request {
|
||||
r2 := req.Clone(req.Context())
|
||||
u := *req.URL
|
||||
u.Path = newPath
|
||||
r2.URL = &u
|
||||
return r2
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// stubMiddleware lets us test endpoint wiring without spinning up a full
|
||||
// traefikoidc instance. Each test injects the behavior it wants.
|
||||
type stubMiddleware struct {
|
||||
calls []stubCall
|
||||
fn func(rw http.ResponseWriter, req *http.Request)
|
||||
}
|
||||
|
||||
type stubCall struct {
|
||||
path string
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (s *stubMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
s.calls = append(s.calls, stubCall{path: req.URL.Path, header: req.Header.Clone()})
|
||||
if s.fn != nil {
|
||||
s.fn(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth_RewritesToSentinel_AndConverts302To401(t *testing.T) {
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Location", "https://idp.example/authorize?state=abc")
|
||||
rw.Header().Add("Set-Cookie", "_oidc_state=abc; Path=/")
|
||||
rw.WriteHeader(http.StatusFound)
|
||||
},
|
||||
}
|
||||
h := newAuthHandler(stub)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil)
|
||||
req.Header.Set("X-Forwarded-Uri", "/protected/page")
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status: want 401, got %d", rec.Code)
|
||||
}
|
||||
if len(stub.calls) != 1 || stub.calls[0].path != sentinelPath {
|
||||
t.Fatalf("middleware path: want %q, got %v", sentinelPath, stub.calls)
|
||||
}
|
||||
if rec.Header().Get("X-Auth-Redirect") == "" {
|
||||
t.Error("X-Auth-Redirect should carry Location")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuth_AuthenticatedReturnsHeadersAnd200(t *testing.T) {
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) {
|
||||
// Middleware would stamp X-Forwarded-User on req then call next.
|
||||
req.Header.Set("X-Forwarded-User", "alice")
|
||||
newSuccessHandler().ServeHTTP(rw, req)
|
||||
},
|
||||
}
|
||||
h := newAuthHandler(stub)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/auth", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("status: want 200, got %d", rec.Code)
|
||||
}
|
||||
if got := rec.Header().Get("X-Forwarded-User"); got != "alice" {
|
||||
t.Errorf("X-Forwarded-User mirrored: want alice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_DelegatesWithSentinel_NoInterception(t *testing.T) {
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Location", "https://idp.example/authorize")
|
||||
rw.WriteHeader(http.StatusFound)
|
||||
},
|
||||
}
|
||||
h := newStartHandler(stub)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusFound {
|
||||
t.Fatalf("start: 302 must flow through, got %d", rec.Code)
|
||||
}
|
||||
if stub.calls[0].path != sentinelPath {
|
||||
t.Fatalf("start path rewrite: want %q, got %q", sentinelPath, stub.calls[0].path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStart_ForwardsRdAsXForwardedURI(t *testing.T) {
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusFound) },
|
||||
}
|
||||
h := newStartHandler(stub)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/start?rd=/back/here", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
if got := stub.calls[0].header.Get("X-Forwarded-Uri"); got != "/back/here" {
|
||||
t.Fatalf("?rd should become X-Forwarded-Uri: want /back/here, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallback_RewritesToConfiguredCallbackURL(t *testing.T) {
|
||||
var seenPath, seenQuery string
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) {
|
||||
seenPath = req.URL.Path
|
||||
seenQuery = req.URL.RawQuery
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
},
|
||||
}
|
||||
h := newCallbackHandler(stub, "/oauth2/callback")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/oauth2/callback?code=abc&state=xyz", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if seenPath != "/oauth2/callback" {
|
||||
t.Fatalf("callback path: want /oauth2/callback, got %q", seenPath)
|
||||
}
|
||||
if seenQuery != "code=abc&state=xyz" {
|
||||
t.Fatalf("callback query must survive rewrite: want code=abc&state=xyz, got %q", seenQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout_RewritesToConfiguredLogoutURL(t *testing.T) {
|
||||
var seenPath string
|
||||
stub := &stubMiddleware{
|
||||
fn: func(rw http.ResponseWriter, req *http.Request) {
|
||||
seenPath = req.URL.Path
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
},
|
||||
}
|
||||
h := newLogoutHandler(stub, "/oauth2/logout")
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/oauth2/logout", nil)
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if seenPath != "/oauth2/logout" {
|
||||
t.Fatalf("logout path: want /oauth2/logout, got %q", seenPath)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user