mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
c3f23cb99b
* Resolve issue with opaque tokens not being parsed correctly * Increase test coverage * Further improvements to test coverage and code quality * Add new providers. * fixup! Add new providers. * Cleanup. * fixup! Cleanup. * fixup! fixup! Cleanup. * fixup! fixup! fixup! Cleanup. * fixup! fixup! fixup! fixup! Cleanup. * Memory management optimisation 24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference) * Pooling cleanup.
600 lines
20 KiB
Go
600 lines
20 KiB
Go
package auth
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
// Test mocks
|
|
// mockLogger is a test double for the handler's logger. It records the
// format string of every Debugf and Errorf call so that tests can assert
// that something was logged at a given level.
type mockLogger struct {
	debugMessages []string
	errorMessages []string
}

// Debugf appends format to debugMessages. The variadic args are accepted to
// match the logger method shape but are intentionally not interpolated: only
// the raw format string is stored.
func (l *mockLogger) Debugf(format string, args ...interface{}) {
	l.debugMessages = append(l.debugMessages, format)
}

// Errorf appends format to errorMessages. As with Debugf, args are ignored
// and only the raw format string is stored.
func (l *mockLogger) Errorf(format string, args ...interface{}) {
	l.errorMessages = append(l.errorMessages, format)
}
|
|
|
|
// mockSessionData is an in-memory test double for the session object handed
// to the auth handler. Every setter writes straight into the matching field so
// tests can inspect the resulting state. saveError is the value returned by
// Save, which lets tests simulate a persistence failure. dirty records whether
// MarkDirty was called.
type mockSessionData struct {
	authenticated bool
	email         string
	accessToken   string
	refreshToken  string
	idToken       string
	csrf          string
	nonce         string
	codeVerifier  string
	incomingPath  string
	redirectCount int
	saveError     error
	dirty         bool
}

// Redirect counter accessors.

// GetRedirectCount returns the current redirect counter.
func (s *mockSessionData) GetRedirectCount() int { return s.redirectCount }

// ResetRedirectCount sets the redirect counter back to zero.
func (s *mockSessionData) ResetRedirectCount() { s.redirectCount = 0 }

// IncrementRedirectCount adds one to the redirect counter.
func (s *mockSessionData) IncrementRedirectCount() { s.redirectCount++ }

// Plain field setters: each stores its argument in the field of the same name.

// SetAuthenticated stores the authenticated flag.
func (s *mockSessionData) SetAuthenticated(auth bool) { s.authenticated = auth }

// SetEmail stores the email.
func (s *mockSessionData) SetEmail(email string) { s.email = email }

// SetAccessToken stores the access token.
func (s *mockSessionData) SetAccessToken(token string) { s.accessToken = token }

// SetRefreshToken stores the refresh token.
func (s *mockSessionData) SetRefreshToken(token string) { s.refreshToken = token }

// SetIDToken stores the ID token.
func (s *mockSessionData) SetIDToken(token string) { s.idToken = token }

// SetNonce stores the nonce.
func (s *mockSessionData) SetNonce(nonce string) { s.nonce = nonce }

// SetCodeVerifier stores the PKCE code verifier.
func (s *mockSessionData) SetCodeVerifier(verifier string) { s.codeVerifier = verifier }

// SetCSRF stores the CSRF token.
func (s *mockSessionData) SetCSRF(csrf string) { s.csrf = csrf }

// SetIncomingPath stores the originally requested path.
func (s *mockSessionData) SetIncomingPath(path string) { s.incomingPath = path }

// MarkDirty flags the session as modified.
func (s *mockSessionData) MarkDirty() { s.dirty = true }

