Files
traefikoidc/mocks_test.go
T
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00

528 lines
12 KiB
Go

package traefikoidc
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/sessions"
)
// MockOAuthProvider simulates an OAuth/OIDC provider for testing
type MockOAuthProvider struct {
TokenEndpoint string
AuthEndpoint string
JWKSEndpoint string
RevokeEndpoint string
EndSessionEndpoint string
// Configurable behaviors
TokenExchangeFunc func(grantType, codeOrToken, redirectURL, codeVerifier string) (*TokenResponse, error)
RefreshTokenFunc func(refreshToken string) (*TokenResponse, error)
RevokeTokenFunc func(token, tokenType string) error
JWKSResponseFunc func() ([]byte, error)
// Simulation flags
SimulateTimeout bool
SimulateRateLimit bool
SimulateServerError bool
TimeoutDuration time.Duration
ResponseDelay time.Duration
// Request tracking
RequestCount int32
LastRequest *http.Request
LastRequestBody []byte
RequestHistory []*http.Request
mu sync.Mutex
}
// NewMockOAuthProvider creates a new mock OAuth provider with default endpoints
func NewMockOAuthProvider() *MockOAuthProvider {
return &MockOAuthProvider{
TokenEndpoint: "https://mock-provider.example.com/token",
AuthEndpoint: "https://mock-provider.example.com/auth",
JWKSEndpoint: "https://mock-provider.example.com/.well-known/jwks.json",
RevokeEndpoint: "https://mock-provider.example.com/revoke",
EndSessionEndpoint: "https://mock-provider.example.com/logout",
TimeoutDuration: 30 * time.Second,
}
}
// ServeHTTP handles HTTP requests to the mock provider
func (m *MockOAuthProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
atomic.AddInt32(&m.RequestCount, 1)
m.mu.Lock()
m.LastRequest = r
if r.Body != nil {
body, _ := io.ReadAll(r.Body)
m.LastRequestBody = body
r.Body = io.NopCloser(strings.NewReader(string(body)))
}
m.RequestHistory = append(m.RequestHistory, r)
m.mu.Unlock()
// Simulate delays
if m.ResponseDelay > 0 {
time.Sleep(m.ResponseDelay)
}
// Simulate timeout
if m.SimulateTimeout {
time.Sleep(m.TimeoutDuration)
return
}
// Simulate rate limiting
if m.SimulateRateLimit {
w.WriteHeader(http.StatusTooManyRequests)
w.Write([]byte(`{"error": "rate_limit_exceeded"}`))
return
}
// Simulate server error
if m.SimulateServerError {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(`{"error": "internal_server_error"}`))
return
}
// Route to appropriate handler
switch {
case strings.Contains(r.URL.Path, "/token"):
m.handleTokenRequest(w, r)
case strings.Contains(r.URL.Path, "/jwks"):
m.handleJWKSRequest(w, r)
case strings.Contains(r.URL.Path, "/revoke"):
m.handleRevokeRequest(w, r)
default:
w.WriteHeader(http.StatusNotFound)
}
}
func (m *MockOAuthProvider) handleTokenRequest(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
grantType := values.Get("grant_type")
var response *TokenResponse
var err error
if grantType == "authorization_code" {
code := values.Get("code")
redirectURL := values.Get("redirect_uri")
codeVerifier := values.Get("code_verifier")
if m.TokenExchangeFunc != nil {
response, err = m.TokenExchangeFunc(grantType, code, redirectURL, codeVerifier)
} else {
// Default successful response
response = &TokenResponse{
AccessToken: "mock_access_token",
IDToken: "mock_id_token",
RefreshToken: "mock_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
}
}
} else if grantType == "refresh_token" {
refreshToken := values.Get("refresh_token")
if m.RefreshTokenFunc != nil {
response, err = m.RefreshTokenFunc(refreshToken)
} else {
// Default successful refresh response
response = &TokenResponse{
AccessToken: "new_mock_access_token",
IDToken: "new_mock_id_token",
RefreshToken: "new_mock_refresh_token",
TokenType: "Bearer",
ExpiresIn: 3600,
}
}
}
if err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_grant",
"error_description": err.Error(),
})
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(response)
}
func (m *MockOAuthProvider) handleJWKSRequest(w http.ResponseWriter, r *http.Request) {
var response []byte
var err error
if m.JWKSResponseFunc != nil {
response, err = m.JWKSResponseFunc()
} else {
// Default JWKS response
response = []byte(`{
"keys": [
{
"kty": "RSA",
"use": "sig",
"kid": "test-key-1",
"n": "test-modulus",
"e": "AQAB"
}
]
}`)
}
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(response)
}
func (m *MockOAuthProvider) handleRevokeRequest(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
values, _ := url.ParseQuery(string(body))
token := values.Get("token")
tokenType := values.Get("token_type_hint")
if m.RevokeTokenFunc != nil {
if err := m.RevokeTokenFunc(token, tokenType); err != nil {
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_token",
})
return
}
}
w.WriteHeader(http.StatusOK)
}
// GetRequestCount returns the number of requests received
func (m *MockOAuthProvider) GetRequestCount() int {
return int(atomic.LoadInt32(&m.RequestCount))
}
// Reset resets the mock provider state
func (m *MockOAuthProvider) Reset() {
atomic.StoreInt32(&m.RequestCount, 0)
m.mu.Lock()
m.LastRequest = nil
m.LastRequestBody = nil
m.RequestHistory = nil
m.mu.Unlock()
m.SimulateTimeout = false
m.SimulateRateLimit = false
m.SimulateServerError = false
}
// MockSessionManager implements a mock session manager for testing
type MockSessionManager struct {
Sessions map[string]*SessionData
mu sync.RWMutex
// Configurable behaviors
GetSessionFunc func(r *http.Request) (*SessionData, error)
SaveSessionFunc func(r *http.Request, w http.ResponseWriter, session *SessionData) error
DeleteSessionFunc func(r *http.Request, w http.ResponseWriter) error
// Simulation flags
SimulateError bool
SimulateNotFound bool
// Tracking
GetCallCount int32
SaveCallCount int32
DeleteCallCount int32
}
// NewMockSessionManager creates a new mock session manager
func NewMockSessionManager() *MockSessionManager {
return &MockSessionManager{
Sessions: make(map[string]*SessionData),
}
}
// GetSession retrieves a session
func (m *MockSessionManager) GetSession(r *http.Request) (*SessionData, error) {
atomic.AddInt32(&m.GetCallCount, 1)
if m.GetSessionFunc != nil {
return m.GetSessionFunc(r)
}
if m.SimulateError {
return nil, errors.New("session error")
}
if m.SimulateNotFound {
return nil, nil
}
// Default implementation using a simple cookie
cookie, err := r.Cookie("session_id")
if err != nil {
return nil, nil
}
m.mu.RLock()
session, exists := m.Sessions[cookie.Value]
m.mu.RUnlock()
if !exists {
return nil, nil
}
return session, nil
}
// SaveSession saves a session
func (m *MockSessionManager) SaveSession(r *http.Request, w http.ResponseWriter, session *SessionData) error {
atomic.AddInt32(&m.SaveCallCount, 1)
if m.SaveSessionFunc != nil {
return m.SaveSessionFunc(r, w, session)
}
if m.SimulateError {
return errors.New("save error")
}
// Generate session ID
sessionID := fmt.Sprintf("session_%d", time.Now().UnixNano())
m.mu.Lock()
m.Sessions[sessionID] = session
m.mu.Unlock()
// Set cookie
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: sessionID,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
})
return nil
}
// DeleteSession deletes a session
func (m *MockSessionManager) DeleteSession(r *http.Request, w http.ResponseWriter) error {
atomic.AddInt32(&m.DeleteCallCount, 1)
if m.DeleteSessionFunc != nil {
return m.DeleteSessionFunc(r, w)
}
cookie, err := r.Cookie("session_id")
if err != nil {
return nil
}
m.mu.Lock()
delete(m.Sessions, cookie.Value)
m.mu.Unlock()
// Clear cookie
http.SetCookie(w, &http.Cookie{
Name: "session_id",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: true,
})
return nil
}
// Reset resets the mock session manager
func (m *MockSessionManager) Reset() {
m.mu.Lock()
m.Sessions = make(map[string]*SessionData)
m.mu.Unlock()
atomic.StoreInt32(&m.GetCallCount, 0)
atomic.StoreInt32(&m.SaveCallCount, 0)
atomic.StoreInt32(&m.DeleteCallCount, 0)
m.SimulateError = false
m.SimulateNotFound = false
}
// MockHTTPClient implements a mock HTTP client for testing
type MockHTTPClient struct {
// Response configuration
ResponseFunc func(req *http.Request) (*http.Response, error)
// Default response settings
DefaultStatusCode int
DefaultBody string
DefaultHeaders map[string]string
// Simulation flags
SimulateTimeout bool
SimulateError bool
TimeoutDuration time.Duration
// Request tracking
Requests []*http.Request
RequestBodies [][]byte
mu sync.Mutex
}
// NewMockHTTPClient creates a new mock HTTP client
func NewMockHTTPClient() *MockHTTPClient {
return &MockHTTPClient{
DefaultStatusCode: http.StatusOK,
DefaultHeaders: make(map[string]string),
TimeoutDuration: 30 * time.Second,
}
}
// Do executes a mock HTTP request
func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) {
m.mu.Lock()
m.Requests = append(m.Requests, req)
if req.Body != nil {
body, _ := io.ReadAll(req.Body)
m.RequestBodies = append(m.RequestBodies, body)
req.Body = io.NopCloser(strings.NewReader(string(body)))
}
m.mu.Unlock()
// Simulate timeout
if m.SimulateTimeout {
ctx, cancel := context.WithTimeout(req.Context(), m.TimeoutDuration)
defer cancel()
<-ctx.Done()
return nil, context.DeadlineExceeded
}
// Simulate error
if m.SimulateError {
return nil, errors.New("http client error")
}
// Use custom response function if provided
if m.ResponseFunc != nil {
return m.ResponseFunc(req)
}
// Default response
resp := &http.Response{
StatusCode: m.DefaultStatusCode,
Header: make(http.Header),
Request: req,
}
// Set headers
for k, v := range m.DefaultHeaders {
resp.Header.Set(k, v)
}
// Set body
if m.DefaultBody != "" {
resp.Body = io.NopCloser(strings.NewReader(m.DefaultBody))
resp.ContentLength = int64(len(m.DefaultBody))
} else {
resp.Body = io.NopCloser(strings.NewReader(""))
}
return resp, nil
}
// Reset resets the mock HTTP client
func (m *MockHTTPClient) Reset() {
m.mu.Lock()
m.Requests = nil
m.RequestBodies = nil
m.mu.Unlock()
m.SimulateTimeout = false
m.SimulateError = false
}
// GetRequestCount returns the number of requests made
func (m *MockHTTPClient) GetRequestCount() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.Requests)
}
// Note: MockTokenExchanger is already defined in main_test.go
// These mock types are provided for additional testing scenarios
// CreateTestHTTPServer creates a test HTTP server with the given handler
func CreateTestHTTPServer(handler http.Handler) *httptest.Server {
return httptest.NewServer(handler)
}
// CreateTestHTTPSServer creates a test HTTPS server with the given handler
func CreateTestHTTPSServer(handler http.Handler) *httptest.Server {
return httptest.NewTLSServer(handler)
}
// CreateMockSessionData creates a mock SessionData for testing
func CreateMockSessionData() *SessionData {
return &SessionData{
mainSession: nil,
accessSession: nil,
refreshSession: nil,
idTokenSession: nil,
accessTokenChunks: make(map[int]*sessions.Session),
refreshTokenChunks: make(map[int]*sessions.Session),
idTokenChunks: make(map[int]*sessions.Session),
}
}
// MockRoundTripper implements http.RoundTripper for testing
type MockRoundTripper struct {
RoundTripFunc func(req *http.Request) (*http.Response, error)
Requests []*http.Request
mu sync.Mutex
}
// RoundTrip executes a mock HTTP round trip
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
m.mu.Lock()
m.Requests = append(m.Requests, req)
m.mu.Unlock()
if m.RoundTripFunc != nil {
return m.RoundTripFunc(req)
}
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("")),
Header: make(http.Header),
Request: req,
}, nil
}
// Reset resets the mock round tripper
func (m *MockRoundTripper) Reset() {
m.mu.Lock()
m.Requests = nil
m.mu.Unlock()
}