Files
traefikoidc/providers/provider_consolidated_test.go
T
lukaszraczylo c3f23cb99b Release 0.7.5 (#70)
* Resolve issue with opaque tokens not being parsed correctly

* Increase test coverage

* Further improvements to test coverage and code quality

* Add new providers.

* fixup! Add new providers.

* Cleanup.

* fixup! Cleanup.

* fixup! fixup! Cleanup.

* fixup! fixup! fixup! Cleanup.

* fixup! fixup! fixup! fixup! Cleanup.

* Memory management optimisation

24 bytes per Put < 256-4096 bytes per buffer allocation avoided (10-170x difference)

* Pooling cleanup.
2025-10-01 12:13:10 +01:00

1110 lines
31 KiB
Go

package providers
import (
"errors"
"fmt"
"net/url"
"runtime"
"strings"
"sync"
"testing"
"time"
internalproviders "github.com/lukaszraczylo/traefikoidc/internal/providers"
)
// ============================================================================
// Mock Implementations
// ============================================================================
// mockSession implements the Session interface for testing
type mockSession struct {
idToken string
accessToken string
refreshToken string
authenticated bool
}
func (m *mockSession) GetIDToken() string { return m.idToken }
func (m *mockSession) GetAccessToken() string { return m.accessToken }
func (m *mockSession) GetRefreshToken() string { return m.refreshToken }
func (m *mockSession) GetAuthenticated() bool { return m.authenticated }
// mockTokenVerifier implements TokenVerifier for testing
type mockTokenVerifier struct {
shouldFail bool
expiredTokens map[string]bool
}
func (m *mockTokenVerifier) VerifyToken(token string) error {
if m.shouldFail {
return errors.New("token verification failed")
}
if m.expiredTokens != nil && m.expiredTokens[token] {
return errors.New("token has expired")
}
return nil
}
// mockTokenCache implements TokenCache for testing
type mockTokenCache struct {
data map[string]map[string]interface{}
}
func (m *mockTokenCache) Get(key string) (map[string]interface{}, bool) {
if m.data == nil {
return nil, false
}
claims, exists := m.data[key]
return claims, exists
}
// mockLegacySettings implements LegacySettings for testing
//
//lint:ignore U1000 Used in tests but staticcheck can't detect the interface implementation
type mockLegacySettings struct {
issuerURL string
authURL string
scopes []string
pkceEnabled bool
clientID string
refreshGracePeriod time.Duration
overrideScopes bool
}
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) GetIssuerURL() string { return m.issuerURL }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) GetAuthURL() string { return m.authURL }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) GetScopes() []string { return m.scopes }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) IsPKCEEnabled() bool { return m.pkceEnabled }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) GetClientID() string { return m.clientID }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) GetRefreshGracePeriod() time.Duration { return m.refreshGracePeriod }
//lint:ignore U1000 Interface method for LegacySettings
func (m *mockLegacySettings) IsOverrideScopes() bool { return m.overrideScopes }
// ============================================================================
// Azure Provider Tests
// ============================================================================
func TestAzureProvider(t *testing.T) {
t.Run("NewAzureProvider", func(t *testing.T) {
provider := internalproviders.NewAzureProvider()
if provider == nil {
t.Fatal("expected non-nil Azure provider")
}
if provider.BaseProvider == nil {
t.Fatal("expected non-nil BaseProvider")
}
})
t.Run("GetType", func(t *testing.T) {
provider := internalproviders.NewAzureProvider()
if got := provider.GetType(); got != internalproviders.ProviderTypeAzure {
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeAzure, got)
}
})
t.Run("GetCapabilities", func(t *testing.T) {
provider := internalproviders.NewAzureProvider()
capabilities := provider.GetCapabilities()
tests := []struct {
name string
field string
expected interface{}
got interface{}
}{
{"SupportsRefreshTokens", "SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
{"RequiresOfflineAccessScope", "RequiresOfflineAccessScope", true, capabilities.RequiresOfflineAccessScope},
{"PreferredTokenValidation", "PreferredTokenValidation", "access", capabilities.PreferredTokenValidation},
}
for _, tt := range tests {
if tt.expected != tt.got {
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
}
}
})
t.Run("BuildAuthParams", func(t *testing.T) {
provider := internalproviders.NewAzureProvider()
tests := []struct {
name string
baseParams url.Values
scopes []string
expectOfflineAccess bool
}{
{
name: "with offline_access scope",
baseParams: url.Values{"client_id": []string{"test-client"}},
scopes: []string{"openid", "offline_access", "email"},
expectOfflineAccess: true,
},
{
name: "without offline_access scope",
baseParams: url.Values{"client_id": []string{"test-client"}},
scopes: []string{"openid", "email"},
expectOfflineAccess: true, // Should be added automatically
},
{
name: "empty scopes",
baseParams: url.Values{},
scopes: []string{},
expectOfflineAccess: true, // Should be added automatically
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if authParams == nil {
t.Fatal("expected non-nil auth params")
}
// Check offline_access scope
if tt.expectOfflineAccess {
hasOfflineAccess := false
for _, scope := range authParams.Scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if !hasOfflineAccess {
t.Error("expected offline_access scope to be present")
}
}
})
}
})
}
// ============================================================================
// Google Provider Tests
// ============================================================================
func TestGoogleProvider(t *testing.T) {
t.Run("internalproviders.NewGoogleProvider", func(t *testing.T) {
provider := internalproviders.NewGoogleProvider()
if provider == nil {
t.Fatal("expected non-nil Google provider")
}
if provider.BaseProvider == nil {
t.Fatal("expected non-nil BaseProvider")
}
})
t.Run("GetType", func(t *testing.T) {
provider := internalproviders.NewGoogleProvider()
if got := provider.GetType(); got != internalproviders.ProviderTypeGoogle {
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeGoogle, got)
}
})
t.Run("GetCapabilities", func(t *testing.T) {
provider := internalproviders.NewGoogleProvider()
capabilities := provider.GetCapabilities()
tests := []struct {
name string
field string
expected interface{}
got interface{}
}{
{"SupportsRefreshTokens", "SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
{"RequiresOfflineAccessScope", "RequiresOfflineAccessScope", false, capabilities.RequiresOfflineAccessScope},
{"RequiresPromptConsent", "RequiresPromptConsent", true, capabilities.RequiresPromptConsent},
{"PreferredTokenValidation", "PreferredTokenValidation", "id", capabilities.PreferredTokenValidation},
}
for _, tt := range tests {
if tt.expected != tt.got {
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
}
}
})
t.Run("BuildAuthParams", func(t *testing.T) {
provider := internalproviders.NewGoogleProvider()
tests := []struct {
name string
baseParams url.Values
scopes []string
expectAccessTypeOffline bool
expectPromptConsent bool
expectOfflineAccessRemoved bool
}{
{
name: "basic params with offline_access scope",
baseParams: url.Values{"client_id": []string{"test-client"}},
scopes: []string{"openid", "offline_access", "email"},
expectAccessTypeOffline: true,
expectPromptConsent: true,
expectOfflineAccessRemoved: true,
},
{
name: "basic params without offline_access scope",
baseParams: url.Values{"client_id": []string{"test-client"}},
scopes: []string{"openid", "email"},
expectAccessTypeOffline: true,
expectPromptConsent: true,
expectOfflineAccessRemoved: false,
},
{
name: "empty scopes",
baseParams: url.Values{},
scopes: []string{},
expectAccessTypeOffline: true,
expectPromptConsent: true,
expectOfflineAccessRemoved: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if authParams == nil {
t.Fatal("expected non-nil auth params")
}
// Check access_type parameter
if tt.expectAccessTypeOffline {
if authParams.URLValues.Get("access_type") != "offline" {
t.Error("expected access_type to be 'offline'")
}
}
// Check prompt parameter
if tt.expectPromptConsent {
if authParams.URLValues.Get("prompt") != "consent" {
t.Error("expected prompt to be 'consent'")
}
}
// Check offline_access scope removal
hasOfflineAccess := false
for _, scope := range authParams.Scopes {
if scope == "offline_access" {
hasOfflineAccess = true
break
}
}
if tt.expectOfflineAccessRemoved && hasOfflineAccess {
t.Error("expected offline_access scope to be removed")
}
if !tt.expectOfflineAccessRemoved && !hasOfflineAccess && containsString(tt.scopes, "offline_access") {
t.Error("expected offline_access scope to be preserved")
}
})
}
})
}
// ============================================================================
// Base Provider Tests
// ============================================================================
func TestBaseProvider(t *testing.T) {
t.Run("GetType", func(t *testing.T) {
provider := internalproviders.NewGenericProvider()
if got := provider.GetType(); got != internalproviders.ProviderTypeGeneric {
t.Errorf("expected provider type %d, got %d", internalproviders.ProviderTypeGeneric, got)
}
})
t.Run("GetCapabilities", func(t *testing.T) {
provider := internalproviders.NewGenericProvider()
capabilities := provider.GetCapabilities()
tests := []struct {
name string
expected interface{}
got interface{}
}{
{"SupportsRefreshTokens", true, capabilities.SupportsRefreshTokens},
{"RequiresOfflineAccessScope", true, capabilities.RequiresOfflineAccessScope},
{"PreferredTokenValidation", "id", capabilities.PreferredTokenValidation},
}
for _, tt := range tests {
if tt.expected != tt.got {
t.Errorf("%s: expected %v, got %v", tt.name, tt.expected, tt.got)
}
}
})
t.Run("ValidateTokenExpiry", func(t *testing.T) {
provider := internalproviders.NewGenericProvider()
tests := []struct {
name string
token string
session *mockSession
cache *mockTokenCache
expectedResult *internalproviders.ValidationResult
}{
{
name: "token not in cache with refresh token",
token: "missing-token",
session: &mockSession{
refreshToken: "refresh-token",
},
cache: &mockTokenCache{
data: map[string]map[string]interface{}{},
},
expectedResult: &internalproviders.ValidationResult{
Authenticated: false,
NeedsRefresh: true,
},
},
{
name: "token not in cache without refresh token",
token: "missing-token",
session: &mockSession{
refreshToken: "",
},
cache: &mockTokenCache{
data: map[string]map[string]interface{}{},
},
expectedResult: &internalproviders.ValidationResult{
Authenticated: false,
NeedsRefresh: false,
},
},
{
name: "valid token in cache",
token: "valid-token",
session: &mockSession{
refreshToken: "refresh-token",
},
cache: &mockTokenCache{
data: map[string]map[string]interface{}{
"valid-token": {
"exp": float64(time.Now().Add(2 * time.Hour).Unix()),
},
},
},
expectedResult: &internalproviders.ValidationResult{
Authenticated: true,
NeedsRefresh: false,
},
},
{
name: "expired token with refresh token",
token: "expired-token",
session: &mockSession{
refreshToken: "refresh-token",
},
cache: &mockTokenCache{
data: map[string]map[string]interface{}{
"expired-token": {
"exp": float64(time.Now().Add(-1 * time.Hour).Unix()),
},
},
},
expectedResult: &internalproviders.ValidationResult{
Authenticated: true,
NeedsRefresh: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := provider.ValidateTokenExpiry(tt.session, tt.token, tt.cache, 5*time.Minute)
if err != nil {
t.Fatalf("ValidateTokenExpiry failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Authenticated != tt.expectedResult.Authenticated {
t.Errorf("expected Authenticated %v, got %v", tt.expectedResult.Authenticated, result.Authenticated)
}
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
t.Errorf("expected NeedsRefresh %v, got %v", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
}
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
t.Errorf("expected NeedsRefresh %v, got %v", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
}
})
}
})
}
// ============================================================================
// Provider Factory Tests
// ============================================================================
func TestProviderFactory(t *testing.T) {
t.Run("NewProviderFactory", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
if factory == nil {
t.Fatal("expected non-nil factory")
}
})
t.Run("CreateProvider", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
tests := []struct {
name string
issuerURL string
wantType internalproviders.ProviderType
wantError bool
errorSubstr string
}{
{
name: "Google provider detection",
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
wantType: internalproviders.ProviderTypeGoogle,
wantError: false,
},
{
name: "Azure provider detection - login.microsoftonline.com",
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
wantType: internalproviders.ProviderTypeAzure,
wantError: false,
},
{
name: "Azure provider detection - sts.windows.net",
issuerURL: "https://sts.windows.net/tenant-id/",
wantType: internalproviders.ProviderTypeAzure,
wantError: false,
},
{
name: "Generic provider detection",
issuerURL: "https://auth.example.com/realms/test",
wantType: internalproviders.ProviderTypeGeneric,
wantError: false,
},
{
name: "Empty issuer URL",
issuerURL: "",
wantError: true,
errorSubstr: "issuer URL cannot be empty",
},
{
name: "Invalid URL format",
issuerURL: "not-a-valid-url",
wantError: true,
errorSubstr: "invalid issuer URL format",
},
{
name: "URL with invalid scheme",
issuerURL: "ftp://example.com/auth",
wantType: internalproviders.ProviderTypeGeneric,
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, err := factory.CreateProvider(tt.issuerURL)
if tt.wantError {
if err == nil {
t.Errorf("expected error but got none")
return
}
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if provider == nil {
t.Error("expected non-nil provider")
return
}
if provider.GetType() != tt.wantType {
t.Errorf("expected provider type %d, got %d", tt.wantType, provider.GetType())
}
})
}
})
t.Run("ConcurrentProviderCreation", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
urls := []string{
"https://accounts.google.com/.well-known/openid_configuration",
"https://login.microsoftonline.com/tenant-id/v2.0",
"https://auth.example.com/realms/test",
}
var wg sync.WaitGroup
errors := make(chan error, len(urls)*10)
for i := 0; i < 10; i++ {
for _, url := range urls {
wg.Add(1)
go func(issuerURL string) {
defer wg.Done()
provider, err := factory.CreateProvider(issuerURL)
if err != nil {
errors <- err
return
}
if provider == nil {
errors <- fmt.Errorf("got nil provider for %s", issuerURL)
}
}(url)
}
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("concurrent creation error: %v", err)
}
})
}
// ============================================================================
// Provider Registry Tests
// ============================================================================
func TestProviderRegistry(t *testing.T) {
t.Run("NewProviderRegistry", func(t *testing.T) {
registry := internalproviders.NewProviderRegistry()
if registry == nil {
t.Fatal("expected non-nil registry")
}
})
t.Run("RegisterAndGet", func(t *testing.T) {
registry := internalproviders.NewProviderRegistry()
// Register providers
googleProvider := internalproviders.NewGoogleProvider()
azureProvider := internalproviders.NewAzureProvider()
registry.RegisterProvider(googleProvider)
registry.RegisterProvider(azureProvider)
// Test getting registered providers
tests := []struct {
name string
providerType internalproviders.ProviderType
shouldExist bool
}{
{"Get Google provider", internalproviders.ProviderTypeGoogle, true},
{"Get Azure provider", internalproviders.ProviderTypeAzure, true},
{"Get unregistered provider", internalproviders.ProviderType(999), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := registry.GetProviderByType(tt.providerType)
if tt.shouldExist {
if provider == nil {
t.Error("expected non-nil provider")
}
} else {
if provider != nil {
t.Error("expected nil provider")
}
}
if tt.shouldExist && provider != nil && provider.GetType() != tt.providerType {
t.Errorf("expected provider type %d, got %d", tt.providerType, provider.GetType())
}
})
}
})
t.Run("Detectinternalproviders.ProviderType", func(t *testing.T) {
registry := internalproviders.NewProviderRegistry()
// Register providers needed for detection
registry.RegisterProvider(internalproviders.NewGoogleProvider())
registry.RegisterProvider(internalproviders.NewAzureProvider())
registry.RegisterProvider(internalproviders.NewGenericProvider())
tests := []struct {
name string
issuerURL string
expectedType internalproviders.ProviderType
}{
{"Google URL", "https://accounts.google.com/.well-known/openid_configuration", internalproviders.ProviderTypeGoogle},
{"Azure login.microsoftonline.com", "https://login.microsoftonline.com/tenant/v2.0", internalproviders.ProviderTypeAzure},
{"Azure sts.windows.net", "https://sts.windows.net/tenant/", internalproviders.ProviderTypeAzure},
{"Generic provider", "https://auth.example.com/realms/test", internalproviders.ProviderTypeGeneric},
// Empty URL should return nil, not a provider
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider := registry.DetectProvider(tt.issuerURL)
if provider == nil {
t.Fatalf("DetectProvider returned nil for URL: %s", tt.issuerURL)
}
providerType := provider.GetType()
if providerType != tt.expectedType {
t.Errorf("expected provider type %d, got %d", tt.expectedType, providerType)
}
})
}
// Test empty URL separately - it should return nil
t.Run("Empty URL", func(t *testing.T) {
provider := registry.DetectProvider("")
if provider != nil {
t.Errorf("expected nil provider for empty URL, got %v", provider)
}
})
})
t.Run("ConcurrentAccess", func(t *testing.T) {
registry := internalproviders.NewProviderRegistry()
// Register initial provider
registry.RegisterProvider(internalproviders.NewGoogleProvider())
var wg sync.WaitGroup
errors := make(chan error, 100)
// Concurrent reads
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
provider := registry.GetProviderByType(internalproviders.ProviderTypeGoogle)
if provider == nil {
errors <- fmt.Errorf("provider not found")
return
}
if provider == nil {
errors <- fmt.Errorf("got nil provider")
}
}()
}
// Concurrent writes
for i := 0; i < 50; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
registry.RegisterProvider(internalproviders.NewGenericProvider())
}(i)
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("concurrent access error: %v", err)
}
})
}
// ============================================================================
// Provider Adapter Tests
// ============================================================================
// NOTE: Adapter tests commented out due to API mismatch - actual NewAdapter requires
// (provider, settings, verifier, cache) parameters, not factory
/*
func TestProviderAdapter(t *testing.T) {
t.Run("internalproviders.NewAdapter", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
adapter := internalproviders.NewAdapter(factory)
if adapter == nil {
t.Fatal("expected non-nil adapter")
}
if adapter.factory == nil {
t.Fatal("expected non-nil factory in adapter")
}
})
t.Run("AdaptLegacySettings", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
adapter := internalproviders.NewAdapter(factory)
tests := []struct {
name string
settings *mockLegacySettings
expectedType internalproviders.ProviderType
expectedScopes []string
expectError bool
}{
{
name: "Google provider settings",
settings: &mockLegacySettings{
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
authURL: "https://accounts.google.com/o/oauth2/v2/auth",
scopes: []string{"openid", "email", "profile"},
pkceEnabled: true,
clientID: "google-client-id",
overrideScopes: false,
},
expectedType: internalproviders.ProviderTypeGoogle,
expectedScopes: []string{"openid", "email", "profile"},
expectError: false,
},
{
name: "Azure provider settings",
settings: &mockLegacySettings{
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
authURL: "https://login.microsoftonline.com/tenant-id/oauth2/v2.0/authorize",
scopes: []string{"openid", "offline_access"},
pkceEnabled: false,
clientID: "azure-client-id",
overrideScopes: false,
},
expectedType: internalproviders.ProviderTypeAzure,
expectedScopes: []string{"openid", "offline_access"},
expectError: false,
},
{
name: "Generic provider settings",
settings: &mockLegacySettings{
issuerURL: "https://auth.example.com/realms/test",
authURL: "https://auth.example.com/realms/test/protocol/openid-connect/auth",
scopes: []string{"openid"},
pkceEnabled: true,
clientID: "generic-client-id",
overrideScopes: true,
},
expectedType: internalproviders.ProviderTypeGeneric,
expectedScopes: []string{"openid"},
expectError: false,
},
{
name: "Empty issuer URL",
settings: &mockLegacySettings{
issuerURL: "",
authURL: "https://auth.example.com/auth",
scopes: []string{"openid"},
clientID: "client-id",
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
provider, authParams, err := adapter.AdaptLegacySettings(tt.settings)
if tt.expectError {
if err == nil {
t.Error("expected error but got none")
}
return
}
if err != nil {
t.Errorf("unexpected error: %v", err)
return
}
if provider == nil {
t.Fatal("expected non-nil provider")
}
if provider.GetType() != tt.expectedType {
t.Errorf("expected provider type %d, got %d", tt.expectedType, provider.GetType())
}
if authParams == nil {
t.Fatal("expected non-nil auth params")
}
// Verify scopes handling
if !tt.settings.overrideScopes {
// When not overriding, provider may modify scopes
if len(authParams.Scopes) == 0 {
t.Error("expected non-empty scopes")
}
} else {
// When overriding, original scopes should be preserved
if !equalStringSlices(authParams.Scopes, tt.expectedScopes) {
t.Errorf("expected scopes %v, got %v", tt.expectedScopes, authParams.Scopes)
}
}
})
}
})
t.Run("ConcurrentAdaptation", func(t *testing.T) {
factory := internalproviders.NewProviderFactory()
adapter := internalproviders.NewAdapter(factory)
settings := []*mockLegacySettings{
{
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
authURL: "https://accounts.google.com/o/oauth2/v2/auth",
scopes: []string{"openid", "email"},
clientID: "google-client",
},
{
issuerURL: "https://login.microsoftonline.com/tenant/v2.0",
authURL: "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
scopes: []string{"openid", "offline_access"},
clientID: "azure-client",
},
}
var wg sync.WaitGroup
errors := make(chan error, len(settings)*10)
for i := 0; i < 10; i++ {
for _, s := range settings {
wg.Add(1)
go func(setting *mockLegacySettings) {
defer wg.Done()
provider, authParams, err := adapter.AdaptLegacySettings(setting)
if err != nil {
errors <- err
return
}
if provider == nil {
errors <- fmt.Errorf("got nil provider")
return
}
if authParams == nil {
errors <- fmt.Errorf("got nil auth params")
}
}(s)
}
}
wg.Wait()
close(errors)
for err := range errors {
t.Errorf("concurrent adaptation error: %v", err)
}
})
}
*/
// ============================================================================
// Validation Tests
// ============================================================================
func TestTokenValidation(t *testing.T) {
t.Run("ValidateWithVerifier", func(t *testing.T) {
tests := []struct {
name string
token string
verifier *mockTokenVerifier
expectValid bool
}{
{
name: "valid token",
token: "valid-token",
verifier: &mockTokenVerifier{
shouldFail: false,
},
expectValid: true,
},
{
name: "invalid token",
token: "invalid-token",
verifier: &mockTokenVerifier{
shouldFail: true,
},
expectValid: false,
},
{
name: "expired token",
token: "expired-token",
verifier: &mockTokenVerifier{
expiredTokens: map[string]bool{
"expired-token": true,
},
},
expectValid: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.verifier.VerifyToken(tt.token)
isValid := err == nil
if isValid != tt.expectValid {
t.Errorf("expected valid=%v, got %v (err: %v)", tt.expectValid, isValid, err)
}
})
}
})
t.Run("ConcurrentValidation", func(t *testing.T) {
verifier := &mockTokenVerifier{
shouldFail: false,
expiredTokens: map[string]bool{
"expired-1": true,
"expired-2": true,
},
}
tokens := []string{"valid-1", "valid-2", "expired-1", "expired-2", "valid-3"}
var wg sync.WaitGroup
results := make(chan bool, len(tokens)*10)
for i := 0; i < 10; i++ {
for _, token := range tokens {
wg.Add(1)
go func(t string) {
defer wg.Done()
err := verifier.VerifyToken(t)
results <- (err == nil)
}(token)
}
}
wg.Wait()
close(results)
validCount := 0
invalidCount := 0
for isValid := range results {
if isValid {
validCount++
} else {
invalidCount++
}
}
expectedValid := 30 // 3 valid tokens * 10 iterations
expectedInvalid := 20 // 2 expired tokens * 10 iterations
if validCount != expectedValid {
t.Errorf("expected %d valid results, got %d", expectedValid, validCount)
}
if invalidCount != expectedInvalid {
t.Errorf("expected %d invalid results, got %d", expectedInvalid, invalidCount)
}
})
}
// ============================================================================
// Memory Management Tests
// ============================================================================
func TestProviderMemoryManagement(t *testing.T) {
t.Run("FactoryMemoryUsage", func(t *testing.T) {
if testing.Short() {
t.Skip("skipping memory test in short mode")
}
var m runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&m)
initialAlloc := m.Alloc
factory := internalproviders.NewProviderFactory()
// Create many providers
providers := make([]internalproviders.OIDCProvider, 0, 1000)
for i := 0; i < 1000; i++ {
var provider internalproviders.OIDCProvider
var err error
switch i % 3 {
case 0:
provider, err = factory.CreateProvider("https://accounts.google.com/.well-known/openid_configuration")
case 1:
provider, err = factory.CreateProvider("https://login.microsoftonline.com/tenant/v2.0")
default:
provider, err = factory.CreateProvider("https://auth.example.com/realms/test")
}
if err != nil {
t.Fatalf("failed to create provider: %v", err)
}
providers = append(providers, provider) // keeping references to prevent GC
}
runtime.GC()
runtime.ReadMemStats(&m)
finalAlloc := m.Alloc
var memUsed, memPerProvider uint64
if finalAlloc > initialAlloc {
memUsed = finalAlloc - initialAlloc
memPerProvider = memUsed / 1000
}
// Each provider should use less than 10KB on average
if memPerProvider > 10*1024 {
t.Errorf("excessive memory usage: %d bytes per provider", memPerProvider)
}
// Use providers to satisfy staticcheck
_ = providers
// Clear references to allow GC
providers = nil
runtime.GC()
runtime.ReadMemStats(&m)
// Memory should be mostly freed
afterGC := m.Alloc
if afterGC > initialAlloc+1024*1024 { // Allow 1MB overhead
t.Errorf("memory not properly freed after GC: %d bytes still allocated", afterGC-initialAlloc)
}
})
}
// ============================================================================
// Helper Functions
// ============================================================================
func containsString(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
//lint:ignore U1000 Used in tests
func equalStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}