// Save performs no I/O; it ignores both arguments and returns the
// preconfigured saveError (nil unless a test set it).
func (s *mockSessionData) Save(_ *http.Request, _ http.ResponseWriter) error {
	return s.saveError
}
|
|
|
|
// TestAuthHandler_NewAuthHandler tests the constructor
|
|
func TestAuthHandler_NewAuthHandler(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
isGoogleProv := func() bool { return false }
|
|
isAzureProv := func() bool { return true }
|
|
scopes := []string{"openid", "profile", "email"}
|
|
|
|
handler := NewAuthHandler(logger, true, isGoogleProv, isAzureProv,
|
|
"test-client-id", "https://example.com/auth", "https://example.com",
|
|
scopes, false)
|
|
|
|
if handler == nil {
|
|
t.Fatal("Expected handler to be created, got nil")
|
|
}
|
|
|
|
if handler.logger != logger {
|
|
t.Error("Logger not set correctly")
|
|
}
|
|
|
|
if !handler.enablePKCE {
|
|
t.Error("PKCE should be enabled")
|
|
}
|
|
|
|
if handler.clientID != "test-client-id" {
|
|
t.Errorf("Expected clientID 'test-client-id', got '%s'", handler.clientID)
|
|
}
|
|
|
|
if handler.authURL != "https://example.com/auth" {
|
|
t.Errorf("Expected authURL 'https://example.com/auth', got '%s'", handler.authURL)
|
|
}
|
|
|
|
if handler.issuerURL != "https://example.com" {
|
|
t.Errorf("Expected issuerURL 'https://example.com', got '%s'", handler.issuerURL)
|
|
}
|
|
|
|
if len(handler.scopes) != 3 {
|
|
t.Errorf("Expected 3 scopes, got %d", len(handler.scopes))
|
|
}
|
|
|
|
if handler.overrideScopes {
|
|
t.Error("overrideScopes should be false")
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_MaxRedirects tests redirect limit enforcement
|
|
func TestAuthHandler_InitiateAuthentication_MaxRedirects(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
|
|
|
session := &mockSessionData{redirectCount: 5} // At the limit
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "test-nonce", nil }
|
|
generateCodeVerifier := func() (string, error) { return "", nil }
|
|
deriveCodeChallenge := func() (string, error) { return "", nil }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
if rw.Code != http.StatusLoopDetected {
|
|
t.Errorf("Expected status %d, got %d", http.StatusLoopDetected, rw.Code)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
if !strings.Contains(body, "Too many redirects") {
|
|
t.Errorf("Expected 'Too many redirects' in response body, got '%s'", body)
|
|
}
|
|
|
|
if session.redirectCount != 0 {
|
|
t.Errorf("Expected redirect count to be reset, got %d", session.redirectCount)
|
|
}
|
|
|
|
if len(logger.errorMessages) == 0 {
|
|
t.Error("Expected error to be logged")
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_NonceGenerationError tests nonce generation failure
|
|
func TestAuthHandler_InitiateAuthentication_NonceGenerationError(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
|
|
|
session := &mockSessionData{}
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "", &testError{"nonce generation failed"} }
|
|
generateCodeVerifier := func() (string, error) { return "", nil }
|
|
deriveCodeChallenge := func() (string, error) { return "", nil }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
if rw.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
if !strings.Contains(body, "Failed to generate nonce") {
|
|
t.Errorf("Expected 'Failed to generate nonce' in response body, got '%s'", body)
|
|
}
|
|
|
|
if len(logger.errorMessages) == 0 {
|
|
t.Error("Expected error to be logged")
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError tests PKCE code verifier generation failure
|
|
func TestAuthHandler_InitiateAuthentication_PKCECodeVerifierError(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
|
|
|
session := &mockSessionData{}
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "test-nonce", nil }
|
|
generateCodeVerifier := func() (string, error) { return "", &testError{"code verifier generation failed"} }
|
|
deriveCodeChallenge := func() (string, error) { return "", nil }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
if rw.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
if !strings.Contains(body, "Failed to generate code verifier") {
|
|
t.Errorf("Expected 'Failed to generate code verifier' in response body, got '%s'", body)
|
|
}
|
|
|
|
if len(logger.errorMessages) == 0 {
|
|
t.Error("Expected error to be logged")
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError tests PKCE code challenge derivation failure
|
|
func TestAuthHandler_InitiateAuthentication_PKCECodeChallengeError(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
|
|
|
session := &mockSessionData{}
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "test-nonce", nil }
|
|
generateCodeVerifier := func() (string, error) { return "test-verifier", nil }
|
|
deriveCodeChallenge := func() (string, error) { return "", &testError{"code challenge derivation failed"} }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
if rw.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
if !strings.Contains(body, "Failed to generate code challenge") {
|
|
t.Errorf("Expected 'Failed to generate code challenge' in response body, got '%s'", body)
|
|
}
|
|
|
|
if len(logger.errorMessages) == 0 {
|
|
t.Error("Expected error to be logged")
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_SessionSaveError tests session save failure
|
|
func TestAuthHandler_InitiateAuthentication_SessionSaveError(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{}, false)
|
|
|
|
session := &mockSessionData{saveError: &testError{"save failed"}}
|
|
req := httptest.NewRequest("GET", "/test?param=value", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "test-nonce", nil }
|
|
generateCodeVerifier := func() (string, error) { return "", nil }
|
|
deriveCodeChallenge := func() (string, error) { return "", nil }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
if rw.Code != http.StatusInternalServerError {
|
|
t.Errorf("Expected status %d, got %d", http.StatusInternalServerError, rw.Code)
|
|
}
|
|
|
|
body := rw.Body.String()
|
|
if !strings.Contains(body, "Failed to save session") {
|
|
t.Errorf("Expected 'Failed to save session' in response body, got '%s'", body)
|
|
}
|
|
|
|
if len(logger.errorMessages) == 0 {
|
|
t.Error("Expected error to be logged")
|
|
}
|
|
|
|
// Verify session was prepared correctly before the save failure
|
|
if session.incomingPath != "/test?param=value" {
|
|
t.Errorf("Expected incoming path '/test?param=value', got '%s'", session.incomingPath)
|
|
}
|
|
|
|
if session.nonce != "test-nonce" {
|
|
t.Errorf("Expected nonce 'test-nonce', got '%s'", session.nonce)
|
|
}
|
|
|
|
if session.redirectCount != 1 {
|
|
t.Errorf("Expected redirect count 1, got %d", session.redirectCount)
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_InitiateAuthentication_Success tests successful authentication initiation
|
|
func TestAuthHandler_InitiateAuthentication_Success(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
|
"test-client", "https://example.com/auth", "https://example.com", []string{"openid", "email"}, false)
|
|
|
|
session := &mockSessionData{}
|
|
req := httptest.NewRequest("GET", "/protected/resource", nil)
|
|
rw := httptest.NewRecorder()
|
|
|
|
generateNonce := func() (string, error) { return "generated-nonce", nil }
|
|
generateCodeVerifier := func() (string, error) { return "generated-verifier", nil }
|
|
deriveCodeChallenge := func() (string, error) { return "generated-challenge", nil }
|
|
|
|
handler.InitiateAuthentication(rw, req, session, "https://example.com/callback",
|
|
generateNonce, generateCodeVerifier, deriveCodeChallenge)
|
|
|
|
// Should redirect
|
|
if rw.Code != http.StatusFound {
|
|
t.Errorf("Expected status %d, got %d", http.StatusFound, rw.Code)
|
|
}
|
|
|
|
location := rw.Header().Get("Location")
|
|
if location == "" {
|
|
t.Error("Expected Location header to be set")
|
|
}
|
|
|
|
// Parse the redirect URL to verify parameters
|
|
parsedURL, err := url.Parse(location)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse redirect URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
|
|
// Verify required parameters
|
|
if query.Get("client_id") != "test-client" {
|
|
t.Errorf("Expected client_id 'test-client', got '%s'", query.Get("client_id"))
|
|
}
|
|
|
|
if query.Get("response_type") != "code" {
|
|
t.Errorf("Expected response_type 'code', got '%s'", query.Get("response_type"))
|
|
}
|
|
|
|
if query.Get("redirect_uri") != "https://example.com/callback" {
|
|
t.Errorf("Expected redirect_uri 'https://example.com/callback', got '%s'", query.Get("redirect_uri"))
|
|
}
|
|
|
|
if query.Get("nonce") != "generated-nonce" {
|
|
t.Errorf("Expected nonce 'generated-nonce', got '%s'", query.Get("nonce"))
|
|
}
|
|
|
|
// Verify PKCE parameters
|
|
if query.Get("code_challenge") != "generated-challenge" {
|
|
t.Errorf("Expected code_challenge 'generated-challenge', got '%s'", query.Get("code_challenge"))
|
|
}
|
|
|
|
if query.Get("code_challenge_method") != "S256" {
|
|
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
|
|
}
|
|
|
|
// Verify scope
|
|
scope := query.Get("scope")
|
|
if !strings.Contains(scope, "openid") || !strings.Contains(scope, "email") {
|
|
t.Errorf("Expected scope to contain 'openid' and 'email', got '%s'", scope)
|
|
}
|
|
|
|
// Verify session was updated correctly
|
|
if !session.dirty {
|
|
t.Error("Expected session to be marked dirty")
|
|
}
|
|
|
|
if session.incomingPath != "/protected/resource" {
|
|
t.Errorf("Expected incoming path '/protected/resource', got '%s'", session.incomingPath)
|
|
}
|
|
|
|
if session.nonce != "generated-nonce" {
|
|
t.Errorf("Expected session nonce 'generated-nonce', got '%s'", session.nonce)
|
|
}
|
|
|
|
if session.codeVerifier != "generated-verifier" {
|
|
t.Errorf("Expected session code verifier 'generated-verifier', got '%s'", session.codeVerifier)
|
|
}
|
|
|
|
// Verify session data was cleared
|
|
if session.authenticated {
|
|
t.Error("Expected session to not be authenticated")
|
|
}
|
|
|
|
if session.email != "" {
|
|
t.Errorf("Expected email to be cleared, got '%s'", session.email)
|
|
}
|
|
|
|
if session.accessToken != "" {
|
|
t.Errorf("Expected access token to be cleared, got '%s'", session.accessToken)
|
|
}
|
|
|
|
if session.idToken != "" {
|
|
t.Errorf("Expected ID token to be cleared, got '%s'", session.idToken)
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_BuildAuthURL_GoogleProvider tests Google-specific URL building
|
|
func TestAuthHandler_BuildAuthURL_GoogleProvider(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return true }, func() bool { return false },
|
|
"google-client", "https://accounts.google.com/oauth2/auth", "https://accounts.google.com",
|
|
[]string{"openid", "profile", "email"}, false)
|
|
|
|
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
|
|
|
parsedURL, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
|
|
// Google-specific parameters
|
|
if query.Get("access_type") != "offline" {
|
|
t.Errorf("Expected access_type 'offline' for Google, got '%s'", query.Get("access_type"))
|
|
}
|
|
|
|
if query.Get("prompt") != "consent" {
|
|
t.Errorf("Expected prompt 'consent' for Google, got '%s'", query.Get("prompt"))
|
|
}
|
|
|
|
// Standard parameters should still be present
|
|
if query.Get("client_id") != "google-client" {
|
|
t.Errorf("Expected client_id 'google-client', got '%s'", query.Get("client_id"))
|
|
}
|
|
|
|
if query.Get("state") != "test-state" {
|
|
t.Errorf("Expected state 'test-state', got '%s'", query.Get("state"))
|
|
}
|
|
|
|
if query.Get("nonce") != "test-nonce" {
|
|
t.Errorf("Expected nonce 'test-nonce', got '%s'", query.Get("nonce"))
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_BuildAuthURL_AzureProvider tests Azure-specific URL building
|
|
func TestAuthHandler_BuildAuthURL_AzureProvider(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return true },
|
|
"azure-client", "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
|
"https://login.microsoftonline.com/tenant/v2.0",
|
|
[]string{"openid", "profile", "email"}, false)
|
|
|
|
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
|
|
|
parsedURL, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
|
|
// Azure-specific parameters
|
|
if query.Get("response_mode") != "query" {
|
|
t.Errorf("Expected response_mode 'query' for Azure, got '%s'", query.Get("response_mode"))
|
|
}
|
|
|
|
// Azure should add offline_access scope automatically
|
|
scope := query.Get("scope")
|
|
if !strings.Contains(scope, "offline_access") {
|
|
t.Errorf("Expected scope to contain 'offline_access' for Azure, got '%s'", scope)
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_BuildAuthURL_PKCEEnabled tests PKCE parameter inclusion
|
|
func TestAuthHandler_BuildAuthURL_PKCEEnabled(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, true, func() bool { return false }, func() bool { return false },
|
|
"pkce-client", "https://example.com/auth", "https://example.com",
|
|
[]string{"openid"}, false)
|
|
|
|
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
|
|
|
parsedURL, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
|
|
if query.Get("code_challenge") != "test-challenge" {
|
|
t.Errorf("Expected code_challenge 'test-challenge', got '%s'", query.Get("code_challenge"))
|
|
}
|
|
|
|
if query.Get("code_challenge_method") != "S256" {
|
|
t.Errorf("Expected code_challenge_method 'S256', got '%s'", query.Get("code_challenge_method"))
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_BuildAuthURL_PKCEDisabled tests when PKCE is disabled
|
|
func TestAuthHandler_BuildAuthURL_PKCEDisabled(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return false },
|
|
"no-pkce-client", "https://example.com/auth", "https://example.com",
|
|
[]string{"openid"}, false)
|
|
|
|
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "test-challenge")
|
|
|
|
parsedURL, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
|
|
query := parsedURL.Query()
|
|
|
|
// PKCE parameters should not be included
|
|
if query.Get("code_challenge") != "" {
|
|
t.Errorf("Expected no code_challenge when PKCE disabled, got '%s'", query.Get("code_challenge"))
|
|
}
|
|
|
|
if query.Get("code_challenge_method") != "" {
|
|
t.Errorf("Expected no code_challenge_method when PKCE disabled, got '%s'", query.Get("code_challenge_method"))
|
|
}
|
|
}
|
|
|
|
// TestAuthHandler_BuildAuthURL_ScopeHandling tests various scope configurations
|
|
func TestAuthHandler_BuildAuthURL_ScopeHandling(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
scopes []string
|
|
overrideScopes bool
|
|
isAzure bool
|
|
expectedScopes []string
|
|
}{
|
|
{
|
|
name: "Basic scopes",
|
|
scopes: []string{"openid", "profile", "email"},
|
|
overrideScopes: false,
|
|
isAzure: false,
|
|
expectedScopes: []string{"openid", "profile", "email", "offline_access"},
|
|
},
|
|
{
|
|
name: "Azure with offline_access already present",
|
|
scopes: []string{"openid", "profile", "offline_access"},
|
|
overrideScopes: false,
|
|
isAzure: true,
|
|
expectedScopes: []string{"openid", "profile", "offline_access"},
|
|
},
|
|
{
|
|
name: "Azure auto-add offline_access",
|
|
scopes: []string{"openid", "profile"},
|
|
overrideScopes: false,
|
|
isAzure: true,
|
|
expectedScopes: []string{"openid", "profile", "offline_access"},
|
|
},
|
|
{
|
|
name: "Override scopes with empty array",
|
|
scopes: []string{},
|
|
overrideScopes: true,
|
|
isAzure: true,
|
|
expectedScopes: []string{"offline_access"},
|
|
},
|
|
{
|
|
name: "Override scopes prevents auto-add",
|
|
scopes: []string{"openid", "custom_scope"},
|
|
overrideScopes: true,
|
|
isAzure: true,
|
|
expectedScopes: []string{"openid", "custom_scope"},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
logger := &mockLogger{}
|
|
handler := NewAuthHandler(logger, false, func() bool { return false }, func() bool { return tt.isAzure },
|
|
"test-client", "https://example.com/auth", "https://example.com",
|
|
tt.scopes, tt.overrideScopes)
|
|
|
|
authURL := handler.BuildAuthURL("https://example.com/callback", "test-state", "test-nonce", "")
|
|
|
|
parsedURL, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("Failed to parse auth URL: %v", err)
|
|
}
|
|
|
|
actualScope := parsedURL.Query().Get("scope")
|
|
actualScopes := strings.Split(actualScope, " ")
|
|
|
|
// Check each expected scope is present
|
|
for _, expectedScope := range tt.expectedScopes {
|
|
found := false
|
|
for _, actualScope := range actualScopes {
|
|
if actualScope == expectedScope {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Expected scope '%s' not found in '%s'", expectedScope, actualScope)
|
|
}
|
|
}
|
|
|
|
// Check no unexpected scopes are present
|
|
for _, actualScope := range actualScopes {
|
|
if actualScope == "" {
|
|
continue // Skip empty strings from split
|
|
}
|
|
found := false
|
|
for _, expectedScope := range tt.expectedScopes {
|
|
if actualScope == expectedScope {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
t.Errorf("Unexpected scope '%s' found in '%s'", actualScope, parsedURL.Query().Get("scope"))
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testError is a test helper type for errors: a minimal error implementation
// whose Error text is the message it was constructed with.
type testError struct {
	message string
}

// Error returns the stored message, satisfying the error interface.
func (e *testError) Error() string {
	return e.message
}
|