mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
1b49e133da
* Fix bug affecting Azure OIDC authentication (and most likely others) * Fixes issue #51 * Ensure that appended roles are unique. Update the documentation. * Improvements targeting possible memory usage spikes. * Additional fixes and cleanup * Refactoring code to fix the issues identified by the users. * Modernize run * Fieldalignment * Multiple changes to improve performance and reduce complexity. - Optimise the errors and recovery. - Deduplicate code in metadata cache. - Remove unused performance monitoring code. - Simplify session management and settings handling. * Fix claims issue. * Add ability to overwrite the default scopes in the settings file * Well... that escalated quickly. Completely forgot that Traefik uses outdated Yaegi and requires compatibility with 1.20 (pre-generic Go code). * Bugfix #51: Ensures that user provided scopes overrides work. * fixup! Bugfix #51: Ensures that user provided scopes overrides work. * fixup! fixup! Bugfix #51: Ensures that user provided scopes overrides work. * Abstract the provider logic into a separate package. * Additional micro fixes and cleanups. * Simplify all the things. * fixup! Simplify all the things. * fixup! fixup! Simplify all the things. * fixup! fixup! fixup! Simplify all the things. * fixup! fixup! fixup! fixup! Simplify all the things. * ... * Cleanup tests. * fixup! Cleanup tests. * fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! Cleanup tests. * fixup! fixup! fixup! fixup! fixup! Cleanup tests. * Issue #53: Fix CSRF token handling in reverse proxy 1. ✅ HTTPS Detection Fixed (session.go:723) - Now uses X-Forwarded-Proto header instead of r.URL.Scheme - Properly detects HTTPS in reverse proxy environments 2. ✅ SameSite Cookie Attribute Fixed - Removed automatic SameSiteStrictMode for HTTPS (would break OAuth) - Keeps SameSiteLaxMode to allow OAuth callbacks from external domains - Only uses Strict for AJAX requests which don't involve OAuth redirects 3.
✅ Cookie Domain Handling Fixed - Now respects X-Forwarded-Host header for cookie domain - Ensures cookies are set for the public domain, not internal proxy domain 4. ✅ EnhanceSessionSecurity Properly Integrated - Function is now actually called during session save - Applies security enhancements without breaking OAuth flow Why Issue #53 Failed Before: 1. Cookies were not marked Secure in HTTPS environments (browser wouldn't send them back) 2. If they had been Secure with SameSite=Strict, Azure callbacks would still fail 3. Cookie domain might have been wrong (internal vs public domain) Why It Works Now: 1. Cookies are properly marked Secure for HTTPS 2. Uses SameSite=Lax to allow OAuth provider callbacks 3. Cookie domain uses public domain from X-Forwarded-Host 4. CSRF token persists through the entire OAuth flow * Next set of enhancements together with memory usage improvements. * Memory leak fixes and optimisations. * CSRF and Cookie Domain fixes * fixup! CSRF and Cookie Domain fixes * Metadata cache leak fix + profiling * fixup! Metadata cache leak fix + profiling * Memory leaks hunting, part 1337. * Further pursue of perfection. * fixup! Further pursue of perfection. * fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! Further pursue of perfection. * Clear race conditions * fixup! Clear race conditions * Weekend fun with memory leaks * Splitting code into multiple files with reasonable testing coverage. 
``` ok github.com/lukaszraczylo/traefikoidc 117.017s coverage: 72.6% of statements ok github.com/lukaszraczylo/traefikoidc/auth 0.505s coverage: 87.1% of statements ok github.com/lukaszraczylo/traefikoidc/circuit_breaker 0.283s coverage: 99.0% of statements github.com/lukaszraczylo/traefikoidc/config coverage: 0.0% of statements ok github.com/lukaszraczylo/traefikoidc/handlers 0.349s coverage: 98.2% of statements ok github.com/lukaszraczylo/traefikoidc/internal/providers (cached) coverage: 94.3% of statements ok github.com/lukaszraczylo/traefikoidc/middleware 0.808s coverage: 78.0% of statements ok github.com/lukaszraczylo/traefikoidc/recovery 0.653s coverage: 100.0% of statements ok github.com/lukaszraczylo/traefikoidc/session/chunking (cached) coverage: 87.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/core (cached) coverage: 85.6% of statements ok github.com/lukaszraczylo/traefikoidc/session/crypto (cached) coverage: 81.8% of statements ok github.com/lukaszraczylo/traefikoidc/session/storage (cached) coverage: 93.5% of statements ok github.com/lukaszraczylo/traefikoidc/session/validators (cached) coverage: 98.8% of statements ``` * fixup! Splitting code into multiple files with reasonable testing coverage. * fixup! fixup! Splitting code into multiple files with reasonable testing coverage. * Weekend fun with further optimisations. * fixup! Weekend fun with further optimisations. * fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * fixup! fixup! fixup! fixup! fixup! Weekend fun with further optimisations. * Pre-release cleanup. * Enhance test coverage. * fixup! Enhance test coverage. * fixup! fixup! Enhance test coverage. * fixup! fixup! fixup! Enhance test coverage.
805 lines
24 KiB
Go
805 lines
24 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
)
|
|
|
|
// TestUncoveredMiddlewareFunctions tests the functions with 0% coverage in middleware package
|
|
func TestUncoveredMiddlewareFunctions(t *testing.T) {
|
|
t.Run("generateNonce", func(t *testing.T) {
|
|
// This function currently returns an error in the stub implementation
|
|
nonce, err := generateNonce()
|
|
if err == nil {
|
|
t.Errorf("Expected generateNonce to return an error in stub implementation")
|
|
}
|
|
if nonce != "" {
|
|
t.Errorf("Expected generateNonce to return empty string, got %s", nonce)
|
|
}
|
|
// Verify the error message
|
|
expectedError := "generateNonce not implemented"
|
|
if err.Error() != expectedError {
|
|
t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("generateCodeVerifier", func(t *testing.T) {
|
|
// This function currently returns an error in the stub implementation
|
|
verifier, err := generateCodeVerifier()
|
|
if err == nil {
|
|
t.Errorf("Expected generateCodeVerifier to return an error in stub implementation")
|
|
}
|
|
if verifier != "" {
|
|
t.Errorf("Expected generateCodeVerifier to return empty string, got %s", verifier)
|
|
}
|
|
// Verify the error message
|
|
expectedError := "generateCodeVerifier not implemented"
|
|
if err.Error() != expectedError {
|
|
t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
|
|
}
|
|
})
|
|
|
|
t.Run("deriveCodeChallenge", func(t *testing.T) {
|
|
// This function currently returns an error in the stub implementation
|
|
challenge, err := deriveCodeChallenge()
|
|
if err == nil {
|
|
t.Errorf("Expected deriveCodeChallenge to return an error in stub implementation")
|
|
}
|
|
if challenge != "" {
|
|
t.Errorf("Expected deriveCodeChallenge to return empty string, got %s", challenge)
|
|
}
|
|
// Verify the error message
|
|
expectedError := "deriveCodeChallenge not implemented"
|
|
if err.Error() != expectedError {
|
|
t.Errorf("Expected error message '%s', got '%s'", expectedError, err.Error())
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestBuildFullURLFunction tests the buildFullURL function that already has 100% coverage
|
|
// but this ensures we maintain that coverage and test edge cases
|
|
func TestBuildFullURLFunction(t *testing.T) {
|
|
t.Run("buildFullURL", func(t *testing.T) {
|
|
// Test basic URL building
|
|
scheme := "https"
|
|
host := "example.com"
|
|
path := "/callback"
|
|
|
|
url := buildFullURL(scheme, host, path)
|
|
expected := "https://example.com/callback"
|
|
|
|
if url != expected {
|
|
t.Errorf("Expected URL %s, got %s", expected, url)
|
|
}
|
|
|
|
// Test with path that doesn't start with / (function just concatenates)
|
|
url2 := buildFullURL(scheme, host, "callback")
|
|
expected2 := "https://example.comcallback"
|
|
|
|
if url2 != expected2 {
|
|
t.Errorf("Expected URL %s, got %s", expected2, url2)
|
|
}
|
|
|
|
// Test with empty path
|
|
url3 := buildFullURL(scheme, host, "")
|
|
expected3 := "https://example.com"
|
|
|
|
if url3 != expected3 {
|
|
t.Errorf("Expected URL %s, got %s", expected3, url3)
|
|
}
|
|
|
|
// Test with different schemes
|
|
url4 := buildFullURL("http", "localhost:8080", "/test")
|
|
expected4 := "http://localhost:8080/test"
|
|
|
|
if url4 != expected4 {
|
|
t.Errorf("Expected URL %s, got %s", expected4, url4)
|
|
}
|
|
|
|
// Test with special characters
|
|
url5 := buildFullURL("https", "api.example.com", "/v1/auth?redirect=true")
|
|
expected5 := "https://api.example.com/v1/auth?redirect=true"
|
|
|
|
if url5 != expected5 {
|
|
t.Errorf("Expected URL %s, got %s", expected5, url5)
|
|
}
|
|
|
|
// Test with empty components
|
|
url6 := buildFullURL("", "", "")
|
|
expected6 := "://"
|
|
|
|
if url6 != expected6 {
|
|
t.Errorf("Expected URL %s, got %s", expected6, url6)
|
|
}
|
|
|
|
// Test with port numbers
|
|
url7 := buildFullURL("http", "localhost:3000", "/admin")
|
|
expected7 := "http://localhost:3000/admin"
|
|
|
|
if url7 != expected7 {
|
|
t.Errorf("Expected URL %s, got %s", expected7, url7)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Mock types for testing

// mockLogger is a logger stub that records every message in logs, prefixed
// with its level ("DEBUG: ", "ERROR: ", "INFO: "). Appends are serialised by
// mu. The *f variants record the raw format string only; their arguments are
// ignored.
type mockLogger struct {
	logs []string
	mu   sync.Mutex
}

func (l *mockLogger) Debug(msg string) {
	l.log("DEBUG: " + msg)
}

func (l *mockLogger) Debugf(format string, args ...interface{}) {
	l.log("DEBUG: " + format)
}

func (l *mockLogger) Error(msg string) {
	l.log("ERROR: " + msg)
}

func (l *mockLogger) Errorf(format string, args ...interface{}) {
	l.log("ERROR: " + format)
}

func (l *mockLogger) Info(msg string) {
	l.log("INFO: " + msg)
}

func (l *mockLogger) Infof(format string, args ...interface{}) {
	l.log("INFO: " + format)
}

// log appends one already-prefixed entry while holding the mutex.
func (l *mockLogger) log(msg string) {
	l.mu.Lock()
	l.logs = append(l.logs, msg)
	l.mu.Unlock()
}
|
|
|
|
type mockSessionManager struct {
|
|
getSessionFunc func(req *http.Request) (SessionData, error)
|
|
cleanupOldCookiesFunc func(rw http.ResponseWriter, req *http.Request)
|
|
}
|
|
|
|
func (m *mockSessionManager) CleanupOldCookies(rw http.ResponseWriter, req *http.Request) {
|
|
if m.cleanupOldCookiesFunc != nil {
|
|
m.cleanupOldCookiesFunc(rw, req)
|
|
}
|
|
}
|
|
|
|
func (m *mockSessionManager) GetSession(req *http.Request) (SessionData, error) {
|
|
if m.getSessionFunc != nil {
|
|
return m.getSessionFunc(req)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// mockSessionData is an in-memory stand-in for SessionData. The getters
// return the stored strings. Clear and ResetRedirectCount delegate to optional
// hooks: with no hook, Clear returns nil and ResetRedirectCount does nothing.
type mockSessionData struct {
	email                  string
	accessToken            string
	idToken                string
	refreshToken           string
	clearFunc              func(req *http.Request, rw http.ResponseWriter) error
	resetRedirectCountFunc func()
}

func (s *mockSessionData) GetEmail() string {
	return s.email
}

func (s *mockSessionData) GetAccessToken() string {
	return s.accessToken
}

func (s *mockSessionData) GetIDToken() string {
	return s.idToken
}

func (s *mockSessionData) GetRefreshToken() string {
	return s.refreshToken
}

func (s *mockSessionData) Clear(req *http.Request, rw http.ResponseWriter) error {
	if s.clearFunc == nil {
		return nil
	}
	return s.clearFunc(req, rw)
}

func (s *mockSessionData) ResetRedirectCount() {
	if hook := s.resetRedirectCountFunc; hook != nil {
		hook()
	}
}

// returnToPoolSafely is a no-op. Presumably it exists only to satisfy the
// SessionData interface — TODO confirm against the interface definition.
func (s *mockSessionData) returnToPoolSafely() {}
|
|
|
|
type mockAuthHandler struct {
|
|
initiateAuthFunc func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error))
|
|
}
|
|
|
|
func (m *mockAuthHandler) InitiateAuthentication(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
|
if m.initiateAuthFunc != nil {
|
|
m.initiateAuthFunc(rw, req, session, redirectURL, generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
}
|
|
}
|
|
|
|
// mockURLHelper stubs the URL helper. Each method defers to its hook when one
// is set. Otherwise it returns a fixed default: not excluded, scheme "https",
// host "example.com".
type mockURLHelper struct {
	determineExcludedFunc func(currentRequest string, excludedURLs map[string]struct{}) bool
	determineSchemeFunc   func(req *http.Request) string
	determineHostFunc     func(req *http.Request) string
}

func (u *mockURLHelper) DetermineExcludedURL(currentRequest string, excludedURLs map[string]struct{}) bool {
	if hook := u.determineExcludedFunc; hook != nil {
		return hook(currentRequest, excludedURLs)
	}
	return false
}

func (u *mockURLHelper) DetermineScheme(req *http.Request) string {
	scheme := "https"
	if u.determineSchemeFunc != nil {
		scheme = u.determineSchemeFunc(req)
	}
	return scheme
}

func (u *mockURLHelper) DetermineHost(req *http.Request) string {
	host := "example.com"
	if u.determineHostFunc != nil {
		host = u.determineHostFunc(req)
	}
	return host
}
|
|
|
|
// mockTokenVerifier stubs token verification. VerifyToken returns whatever
// verifyFunc returns, or nil (token accepted) when no hook is installed.
type mockTokenVerifier struct {
	verifyFunc func(token string) error
}

func (v *mockTokenVerifier) VerifyToken(token string) error {
	if v.verifyFunc == nil {
		return nil
	}
	return v.verifyFunc(token)
}
|
|
|
|
// TestStubFunctionsErrorBehavior tests error behaviors more thoroughly
|
|
func TestStubFunctionsErrorBehavior(t *testing.T) {
|
|
t.Run("generateNonce_multiple_calls", func(t *testing.T) {
|
|
// Test multiple calls to ensure consistent behavior
|
|
for i := 0; i < 3; i++ {
|
|
nonce, err := generateNonce()
|
|
if err == nil {
|
|
t.Errorf("Call %d: Expected generateNonce to return an error", i)
|
|
}
|
|
if nonce != "" {
|
|
t.Errorf("Call %d: Expected empty nonce, got %s", i, nonce)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("generateCodeVerifier_multiple_calls", func(t *testing.T) {
|
|
// Test multiple calls to ensure consistent behavior
|
|
for i := 0; i < 3; i++ {
|
|
verifier, err := generateCodeVerifier()
|
|
if err == nil {
|
|
t.Errorf("Call %d: Expected generateCodeVerifier to return an error", i)
|
|
}
|
|
if verifier != "" {
|
|
t.Errorf("Call %d: Expected empty verifier, got %s", i, verifier)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("deriveCodeChallenge_multiple_calls", func(t *testing.T) {
|
|
// Test multiple calls to ensure consistent behavior
|
|
for i := 0; i < 3; i++ {
|
|
challenge, err := deriveCodeChallenge()
|
|
if err == nil {
|
|
t.Errorf("Call %d: Expected deriveCodeChallenge to return an error", i)
|
|
}
|
|
if challenge != "" {
|
|
t.Errorf("Call %d: Expected empty challenge, got %s", i, challenge)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestHandleTokenRefresh tests the handleTokenRefresh method with various scenarios
|
|
func TestHandleTokenRefresh(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
needsRefresh bool
|
|
authenticated bool
|
|
isAjaxRequest bool
|
|
refreshSuccess bool
|
|
allowedDomain bool
|
|
expectErrorResponse bool
|
|
expectProcessAuthorized bool
|
|
expectInitAuth bool
|
|
}{
|
|
{
|
|
name: "successful_refresh_authenticated",
|
|
needsRefresh: true,
|
|
authenticated: true,
|
|
isAjaxRequest: false,
|
|
refreshSuccess: true,
|
|
allowedDomain: true,
|
|
expectProcessAuthorized: true,
|
|
},
|
|
{
|
|
name: "successful_refresh_not_authenticated",
|
|
needsRefresh: true,
|
|
authenticated: false,
|
|
isAjaxRequest: false,
|
|
refreshSuccess: true,
|
|
allowedDomain: true,
|
|
expectProcessAuthorized: true,
|
|
},
|
|
{
|
|
name: "successful_refresh_disallowed_domain",
|
|
needsRefresh: true,
|
|
authenticated: true,
|
|
isAjaxRequest: false,
|
|
refreshSuccess: true,
|
|
allowedDomain: false,
|
|
expectErrorResponse: true,
|
|
},
|
|
{
|
|
name: "failed_refresh_browser_request",
|
|
needsRefresh: true,
|
|
authenticated: true,
|
|
isAjaxRequest: false,
|
|
refreshSuccess: false,
|
|
expectInitAuth: true,
|
|
},
|
|
{
|
|
name: "failed_refresh_ajax_request",
|
|
needsRefresh: true,
|
|
authenticated: true,
|
|
isAjaxRequest: true,
|
|
refreshSuccess: false,
|
|
expectErrorResponse: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Setup mocks
|
|
logger := &mockLogger{}
|
|
nextHandlerCalled := false
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextHandlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
session := &mockSessionData{
|
|
email: "test@example.com",
|
|
accessToken: "access_token",
|
|
idToken: "id_token",
|
|
refreshToken: "refresh_token",
|
|
}
|
|
|
|
initAuthCalled := false
|
|
errorResponseSent := false
|
|
|
|
m := &AuthMiddleware{
|
|
logger: logger,
|
|
next: nextHandler,
|
|
logoutURLPath: "/logout",
|
|
refreshTokenFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData) bool {
|
|
return tt.refreshSuccess
|
|
},
|
|
isAllowedDomainFunc: func(email string) bool {
|
|
return tt.allowedDomain
|
|
},
|
|
isAjaxRequestFunc: func(req *http.Request) bool {
|
|
return tt.isAjaxRequest
|
|
},
|
|
sendErrorResponseFunc: func(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
|
errorResponseSent = true
|
|
rw.WriteHeader(code)
|
|
},
|
|
authHandler: &mockAuthHandler{
|
|
initiateAuthFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
|
initAuthCalled = true
|
|
},
|
|
},
|
|
extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
|
|
return nil, nil, nil
|
|
},
|
|
}
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
// Call the method under test
|
|
m.handleTokenRefresh(rw, req, session, "https://example.com/callback",
|
|
tt.needsRefresh, tt.authenticated, tt.isAjaxRequest)
|
|
|
|
// Verify expectations - processAuthorizedRequest will call the next handler if successful
|
|
if tt.expectProcessAuthorized && !nextHandlerCalled {
|
|
t.Error("Expected processAuthorizedRequest to complete (next handler called)")
|
|
}
|
|
if tt.expectInitAuth && !initAuthCalled {
|
|
t.Error("Expected InitiateAuthentication to be called")
|
|
}
|
|
if tt.expectErrorResponse && !errorResponseSent {
|
|
t.Error("Expected error response to be sent")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestProcessAuthorizedRequest tests the processAuthorizedRequest method
|
|
func TestProcessAuthorizedRequest(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
email string
|
|
idToken string
|
|
accessToken string
|
|
allowedRoles map[string]struct{}
|
|
userGroups []string
|
|
userRoles []string
|
|
extractError error
|
|
expectHeaders bool
|
|
expectForbidden bool
|
|
expectReauth bool
|
|
}{
|
|
{
|
|
name: "no_email_triggers_reauth",
|
|
email: "",
|
|
idToken: "token",
|
|
expectReauth: true,
|
|
},
|
|
{
|
|
name: "successful_with_id_token",
|
|
email: "user@example.com",
|
|
idToken: "id_token",
|
|
accessToken: "access_token",
|
|
expectHeaders: true,
|
|
},
|
|
{
|
|
name: "successful_with_access_token_only",
|
|
email: "user@example.com",
|
|
idToken: "",
|
|
accessToken: "access_token",
|
|
expectHeaders: true,
|
|
},
|
|
{
|
|
name: "no_token_with_role_requirements",
|
|
email: "user@example.com",
|
|
idToken: "",
|
|
accessToken: "",
|
|
allowedRoles: map[string]struct{}{"admin": {}},
|
|
expectReauth: true,
|
|
},
|
|
{
|
|
name: "user_has_allowed_role",
|
|
email: "user@example.com",
|
|
idToken: "token",
|
|
allowedRoles: map[string]struct{}{"admin": {}},
|
|
userRoles: []string{"admin", "user"},
|
|
expectHeaders: true,
|
|
},
|
|
{
|
|
name: "user_has_allowed_group",
|
|
email: "user@example.com",
|
|
idToken: "token",
|
|
allowedRoles: map[string]struct{}{"developers": {}},
|
|
userGroups: []string{"developers", "testers"},
|
|
expectHeaders: true,
|
|
},
|
|
{
|
|
name: "user_lacks_required_roles",
|
|
email: "user@example.com",
|
|
idToken: "token",
|
|
allowedRoles: map[string]struct{}{"admin": {}},
|
|
userRoles: []string{"user"},
|
|
expectForbidden: true,
|
|
},
|
|
{
|
|
name: "extract_error_with_role_requirements",
|
|
email: "user@example.com",
|
|
idToken: "token",
|
|
allowedRoles: map[string]struct{}{"admin": {}},
|
|
extractError: errors.New("extraction failed"),
|
|
expectReauth: true,
|
|
},
|
|
{
|
|
name: "extract_error_without_role_requirements",
|
|
email: "user@example.com",
|
|
idToken: "token",
|
|
extractError: errors.New("extraction failed"),
|
|
expectHeaders: true, // Should continue without roles/groups
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Setup mocks
|
|
logger := &mockLogger{}
|
|
nextHandlerCalled := false
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextHandlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
session := &mockSessionData{
|
|
email: tt.email,
|
|
accessToken: tt.accessToken,
|
|
idToken: tt.idToken,
|
|
}
|
|
|
|
initAuthCalled := false
|
|
errorResponseSent := false
|
|
var errorCode int
|
|
|
|
m := &AuthMiddleware{
|
|
logger: logger,
|
|
next: nextHandler,
|
|
allowedRolesAndGroups: tt.allowedRoles,
|
|
logoutURLPath: "/logout",
|
|
extractGroupsAndRolesFunc: func(tokenString string) ([]string, []string, error) {
|
|
if tt.extractError != nil {
|
|
return nil, nil, tt.extractError
|
|
}
|
|
return tt.userGroups, tt.userRoles, nil
|
|
},
|
|
sendErrorResponseFunc: func(rw http.ResponseWriter, req *http.Request, message string, code int) {
|
|
errorResponseSent = true
|
|
errorCode = code
|
|
rw.WriteHeader(code)
|
|
},
|
|
authHandler: &mockAuthHandler{
|
|
initiateAuthFunc: func(rw http.ResponseWriter, req *http.Request, session SessionData, redirectURL string,
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge func() (string, error)) {
|
|
initAuthCalled = true
|
|
// Ensure ResetRedirectCount was called
|
|
if mockSession, ok := session.(*mockSessionData); ok {
|
|
if mockSession.resetRedirectCountFunc != nil {
|
|
mockSession.resetRedirectCountFunc()
|
|
}
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
// Track ResetRedirectCount calls
|
|
resetCountCalled := false
|
|
session.resetRedirectCountFunc = func() {
|
|
resetCountCalled = true
|
|
}
|
|
|
|
// Create request and response recorder
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
// Call the method under test
|
|
m.processAuthorizedRequest(rw, req, session, "https://example.com/callback")
|
|
|
|
// Verify expectations
|
|
if tt.expectHeaders && !nextHandlerCalled {
|
|
t.Error("Expected next handler to be called")
|
|
}
|
|
|
|
if tt.expectHeaders {
|
|
if req.Header.Get("X-Forwarded-User") != tt.email {
|
|
t.Errorf("Expected X-Forwarded-User header to be %s, got %s",
|
|
tt.email, req.Header.Get("X-Forwarded-User"))
|
|
}
|
|
if req.Header.Get("X-Auth-Request-User") != tt.email {
|
|
t.Errorf("Expected X-Auth-Request-User header to be %s, got %s",
|
|
tt.email, req.Header.Get("X-Auth-Request-User"))
|
|
}
|
|
if tt.idToken != "" && req.Header.Get("X-Auth-Request-Token") != tt.idToken {
|
|
t.Errorf("Expected X-Auth-Request-Token header to be %s, got %s",
|
|
tt.idToken, req.Header.Get("X-Auth-Request-Token"))
|
|
}
|
|
if len(tt.userGroups) > 0 && req.Header.Get("X-User-Groups") == "" {
|
|
t.Error("Expected X-User-Groups header to be set")
|
|
}
|
|
if len(tt.userRoles) > 0 && req.Header.Get("X-User-Roles") == "" {
|
|
t.Error("Expected X-User-Roles header to be set")
|
|
}
|
|
}
|
|
|
|
if tt.expectForbidden && (!errorResponseSent || errorCode != http.StatusForbidden) {
|
|
t.Error("Expected forbidden response")
|
|
}
|
|
|
|
if tt.expectReauth {
|
|
if !initAuthCalled {
|
|
t.Error("Expected InitiateAuthentication to be called")
|
|
}
|
|
if !resetCountCalled {
|
|
t.Error("Expected ResetRedirectCount to be called before reauth")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeHTTP_AdditionalCoverage tests additional ServeHTTP scenarios for better coverage
|
|
func TestServeHTTP_AdditionalCoverage(t *testing.T) {
|
|
t.Run("first_request_starts_background_tasks", func(t *testing.T) {
|
|
// Setup mocks
|
|
logger := &mockLogger{}
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
tokenCleanupStarted := false
|
|
metadataRefreshStarted := false
|
|
|
|
initComplete := make(chan struct{})
|
|
close(initComplete) // Already initialized
|
|
|
|
wg := &sync.WaitGroup{}
|
|
|
|
m := &AuthMiddleware{
|
|
logger: logger,
|
|
next: nextHandler,
|
|
issuerURL: "https://issuer.example.com",
|
|
providerURL: "https://provider.example.com",
|
|
initComplete: initComplete,
|
|
goroutineWG: wg,
|
|
sessionManager: &mockSessionManager{
|
|
getSessionFunc: func(req *http.Request) (SessionData, error) {
|
|
return &mockSessionData{
|
|
email: "user@example.com",
|
|
accessToken: "token",
|
|
}, nil
|
|
},
|
|
},
|
|
urlHelper: &mockURLHelper{
|
|
determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
|
|
return false
|
|
},
|
|
},
|
|
isUserAuthenticatedFunc: func(session SessionData) (bool, bool, bool) {
|
|
return true, false, false
|
|
},
|
|
isAllowedDomainFunc: func(email string) bool {
|
|
return true
|
|
},
|
|
tokenVerifier: &mockTokenVerifier{},
|
|
extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
|
|
return nil, nil, nil
|
|
},
|
|
startTokenCleanupFunc: func() {
|
|
tokenCleanupStarted = true
|
|
},
|
|
startMetadataRefreshFunc: func(url string) {
|
|
metadataRefreshStarted = true
|
|
},
|
|
}
|
|
|
|
// First request should start background tasks
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rw, req)
|
|
|
|
if !tokenCleanupStarted {
|
|
t.Error("Expected token cleanup to be started on first request")
|
|
}
|
|
if !metadataRefreshStarted {
|
|
t.Error("Expected metadata refresh to be started on first request")
|
|
}
|
|
if !m.firstRequestReceived {
|
|
t.Error("Expected firstRequestReceived to be set")
|
|
}
|
|
|
|
// Second request should not start tasks again
|
|
tokenCleanupStarted = false
|
|
metadataRefreshStarted = false
|
|
|
|
req2 := httptest.NewRequest("GET", "/api/test2", nil)
|
|
rw2 := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rw2, req2)
|
|
|
|
if tokenCleanupStarted {
|
|
t.Error("Token cleanup should not be started again")
|
|
}
|
|
if metadataRefreshStarted {
|
|
t.Error("Metadata refresh should not be started again")
|
|
}
|
|
})
|
|
|
|
t.Run("health_endpoint_skips_first_request_logic", func(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
tokenCleanupStarted := false
|
|
metadataRefreshStarted := false
|
|
|
|
initComplete := make(chan struct{})
|
|
close(initComplete)
|
|
|
|
m := &AuthMiddleware{
|
|
logger: logger,
|
|
next: nextHandler,
|
|
issuerURL: "https://issuer.example.com",
|
|
initComplete: initComplete,
|
|
excludedURLs: map[string]struct{}{"/health": {}},
|
|
sessionManager: &mockSessionManager{
|
|
getSessionFunc: func(req *http.Request) (SessionData, error) {
|
|
return &mockSessionData{}, nil
|
|
},
|
|
},
|
|
urlHelper: &mockURLHelper{
|
|
determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
|
|
_, ok := urls[path]
|
|
return ok
|
|
},
|
|
},
|
|
startTokenCleanupFunc: func() {
|
|
tokenCleanupStarted = true
|
|
},
|
|
startMetadataRefreshFunc: func(url string) {
|
|
metadataRefreshStarted = true
|
|
},
|
|
}
|
|
|
|
// Health request should not trigger background tasks
|
|
req := httptest.NewRequest("GET", "/health", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rw, req)
|
|
|
|
if tokenCleanupStarted {
|
|
t.Error("Token cleanup should not be started for health endpoint")
|
|
}
|
|
if metadataRefreshStarted {
|
|
t.Error("Metadata refresh should not be started for health endpoint")
|
|
}
|
|
if m.firstRequestReceived {
|
|
t.Error("firstRequestReceived should not be set for health endpoint")
|
|
}
|
|
})
|
|
|
|
t.Run("opaque_access_token_skips_jwt_verification", func(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
nextHandlerCalled := false
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
nextHandlerCalled = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
initComplete := make(chan struct{})
|
|
close(initComplete)
|
|
|
|
verifyTokenCalled := false
|
|
|
|
m := &AuthMiddleware{
|
|
logger: logger,
|
|
next: nextHandler,
|
|
issuerURL: "https://issuer.example.com",
|
|
initComplete: initComplete,
|
|
firstRequestReceived: true, // Skip first request logic
|
|
sessionManager: &mockSessionManager{
|
|
getSessionFunc: func(req *http.Request) (SessionData, error) {
|
|
return &mockSessionData{
|
|
email: "user@example.com",
|
|
accessToken: "opaque_token_without_dots", // Opaque token
|
|
}, nil
|
|
},
|
|
},
|
|
urlHelper: &mockURLHelper{
|
|
determineExcludedFunc: func(path string, urls map[string]struct{}) bool {
|
|
return false
|
|
},
|
|
},
|
|
isUserAuthenticatedFunc: func(session SessionData) (bool, bool, bool) {
|
|
return true, false, false // Authenticated, no refresh needed
|
|
},
|
|
isAllowedDomainFunc: func(email string) bool {
|
|
return true
|
|
},
|
|
tokenVerifier: &mockTokenVerifier{
|
|
verifyFunc: func(token string) error {
|
|
verifyTokenCalled = true
|
|
return nil
|
|
},
|
|
},
|
|
extractGroupsAndRolesFunc: func(token string) ([]string, []string, error) {
|
|
return nil, nil, nil
|
|
},
|
|
startTokenCleanupFunc: func() {},
|
|
startMetadataRefreshFunc: func(url string) {},
|
|
}
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
m.ServeHTTP(rw, req)
|
|
|
|
if verifyTokenCalled {
|
|
t.Error("JWT verification should be skipped for opaque tokens")
|
|
}
|
|
if !nextHandlerCalled {
|
|
t.Error("Next handler should be called for valid opaque token")
|
|
}
|
|
})
|
|
}
|