mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
9cbca4c4fb
patch-release
The refresh path in token_manager.go hardcoded the "email" claim when
extracting the user identifier from a refreshed ID token, ignoring the
configured userIdentifierClaim. Keycloak users without an email claim
(using sub or another identifier) were kicked out on refresh even
though their initial login worked.
The callback path (auth_flow.go:226-239) already honored
userIdentifierClaim with "sub" fallback; PR #100 (commit a316a98)
added that support but missed the refresh path.
Mirror the callback logic in refreshToken so both paths behave the same.
Cleanup: rename Get/SetEmail to Get/SetUserIdentifier on SessionData
to match the actual semantics. The slot already stored the configured
identifier (email, sub, oid, upn, preferred_username), only the API
name was misleading. Storage key "email" → "user_identifier" and
combinedSessionPayload field E (json:"e") → Ui (json:"ui").
Compat note: existing user sessions are invalidated on upgrade, so every
active user has to re-authenticate once after this change is deployed.
499 lines
12 KiB
Go
499 lines
12 KiB
Go
package traefikoidc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// TestFramework provides a unified testing framework for the OIDC middleware
|
|
type TestFramework struct {
|
|
t *testing.T
|
|
server *httptest.Server
|
|
oidc *TraefikOidc
|
|
config *Config
|
|
mocks *TestMocks
|
|
fixtures *TestFixtures
|
|
privateKey *rsa.PrivateKey
|
|
publicKey *rsa.PublicKey
|
|
cleanup []func()
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// TestMocks contains all mock implementations
|
|
type TestMocks struct {
|
|
JWKCache *MockJWKCache
|
|
TokenVerifier *MockTokenVerifier
|
|
TokenExchanger *MockTokenExchanger
|
|
JWTVerifier *MockJWTVerifier
|
|
HTTPClient *http.Client
|
|
Provider interface{}
|
|
}
|
|
|
|
// TestFixtures is the bag of reusable test data: sample tokens, identity
// claims, client settings and the values used during the authorization flow.
type TestFixtures struct {
	// Sample tokens.
	ValidJWT, ExpiredJWT, InvalidJWT    string
	RefreshToken, AccessToken, IDToken string

	// Claims is the claim set used when test tokens are generated.
	Claims map[string]interface{}

	// Identity of the test user.
	UserEmail, UserSub string

	// Client and provider settings.
	ClientID, ClientSecret   string
	ProviderURL, CallbackURL string
	EncryptionKey            string

	// Values exchanged during the authorization code flow.
	Nonce, State                string
	CodeVerifier, CodeChallenge string
	AuthCode                    string
}
|
|
|
|
// NewTestFramework creates a new test framework instance
|
|
func NewTestFramework(t *testing.T) *TestFramework {
|
|
privateKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
|
|
tf := &TestFramework{
|
|
t: t,
|
|
privateKey: privateKey,
|
|
publicKey: &privateKey.PublicKey,
|
|
mocks: &TestMocks{},
|
|
fixtures: generateTestFixtures(),
|
|
cleanup: make([]func(), 0),
|
|
}
|
|
|
|
// Register cleanup
|
|
t.Cleanup(tf.Cleanup)
|
|
|
|
return tf
|
|
}
|
|
|
|
// generateTestFixtures creates standard test data
|
|
func generateTestFixtures() *TestFixtures {
|
|
return &TestFixtures{
|
|
UserEmail: "test@example.com",
|
|
UserSub: "test-user-123",
|
|
ClientID: "test-client-id",
|
|
ClientSecret: "test-client-secret",
|
|
ProviderURL: "https://provider.example.com",
|
|
CallbackURL: "/callback",
|
|
EncryptionKey: "test-encryption-key-32-bytes-long!!",
|
|
Nonce: "test-nonce-123",
|
|
State: "test-state-456",
|
|
AuthCode: "test-auth-code",
|
|
RefreshToken: "test-refresh-token",
|
|
AccessToken: "test-access-token",
|
|
Claims: map[string]interface{}{
|
|
"email": "test@example.com",
|
|
"sub": "test-user-123",
|
|
"name": "Test User",
|
|
"iat": time.Now().Unix(),
|
|
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
|
},
|
|
}
|
|
}
|
|
|
|
// SetupOIDC creates a configured OIDC middleware instance for testing
|
|
func (tf *TestFramework) SetupOIDC(customConfig ...*Config) *TraefikOidc {
|
|
tf.mu.Lock()
|
|
defer tf.mu.Unlock()
|
|
|
|
config := tf.GetDefaultConfig()
|
|
if len(customConfig) > 0 && customConfig[0] != nil {
|
|
config = customConfig[0]
|
|
}
|
|
|
|
tf.config = config
|
|
|
|
// Create OIDC instance
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("authenticated"))
|
|
})
|
|
|
|
oidc, err := New(context.Background(), nextHandler, config, "test")
|
|
if err != nil {
|
|
tf.t.Fatalf("Failed to create OIDC middleware: %v", err)
|
|
}
|
|
|
|
tf.oidc = oidc.(*TraefikOidc)
|
|
|
|
// Override with mocks if configured
|
|
if tf.mocks.TokenVerifier != nil {
|
|
tf.oidc.tokenVerifier = tf.mocks.TokenVerifier
|
|
}
|
|
if tf.mocks.TokenExchanger != nil {
|
|
tf.oidc.tokenExchanger = tf.mocks.TokenExchanger
|
|
}
|
|
|
|
tf.AddCleanup(func() {
|
|
if tf.oidc != nil {
|
|
tf.oidc.Close()
|
|
}
|
|
})
|
|
|
|
return tf.oidc
|
|
}
|
|
|
|
// SetupMockProvider creates a mock OIDC provider server
|
|
func (tf *TestFramework) SetupMockProvider() *httptest.Server {
|
|
tf.mu.Lock()
|
|
defer tf.mu.Unlock()
|
|
|
|
mux := http.NewServeMux()
|
|
|
|
// Well-known configuration endpoint
|
|
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
|
metadata := map[string]interface{}{
|
|
"issuer": tf.fixtures.ProviderURL,
|
|
"authorization_endpoint": tf.fixtures.ProviderURL + "/authorize",
|
|
"token_endpoint": tf.fixtures.ProviderURL + "/token",
|
|
"jwks_uri": tf.fixtures.ProviderURL + "/jwks",
|
|
"userinfo_endpoint": tf.fixtures.ProviderURL + "/userinfo",
|
|
"end_session_endpoint": tf.fixtures.ProviderURL + "/logout",
|
|
}
|
|
json.NewEncoder(w).Encode(metadata)
|
|
})
|
|
|
|
// JWKS endpoint
|
|
mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) {
|
|
jwks := tf.GenerateJWKS()
|
|
json.NewEncoder(w).Encode(jwks)
|
|
})
|
|
|
|
// Token endpoint
|
|
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
|
response := map[string]interface{}{
|
|
"access_token": tf.fixtures.AccessToken,
|
|
"refresh_token": tf.fixtures.RefreshToken,
|
|
"id_token": tf.GenerateJWT(tf.fixtures.Claims),
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
}
|
|
json.NewEncoder(w).Encode(response)
|
|
})
|
|
|
|
// UserInfo endpoint
|
|
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
|
json.NewEncoder(w).Encode(tf.fixtures.Claims)
|
|
})
|
|
|
|
server := httptest.NewServer(mux)
|
|
tf.server = server
|
|
tf.fixtures.ProviderURL = server.URL
|
|
|
|
tf.AddCleanup(server.Close)
|
|
|
|
return server
|
|
}
|
|
|
|
// GetDefaultConfig returns a default test configuration
|
|
func (tf *TestFramework) GetDefaultConfig() *Config {
|
|
return &Config{
|
|
ProviderURL: tf.fixtures.ProviderURL,
|
|
ClientID: tf.fixtures.ClientID,
|
|
ClientSecret: tf.fixtures.ClientSecret,
|
|
CallbackURL: tf.fixtures.CallbackURL,
|
|
SessionEncryptionKey: tf.fixtures.EncryptionKey,
|
|
LogLevel: "debug",
|
|
ForceHTTPS: false,
|
|
Scopes: []string{"openid", "email", "profile"},
|
|
RateLimit: 100,
|
|
}
|
|
}
|
|
|
|
// GenerateJWT creates a test JWT with the given claims
|
|
func (tf *TestFramework) GenerateJWT(claims map[string]interface{}) string {
|
|
tokenString, _ := createTestJWT(tf.privateKey, "RS256", "test-key", claims)
|
|
return tokenString
|
|
}
|
|
|
|
// GenerateExpiredJWT creates an expired JWT for testing
|
|
func (tf *TestFramework) GenerateExpiredJWT() string {
|
|
claims := make(map[string]interface{})
|
|
for k, v := range tf.fixtures.Claims {
|
|
claims[k] = v
|
|
}
|
|
claims["exp"] = time.Now().Add(-1 * time.Hour).Unix()
|
|
return tf.GenerateJWT(claims)
|
|
}
|
|
|
|
// GenerateInvalidJWT creates an invalid JWT for testing
|
|
func (tf *TestFramework) GenerateInvalidJWT() string {
|
|
return "invalid.jwt.token"
|
|
}
|
|
|
|
// GenerateJWKS creates a JWKS response
|
|
func (tf *TestFramework) GenerateJWKS() map[string]interface{} {
|
|
n := base64.RawURLEncoding.EncodeToString(tf.publicKey.N.Bytes())
|
|
e := base64.RawURLEncoding.EncodeToString([]byte{1, 0, 1})
|
|
|
|
return map[string]interface{}{
|
|
"keys": []map[string]interface{}{
|
|
{
|
|
"kty": "RSA",
|
|
"use": "sig",
|
|
"kid": "test-key-id",
|
|
"n": n,
|
|
"e": e,
|
|
"alg": "RS256",
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// CreateRequest creates a test HTTP request
|
|
func (tf *TestFramework) CreateRequest(method, path string, body ...string) *http.Request {
|
|
var bodyReader *strings.Reader
|
|
if len(body) > 0 {
|
|
bodyReader = strings.NewReader(body[0])
|
|
} else {
|
|
bodyReader = strings.NewReader("")
|
|
}
|
|
|
|
req := httptest.NewRequest(method, path, bodyReader)
|
|
req.Header.Set("User-Agent", "test-agent")
|
|
return req
|
|
}
|
|
|
|
// CreateAuthenticatedRequest creates a request with session cookies
|
|
func (tf *TestFramework) CreateAuthenticatedRequest(method, path string) (*http.Request, *httptest.ResponseRecorder) {
|
|
req := tf.CreateRequest(method, path)
|
|
rw := httptest.NewRecorder()
|
|
|
|
// Create session
|
|
sessionManager, err := NewSessionManager(
|
|
tf.fixtures.EncryptionKey,
|
|
false,
|
|
"",
|
|
"",
|
|
0,
|
|
tf.oidc.logger,
|
|
)
|
|
if err != nil {
|
|
tf.t.Fatalf("Error: %v", err)
|
|
}
|
|
|
|
session, err := sessionManager.GetSession(req)
|
|
if err != nil {
|
|
tf.t.Fatalf("Error: %v", err)
|
|
}
|
|
|
|
session.SetAuthenticated(true)
|
|
session.SetUserIdentifier(tf.fixtures.UserEmail)
|
|
session.SetAccessToken(tf.fixtures.AccessToken)
|
|
session.SetRefreshToken(tf.fixtures.RefreshToken)
|
|
session.SetIDToken(tf.GenerateJWT(tf.fixtures.Claims))
|
|
|
|
err = session.Save(req, rw)
|
|
if err != nil {
|
|
tf.t.Fatalf("Error: %v", err)
|
|
}
|
|
|
|
// Copy cookies to request
|
|
for _, cookie := range rw.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
return req, httptest.NewRecorder()
|
|
}
|
|
|
|
// CreateCallbackRequest creates an OAuth callback request
|
|
func (tf *TestFramework) CreateCallbackRequest() *http.Request {
|
|
values := url.Values{
|
|
"code": {tf.fixtures.AuthCode},
|
|
"state": {tf.fixtures.State},
|
|
}
|
|
|
|
req := tf.CreateRequest("GET", tf.fixtures.CallbackURL+"?"+values.Encode())
|
|
|
|
// Add session with state
|
|
sessionManager, _ := NewSessionManager(
|
|
tf.fixtures.EncryptionKey,
|
|
false,
|
|
"",
|
|
"",
|
|
0,
|
|
tf.oidc.logger,
|
|
)
|
|
|
|
session, _ := sessionManager.GetSession(req)
|
|
session.SetCSRF(tf.fixtures.State)
|
|
session.SetNonce(tf.fixtures.Nonce)
|
|
|
|
rw := httptest.NewRecorder()
|
|
session.Save(req, rw)
|
|
|
|
for _, cookie := range rw.Result().Cookies() {
|
|
req.AddCookie(cookie)
|
|
}
|
|
|
|
return req
|
|
}
|
|
|
|
// AssertResponse validates HTTP response
|
|
func (tf *TestFramework) AssertResponse(rw *httptest.ResponseRecorder, expectedStatus int, contains ...string) {
|
|
if rw.Code != expectedStatus {
|
|
tf.t.Errorf("Unexpected status code: got %d, want %d", rw.Code, expectedStatus)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
for _, expected := range contains {
|
|
if !strings.Contains(body, expected) {
|
|
tf.t.Errorf("Response body missing expected content: %s", expected)
|
|
}
|
|
}
|
|
}
|
|
|
|
// AssertRedirect validates redirect response
|
|
func (tf *TestFramework) AssertRedirect(rw *httptest.ResponseRecorder, expectedLocation string) {
|
|
if rw.Code != http.StatusFound {
|
|
tf.t.Errorf("Expected redirect status, got %d", rw.Code)
|
|
}
|
|
location := rw.Header().Get("Location")
|
|
if strings.HasPrefix(expectedLocation, "http") {
|
|
if location != expectedLocation {
|
|
tf.t.Errorf("Expected location %s, got %s", expectedLocation, location)
|
|
}
|
|
} else {
|
|
if !strings.Contains(location, expectedLocation) {
|
|
tf.t.Errorf("Location should contain %s, got %s", expectedLocation, location)
|
|
}
|
|
}
|
|
}
|
|
|
|
// AssertCookie validates response cookies
|
|
func (tf *TestFramework) AssertCookie(rw *httptest.ResponseRecorder, name string, exists bool) {
|
|
cookies := rw.Result().Cookies()
|
|
found := false
|
|
for _, cookie := range cookies {
|
|
if cookie.Name == name {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if exists {
|
|
if !found {
|
|
tf.t.Errorf("Cookie %s not found", name)
|
|
}
|
|
} else {
|
|
if found {
|
|
tf.t.Errorf("Cookie %s should not exist", name)
|
|
}
|
|
}
|
|
}
|
|
|
|
// AddCleanup registers a cleanup function
|
|
func (tf *TestFramework) AddCleanup(fn func()) {
|
|
tf.mu.Lock()
|
|
defer tf.mu.Unlock()
|
|
tf.cleanup = append(tf.cleanup, fn)
|
|
}
|
|
|
|
// Cleanup runs all registered cleanup functions
|
|
func (tf *TestFramework) Cleanup() {
|
|
tf.mu.Lock()
|
|
defer tf.mu.Unlock()
|
|
|
|
for i := len(tf.cleanup) - 1; i >= 0; i-- {
|
|
if tf.cleanup[i] != nil {
|
|
tf.cleanup[i]()
|
|
}
|
|
}
|
|
|
|
tf.cleanup = nil
|
|
}
|
|
|
|
// RunSubtest runs a subtest with the framework
|
|
func (tf *TestFramework) RunSubtest(name string, fn func()) {
|
|
tf.t.Run(name, func(t *testing.T) {
|
|
// Create sub-framework with shared resources
|
|
subTF := &TestFramework{
|
|
t: t,
|
|
server: tf.server,
|
|
oidc: tf.oidc,
|
|
config: tf.config,
|
|
mocks: tf.mocks,
|
|
fixtures: tf.fixtures,
|
|
privateKey: tf.privateKey,
|
|
publicKey: tf.publicKey,
|
|
cleanup: make([]func(), 0),
|
|
}
|
|
|
|
defer subTF.Cleanup()
|
|
|
|
// Set the current test framework for the function
|
|
currentTestFramework = subTF
|
|
fn()
|
|
currentTestFramework = nil
|
|
})
|
|
}
|
|
|
|
var currentTestFramework *TestFramework
|
|
|
|
// GetTestFramework returns the current test framework (for use in test functions)
|
|
func GetTestFramework() *TestFramework {
|
|
return currentTestFramework
|
|
}
|
|
|
|
// Mock implementations are defined in main_test.go and other test files
|
|
// The test framework uses the existing mock types
|
|
|
|
// TestScenarios provides common test scenarios
|
|
|
|
// TestScenario represents a test scenario
|
|
type TestScenario struct {
|
|
Setup func(*TestFramework)
|
|
Request func(*TestFramework) *http.Request
|
|
Validate func(*TestFramework, *httptest.ResponseRecorder)
|
|
Name string
|
|
ExpectedBody string
|
|
ExpectedStatus int
|
|
}
|
|
|
|
// RunScenarios executes a set of test scenarios
|
|
func (tf *TestFramework) RunScenarios(scenarios []TestScenario) {
|
|
for _, scenario := range scenarios {
|
|
tf.RunSubtest(scenario.Name, func() {
|
|
// Setup
|
|
if scenario.Setup != nil {
|
|
scenario.Setup(tf)
|
|
}
|
|
|
|
// Create request
|
|
req := scenario.Request(tf)
|
|
rw := httptest.NewRecorder()
|
|
|
|
// Execute
|
|
tf.oidc.ServeHTTP(rw, req)
|
|
|
|
// Validate
|
|
if scenario.ExpectedStatus > 0 {
|
|
tf.AssertResponse(rw, scenario.ExpectedStatus)
|
|
}
|
|
|
|
if scenario.ExpectedBody != "" {
|
|
tf.AssertResponse(rw, rw.Code, scenario.ExpectedBody)
|
|
}
|
|
|
|
if scenario.Validate != nil {
|
|
scenario.Validate(tf, rw)
|
|
}
|
|
})
|
|
}
|
|
}
|