traefikoidc/middleware/middleware_test.go
lukaszraczylo 1b49e133da Complete rebuild of the plugin
* Fix bug affecting Azure OIDC authentication (and most likely others)

* Fixes issue #51

* Ensure that appended roles are unique. Update the documentation.

* Improvements targeting possible memory usage spikes.

* Additional fixes and cleanup

* Refactoring code to fix the issues identified by the users.

* Modernize run

* Fieldalignment

* Multiple changes to improve performance and reduce complexity.
- Optimise the errors and recovery.
- Deduplicate code in metadata cache.
- Remove unused performance monitoring code.
- Simplify session management and settings handling.

* Fix claims issue.

* Add ability to overwrite the default scopes in the settings file

* Well.. that escalated quickly.

Completely forgot that Traefik uses an outdated Yaegi and requires compatibility with Go 1.20 (pre-generic Go code).

* Bugfix #51: Ensures that user-provided scope overrides work.

* fixup! Bugfix #51: Ensures that user-provided scope overrides work.

* fixup! fixup! Bugfix #51: Ensures that user-provided scope overrides work.

* Abstract the provider logic into a separate package.

* Additional micro fixes and cleanups.

* Simplify all the things.

* fixup! Simplify all the things.

* fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! Simplify all the things.

* fixup! fixup! fixup! fixup! Simplify all the things.

* ...

* Cleanup tests.

* fixup! Cleanup tests.

* fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! Cleanup tests.

* fixup! fixup! fixup! fixup! fixup! Cleanup tests.

* Issue #53: Fix CSRF token handling in reverse proxy

1. HTTPS Detection Fixed (session.go:723)
   - Now uses the X-Forwarded-Proto header instead of r.URL.Scheme
   - Properly detects HTTPS in reverse proxy environments
2. SameSite Cookie Attribute Fixed
   - Removed automatic SameSiteStrictMode for HTTPS (it would break OAuth)
   - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains
   - Only uses Strict for AJAX requests that don't involve OAuth redirects
3. Cookie Domain Handling Fixed
   - Now respects the X-Forwarded-Host header for the cookie domain
   - Ensures cookies are set for the public domain, not the internal proxy domain
4. EnhanceSessionSecurity Properly Integrated
   - The function is now actually called during session save
   - Applies security enhancements without breaking the OAuth flow

Why Issue #53 Failed Before:

1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back)
2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail
3. Cookie domain might have been wrong (internal vs public domain)

Why It Works Now:

1. Cookies are properly marked Secure for HTTPS
2. Uses SameSite=Lax to allow OAuth provider callbacks
3. Cookie domain uses public domain from X-Forwarded-Host
4. CSRF token persists through the entire OAuth flow

* Next set of enhancements together with memory usage improvements.

* Memory leak fixes and optimisations.

* CSRF and Cookie Domain fixes

* fixup! CSRF and Cookie Domain fixes

* Metadata cache leak fix + profiling

* fixup! Metadata cache leak fix + profiling

* Memory leak hunting, part 1337.

* Further pursuit of perfection.

* fixup! Further pursuit of perfection.

* fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursuit of perfection.

* fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursuit of perfection.

* Clear race conditions

* fixup! Clear race conditions

* Weekend fun with memory leaks

* Splitting code into multiple files with reasonable testing coverage.

```
ok      github.com/lukaszraczylo/traefikoidc    117.017s        coverage: 72.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/auth       0.505s  coverage: 87.1% of statements
ok      github.com/lukaszraczylo/traefikoidc/circuit_breaker    0.283s  coverage: 99.0% of statements
        github.com/lukaszraczylo/traefikoidc/config             coverage: 0.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/handlers   0.349s  coverage: 98.2% of statements
ok      github.com/lukaszraczylo/traefikoidc/internal/providers (cached)        coverage: 94.3% of statements
ok      github.com/lukaszraczylo/traefikoidc/middleware 0.808s  coverage: 78.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/recovery   0.653s  coverage: 100.0% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/chunking   (cached)        coverage: 87.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/core       (cached)        coverage: 85.6% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/crypto     (cached)        coverage: 81.8% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/storage    (cached)        coverage: 93.5% of statements
ok      github.com/lukaszraczylo/traefikoidc/session/validators (cached)        coverage: 98.8% of statements
```

* fixup! Splitting code into multiple files with reasonable testing coverage.

* fixup! fixup! Splitting code into multiple files with reasonable testing coverage.

* Weekend fun with further optimisations.

* fixup! Weekend fun with further optimisations.

* fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations.

* Pre-release cleanup.

* Enhance test coverage.

* fixup! Enhance test coverage.

* fixup! fixup! Enhance test coverage.

* fixup! fixup! fixup! Enhance test coverage.
2025-09-18 11:01:30 +01:00
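Two role-handling points can be illustrated with small helpers. The first is the "Ensure that appended roles are unique" entry in the commit log. The second is the any-match rule exercised by TestProcessAuthorizedRequest below (`user_has_allowed_role`, `user_has_allowed_group`, `user_lacks_required_roles`). This is a sketch, not the plugin's code. `appendUnique` and `hasAllowedRoleOrGroup` are hypothetical names. Treating an empty allowed set as "no restriction" is an assumption consistent with the test cases that configure no `allowedRoles`. Both helpers avoid generics, in line with the Yaegi compatibility note in the commit log.

```go
package main

import "fmt"

// appendUnique appends each item to dst unless it is already present,
// preserving first-seen order. Deliberately non-generic (Yaegi).
func appendUnique(dst []string, items ...string) []string {
	seen := make(map[string]struct{}, len(dst)+len(items))
	for _, v := range dst {
		seen[v] = struct{}{}
	}
	for _, v := range items {
		if _, ok := seen[v]; ok {
			continue // skip duplicates
		}
		seen[v] = struct{}{}
		dst = append(dst, v)
	}
	return dst
}

// hasAllowedRoleOrGroup reports whether any of the user's groups or roles
// is in allowed. An empty allowed set is treated as "no restriction".
func hasAllowedRoleOrGroup(groups, roles []string, allowed map[string]struct{}) bool {
	if len(allowed) == 0 {
		return true
	}
	// Iterate both slices separately so neither caller slice is modified.
	for _, list := range [][]string{groups, roles} {
		for _, name := range list {
			if _, ok := allowed[name]; ok {
				return true
			}
		}
	}
	return false
}

func main() {
	roles := appendUnique([]string{"user"}, "admin", "user", "admin")
	fmt.Println(roles) // prints: [user admin]

	allowed := map[string]struct{}{"developers": {}}
	fmt.Println(hasAllowedRoleOrGroup([]string{"developers", "testers"}, nil, allowed)) // prints: true
	fmt.Println(hasAllowedRoleOrGroup(nil, []string{"user"}, map[string]struct{}{"admin": {}})) // prints: false
}
```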

805 lines
24 KiB
Go

package middleware

import (
	"errors"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
)

// TestUncoveredMiddlewareFunctions tests the functions with 0% coverage in the middleware package.
func TestUncoveredMiddlewareFunctions(t *testing.T) {
	t.Run("generateNonce", func(t *testing.T) {
		// This function currently returns an error in the stub implementation.
		nonce, err := generateNonce()
		if err == nil {
			// Stop here: err.Error() below would panic on a nil error.
			t.Fatal("Expected generateNonce to return an error in stub implementation")
		}
		if nonce != "" {
			t.Errorf("Expected generateNonce to return empty string, got %s", nonce)
		}
		// Verify the error message.
		expectedError := "generateNonce not implemented"
		if err.Error() != expectedError {
			t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
		}
	})

	t.Run("generateCodeVerifier", func(t *testing.T) {
		// This function currently returns an error in the stub implementation.
		verifier, err := generateCodeVerifier()
		if err == nil {
			// Stop here: err.Error() below would panic on a nil error.
			t.Fatal("Expected generateCodeVerifier to return an error in stub implementation")
		}
		if verifier != "" {
			t.Errorf("Expected generateCodeVerifier to return empty string, got %s", verifier)
		}
		// Verify the error message.
		expectedError := "generateCodeVerifier not implemented"
		if err.Error() != expectedError {
			t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
		}
	})

	t.Run("deriveCodeChallenge", func(t *testing.T) {
		// This function currently returns an error in the stub implementation.
		challenge, err := deriveCodeChallenge()
		if err == nil {
			// Stop here: err.Error() below would panic on a nil error.
			t.Fatal("Expected deriveCodeChallenge to return an error in stub implementation")
		}
		if challenge != "" {
			t.Errorf("Expected deriveCodeChallenge to return empty string, got %s", challenge)
		}
		// Verify the error message.
		expectedError := "deriveCodeChallenge not implemented"
		if err.Error() != expectedError {
			t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
		}
	})
}
// TestBuildFullURLFunction checks buildFullURL edge cases. The function already has
// 100% coverage; this test keeps that coverage and documents the expected output.
func TestBuildFullURLFunction(t *testing.T) {
	t.Run("buildFullURL", func(t *testing.T) {
		// Test basic URL building
		scheme := "https"
		host := "example.com"
		path := "/callback"
		url := buildFullURL(scheme, host, path)
		expected := "https://example.com/callback"
		if url != expected {
			t.Errorf("Expected URL %s, got %s", expected, url)
		}

		// Test with a path that doesn't start with "/" (the function just concatenates)
		url2 := buildFullURL(scheme, host, "callback")
		expected2 := "https://example.comcallback"
		if url2 != expected2 {
			t.Errorf("Expected URL %s, got %s", expected2, url2)
		}

		// Test with an empty path
		url3 := buildFullURL(scheme, host, "")
		expected3 := "https://example.com"
		if url3 != expected3 {
			t.Errorf("Expected URL %s, got %s", expected3, url3)
		}

		// Test with a different scheme
		url4 := buildFullURL("http", "localhost:8080", "/test")
		expected4 := "http://localhost:8080/test"
		if url4 != expected4 {
			t.Errorf("Expected URL %s, got %s", expected4, url4)
		}

		// Test with a query string in the path
		url5 := buildFullURL("https", "api.example.com", "/v1/auth?redirect=true")
		expected5 := "https://api.example.com/v1/auth?redirect=true"
		if url5 != expected5 {
			t.Errorf("Expected URL %s, got %s", expected5, url5)
		}

		// Test with all components empty
		url6 := buildFullURL("", "", "")
		expected6 := "://"
		if url6 != expected6 {
			t.Errorf("Expected URL %s, got %s", expected6, url6)
		}

		// Test with port numbers
		url7 := buildFullURL("http", "localhost:3000", "/admin")
		expected7 := "http://localhost:3000/admin"
		if url7 != expected7 {
			t.Errorf("Expected URL %s, got %s", expected7, url7)
		}
	})
}
// Mock types for testing.

// mockLogger records log calls. Format arguments are not expanded; only the
// level prefix and the format string are stored.
type mockLogger struct {
	logs []string
	mu   sync.Mutex
}

func (m *mockLogger) Debug(msg string)                          { m.log("DEBUG: " + msg) }
func (m *mockLogger) Debugf(format string, args ...interface{}) { m.log("DEBUG: " + format) }
func (m *mockLogger) Error(msg string)                          { m.log("ERROR: " + msg) }
func (m *mockLogger) Errorf(format string, args ...interface{}) { m.log("ERROR: " + format) }
func (m *mockLogger) Info(msg string)                           { m.log("INFO: " + msg) }
func (m *mockLogger) Infof(format string, args ...interface{})  { m.log("INFO: " + format) }

func (m *mockLogger) log(msg string) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.logs = append(m.logs, msg)
}

type mockSessionManager struct {
	getSessionFunc        func(req *http.Request) (SessionData, error)
	cleanupOldCookiesFunc func(rw http.ResponseWriter, req *http.Request)
}

func (m *mockSessionManager) CleanupOldCookies(rw http.ResponseWriter, req *http.Request) {
	if m.cleanupOldCookiesFunc != nil {
		m.cleanupOldCookiesFunc(rw, req)
	}
}

func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
	if m.getSessionFunc != nil {
		return m.getSessionFunc(req)
	}
	return nil, nil
}

type mockSessionData struct {
	email                  string
	accessToken            string
	idToken                string
	refreshToken           string
	clearFunc              func(req *http.Request, rw http.ResponseWriter) error
	resetRedirectCountFunc func()
}

func (m *mockSessionData) GetEmail() string        { return m.email }
func (m *mockSessionData) GetAccessToken() string  { return m.accessToken }
func (m *mockSessionData) GetIDToken() string      { return m.idToken }
func (m *mockSessionData) GetRefreshToken() string { return m.refreshToken }

func (m *mockSessionData) Clear(req *http.Request, rw http.ResponseWriter) error {
	if m.clearFunc != nil {
		return m.clearFunc(req, rw)
	}
	return nil
}

func (m *mockSessionData) ResetRedirectCount() {
	if m.resetRedirectCountFunc != nil {
		m.resetRedirectCountFunc()
	}
}

func (m *mockSessionData) returnToPoolSafely() {}

type mockAuthHandler struct {
	initiateAuthFunc func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
		generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error))
}

func (m *mockAuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
	generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
	if m.initiateAuthFunc != nil {
		m.initiateAuthFunc(rw, req, session, redirectURL, generateNonce, generateCodeVerifier, deriveCodeChallenge)
	}
}

type mockURLHelper struct {
	determineExcludedFunc func(currentRequest string, excludedURLs map[string]struct{}) bool
	determineSchemeFunc   func(req *http.Request) string
	determineHostFunc     func(req *http.Request) string
}

func (m *mockURLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
	if m.determineExcludedFunc != nil {
		return m.determineExcludedFunc(currentRequest, excludedURLs)
	}
	return false
}

func (m *mockURLHelper) DetermineScheme(req *http.Request) string {
	if m.determineSchemeFunc != nil {
		return m.determineSchemeFunc(req)
	}
	return "https"
}

func (m *mockURLHelper) DetermineHost(req *http.Request) string {
	if m.determineHostFunc != nil {
		return m.determineHostFunc(req)
	}
	return "example.com"
}

type mockTokenVerifier struct {
	verifyFunc func(token string) error
}

func (m *mockTokenVerifier) VerifyToken(token string) error {
	if m.verifyFunc != nil {
		return m.verifyFunc(token)
	}
	return nil
}
// TestStubFunctionsErrorBehavior tests error behaviors more thoroughly
func TestStubFunctionsErrorBehavior(t *testing.T) {
	t.Run("generateNonce_multiple_calls", func(t *testing.T) {
		// Test multiple calls to ensure consistent behavior
		for i := 0; i < 3; i++ {
			nonce, err := generateNonce()
			if err == nil {
				t.Errorf("Call %d: Expected generateNonce to return an error", i)
			}
			if nonce != "" {
				t.Errorf("Call %d: Expected empty nonce, got %s", i, nonce)
			}
		}
	})

	t.Run("generateCodeVerifier_multiple_calls", func(t *testing.T) {
		// Test multiple calls to ensure consistent behavior
		for i := 0; i < 3; i++ {
			verifier, err := generateCodeVerifier()
			if err == nil {
				t.Errorf("Call %d: Expected generateCodeVerifier to return an error", i)
			}
			if verifier != "" {
				t.Errorf("Call %d: Expected empty verifier, got %s", i, verifier)
			}
		}
	})

	t.Run("deriveCodeChallenge_multiple_calls", func(t *testing.T) {
		// Test multiple calls to ensure consistent behavior
		for i := 0; i < 3; i++ {
			challenge, err := deriveCodeChallenge()
			if err == nil {
				t.Errorf("Call %d: Expected deriveCodeChallenge to return an error", i)
			}
			if challenge != "" {
				t.Errorf("Call %d: Expected empty challenge, got %s", i, challenge)
			}
		}
	})
}
// TestHandleTokenRefresh tests the handleTokenRefresh method with various scenarios
func TestHandleTokenRefresh(t *testing.T) {
	tests := []struct {
		name                    string
		needsRefresh            bool
		authenticated           bool
		isAjaxRequest           bool
		refreshSuccess          bool
		allowedDomain           bool
		expectErrorResponse     bool
		expectProcessAuthorized bool
		expectInitAuth          bool
	}{
		{
			name:                    "successful_refresh_authenticated",
			needsRefresh:            true,
			authenticated:           true,
			isAjaxRequest:           false,
			refreshSuccess:          true,
			allowedDomain:           true,
			expectProcessAuthorized: true,
		},
		{
			name:                    "successful_refresh_not_authenticated",
			needsRefresh:            true,
			authenticated:           false,
			isAjaxRequest:           false,
			refreshSuccess:          true,
			allowedDomain:           true,
			expectProcessAuthorized: true,
		},
		{
			name:                "successful_refresh_disallowed_domain",
			needsRefresh:        true,
			authenticated:       true,
			isAjaxRequest:       false,
			refreshSuccess:      true,
			allowedDomain:       false,
			expectErrorResponse: true,
		},
		{
			name:           "failed_refresh_browser_request",
			needsRefresh:   true,
			authenticated:  true,
			isAjaxRequest:  false,
			refreshSuccess: false,
			expectInitAuth: true,
		},
		{
			name:                "failed_refresh_ajax_request",
			needsRefresh:        true,
			authenticated:       true,
			isAjaxRequest:       true,
			refreshSuccess:      false,
			expectErrorResponse: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Setup mocks
			logger := &mockLogger{}
			nextHandlerCalled := false
			nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextHandlerCalled = true
				w.WriteHeader(http.StatusOK)
			})
			session := &mockSessionData{
				email:        "test@example.com",
				accessToken:  "access_token",
				idToken:      "id_token",
				refreshToken: "refresh_token",
			}
			initAuthCalled := false
			errorResponseSent := false

			m := &AuthMiddleware{
				logger:        logger,
				next:          nextHandler,
				logoutURLPath: "/logout",
				refreshTokenFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData) bool {
					return tt.refreshSuccess
				},
				isAllowedDomainFunc: func(email string) bool {
					return tt.allowedDomain
				},
				isAjaxRequestFunc: func(req *http.Request) bool {
					return tt.isAjaxRequest
				},
				sendErrorResponseFunc: func(rw http.ResponseWriter, req *http.Request, message string, code int) {
					errorResponseSent = true
					rw.WriteHeader(code)
				},
				authHandler: &mockAuthHandler{
					initiateAuthFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
						generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
						initAuthCalled = true
					},
				},
				extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
					return nil, nil, nil
				},
			}

			// Create request and response recorder
			req := httptest.NewRequest("GET", "/test", nil)
			rw := httptest.NewRecorder()

			// Call the method under test
			m.handleTokenRefresh(rw, req, session, "https://example.com/callback",
				tt.needsRefresh, tt.authenticated, tt.isAjaxRequest)

			// Verify expectations - processAuthorizedRequest will call the next handler if successful
			if tt.expectProcessAuthorized && !nextHandlerCalled {
				t.Error("Expected processAuthorizedRequest to complete (next handler called)")
			}
			if tt.expectInitAuth && !initAuthCalled {
				t.Error("Expected InitiateAuthentication to be called")
			}
			if tt.expectErrorResponse && !errorResponseSent {
				t.Error("Expected error response to be sent")
			}
		})
	}
}
// TestProcessAuthorizedRequest tests the processAuthorizedRequest method
func TestProcessAuthorizedRequest(t *testing.T) {
	tests := []struct {
		name            string
		email           string
		idToken         string
		accessToken     string
		allowedRoles    map[string]struct{}
		userGroups      []string
		userRoles       []string
		extractError    error
		expectHeaders   bool
		expectForbidden bool
		expectReauth    bool
	}{
		{
			name:         "no_email_triggers_reauth",
			email:        "",
			idToken:      "token",
			expectReauth: true,
		},
		{
			name:          "successful_with_id_token",
			email:         "user@example.com",
			idToken:       "id_token",
			accessToken:   "access_token",
			expectHeaders: true,
		},
		{
			name:          "successful_with_access_token_only",
			email:         "user@example.com",
			idToken:       "",
			accessToken:   "access_token",
			expectHeaders: true,
		},
		{
			name:         "no_token_with_role_requirements",
			email:        "user@example.com",
			idToken:      "",
			accessToken:  "",
			allowedRoles: map[string]struct{}{"admin": {}},
			expectReauth: true,
		},
		{
			name:          "user_has_allowed_role",
			email:         "user@example.com",
			idToken:       "token",
			allowedRoles:  map[string]struct{}{"admin": {}},
			userRoles:     []string{"admin", "user"},
			expectHeaders: true,
		},
		{
			name:          "user_has_allowed_group",
			email:         "user@example.com",
			idToken:       "token",
			allowedRoles:  map[string]struct{}{"developers": {}},
			userGroups:    []string{"developers", "testers"},
			expectHeaders: true,
		},
		{
			name:            "user_lacks_required_roles",
			email:           "user@example.com",
			idToken:         "token",
			allowedRoles:    map[string]struct{}{"admin": {}},
			userRoles:       []string{"user"},
			expectForbidden: true,
		},
		{
			name:         "extract_error_with_role_requirements",
			email:        "user@example.com",
			idToken:      "token",
			allowedRoles: map[string]struct{}{"admin": {}},
			extractError: errors.New("extraction failed"),
			expectReauth: true,
		},
		{
			name:          "extract_error_without_role_requirements",
			email:         "user@example.com",
			idToken:       "token",
			extractError:  errors.New("extraction failed"),
			expectHeaders: true, // Should continue without roles/groups
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Setup mocks
			logger := &mockLogger{}
			nextHandlerCalled := false
			nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				nextHandlerCalled = true
				w.WriteHeader(http.StatusOK)
			})
			session := &mockSessionData{
				email:       tt.email,
				accessToken: tt.accessToken,
				idToken:     tt.idToken,
			}
			initAuthCalled := false
			errorResponseSent := false
			var errorCode int

			m := &AuthMiddleware{
				logger:                logger,
				next:                  nextHandler,
				allowedRolesAndGroups: tt.allowedRoles,
				logoutURLPath:         "/logout",
				extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) {
					if tt.extractError != nil {
						return nil, nil, tt.extractError
					}
					return tt.userGroups, tt.userRoles, nil
				},
				sendErrorResponseFunc: func(rw http.ResponseWriter, req *http.Request, message string, code int) {
					errorResponseSent = true
					errorCode = code
					rw.WriteHeader(code)
				},
				authHandler: &mockAuthHandler{
					initiateAuthFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
						generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
						initAuthCalled = true
						// Call the session's ResetRedirectCount hook, if one is set, so the
						// reauth path is observable through resetCountCalled below.
						if mockSession, ok := session.(*mockSessionData); ok {
							if mockSession.resetRedirectCountFunc != nil {
								mockSession.resetRedirectCountFunc()
							}
						}
					},
				},
			}

			// Track ResetRedirectCount calls
			resetCountCalled := false
			session.resetRedirectCountFunc = func() {
				resetCountCalled = true
			}

			// Create request and response recorder
			req := httptest.NewRequest("GET", "/test", nil)
			rw := httptest.NewRecorder()

			// Call the method under test
			m.processAuthorizedRequest(rw, req, session, "https://example.com/callback")

			// Verify expectations
			if tt.expectHeaders && !nextHandlerCalled {
				t.Error("Expected next handler to be called")
			}
			if tt.expectHeaders {
				if req.Header.Get("X-Forwarded-User") != tt.email {
					t.Errorf("Expected X-Forwarded-User header to be %s, got %s",
						tt.email, req.Header.Get("X-Forwarded-User"))
				}
				if req.Header.Get("X-Auth-Request-User") != tt.email {
					t.Errorf("Expected X-Auth-Request-User header to be %s, got %s",
						tt.email, req.Header.Get("X-Auth-Request-User"))
				}
				if tt.idToken != "" && req.Header.Get("X-Auth-Request-Token") != tt.idToken {
					t.Errorf("Expected X-Auth-Request-Token header to be %s, got %s",
						tt.idToken, req.Header.Get("X-Auth-Request-Token"))
				}
				if len(tt.userGroups) > 0 && req.Header.Get("X-User-Groups") == "" {
					t.Error("Expected X-User-Groups header to be set")
				}
				if len(tt.userRoles) > 0 && req.Header.Get("X-User-Roles") == "" {
					t.Error("Expected X-User-Roles header to be set")
				}
			}
			if tt.expectForbidden && (!errorResponseSent || errorCode != http.StatusForbidden) {
				t.Error("Expected forbidden response")
			}
			if tt.expectReauth {
				if !initAuthCalled {
					t.Error("Expected InitiateAuthentication to be called")
				}
				if !resetCountCalled {
					t.Error("Expected ResetRedirectCount to be called before reauth")
				}
			}
		})
	}
}
// TestServeHTTP_AdditionalCoverage tests additional ServeHTTP scenarios for better coverage
func TestServeHTTP_AdditionalCoverage(t *testing.T) {
	t.Run("first_request_starts_background_tasks", func(t *testing.T) {
		// Setup mocks
		logger := &mockLogger{}
		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		tokenCleanupStarted := false
		metadataRefreshStarted := false
		initComplete := make(chan struct{})
		close(initComplete) // Already initialized
		wg := &sync.WaitGroup{}

		m := &AuthMiddleware{
			logger:       logger,
			next:         nextHandler,
			issuerURL:    "https://issuer.example.com",
			providerURL:  "https://provider.example.com",
			initComplete: initComplete,
			goroutineWG:  wg,
			sessionManager: &mockSessionManager{
				getSessionFunc: func(req *http.Request) (SessionData, error) {
					return &mockSessionData{
						email:       "user@example.com",
						accessToken: "token",
					}, nil
				},
			},
			urlHelper: &mockURLHelper{
				determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
					return false
				},
			},
			isUserAuthenticatedFunc: func(session SessionData) (bool, bool, bool) {
				return true, false, false
			},
			isAllowedDomainFunc: func(email string) bool {
				return true
			},
			tokenVerifier: &mockTokenVerifier{},
			extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
				return nil, nil, nil
			},
			startTokenCleanupFunc: func() {
				tokenCleanupStarted = true
			},
			startMetadataRefreshFunc: func(url string) {
				metadataRefreshStarted = true
			},
		}

		// First request should start background tasks
		req := httptest.NewRequest("GET", "/api/test", nil)
		rw := httptest.NewRecorder()
		m.ServeHTTP(rw, req)
		if !tokenCleanupStarted {
			t.Error("Expected token cleanup to be started on first request")
		}
		if !metadataRefreshStarted {
			t.Error("Expected metadata refresh to be started on first request")
		}
		if !m.firstRequestReceived {
			t.Error("Expected firstRequestReceived to be set")
		}

		// Second request should not start tasks again
		tokenCleanupStarted = false
		metadataRefreshStarted = false
		req2 := httptest.NewRequest("GET", "/api/test2", nil)
		rw2 := httptest.NewRecorder()
		m.ServeHTTP(rw2, req2)
		if tokenCleanupStarted {
			t.Error("Token cleanup should not be started again")
		}
		if metadataRefreshStarted {
			t.Error("Metadata refresh should not be started again")
		}
	})

	t.Run("health_endpoint_skips_first_request_logic", func(t *testing.T) {
		logger := &mockLogger{}
		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
		})
		tokenCleanupStarted := false
		metadataRefreshStarted := false
		initComplete := make(chan struct{})
		close(initComplete)

		m := &AuthMiddleware{
			logger:       logger,
			next:         nextHandler,
			issuerURL:    "https://issuer.example.com",
			initComplete: initComplete,
			excludedURLs: map[string]struct{}{"/health": {}},
			sessionManager: &mockSessionManager{
				getSessionFunc: func(req *http.Request) (SessionData, error) {
					return &mockSessionData{}, nil
				},
			},
			urlHelper: &mockURLHelper{
				determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
					_, ok := urls[path]
					return ok
				},
			},
			startTokenCleanupFunc: func() {
				tokenCleanupStarted = true
			},
			startMetadataRefreshFunc: func(url string) {
				metadataRefreshStarted = true
			},
		}

		// Health request should not trigger background tasks
		req := httptest.NewRequest("GET", "/health", nil)
		rw := httptest.NewRecorder()
		m.ServeHTTP(rw, req)
		if tokenCleanupStarted {
			t.Error("Token cleanup should not be started for health endpoint")
		}
		if metadataRefreshStarted {
			t.Error("Metadata refresh should not be started for health endpoint")
		}
		if m.firstRequestReceived {
			t.Error("firstRequestReceived should not be set for health endpoint")
		}
	})

	t.Run("opaque_access_token_skips_jwt_verification", func(t *testing.T) {
		logger := &mockLogger{}
		nextHandlerCalled := false
		nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			nextHandlerCalled = true
			w.WriteHeader(http.StatusOK)
		})
		initComplete := make(chan struct{})
		close(initComplete)
		verifyTokenCalled := false

		m := &AuthMiddleware{
			logger:               logger,
			next:                 nextHandler,
			issuerURL:            "https://issuer.example.com",
			initComplete:         initComplete,
			firstRequestReceived: true, // Skip first request logic
			sessionManager: &mockSessionManager{
				getSessionFunc: func(req *http.Request) (SessionData, error) {
					return &mockSessionData{
						email:       "user@example.com",
						accessToken: "opaque_token_without_dots", // Opaque token
					}, nil
				},
			},
			urlHelper: &mockURLHelper{
				determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
					return false
				},
			},
			isUserAuthenticatedFunc: func(session SessionData) (bool, bool, bool) {
				return true, false, false // Authenticated, no refresh needed
			},
			isAllowedDomainFunc: func(email string) bool {
				return true
			},
			tokenVerifier: &mockTokenVerifier{
				verifyFunc: func(token string) error {
					verifyTokenCalled = true
					return nil
				},
			},
			extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
				return nil, nil, nil
			},
			startTokenCleanupFunc:    func() {},
			startMetadataRefreshFunc: func(url string) {},
		}

		req := httptest.NewRequest("GET", "/api/test", nil)
		rw := httptest.NewRecorder()
		m.ServeHTTP(rw, req)
		if verifyTokenCalled {
			t.Error("JWT verification should be skipped for opaque tokens")
		}
		if !nextHandlerCalled {
			t.Error("Next handler should be called for valid opaque token")
		}
	})
}