mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-05 22:44:17 +00:00
c3f23cb99b
* Resolve issue with opaque tokens not being parsed correctly * Increase test coverage * Further improvements to test coverage and code quality * Add new providers. * fixup! Add new providers. * Cleanup. * fixup! Cleanup. * fixup! fixup! Cleanup. * fixup! fixup! fixup! Cleanup. * fixup! fixup! fixup! fixup! Cleanup. * Memory management optimisation 24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference) * Pooling cleanup.
1110 lines
31 KiB
Go
1110 lines
31 KiB
Go
package providers
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
internalproviders "github.com/lukaszraczylo/traefikoidc/internal/providers"
|
|
)
|
|
|
|
// ============================================================================
|
|
// Mock Implementations
|
|
// ============================================================================
|
|
|
|
// mockSession implements the Session interface for testing
|
|
type mockSession struct {
|
|
idToken string
|
|
accessToken string
|
|
refreshToken string
|
|
authenticated bool
|
|
}
|
|
|
|
func (m *mockSession) GetIDToken() string { return m.idToken }
|
|
func (m *mockSession) GetAccessToken() string { return m.accessToken }
|
|
func (m *mockSession) GetRefreshToken() string { return m.refreshToken }
|
|
func (m *mockSession) GetAuthenticated() bool { return m.authenticated }
|
|
|
|
// mockTokenVerifier implements TokenVerifier for testing
|
|
type mockTokenVerifier struct {
|
|
shouldFail bool
|
|
expiredTokens map[string]bool
|
|
}
|
|
|
|
func (m *mockTokenVerifier) VerifyToken(token string) error {
|
|
if m.shouldFail {
|
|
return errors.New("token verification failed")
|
|
}
|
|
if m.expiredTokens != nil && m.expiredTokens[token] {
|
|
return errors.New("token has expired")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// mockTokenCache implements TokenCache for testing
|
|
type mockTokenCache struct {
|
|
data map[string]map[string]interface{}
|
|
}
|
|
|
|
func (m *mockTokenCache) Get(key string) (map[string]interface{}, bool) {
|
|
if m.data == nil {
|
|
return nil, false
|
|
}
|
|
claims, exists := m.data[key]
|
|
return claims, exists
|
|
}
|
|
|
|
// mockLegacySettings implements LegacySettings for testing
|
|
//
|
|
//lint:ignore U1000 Used in tests but staticcheck can't detect the interface implementation
|
|
type mockLegacySettings struct {
|
|
issuerURL string
|
|
authURL string
|
|
scopes []string
|
|
pkceEnabled bool
|
|
clientID string
|
|
refreshGracePeriod time.Duration
|
|
overrideScopes bool
|
|
}
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) GetIssuerURL() string { return m.issuerURL }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) GetAuthURL() string { return m.authURL }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) GetScopes() []string { return m.scopes }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) IsPKCEEnabled() bool { return m.pkceEnabled }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) GetClientID() string { return m.clientID }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) GetRefreshGracePeriod() time.Duration { return m.refreshGracePeriod }
|
|
|
|
//lint:ignore U1000 Interface method for LegacySettings
|
|
func (m *mockLegacySettings) IsOverrideScopes() bool { return m.overrideScopes }
|
|
|
|
// ============================================================================
|
|
// Azure Provider Tests
|
|
// ============================================================================
|
|
|
|
func TestAzureProvider(t *testing.T) {
|
|
t.Run("NewAzureProvider", func(t *testing.T) {
|
|
provider := internalproviders.NewAzureProvider()
|
|
if provider == nil {
|
|
t.Fatal("expected non-nil Azure provider")
|
|
}
|
|
if provider.BaseProvider == nil {
|
|
t.Fatal("expected non-nil BaseProvider")
|
|
}
|
|
})
|
|
|
|
t.Run("GetType", func(t *testing.T) {
|
|
provider := internalproviders.NewAzureProvider()
|
|
if got := provider.GetType(); got != internalproviders.ProviderTypeAzure {
|
|
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeAzure, got)
|
|
}
|
|
})
|
|
|
|
t.Run("GetCapabilities", func(t *testing.T) {
|
|
provider := internalproviders.NewAzureProvider()
|
|
capabilities := provider.GetCapabilities()
|
|
|
|
tests := []struct {
|
|
name string
|
|
field string
|
|
expected interface{}
|
|
got interface{}
|
|
}{
|
|
{"SupportsRefreshTokens", "SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
|
|
{"RequiresOfflineAccessScope", "RequiresOfflineAccessScope", true, capabilities.RequiresOfflineAccessScope},
|
|
{"PreferredTokenValidation", "PreferredTokenValidation", "access", capabilities.PreferredTokenValidation},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if tt.expected != tt.got {
|
|
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("BuildAuthParams", func(t *testing.T) {
|
|
provider := internalproviders.NewAzureProvider()
|
|
|
|
tests := []struct {
|
|
name string
|
|
baseParams url.Values
|
|
scopes []string
|
|
expectOfflineAccess bool
|
|
}{
|
|
{
|
|
name: "with offline_access scope",
|
|
baseParams: url.Values{"client_id": []string{"test-client"}},
|
|
scopes: []string{"openid", "offline_access", "email"},
|
|
expectOfflineAccess: true,
|
|
},
|
|
{
|
|
name: "without offline_access scope",
|
|
baseParams: url.Values{"client_id": []string{"test-client"}},
|
|
scopes: []string{"openid", "email"},
|
|
expectOfflineAccess: true, // Should be added automatically
|
|
},
|
|
{
|
|
name: "empty scopes",
|
|
baseParams: url.Values{},
|
|
scopes: []string{},
|
|
expectOfflineAccess: true, // Should be added automatically
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if authParams == nil {
|
|
t.Fatal("expected non-nil auth params")
|
|
}
|
|
|
|
// Check offline_access scope
|
|
if tt.expectOfflineAccess {
|
|
hasOfflineAccess := false
|
|
for _, scope := range authParams.Scopes {
|
|
if scope == "offline_access" {
|
|
hasOfflineAccess = true
|
|
break
|
|
}
|
|
}
|
|
if !hasOfflineAccess {
|
|
t.Error("expected offline_access scope to be present")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Google Provider Tests
|
|
// ============================================================================
|
|
|
|
func TestGoogleProvider(t *testing.T) {
|
|
t.Run("internalproviders.NewGoogleProvider", func(t *testing.T) {
|
|
provider := internalproviders.NewGoogleProvider()
|
|
if provider == nil {
|
|
t.Fatal("expected non-nil Google provider")
|
|
}
|
|
if provider.BaseProvider == nil {
|
|
t.Fatal("expected non-nil BaseProvider")
|
|
}
|
|
})
|
|
|
|
t.Run("GetType", func(t *testing.T) {
|
|
provider := internalproviders.NewGoogleProvider()
|
|
if got := provider.GetType(); got != internalproviders.ProviderTypeGoogle {
|
|
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeGoogle, got)
|
|
}
|
|
})
|
|
|
|
t.Run("GetCapabilities", func(t *testing.T) {
|
|
provider := internalproviders.NewGoogleProvider()
|
|
capabilities := provider.GetCapabilities()
|
|
|
|
tests := []struct {
|
|
name string
|
|
field string
|
|
expected interface{}
|
|
got interface{}
|
|
}{
|
|
{"SupportsRefreshTokens", "SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
|
|
{"RequiresOfflineAccessScope", "RequiresOfflineAccessScope", false, capabilities.RequiresOfflineAccessScope},
|
|
{"RequiresPromptConsent", "RequiresPromptConsent", true, capabilities.RequiresPromptConsent},
|
|
{"PreferredTokenValidation", "PreferredTokenValidation", "id", capabilities.PreferredTokenValidation},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if tt.expected != tt.got {
|
|
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("BuildAuthParams", func(t *testing.T) {
|
|
provider := internalproviders.NewGoogleProvider()
|
|
|
|
tests := []struct {
|
|
name string
|
|
baseParams url.Values
|
|
scopes []string
|
|
expectAccessTypeOffline bool
|
|
expectPromptConsent bool
|
|
expectOfflineAccessRemoved bool
|
|
}{
|
|
{
|
|
name: "basic params with offline_access scope",
|
|
baseParams: url.Values{"client_id": []string{"test-client"}},
|
|
scopes: []string{"openid", "offline_access", "email"},
|
|
expectAccessTypeOffline: true,
|
|
expectPromptConsent: true,
|
|
expectOfflineAccessRemoved: true,
|
|
},
|
|
{
|
|
name: "basic params without offline_access scope",
|
|
baseParams: url.Values{"client_id": []string{"test-client"}},
|
|
scopes: []string{"openid", "email"},
|
|
expectAccessTypeOffline: true,
|
|
expectPromptConsent: true,
|
|
expectOfflineAccessRemoved: false,
|
|
},
|
|
{
|
|
name: "empty scopes",
|
|
baseParams: url.Values{},
|
|
scopes: []string{},
|
|
expectAccessTypeOffline: true,
|
|
expectPromptConsent: true,
|
|
expectOfflineAccessRemoved: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if authParams == nil {
|
|
t.Fatal("expected non-nil auth params")
|
|
}
|
|
|
|
// Check access_type parameter
|
|
if tt.expectAccessTypeOffline {
|
|
if authParams.URLValues.Get("access_type") != "offline" {
|
|
t.Error("expected access_type to be 'offline'")
|
|
}
|
|
}
|
|
|
|
// Check prompt parameter
|
|
if tt.expectPromptConsent {
|
|
if authParams.URLValues.Get("prompt") != "consent" {
|
|
t.Error("expected prompt to be 'consent'")
|
|
}
|
|
}
|
|
|
|
// Check offline_access scope removal
|
|
hasOfflineAccess := false
|
|
for _, scope := range authParams.Scopes {
|
|
if scope == "offline_access" {
|
|
hasOfflineAccess = true
|
|
break
|
|
}
|
|
}
|
|
if tt.expectOfflineAccessRemoved && hasOfflineAccess {
|
|
t.Error("expected offline_access scope to be removed")
|
|
}
|
|
if !tt.expectOfflineAccessRemoved && !hasOfflineAccess && containsString(tt.scopes, "offline_access") {
|
|
t.Error("expected offline_access scope to be preserved")
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Base Provider Tests
|
|
// ============================================================================
|
|
|
|
func TestBaseProvider(t *testing.T) {
|
|
t.Run("GetType", func(t *testing.T) {
|
|
provider := internalproviders.NewGenericProvider()
|
|
if got := provider.GetType(); got != internalproviders.ProviderTypeGeneric {
|
|
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeGeneric, got)
|
|
}
|
|
})
|
|
|
|
t.Run("GetCapabilities", func(t *testing.T) {
|
|
provider := internalproviders.NewGenericProvider()
|
|
capabilities := provider.GetCapabilities()
|
|
|
|
tests := []struct {
|
|
name string
|
|
expected interface{}
|
|
got interface{}
|
|
}{
|
|
{"SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
|
|
{"RequiresOfflineAccessScope", true, capabilities.RequiresOfflineAccessScope},
|
|
{"PreferredTokenValidation", "id", capabilities.PreferredTokenValidation},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
if tt.expected != tt.got {
|
|
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("ValidateTokenExpiry", func(t *testing.T) {
|
|
provider := internalproviders.NewGenericProvider()
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
session *mockSession
|
|
cache *mockTokenCache
|
|
expectedResult *internalproviders.ValidationResult
|
|
}{
|
|
{
|
|
name: "token not in cache with refresh token",
|
|
token: "missing-token",
|
|
session: &mockSession{
|
|
refreshToken: "refresh-token",
|
|
},
|
|
cache: &mockTokenCache{
|
|
data: map[string]map[string]interface{}{},
|
|
},
|
|
expectedResult: &internalproviders.ValidationResult{
|
|
Authenticated: false,
|
|
NeedsRefresh: true,
|
|
},
|
|
},
|
|
{
|
|
name: "token not in cache without refresh token",
|
|
token: "missing-token",
|
|
session: &mockSession{
|
|
refreshToken: "",
|
|
},
|
|
cache: &mockTokenCache{
|
|
data: map[string]map[string]interface{}{},
|
|
},
|
|
expectedResult: &internalproviders.ValidationResult{
|
|
Authenticated: false,
|
|
NeedsRefresh: false,
|
|
},
|
|
},
|
|
{
|
|
name: "valid token in cache",
|
|
token: "valid-token",
|
|
session: &mockSession{
|
|
refreshToken: "refresh-token",
|
|
},
|
|
cache: &mockTokenCache{
|
|
data: map[string]map[string]interface{}{
|
|
"valid-token": {
|
|
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
|
|
},
|
|
},
|
|
},
|
|
expectedResult: &internalproviders.ValidationResult{
|
|
Authenticated: true,
|
|
NeedsRefresh: false,
|
|
},
|
|
},
|
|
{
|
|
name: "expired token with refresh token",
|
|
token: "expired-token",
|
|
session: &mockSession{
|
|
refreshToken: "refresh-token",
|
|
},
|
|
cache: &mockTokenCache{
|
|
data: map[string]map[string]interface{}{
|
|
"expired-token": {
|
|
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()),
|
|
},
|
|
},
|
|
},
|
|
expectedResult: &internalproviders.ValidationResult{
|
|
Authenticated: true,
|
|
NeedsRefresh: true,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result, err := provider.ValidateTokenExpiry(tt.session, tt.token, tt.cache, 5*time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("ValidateTokenExpiry failed: %v", err)
|
|
}
|
|
|
|
if result == nil {
|
|
t.Fatal("expected non-nil result")
|
|
}
|
|
|
|
if result.Authenticated != tt.expectedResult.Authenticated {
|
|
t.Errorf("expected Authenticated %v, got %v", tt.expectedResult.Authenticated, result.Authenticated)
|
|
}
|
|
|
|
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
|
t.Errorf("expected NeedsRefresh %v, got %v", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
|
}
|
|
|
|
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
|
t.Errorf("expected NeedsRefresh %v, got %v", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Provider Factory Tests
|
|
// ============================================================================
|
|
|
|
func TestProviderFactory(t *testing.T) {
|
|
t.Run("NewProviderFactory", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
if factory == nil {
|
|
t.Fatal("expected non-nil factory")
|
|
}
|
|
})
|
|
|
|
t.Run("CreateProvider", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
|
|
tests := []struct {
|
|
name string
|
|
issuerURL string
|
|
wantType internalproviders.ProviderType
|
|
wantError bool
|
|
errorSubstr string
|
|
}{
|
|
{
|
|
name: "Google provider detection",
|
|
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
|
wantType: internalproviders.ProviderTypeGoogle,
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "Azure provider detection - login.microsoftonline.com",
|
|
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
|
wantType: internalproviders.ProviderTypeAzure,
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "Azure provider detection - sts.windows.net",
|
|
issuerURL: "https://sts.windows.net/tenant-id/",
|
|
wantType: internalproviders.ProviderTypeAzure,
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "Generic provider detection",
|
|
issuerURL: "https://auth.example.com/realms/test",
|
|
wantType: internalproviders.ProviderTypeGeneric,
|
|
wantError: false,
|
|
},
|
|
{
|
|
name: "Empty issuer URL",
|
|
issuerURL: "",
|
|
wantError: true,
|
|
errorSubstr: "issuer URL cannot be empty",
|
|
},
|
|
{
|
|
name: "Invalid URL format",
|
|
issuerURL: "not-a-valid-url",
|
|
wantError: true,
|
|
errorSubstr: "invalid issuer URL format",
|
|
},
|
|
{
|
|
name: "URL with invalid scheme",
|
|
issuerURL: "ftp://example.com/auth",
|
|
wantType: internalproviders.ProviderTypeGeneric,
|
|
wantError: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
provider, err := factory.CreateProvider(tt.issuerURL)
|
|
|
|
if tt.wantError {
|
|
if err == nil {
|
|
t.Errorf("expected error but got none")
|
|
return
|
|
}
|
|
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
|
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if provider == nil {
|
|
t.Error("expected non-nil provider")
|
|
return
|
|
}
|
|
|
|
if provider.GetType() != tt.wantType {
|
|
t.Errorf("expected provider type %d, got %d", tt.wantType, provider.GetType())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("ConcurrentProviderCreation", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
urls := []string{
|
|
"https://accounts.google.com/.well-known/openid_configuration",
|
|
"https://login.microsoftonline.com/tenant-id/v2.0",
|
|
"https://auth.example.com/realms/test",
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, len(urls)*10)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
for _, url := range urls {
|
|
wg.Add(1)
|
|
go func(issuerURL string) {
|
|
defer wg.Done()
|
|
provider, err := factory.CreateProvider(issuerURL)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
if provider == nil {
|
|
errors <- fmt.Errorf("got nil provider for %s", issuerURL)
|
|
}
|
|
}(url)
|
|
}
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
t.Errorf("concurrent creation error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Provider Registry Tests
|
|
// ============================================================================
|
|
|
|
func TestProviderRegistry(t *testing.T) {
|
|
t.Run("NewProviderRegistry", func(t *testing.T) {
|
|
registry := internalproviders.NewProviderRegistry()
|
|
if registry == nil {
|
|
t.Fatal("expected non-nil registry")
|
|
}
|
|
})
|
|
|
|
t.Run("RegisterAndGet", func(t *testing.T) {
|
|
registry := internalproviders.NewProviderRegistry()
|
|
|
|
// Register providers
|
|
googleProvider := internalproviders.NewGoogleProvider()
|
|
azureProvider := internalproviders.NewAzureProvider()
|
|
|
|
registry.RegisterProvider(googleProvider)
|
|
registry.RegisterProvider(azureProvider)
|
|
|
|
// Test getting registered providers
|
|
tests := []struct {
|
|
name string
|
|
providerType internalproviders.ProviderType
|
|
shouldExist bool
|
|
}{
|
|
{"Get Google provider", internalproviders.ProviderTypeGoogle, true},
|
|
{"Get Azure provider", internalproviders.ProviderTypeAzure, true},
|
|
{"Get unregistered provider", internalproviders.ProviderType(999), false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
provider := registry.GetProviderByType(tt.providerType)
|
|
|
|
if tt.shouldExist {
|
|
if provider == nil {
|
|
t.Error("expected non-nil provider")
|
|
}
|
|
} else {
|
|
if provider != nil {
|
|
t.Error("expected nil provider")
|
|
}
|
|
}
|
|
|
|
if tt.shouldExist && provider != nil && provider.GetType() != tt.providerType {
|
|
t.Errorf("expected provider type %d, got %d", tt.providerType, provider.GetType())
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("Detectinternalproviders.ProviderType", func(t *testing.T) {
|
|
registry := internalproviders.NewProviderRegistry()
|
|
|
|
// Register providers needed for detection
|
|
registry.RegisterProvider(internalproviders.NewGoogleProvider())
|
|
registry.RegisterProvider(internalproviders.NewAzureProvider())
|
|
registry.RegisterProvider(internalproviders.NewGenericProvider())
|
|
|
|
tests := []struct {
|
|
name string
|
|
issuerURL string
|
|
expectedType internalproviders.ProviderType
|
|
}{
|
|
{"Google URL", "https://accounts.google.com/.well-known/openid_configuration", internalproviders.ProviderTypeGoogle},
|
|
{"Azure login.microsoftonline.com", "https://login.microsoftonline.com/tenant/v2.0", internalproviders.ProviderTypeAzure},
|
|
{"Azure sts.windows.net", "https://sts.windows.net/tenant/", internalproviders.ProviderTypeAzure},
|
|
{"Generic provider", "https://auth.example.com/realms/test", internalproviders.ProviderTypeGeneric},
|
|
// Empty URL should return nil, not a provider
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
provider := registry.DetectProvider(tt.issuerURL)
|
|
if provider == nil {
|
|
t.Fatalf("DetectProvider returned nil for URL: %s", tt.issuerURL)
|
|
}
|
|
providerType := provider.GetType()
|
|
if providerType != tt.expectedType {
|
|
t.Errorf("expected provider type %d, got %d", tt.expectedType, providerType)
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test empty URL separately - it should return nil
|
|
t.Run("Empty URL", func(t *testing.T) {
|
|
provider := registry.DetectProvider("")
|
|
if provider != nil {
|
|
t.Errorf("expected nil provider for empty URL, got %v", provider)
|
|
}
|
|
})
|
|
})
|
|
|
|
t.Run("ConcurrentAccess", func(t *testing.T) {
|
|
registry := internalproviders.NewProviderRegistry()
|
|
|
|
// Register initial provider
|
|
registry.RegisterProvider(internalproviders.NewGoogleProvider())
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, 100)
|
|
|
|
// Concurrent reads
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
provider := registry.GetProviderByType(internalproviders.ProviderTypeGoogle)
|
|
if provider == nil {
|
|
errors <- fmt.Errorf("provider not found")
|
|
return
|
|
}
|
|
if provider == nil {
|
|
errors <- fmt.Errorf("got nil provider")
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Concurrent writes
|
|
for i := 0; i < 50; i++ {
|
|
wg.Add(1)
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
registry.RegisterProvider(internalproviders.NewGenericProvider())
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
t.Errorf("concurrent access error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Provider Adapter Tests
|
|
// ============================================================================
|
|
// NOTE: Adapter tests commented out due to API mismatch - actual NewAdapter requires
|
|
// (provider, settings, verifier, cache) parameters, not factory
|
|
/*
|
|
func TestProviderAdapter(t *testing.T) {
|
|
t.Run("internalproviders.NewAdapter", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
adapter := internalproviders.NewAdapter(factory)
|
|
|
|
if adapter == nil {
|
|
t.Fatal("expected non-nil adapter")
|
|
}
|
|
if adapter.factory == nil {
|
|
t.Fatal("expected non-nil factory in adapter")
|
|
}
|
|
})
|
|
|
|
t.Run("AdaptLegacySettings", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
adapter := internalproviders.NewAdapter(factory)
|
|
|
|
tests := []struct {
|
|
name string
|
|
settings *mockLegacySettings
|
|
expectedType internalproviders.ProviderType
|
|
expectedScopes []string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "Google provider settings",
|
|
settings: &mockLegacySettings{
|
|
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
|
authURL: "https://accounts.google.com/o/oauth2/v2/auth",
|
|
scopes: []string{"openid", "email", "profile"},
|
|
pkceEnabled: true,
|
|
clientID: "google-client-id",
|
|
overrideScopes: false,
|
|
},
|
|
expectedType: internalproviders.ProviderTypeGoogle,
|
|
expectedScopes: []string{"openid", "email", "profile"},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Azure provider settings",
|
|
settings: &mockLegacySettings{
|
|
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
|
authURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
|
|
scopes: []string{"openid", "offline_access"},
|
|
pkceEnabled: false,
|
|
clientID: "azure-client-id",
|
|
overrideScopes: false,
|
|
},
|
|
expectedType: internalproviders.ProviderTypeAzure,
|
|
expectedScopes: []string{"openid", "offline_access"},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Generic provider settings",
|
|
settings: &mockLegacySettings{
|
|
issuerURL: "https://auth.example.com/realms/test",
|
|
authURL: "https://auth.example.com/realms/test/protocol/openid-connect/auth",
|
|
scopes: []string{"openid"},
|
|
pkceEnabled: true,
|
|
clientID: "generic-client-id",
|
|
overrideScopes: true,
|
|
},
|
|
expectedType: internalproviders.ProviderTypeGeneric,
|
|
expectedScopes: []string{"openid"},
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "Empty issuer URL",
|
|
settings: &mockLegacySettings{
|
|
issuerURL: "",
|
|
authURL: "https://auth.example.com/auth",
|
|
scopes: []string{"openid"},
|
|
clientID: "client-id",
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
provider, authParams, err := adapter.AdaptLegacySettings(tt.settings)
|
|
|
|
if tt.expectError {
|
|
if err == nil {
|
|
t.Error("expected error but got none")
|
|
}
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
return
|
|
}
|
|
|
|
if provider == nil {
|
|
t.Fatal("expected non-nil provider")
|
|
}
|
|
|
|
if provider.GetType() != tt.expectedType {
|
|
t.Errorf("expected provider type %d, got %d", tt.expectedType, provider.GetType())
|
|
}
|
|
|
|
if authParams == nil {
|
|
t.Fatal("expected non-nil auth params")
|
|
}
|
|
|
|
// Verify scopes handling
|
|
if !tt.settings.overrideScopes {
|
|
// When not overriding, provider may modify scopes
|
|
if len(authParams.Scopes) == 0 {
|
|
t.Error("expected non-empty scopes")
|
|
}
|
|
} else {
|
|
// When overriding, original scopes should be preserved
|
|
if !equalStringSlices(authParams.Scopes, tt.expectedScopes) {
|
|
t.Errorf("expected scopes %v, got %v", tt.expectedScopes, authParams.Scopes)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("ConcurrentAdaptation", func(t *testing.T) {
|
|
factory := internalproviders.NewProviderFactory()
|
|
adapter := internalproviders.NewAdapter(factory)
|
|
|
|
settings := []*mockLegacySettings{
|
|
{
|
|
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
|
authURL: "https://accounts.google.com/o/oauth2/v2/auth",
|
|
scopes: []string{"openid", "email"},
|
|
clientID: "google-client",
|
|
},
|
|
{
|
|
issuerURL: "https://login.microsoftonline.com/tenant/v2.0",
|
|
authURL: "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
|
scopes: []string{"openid", "offline_access"},
|
|
clientID: "azure-client",
|
|
},
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
errors := make(chan error, len(settings)*10)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
for _, s := range settings {
|
|
wg.Add(1)
|
|
go func(setting *mockLegacySettings) {
|
|
defer wg.Done()
|
|
provider, authParams, err := adapter.AdaptLegacySettings(setting)
|
|
if err != nil {
|
|
errors <- err
|
|
return
|
|
}
|
|
if provider == nil {
|
|
errors <- fmt.Errorf("got nil provider")
|
|
return
|
|
}
|
|
if authParams == nil {
|
|
errors <- fmt.Errorf("got nil auth params")
|
|
}
|
|
}(s)
|
|
}
|
|
}
|
|
|
|
wg.Wait()
|
|
close(errors)
|
|
|
|
for err := range errors {
|
|
t.Errorf("concurrent adaptation error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
*/
|
|
|
|
// ============================================================================
|
|
// Validation Tests
|
|
// ============================================================================
|
|
|
|
func TestTokenValidation(t *testing.T) {
|
|
t.Run("ValidateWithVerifier", func(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
verifier *mockTokenVerifier
|
|
expectValid bool
|
|
}{
|
|
{
|
|
name: "valid token",
|
|
token: "valid-token",
|
|
verifier: &mockTokenVerifier{
|
|
shouldFail: false,
|
|
},
|
|
expectValid: true,
|
|
},
|
|
{
|
|
name: "invalid token",
|
|
token: "invalid-token",
|
|
verifier: &mockTokenVerifier{
|
|
shouldFail: true,
|
|
},
|
|
expectValid: false,
|
|
},
|
|
{
|
|
name: "expired token",
|
|
token: "expired-token",
|
|
verifier: &mockTokenVerifier{
|
|
expiredTokens: map[string]bool{
|
|
"expired-token": true,
|
|
},
|
|
},
|
|
expectValid: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := tt.verifier.VerifyToken(tt.token)
|
|
isValid := err == nil
|
|
|
|
if isValid != tt.expectValid {
|
|
t.Errorf("expected valid=%v, got %v (err: %v)", tt.expectValid, isValid, err)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("ConcurrentValidation", func(t *testing.T) {
|
|
verifier := &mockTokenVerifier{
|
|
shouldFail: false,
|
|
expiredTokens: map[string]bool{
|
|
"expired-1": true,
|
|
"expired-2": true,
|
|
},
|
|
}
|
|
|
|
tokens := []string{"valid-1", "valid-2", "expired-1", "expired-2", "valid-3"}
|
|
|
|
var wg sync.WaitGroup
|
|
results := make(chan bool, len(tokens)*10)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
for _, token := range tokens {
|
|
wg.Add(1)
|
|
go func(t string) {
|
|
defer wg.Done()
|
|
err := verifier.VerifyToken(t)
|
|
results <- (err == nil)
|
|
}(token)
|
|
}
|
|
}
|
|
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
validCount := 0
|
|
invalidCount := 0
|
|
for isValid := range results {
|
|
if isValid {
|
|
validCount++
|
|
} else {
|
|
invalidCount++
|
|
}
|
|
}
|
|
|
|
expectedValid := 30 // 3 valid tokens * 10 iterations
|
|
expectedInvalid := 20 // 2 expired tokens * 10 iterations
|
|
|
|
if validCount != expectedValid {
|
|
t.Errorf("expected %d valid results, got %d", expectedValid, validCount)
|
|
}
|
|
if invalidCount != expectedInvalid {
|
|
t.Errorf("expected %d invalid results, got %d", expectedInvalid, invalidCount)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Memory Management Tests
|
|
// ============================================================================
|
|
|
|
func TestProviderMemoryManagement(t *testing.T) {
|
|
t.Run("FactoryMemoryUsage", func(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping memory test in short mode")
|
|
}
|
|
|
|
var m runtime.MemStats
|
|
runtime.GC()
|
|
runtime.ReadMemStats(&m)
|
|
initialAlloc := m.Alloc
|
|
|
|
factory := internalproviders.NewProviderFactory()
|
|
|
|
// Create many providers
|
|
providers := make([]internalproviders.OIDCProvider, 0, 1000)
|
|
for i := 0; i < 1000; i++ {
|
|
var provider internalproviders.OIDCProvider
|
|
var err error
|
|
|
|
switch i % 3 {
|
|
case 0:
|
|
provider, err = factory.CreateProvider("https://accounts.google.com/.well-known/openid_configuration")
|
|
case 1:
|
|
provider, err = factory.CreateProvider("https://login.microsoftonline.com/tenant/v2.0")
|
|
default:
|
|
provider, err = factory.CreateProvider("https://auth.example.com/realms/test")
|
|
}
|
|
|
|
if err != nil {
|
|
t.Fatalf("failed to create provider: %v", err)
|
|
}
|
|
providers = append(providers, provider) // keeping references to prevent GC
|
|
}
|
|
|
|
runtime.GC()
|
|
runtime.ReadMemStats(&m)
|
|
finalAlloc := m.Alloc
|
|
|
|
var memUsed, memPerProvider uint64
|
|
if finalAlloc > initialAlloc {
|
|
memUsed = finalAlloc - initialAlloc
|
|
memPerProvider = memUsed / 1000
|
|
}
|
|
|
|
// Each provider should use less than 10KB on average
|
|
if memPerProvider > 10*1024 {
|
|
t.Errorf("excessive memory usage: %d bytes per provider", memPerProvider)
|
|
}
|
|
|
|
// Use providers to satisfy staticcheck
|
|
_ = providers
|
|
// Clear references to allow GC
|
|
providers = nil
|
|
runtime.GC()
|
|
runtime.ReadMemStats(&m)
|
|
|
|
// Memory should be mostly freed
|
|
afterGC := m.Alloc
|
|
if afterGC > initialAlloc+1024*1024 { // Allow 1MB overhead
|
|
t.Errorf("memory not properly freed after GC: %d bytes still allocated", afterGC-initialAlloc)
|
|
}
|
|
})
|
|
}
|
|
|
|
// ============================================================================
|
|
// Helper Functions
|
|
// ============================================================================
|
|
|
|
func containsString(slice []string, str string) bool {
|
|
for _, s := range slice {
|
|
if s == str {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
//lint:ignore U1000 Used in tests
|
|
func equalStringSlices(a, b []string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
for i := range a {
|
|
if a[i] != b[i] {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|