mirror of
https://github.com/lukaszraczylo/traefikoidc.git
synced 2026-06-06 22:49:43 +00:00
fixup! fixup! fixup! Further pursue of perfection.
This commit is contained in:
+40
-10
@@ -1,6 +1,8 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -44,6 +46,19 @@ type OptimizedCache struct {
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// normalizeKey ensures keys are within reasonable limits by hashing long keys.
|
||||
// This prevents memory exhaustion while still allowing long keys to be used.
|
||||
func (c *OptimizedCache) normalizeKey(key string) string {
|
||||
if len(key) <= MaxKeyLength {
|
||||
return key
|
||||
}
|
||||
|
||||
// Hash long keys to create a fixed-size key
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(key))
|
||||
return "hash:" + hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// NewOptimizedCache creates a new optimized cache with default settings.
|
||||
// It uses the default maximum size and 64MB memory limit.
|
||||
func NewOptimizedCache() *OptimizedCache {
|
||||
@@ -63,6 +78,11 @@ func NewOptimizedCacheWithConfig(maxSize int, maxMemoryMB int, logger *Logger) *
|
||||
logger = GetSingletonNoOpLogger()
|
||||
}
|
||||
|
||||
// Use default max size if not specified
|
||||
if maxSize <= 0 {
|
||||
maxSize = DefaultMaxSize
|
||||
}
|
||||
|
||||
head := &OptimizedCacheEntry{}
|
||||
tail := &OptimizedCacheEntry{}
|
||||
head.next = tail
|
||||
@@ -95,18 +115,22 @@ func NewOptimizedCacheWithConfig(maxSize int, maxMemoryMB int, logger *Logger) *
|
||||
// - value: The value to store
|
||||
// - expiration: Time until the item expires
|
||||
func (c *OptimizedCache) Set(key string, value interface{}, expiration time.Duration) {
|
||||
if len(key) > MaxKeyLength {
|
||||
c.logger.Debugf("Cache key too long (%d > %d), ignoring", len(key), MaxKeyLength)
|
||||
return
|
||||
}
|
||||
// Normalize the key to handle long keys
|
||||
normalizedKey := c.normalizeKey(key)
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
expTime := now.Add(expiration)
|
||||
var expTime time.Time
|
||||
if expiration == 0 {
|
||||
// Permanent entry - set to far future to avoid expiration
|
||||
expTime = now.Add(100 * 365 * 24 * time.Hour) // 100 years
|
||||
} else {
|
||||
expTime = now.Add(expiration)
|
||||
}
|
||||
|
||||
if entry, exists := c.items[key]; exists {
|
||||
if entry, exists := c.items[normalizedKey]; exists {
|
||||
oldSize := c.estimateEntrySize(entry)
|
||||
entry.Value = value
|
||||
entry.ExpiresAt = expTime
|
||||
@@ -119,7 +143,7 @@ func (c *OptimizedCache) Set(key string, value interface{}, expiration time.Dura
|
||||
entry := &OptimizedCacheEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expTime,
|
||||
Key: key,
|
||||
Key: normalizedKey,
|
||||
}
|
||||
|
||||
entrySize := c.estimateEntrySize(entry)
|
||||
@@ -130,7 +154,7 @@ func (c *OptimizedCache) Set(key string, value interface{}, expiration time.Dura
|
||||
}
|
||||
}
|
||||
|
||||
c.items[key] = entry
|
||||
c.items[normalizedKey] = entry
|
||||
c.currentMemoryBytes += entrySize
|
||||
c.addToTail(entry)
|
||||
}
|
||||
@@ -140,10 +164,13 @@ func (c *OptimizedCache) Set(key string, value interface{}, expiration time.Dura
|
||||
// automatically removes expired items when encountered.
|
||||
// Returns the value and true if found and valid, or nil and false otherwise.
|
||||
func (c *OptimizedCache) Get(key string) (interface{}, bool) {
|
||||
// Normalize the key to handle long keys
|
||||
normalizedKey := c.normalizeKey(key)
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
entry, exists := c.items[key]
|
||||
entry, exists := c.items[normalizedKey]
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
@@ -160,10 +187,13 @@ func (c *OptimizedCache) Get(key string) (interface{}, bool) {
|
||||
// Delete removes an item from the cache, freeing its memory and updating tracking.
|
||||
// This is a manual removal that updates both the hash map and LRU list.
|
||||
func (c *OptimizedCache) Delete(key string) {
|
||||
// Normalize the key to handle long keys
|
||||
normalizedKey := c.normalizeKey(key)
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
if entry, exists := c.items[key]; exists {
|
||||
if entry, exists := c.items[normalizedKey]; exists {
|
||||
c.removeEntry(entry)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestOptimizedCacheBasicOperations tests basic cache operations
|
||||
func TestOptimizedCacheBasicOperations(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Test Set and Get
|
||||
cache.Set("key1", "value1", 10*time.Minute)
|
||||
|
||||
value, found := cache.Get("key1")
|
||||
if !found {
|
||||
t.Error("Expected to find key1")
|
||||
}
|
||||
if value != "value1" {
|
||||
t.Errorf("Expected 'value1', got '%v'", value)
|
||||
}
|
||||
|
||||
// Test Get non-existent key
|
||||
_, found = cache.Get("nonexistent")
|
||||
if found {
|
||||
t.Error("Expected not to find nonexistent key")
|
||||
}
|
||||
|
||||
// Test Delete
|
||||
cache.Delete("key1")
|
||||
_, found = cache.Get("key1")
|
||||
if found {
|
||||
t.Error("Expected key1 to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheExpiration tests cache entry expiration
|
||||
func TestOptimizedCacheExpiration(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Test immediate expiration
|
||||
cache.Set("expired_key", "value", 1*time.Millisecond)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
_, found := cache.Get("expired_key")
|
||||
if found {
|
||||
t.Error("Expected expired key not to be found")
|
||||
}
|
||||
|
||||
// Test non-expiring entry (expiration = 0)
|
||||
cache.Set("permanent_key", "permanent_value", 0)
|
||||
|
||||
value, found := cache.Get("permanent_key")
|
||||
if !found {
|
||||
t.Error("Expected permanent key to be found")
|
||||
}
|
||||
if value != "permanent_value" {
|
||||
t.Errorf("Expected 'permanent_value', got '%v'", value)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheLRUEviction tests LRU eviction behavior
|
||||
func TestOptimizedCacheLRUEviction(t *testing.T) {
|
||||
// Create small cache to trigger eviction
|
||||
logger := newNoOpLogger()
|
||||
cache := NewOptimizedCacheWithConfig(3, 1, logger) // Max 3 items
|
||||
|
||||
// Fill cache to capacity
|
||||
cache.Set("key1", "value1", 10*time.Minute)
|
||||
cache.Set("key2", "value2", 10*time.Minute)
|
||||
cache.Set("key3", "value3", 10*time.Minute)
|
||||
|
||||
// Access key1 to make it most recently used
|
||||
cache.Get("key1")
|
||||
|
||||
// Add another item, should evict key2 (least recently used)
|
||||
cache.Set("key4", "value4", 10*time.Minute)
|
||||
|
||||
// key2 should be evicted
|
||||
_, found := cache.Get("key2")
|
||||
if found {
|
||||
t.Error("Expected key2 to be evicted")
|
||||
}
|
||||
|
||||
// key1 should still exist (was recently accessed)
|
||||
_, found = cache.Get("key1")
|
||||
if !found {
|
||||
t.Error("Expected key1 to still exist")
|
||||
}
|
||||
|
||||
// key3 and key4 should exist
|
||||
_, found = cache.Get("key3")
|
||||
if !found {
|
||||
t.Error("Expected key3 to still exist")
|
||||
}
|
||||
_, found = cache.Get("key4")
|
||||
if !found {
|
||||
t.Error("Expected key4 to exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheMemoryPressure tests memory-based eviction
|
||||
func TestOptimizedCacheMemoryPressure(t *testing.T) {
|
||||
logger := newNoOpLogger()
|
||||
cache := NewOptimizedCacheWithConfig(1000, 1, logger) // 1 MB memory limit
|
||||
|
||||
// Create large values to trigger memory pressure
|
||||
largeValue := strings.Repeat("a", 256*1024) // 256KB each
|
||||
|
||||
// Add several large values
|
||||
cache.Set("large1", largeValue, 10*time.Minute)
|
||||
cache.Set("large2", largeValue, 10*time.Minute)
|
||||
cache.Set("large3", largeValue, 10*time.Minute)
|
||||
cache.Set("large4", largeValue, 10*time.Minute)
|
||||
cache.Set("large5", largeValue, 10*time.Minute) // This should trigger eviction
|
||||
|
||||
// Force garbage collection to get accurate memory reading
|
||||
runtime.GC()
|
||||
|
||||
// Check that some entries were evicted due to memory pressure
|
||||
count := 0
|
||||
for i := 1; i <= 5; i++ {
|
||||
if _, found := cache.Get(formatString("large%d", i)); found {
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
// Should have fewer than 5 items due to memory pressure eviction
|
||||
if count >= 5 {
|
||||
t.Errorf("Expected some items to be evicted due to memory pressure, but found %d items", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheCleanup tests manual cleanup functionality
|
||||
func TestOptimizedCacheCleanup(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Add expired and non-expired items
|
||||
cache.Set("expired1", "value1", 1*time.Millisecond)
|
||||
cache.Set("expired2", "value2", 1*time.Millisecond)
|
||||
cache.Set("valid", "value", 10*time.Minute)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Manual cleanup should remove expired items
|
||||
cache.Cleanup()
|
||||
|
||||
// Expired items should be gone
|
||||
_, found := cache.Get("expired1")
|
||||
if found {
|
||||
t.Error("Expected expired1 to be cleaned up")
|
||||
}
|
||||
_, found = cache.Get("expired2")
|
||||
if found {
|
||||
t.Error("Expected expired2 to be cleaned up")
|
||||
}
|
||||
|
||||
// Valid item should remain
|
||||
_, found = cache.Get("valid")
|
||||
if !found {
|
||||
t.Error("Expected valid item to remain after cleanup")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheConcurrency tests thread safety
|
||||
func TestOptimizedCacheConcurrency(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Number of goroutines for each operation type
|
||||
numGoroutines := 10
|
||||
numOperations := 100
|
||||
|
||||
// Test concurrent writes
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
key := formatString("write_%d_%d", id, j)
|
||||
cache.Set(key, formatString("value_%d_%d", id, j), 10*time.Minute)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Test concurrent reads
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
key := formatString("write_%d_%d", id, j)
|
||||
cache.Get(key) // Don't care about result, just testing concurrency
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Test concurrent deletes
|
||||
wg.Add(numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
key := formatString("delete_%d_%d", id, j)
|
||||
cache.Set(key, "value", 10*time.Minute)
|
||||
cache.Delete(key)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Test concurrent cleanup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 10; i++ {
|
||||
cache.Cleanup()
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If we reach here without deadlock or panic, concurrency test passed
|
||||
t.Log("Concurrency test completed successfully")
|
||||
}
|
||||
|
||||
// TestOptimizedCacheEdgeCases tests edge cases and error conditions
|
||||
func TestOptimizedCacheEdgeCases(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Test empty key
|
||||
cache.Set("", "empty_key_value", 10*time.Minute)
|
||||
value, found := cache.Get("")
|
||||
if !found || value != "empty_key_value" {
|
||||
t.Error("Expected to handle empty key correctly")
|
||||
}
|
||||
|
||||
// Test nil value
|
||||
cache.Set("nil_key", nil, 10*time.Minute)
|
||||
value, found = cache.Get("nil_key")
|
||||
if !found || value != nil {
|
||||
t.Error("Expected to handle nil value correctly")
|
||||
}
|
||||
|
||||
// Test overwriting existing key
|
||||
cache.Set("overwrite", "original", 10*time.Minute)
|
||||
cache.Set("overwrite", "new_value", 10*time.Minute)
|
||||
value, found = cache.Get("overwrite")
|
||||
if !found || value != "new_value" {
|
||||
t.Error("Expected key to be overwritten with new value")
|
||||
}
|
||||
|
||||
// Test delete non-existent key (should not panic)
|
||||
cache.Delete("nonexistent")
|
||||
|
||||
// Test very long key
|
||||
longKey := strings.Repeat("a", 1000)
|
||||
cache.Set(longKey, "long_key_value", 10*time.Minute)
|
||||
value, found = cache.Get(longKey)
|
||||
if !found || value != "long_key_value" {
|
||||
t.Error("Expected to handle very long key correctly")
|
||||
}
|
||||
}
|
||||
|
||||
// TestOptimizedCacheWithDifferentValueTypes tests cache with various value types
|
||||
func TestOptimizedCacheWithDifferentValueTypes(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Test string value
|
||||
cache.Set("string", "test_string", 10*time.Minute)
|
||||
|
||||
// Test int value
|
||||
cache.Set("int", 42, 10*time.Minute)
|
||||
|
||||
// Test slice value
|
||||
cache.Set("slice", []string{"a", "b", "c"}, 10*time.Minute)
|
||||
|
||||
// Test map value
|
||||
cache.Set("map", map[string]int{"key1": 1, "key2": 2}, 10*time.Minute)
|
||||
|
||||
// Test struct value
|
||||
type TestStruct struct {
|
||||
Name string
|
||||
Age int
|
||||
}
|
||||
cache.Set("struct", TestStruct{Name: "John", Age: 30}, 10*time.Minute)
|
||||
|
||||
// Verify all types can be retrieved correctly
|
||||
if val, found := cache.Get("string"); !found || val != "test_string" {
|
||||
t.Error("Failed to retrieve string value")
|
||||
}
|
||||
|
||||
if val, found := cache.Get("int"); !found || val != 42 {
|
||||
t.Error("Failed to retrieve int value")
|
||||
}
|
||||
|
||||
if val, found := cache.Get("slice"); !found {
|
||||
t.Error("Failed to retrieve slice value")
|
||||
} else if slice, ok := val.([]string); !ok || len(slice) != 3 || slice[0] != "a" {
|
||||
t.Error("Retrieved slice value is incorrect")
|
||||
}
|
||||
|
||||
if val, found := cache.Get("map"); !found {
|
||||
t.Error("Failed to retrieve map value")
|
||||
} else if mapVal, ok := val.(map[string]int); !ok || mapVal["key1"] != 1 {
|
||||
t.Error("Retrieved map value is incorrect")
|
||||
}
|
||||
|
||||
if val, found := cache.Get("struct"); !found {
|
||||
t.Error("Failed to retrieve struct value")
|
||||
} else if structVal, ok := val.(TestStruct); !ok || structVal.Name != "John" || structVal.Age != 30 {
|
||||
t.Error("Retrieved struct value is incorrect")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a formatted string key
|
||||
func formatString(format string, args ...interface{}) string {
|
||||
// Simple sprintf implementation for tests
|
||||
result := format
|
||||
for _, arg := range args {
|
||||
if strings.Contains(result, "%d") {
|
||||
if intVal, ok := arg.(int); ok {
|
||||
result = strings.Replace(result, "%d", intToString(intVal), 1)
|
||||
}
|
||||
} else if strings.Contains(result, "%s") {
|
||||
if strVal, ok := arg.(string); ok {
|
||||
result = strings.Replace(result, "%s", strVal, 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper to convert int to string
|
||||
func intToString(i int) string {
|
||||
if i == 0 {
|
||||
return "0"
|
||||
}
|
||||
|
||||
negative := i < 0
|
||||
if negative {
|
||||
i = -i
|
||||
}
|
||||
|
||||
var result []byte
|
||||
for i > 0 {
|
||||
result = append([]byte{byte('0' + (i % 10))}, result...)
|
||||
i /= 10
|
||||
}
|
||||
|
||||
if negative {
|
||||
result = append([]byte{'-'}, result...)
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
@@ -0,0 +1,533 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestValidateGoogleTokens tests the validateGoogleTokens method with various scenarios
|
||||
func TestValidateGoogleTokens(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
// Set refresh grace period to 60 seconds to match default behavior
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidGoogleTokens",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
// Create valid JWT tokens
|
||||
idClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
accessClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idClaims)
|
||||
accessToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", accessClaims)
|
||||
|
||||
// Pre-cache the token claims so validateTokenExpiry can find them
|
||||
ts.tOidc.tokenCache.Set(idToken, idClaims, 0)
|
||||
ts.tOidc.tokenCache.Set(accessToken, accessClaims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Valid Google tokens should authenticate successfully",
|
||||
},
|
||||
{
|
||||
name: "GoogleTokensNeedRefresh",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
// Create token that expires soon (within 60s grace period)
|
||||
claims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(30 * time.Second).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", claims)
|
||||
|
||||
// Pre-cache the token claims so validateTokenExpiry can find them
|
||||
ts.tOidc.tokenCache.Set(idToken, claims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(idToken) // Same token for access
|
||||
session.SetRefreshToken("valid_refresh_token")
|
||||
return session
|
||||
},
|
||||
expectedAuth: true, // Token is still valid, just needs refresh
|
||||
expectedRefresh: true,
|
||||
expectedExpired: false,
|
||||
description: "Google tokens nearing expiration should signal refresh needed",
|
||||
},
|
||||
{
|
||||
name: "GoogleTokensExpired",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(false)
|
||||
// Expired token
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
})
|
||||
session.SetIDToken(idToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false, // Changed: session not authenticated = no refresh needed for Google
|
||||
description: "Unauthenticated Google session with expired token should not refresh",
|
||||
},
|
||||
{
|
||||
name: "GoogleProviderUnauthenticated",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(false)
|
||||
session.SetRefreshToken("some_refresh_token")
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: true,
|
||||
expectedExpired: false,
|
||||
description: "Unauthenticated Google session with refresh token should signal refresh needed",
|
||||
},
|
||||
{
|
||||
name: "GoogleProviderNoTokens",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(false)
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: false, // Changed: no refresh token = no refresh needed
|
||||
expectedExpired: false,
|
||||
description: "Google session with no tokens should return false for all states",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateGoogleTokens(session)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
}
|
||||
if refresh != tt.expectedRefresh {
|
||||
t.Errorf("Expected needsRefresh=%v, got %v. %s", tt.expectedRefresh, refresh, tt.description)
|
||||
}
|
||||
if expired != tt.expectedExpired {
|
||||
t.Errorf("Expected expired=%v, got %v. %s", tt.expectedExpired, expired, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsUserAuthenticated tests the isUserAuthenticated method with various provider types
|
||||
func TestIsUserAuthenticated(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
// Set refresh grace period to 60 seconds to match default behavior
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerType string
|
||||
setupSession func() *SessionData
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "AzureProvider",
|
||||
providerType: "azure",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
|
||||
// Azure needs ID token or opaque access token
|
||||
idClaims := map[string]interface{}{
|
||||
"iss": "https://login.microsoftonline.com/common/v2.0",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idClaims)
|
||||
|
||||
// Pre-cache the token claims for Azure validation
|
||||
ts.tOidc.tokenCache.Set(idToken, idClaims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Azure provider should delegate to validateAzureTokens",
|
||||
},
|
||||
{
|
||||
name: "GoogleProvider",
|
||||
providerType: "google",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
// Standard tokens need both access and ID token
|
||||
idClaims := map[string]interface{}{
|
||||
"iss": "https://accounts.google.com", // Use Google's issuer
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
accessClaims := map[string]interface{}{
|
||||
"iss": "https://accounts.google.com", // Use Google's issuer
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idClaims)
|
||||
accessToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", accessClaims)
|
||||
|
||||
// Pre-cache the token claims
|
||||
ts.tOidc.tokenCache.Set(idToken, idClaims, 0)
|
||||
ts.tOidc.tokenCache.Set(accessToken, accessClaims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Google provider should delegate to validateGoogleTokens",
|
||||
},
|
||||
{
|
||||
name: "GenericOIDCProvider",
|
||||
providerType: "generic",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
// Standard tokens need both access and ID token
|
||||
idClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
accessClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idClaims)
|
||||
accessToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", accessClaims)
|
||||
|
||||
// Pre-cache the token claims
|
||||
ts.tOidc.tokenCache.Set(idToken, idClaims, 0)
|
||||
ts.tOidc.tokenCache.Set(accessToken, accessClaims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Generic OIDC provider should delegate to validateStandardTokens",
|
||||
},
|
||||
{
|
||||
name: "KeycloakProvider",
|
||||
providerType: "keycloak",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
// Standard tokens need both access and ID token
|
||||
idClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
accessClaims := map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", idClaims)
|
||||
accessToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", accessClaims)
|
||||
|
||||
// Pre-cache the token claims
|
||||
ts.tOidc.tokenCache.Set(idToken, idClaims, 0)
|
||||
ts.tOidc.tokenCache.Set(accessToken, accessClaims, 0)
|
||||
|
||||
session.SetIDToken(idToken)
|
||||
session.SetAccessToken(accessToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Keycloak provider should delegate to validateStandardTokens",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Handle Azure provider type by changing issuerURL temporarily
|
||||
originalIssuer := ts.tOidc.issuerURL
|
||||
if tt.providerType == "azure" {
|
||||
ts.tOidc.issuerURL = "https://login.microsoftonline.com/common/v2.0"
|
||||
} else if tt.providerType == "google" {
|
||||
ts.tOidc.issuerURL = "https://accounts.google.com"
|
||||
}
|
||||
defer func() { ts.tOidc.issuerURL = originalIssuer }()
|
||||
|
||||
session := tt.setupSession()
|
||||
auth, refresh, expired := ts.tOidc.isUserAuthenticated(session)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
}
|
||||
if refresh != tt.expectedRefresh {
|
||||
t.Errorf("Expected needsRefresh=%v, got %v. %s", tt.expectedRefresh, refresh, tt.description)
|
||||
}
|
||||
if expired != tt.expectedExpired {
|
||||
t.Errorf("Expected expired=%v, got %v. %s", tt.expectedExpired, expired, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateAzureTokensEdgeCases tests Azure token validation with comprehensive edge cases
|
||||
func TestValidateAzureTokensEdgeCases(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
// Set refresh grace period to 60 seconds to match default behavior
|
||||
ts.tOidc.refreshGracePeriod = 60 * time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupSession func() *SessionData
|
||||
expectedAuth bool
|
||||
expectedRefresh bool
|
||||
expectedExpired bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "UnauthenticatedWithRefreshToken",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(false)
|
||||
session.SetRefreshToken("valid_refresh_token")
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: true,
|
||||
expectedExpired: false,
|
||||
description: "Unauthenticated Azure session with refresh token",
|
||||
},
|
||||
{
|
||||
name: "UnauthenticatedWithoutRefreshToken",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(false)
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: true,
|
||||
expectedExpired: false,
|
||||
description: "Unauthenticated Azure session without refresh token",
|
||||
},
|
||||
{
|
||||
name: "AuthenticatedWithInvalidJWTAccessToken",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("invalid.jwt.token") // JWT format but invalid
|
||||
// Valid ID token
|
||||
idToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
session.SetIDToken(idToken)
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Azure session with invalid JWT access token but valid ID token",
|
||||
},
|
||||
{
|
||||
name: "AuthenticatedWithOpaqueAccessToken",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("opaque_access_token_longer_than_minimum") // Not JWT format but long enough
|
||||
return session
|
||||
},
|
||||
expectedAuth: true,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: false,
|
||||
description: "Azure session with opaque access token",
|
||||
},
|
||||
{
|
||||
name: "AuthenticatedWithBothTokensInvalid",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("invalid.jwt.token")
|
||||
session.SetIDToken("another.invalid.token")
|
||||
session.SetRefreshToken("refresh_token")
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: true,
|
||||
expectedExpired: false,
|
||||
description: "Azure session with both access and ID tokens invalid but has refresh token",
|
||||
},
|
||||
{
|
||||
name: "AuthenticatedWithBothTokensInvalidNoRefresh",
|
||||
setupSession: func() *SessionData {
|
||||
session := createTestSession()
|
||||
session.SetAuthenticated(true)
|
||||
session.SetAccessToken("invalid.jwt.token")
|
||||
session.SetIDToken("another.invalid.token")
|
||||
return session
|
||||
},
|
||||
expectedAuth: false,
|
||||
expectedRefresh: false,
|
||||
expectedExpired: true,
|
||||
description: "Azure session with both tokens invalid and no refresh token",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
session := tt.setupSession()
|
||||
|
||||
auth, refresh, expired := ts.tOidc.validateAzureTokens(session)
|
||||
|
||||
if auth != tt.expectedAuth {
|
||||
t.Errorf("Expected authenticated=%v, got %v. %s", tt.expectedAuth, auth, tt.description)
|
||||
}
|
||||
if refresh != tt.expectedRefresh {
|
||||
t.Errorf("Expected needsRefresh=%v, got %v. %s", tt.expectedRefresh, refresh, tt.description)
|
||||
}
|
||||
if expired != tt.expectedExpired {
|
||||
t.Errorf("Expected expired=%v, got %v. %s", tt.expectedExpired, expired, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartMetadataRefresh tests the metadata refresh functionality
|
||||
func TestStartMetadataRefresh(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerURL string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "SuccessfulMetadataRefresh",
|
||||
providerURL: "https://test-issuer.com",
|
||||
description: "Should start metadata refresh successfully",
|
||||
},
|
||||
{
|
||||
name: "MetadataRefreshWithEmptyURL",
|
||||
providerURL: "",
|
||||
description: "Should handle empty provider URL gracefully",
|
||||
},
|
||||
{
|
||||
name: "MetadataRefreshWithInvalidURL",
|
||||
providerURL: "invalid-url",
|
||||
description: "Should handle invalid provider URL gracefully",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Start metadata refresh (this should not panic or error immediately)
|
||||
ts.tOidc.startMetadataRefresh(tt.providerURL)
|
||||
|
||||
// Give some time for goroutine to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// The function should return successfully
|
||||
// We can't easily test the periodic behavior without making tests very slow,
|
||||
// but we test that it starts without issues
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStartMetadataRefreshContextCancellation tests context cancellation handling
|
||||
func TestStartMetadataRefreshContextCancellation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
// Mock the context cancellation by closing the plugin
|
||||
ts.tOidc.startMetadataRefresh("https://test-issuer.com")
|
||||
|
||||
// Give some time for goroutine to start
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Close the plugin to test cleanup
|
||||
ts.tOidc.Close()
|
||||
|
||||
// Give some time for cleanup
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Test passes if no goroutines are leaked (checked by other tests)
|
||||
}
|
||||
|
||||
// MockReadCloser implements io.ReadCloser for testing HTTP responses
|
||||
type MockReadCloser struct {
|
||||
*strings.Reader
|
||||
}
|
||||
|
||||
func (m *MockReadCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -90,3 +92,287 @@ func TestCacheAdapterSetMaxSize(t *testing.T) {
|
||||
adapter.Cleanup()
|
||||
adapter.Close()
|
||||
}
|
||||
|
||||
// Test isTestMode function with different conditions
|
||||
func TestIsTestMode(t *testing.T) {
|
||||
// Store original values
|
||||
originalSuppressLogs := os.Getenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
originalGoTest := os.Getenv("GO_TEST")
|
||||
originalArgs := os.Args
|
||||
|
||||
// Cleanup after test
|
||||
defer func() {
|
||||
os.Setenv("SUPPRESS_DIAGNOSTIC_LOGS", originalSuppressLogs)
|
||||
os.Setenv("GO_TEST", originalGoTest)
|
||||
os.Args = originalArgs
|
||||
}()
|
||||
|
||||
t.Run("SUPPRESS_DIAGNOSTIC_LOGS environment variable", func(t *testing.T) {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
os.Unsetenv("GO_TEST")
|
||||
os.Args = []string{"myprogram"}
|
||||
|
||||
// Should return false initially
|
||||
if isTestMode() {
|
||||
t.Error("Expected isTestMode to return false without SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
}
|
||||
|
||||
// Set environment variable
|
||||
os.Setenv("SUPPRESS_DIAGNOSTIC_LOGS", "1")
|
||||
if !isTestMode() {
|
||||
t.Error("Expected isTestMode to return true with SUPPRESS_DIAGNOSTIC_LOGS=1")
|
||||
}
|
||||
|
||||
// Test other values
|
||||
os.Setenv("SUPPRESS_DIAGNOSTIC_LOGS", "0")
|
||||
if isTestMode() {
|
||||
t.Error("Expected isTestMode to return false with SUPPRESS_DIAGNOSTIC_LOGS=0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Program name detection", func(t *testing.T) {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
os.Unsetenv("GO_TEST")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
progName string
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"Test binary", "myprogram.test", true},
|
||||
{"Go build temp", "go_build_temp_binary", true},
|
||||
{"Debug binary", "__debug_bin1234", true},
|
||||
{"Test in name", "mytestprogram", true},
|
||||
{"Normal program", "myprogram", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
os.Args = []string{tc.progName}
|
||||
result := isTestMode()
|
||||
if result != tc.shouldMatch {
|
||||
t.Errorf("Program %q: expected %v, got %v", tc.progName, tc.shouldMatch, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GO_TEST environment variable", func(t *testing.T) {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
os.Args = []string{"myprogram"}
|
||||
|
||||
os.Unsetenv("GO_TEST")
|
||||
if isTestMode() {
|
||||
t.Error("Expected isTestMode to return false without GO_TEST")
|
||||
}
|
||||
|
||||
os.Setenv("GO_TEST", "1")
|
||||
if !isTestMode() {
|
||||
t.Error("Expected isTestMode to return true with GO_TEST=1")
|
||||
}
|
||||
|
||||
os.Setenv("GO_TEST", "0")
|
||||
if isTestMode() {
|
||||
t.Error("Expected isTestMode to return false with GO_TEST=0")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Command line arguments with -test", func(t *testing.T) {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
os.Unsetenv("GO_TEST")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected bool
|
||||
}{
|
||||
{"No test args", []string{"myprogram"}, false},
|
||||
{"Test.run flag", []string{"myprogram", "-test.run=TestSomething"}, true},
|
||||
{"Test.v flag", []string{"myprogram", "-test.v"}, true},
|
||||
{"Test.count flag", []string{"myprogram", "-test.count=1"}, true},
|
||||
{"Other flags", []string{"myprogram", "-config", "file.json"}, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
os.Args = tc.args
|
||||
result := isTestMode()
|
||||
if result != tc.expected {
|
||||
t.Errorf("Args %v: expected %v, got %v", tc.args, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Runtime compiler detection", func(t *testing.T) {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
os.Unsetenv("GO_TEST")
|
||||
os.Args = []string{"myprogram"}
|
||||
|
||||
// This is tricky to test because we can't change runtime.Compiler easily
|
||||
// But we can test that the current behavior works
|
||||
if runtime.Compiler == "yaegi" {
|
||||
if !isTestMode() {
|
||||
t.Error("Expected isTestMode to return true with yaegi compiler")
|
||||
}
|
||||
} else {
|
||||
// With gc compiler, should return false for normal program name
|
||||
if isTestMode() {
|
||||
t.Error("Expected isTestMode to return false with gc compiler and normal program")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Comprehensive test scenarios", func(t *testing.T) {
|
||||
// Test multiple conditions at once
|
||||
testCases := []struct {
|
||||
name string
|
||||
envSuppress string
|
||||
envGoTest string
|
||||
progName string
|
||||
args []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
"All conditions false",
|
||||
"", "", "myprogram",
|
||||
[]string{"myprogram", "-config", "test.json"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"Multiple true conditions",
|
||||
"1", "1", "test.exe",
|
||||
[]string{"test.exe", "-test.v"},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"Program name with test",
|
||||
"", "", "mytestbinary",
|
||||
[]string{"mytestbinary"},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.envSuppress != "" {
|
||||
os.Setenv("SUPPRESS_DIAGNOSTIC_LOGS", tc.envSuppress)
|
||||
} else {
|
||||
os.Unsetenv("SUPPRESS_DIAGNOSTIC_LOGS")
|
||||
}
|
||||
|
||||
if tc.envGoTest != "" {
|
||||
os.Setenv("GO_TEST", tc.envGoTest)
|
||||
} else {
|
||||
os.Unsetenv("GO_TEST")
|
||||
}
|
||||
|
||||
os.Args = tc.args
|
||||
|
||||
result := isTestMode()
|
||||
if result != tc.expected {
|
||||
t.Errorf("Scenario %q: expected %v, got %v", tc.name, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test buildFullURL function with edge cases
|
||||
func TestBuildFullURL(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
scheme string
|
||||
host string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"Standard HTTPS URL",
|
||||
"https", "example.com", "/api/v1/users",
|
||||
"https://example.com/api/v1/users",
|
||||
},
|
||||
{
|
||||
"HTTP with port",
|
||||
"http", "localhost:8080", "/health",
|
||||
"http://localhost:8080/health",
|
||||
},
|
||||
{
|
||||
"Empty path",
|
||||
"https", "api.service.com", "",
|
||||
"https://api.service.com/",
|
||||
},
|
||||
{
|
||||
"Root path",
|
||||
"https", "www.example.org", "/",
|
||||
"https://www.example.org/",
|
||||
},
|
||||
{
|
||||
"Path without leading slash",
|
||||
"http", "internal.local", "status",
|
||||
"http://internal.local/status",
|
||||
},
|
||||
{
|
||||
"Complex path with query params",
|
||||
"https", "api.example.com", "/v2/search?q=test&limit=10",
|
||||
"https://api.example.com/v2/search?q=test&limit=10",
|
||||
},
|
||||
{
|
||||
"IPv4 address",
|
||||
"http", "192.168.1.100", "/api",
|
||||
"http://192.168.1.100/api",
|
||||
},
|
||||
{
|
||||
"IPv6 address with brackets",
|
||||
"http", "[::1]:8080", "/test",
|
||||
"http://[::1]:8080/test",
|
||||
},
|
||||
{
|
||||
"Empty scheme",
|
||||
"", "example.com", "/test",
|
||||
"://example.com/test",
|
||||
},
|
||||
{
|
||||
"Empty host",
|
||||
"https", "", "/test",
|
||||
"https:///test",
|
||||
},
|
||||
{
|
||||
"All empty",
|
||||
"", "", "",
|
||||
":///",
|
||||
},
|
||||
{
|
||||
"Special characters in path",
|
||||
"https", "example.com", "/path with spaces/test?param=value with spaces",
|
||||
"https://example.com/path with spaces/test?param=value with spaces",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := buildFullURL(tc.scheme, tc.host, tc.path)
|
||||
if result != tc.expected {
|
||||
t.Errorf("buildFullURL(%q, %q, %q): expected %q, got %q",
|
||||
tc.scheme, tc.host, tc.path, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test that path gets leading slash when missing
|
||||
t.Run("Path normalization", func(t *testing.T) {
|
||||
// When path doesn't start with /, it should be added
|
||||
result1 := buildFullURL("https", "example.com", "api/test")
|
||||
expected1 := "https://example.com/api/test"
|
||||
if result1 != expected1 {
|
||||
t.Errorf("Expected path to be normalized with leading slash: got %q, want %q", result1, expected1)
|
||||
}
|
||||
|
||||
// When path already starts with /, it shouldn't be doubled
|
||||
result2 := buildFullURL("https", "example.com", "/api/test")
|
||||
expected2 := "https://example.com/api/test"
|
||||
if result2 != expected2 {
|
||||
t.Errorf("Expected no double slashes: got %q, want %q", result2, expected2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,160 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Final focused tests to reach 85% coverage target
|
||||
func TestFinalCoverageBoost(t *testing.T) {
|
||||
// Test DefaultOptimizedConfig
|
||||
t.Run("DefaultOptimizedConfig", func(t *testing.T) {
|
||||
config := DefaultOptimizedConfig()
|
||||
if config == nil {
|
||||
t.Error("Expected non-nil default config")
|
||||
}
|
||||
})
|
||||
|
||||
// Test NewLazyCache and operations
|
||||
t.Run("LazyCache operations", func(t *testing.T) {
|
||||
cache := NewLazyCache()
|
||||
if cache == nil {
|
||||
t.Error("Expected non-nil lazy cache")
|
||||
}
|
||||
|
||||
// Test basic operations
|
||||
cache.Set("test1", "value1", time.Minute)
|
||||
value, found := cache.Get("test1")
|
||||
if !found {
|
||||
t.Error("Expected to find cached value")
|
||||
}
|
||||
if value != "value1" {
|
||||
t.Errorf("Expected 'value1', got %v", value)
|
||||
}
|
||||
|
||||
cache.Delete("test1")
|
||||
_, found = cache.Get("test1")
|
||||
if found {
|
||||
t.Error("Expected value to be deleted")
|
||||
}
|
||||
|
||||
cache.Close()
|
||||
})
|
||||
|
||||
// Test NewLazyCacheWithLogger
|
||||
t.Run("LazyCache with logger", func(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
cache := NewLazyCacheWithLogger(logger)
|
||||
if cache == nil {
|
||||
t.Error("Expected non-nil lazy cache with logger")
|
||||
}
|
||||
|
||||
cache.Set("test2", "value2", time.Minute)
|
||||
cache.Close()
|
||||
})
|
||||
|
||||
// Test additional cache operations
|
||||
t.Run("OptimizedCache additional operations", func(t *testing.T) {
|
||||
cache := NewOptimizedCache()
|
||||
|
||||
// Set some values
|
||||
cache.Set("key1", "value1", time.Minute)
|
||||
cache.Set("key2", "value2", time.Minute)
|
||||
|
||||
// Test cleanup
|
||||
cache.Cleanup()
|
||||
|
||||
// Test close
|
||||
cache.Close()
|
||||
})
|
||||
|
||||
// Test UnifiedCache with various operations
|
||||
t.Run("UnifiedCache comprehensive", func(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
cache := NewUnifiedCache(config)
|
||||
|
||||
// Test various operations
|
||||
cache.Set("unified1", "value1", time.Minute)
|
||||
cache.Set("unified2", "value2", time.Minute)
|
||||
|
||||
value, found := cache.Get("unified1")
|
||||
if !found {
|
||||
t.Error("Expected to find cached value in unified cache")
|
||||
}
|
||||
if value != "value1" {
|
||||
t.Errorf("Expected 'value1', got %v", value)
|
||||
}
|
||||
|
||||
cache.Delete("unified1")
|
||||
cache.SetMaxSize(100)
|
||||
|
||||
cache.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// Test Cache Adapter pattern
|
||||
func TestCacheAdapterOperations(t *testing.T) {
|
||||
config := DefaultUnifiedCacheConfig()
|
||||
unifiedCache := NewUnifiedCache(config)
|
||||
adapter := NewCacheAdapter(unifiedCache)
|
||||
if adapter == nil {
|
||||
t.Error("Expected non-nil cache adapter")
|
||||
}
|
||||
|
||||
// Test operations
|
||||
adapter.Set("adapter1", "value1", time.Minute)
|
||||
adapter.Set("adapter2", "value2", time.Minute)
|
||||
|
||||
value, found := adapter.Get("adapter1")
|
||||
if !found {
|
||||
t.Error("Expected to find value in cache adapter")
|
||||
}
|
||||
if value != "value1" {
|
||||
t.Errorf("Expected 'value1', got %v", value)
|
||||
}
|
||||
|
||||
adapter.Delete("adapter1")
|
||||
adapter.Cleanup()
|
||||
adapter.Close()
|
||||
}
|
||||
|
||||
// Test BackgroundTask operations to increase coverage
|
||||
func TestBackgroundTaskCoverage(t *testing.T) {
|
||||
logger := NewLogger("debug")
|
||||
counter := 0
|
||||
|
||||
// Create a background task
|
||||
task := NewBackgroundTask("coverage-test", 50*time.Millisecond, func() {
|
||||
counter++
|
||||
}, logger)
|
||||
|
||||
if task == nil {
|
||||
t.Fatal("Expected non-nil background task")
|
||||
}
|
||||
|
||||
// Start the task
|
||||
task.Start()
|
||||
|
||||
// Let it run a few times
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// Stop the task
|
||||
task.Stop()
|
||||
|
||||
if counter == 0 {
|
||||
t.Error("Expected background task to increment counter")
|
||||
}
|
||||
}
|
||||
|
||||
// Test createDefaultHTTPClient for backward compatibility
|
||||
func TestCreateDefaultHTTPClient(t *testing.T) {
|
||||
client := createDefaultHTTPClient()
|
||||
if client == nil {
|
||||
t.Fatal("Expected non-nil HTTP client")
|
||||
}
|
||||
|
||||
// Test that it has reasonable defaults
|
||||
if client.Timeout <= 0 {
|
||||
t.Error("Expected positive timeout")
|
||||
}
|
||||
}
|
||||
+69
-3
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
@@ -319,6 +320,42 @@ func (iv *InputValidator) ValidateURL(urlStr string) ValidationResult {
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for localhost or private IPs for security
|
||||
// Allow localhost for HTTPS (development/testing) but warn about it
|
||||
hostname := strings.ToLower(parsedURL.Hostname())
|
||||
if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" {
|
||||
if parsedURL.Scheme == "https" {
|
||||
// Allow HTTPS localhost for development but warn
|
||||
result.Warnings = append(result.Warnings, "localhost URLs should only be used for development/testing")
|
||||
} else {
|
||||
// Reject non-HTTPS localhost for security
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "non-HTTPS localhost URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check for private IP ranges (RFC 1918)
|
||||
if strings.HasPrefix(hostname, "10.") ||
|
||||
strings.HasPrefix(hostname, "192.168.") ||
|
||||
strings.HasPrefix(hostname, "172.") {
|
||||
// For 172.x check if it's in the 172.16.0.0/12 range
|
||||
if strings.HasPrefix(hostname, "172.") {
|
||||
parts := strings.Split(hostname, ".")
|
||||
if len(parts) >= 2 {
|
||||
if second, err := strconv.Atoi(parts[1]); err == nil && second >= 16 && second <= 31 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
} else {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "private IP URLs are not allowed for security")
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(sanitized); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
@@ -407,7 +444,9 @@ func (iv *InputValidator) ValidateClaim(claimName, claimValue string) Validation
|
||||
}
|
||||
|
||||
if iv.containsControlCharacters(claimValue) {
|
||||
result.Warnings = append(result.Warnings, "claim value contains control characters")
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "claim value contains control characters")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
@@ -420,7 +459,25 @@ func (iv *InputValidator) ValidateClaim(claimName, claimValue string) Validation
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(claimValue); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for excessive unicode (emojis and special characters)
|
||||
unicodeCount := 0
|
||||
runeCount := 0
|
||||
for _, r := range claimValue {
|
||||
runeCount++
|
||||
if r > 127 { // Non-ASCII character
|
||||
unicodeCount++
|
||||
}
|
||||
}
|
||||
// If more than 50% of the characters are unicode, consider it suspicious
|
||||
if runeCount > 0 && unicodeCount > runeCount/2 {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "claim value contains excessive unicode characters")
|
||||
return result
|
||||
}
|
||||
|
||||
// Specific validations based on claim name
|
||||
@@ -505,6 +562,13 @@ func (iv *InputValidator) ValidateHeader(headerName, headerValue string) Validat
|
||||
return result
|
||||
}
|
||||
|
||||
// Check for control characters in header value
|
||||
if iv.containsControlCharacters(headerValue) {
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, "header value contains control characters")
|
||||
return result
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
if !utf8.ValidString(headerValue) {
|
||||
result.IsValid = false
|
||||
@@ -515,7 +579,9 @@ func (iv *InputValidator) ValidateHeader(headerName, headerValue string) Validat
|
||||
// Check for suspicious patterns
|
||||
if risk := iv.detectSecurityRisk(headerValue); risk != "" {
|
||||
result.SecurityRisk = risk
|
||||
result.Warnings = append(result.Warnings, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
result.IsValid = false
|
||||
result.Errors = append(result.Errors, fmt.Sprintf("potential security risk detected: %s", risk))
|
||||
return result
|
||||
}
|
||||
|
||||
result.SanitizedValue = strings.TrimSpace(headerValue)
|
||||
|
||||
@@ -0,0 +1,480 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestInputValidatorValidateToken tests comprehensive token validation
|
||||
func TestInputValidatorValidateToken(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidJWTToken",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiZXhwIjoxNTE2MjM5MDIyLCJpYXQiOjE1MTYyMzkwMjJ9.signature",
|
||||
expectValid: true,
|
||||
description: "Valid JWT token should pass validation",
|
||||
},
|
||||
{
|
||||
name: "InvalidOpaqueToken",
|
||||
token: "opaque_access_token_that_is_long_enough_to_pass",
|
||||
expectValid: false,
|
||||
description: "Opaque token (non-JWT) should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyToken",
|
||||
token: "",
|
||||
expectValid: false,
|
||||
description: "Empty token should fail validation",
|
||||
},
|
||||
{
|
||||
name: "TokenWithNullBytes",
|
||||
token: "token_with_null\x00byte",
|
||||
expectValid: false,
|
||||
description: "Token with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "TokenTooLong",
|
||||
token: strings.Repeat("a", config.MaxTokenLength+1),
|
||||
expectValid: false,
|
||||
description: "Token exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "TokenWithControlCharacters",
|
||||
token: "token_with_control\x01character",
|
||||
expectValid: false,
|
||||
description: "Token with control characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "TokenWithHighUnicode",
|
||||
token: "token_with_unicode_\uffff",
|
||||
expectValid: false,
|
||||
description: "Token with high unicode characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "MaliciousJWTWithExtraData",
|
||||
token: "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.sig.malicious_extra",
|
||||
expectValid: false,
|
||||
description: "JWT with extra malicious data should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateToken(tt.token)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputValidatorValidateEmail tests email validation edge cases
|
||||
func TestInputValidatorValidateEmail(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidEmail",
|
||||
email: "user@example.com",
|
||||
expectValid: true,
|
||||
description: "Valid email should pass validation",
|
||||
},
|
||||
{
|
||||
name: "ValidEmailWithSubdomain",
|
||||
email: "user@mail.example.com",
|
||||
expectValid: true,
|
||||
description: "Valid email with subdomain should pass validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyEmail",
|
||||
email: "",
|
||||
expectValid: false,
|
||||
description: "Empty email should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmailWithoutAtSign",
|
||||
email: "userexample.com",
|
||||
expectValid: false,
|
||||
description: "Email without @ sign should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmailWithNullBytes",
|
||||
email: "user@example\x00.com",
|
||||
expectValid: false,
|
||||
description: "Email with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmailTooLong",
|
||||
email: strings.Repeat("a", config.MaxEmailLength-10) + "@example.com",
|
||||
expectValid: false,
|
||||
description: "Email exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmailWithControlCharacters",
|
||||
email: "user\x01@example.com",
|
||||
expectValid: false,
|
||||
description: "Email with control characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "MaliciousEmailWithScriptTag",
|
||||
email: "user<script>@example.com",
|
||||
expectValid: false,
|
||||
description: "Email with script tag should fail validation",
|
||||
},
|
||||
{
|
||||
name: "EmailWithUnicodeCharacters",
|
||||
email: "üser@éxample.com",
|
||||
expectValid: false,
|
||||
description: "Email with unicode should fail basic validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateEmail(tt.email)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputValidatorValidateURL tests URL validation with security focus
|
||||
func TestInputValidatorValidateURL(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidHTTPSURL",
|
||||
url: "https://example.com/path",
|
||||
expectValid: true,
|
||||
description: "Valid HTTPS URL should pass validation",
|
||||
},
|
||||
{
|
||||
name: "ValidHTTPURL",
|
||||
url: "http://example.com/path",
|
||||
expectValid: true,
|
||||
description: "Valid HTTP URL should pass validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyURL",
|
||||
url: "",
|
||||
expectValid: false,
|
||||
description: "Empty URL should fail validation",
|
||||
},
|
||||
{
|
||||
name: "InvalidScheme",
|
||||
url: "ftp://example.com",
|
||||
expectValid: false,
|
||||
description: "URL with invalid scheme should fail validation",
|
||||
},
|
||||
{
|
||||
name: "URLWithNullBytes",
|
||||
url: "https://example\x00.com",
|
||||
expectValid: false,
|
||||
description: "URL with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "URLTooLong",
|
||||
url: "https://" + strings.Repeat("a", config.MaxURLLength) + ".com",
|
||||
expectValid: false,
|
||||
description: "URL exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "MalformedURL",
|
||||
url: "https://",
|
||||
expectValid: false,
|
||||
description: "Malformed URL should fail validation",
|
||||
},
|
||||
{
|
||||
name: "HTTPSLocalhostURL",
|
||||
url: "https://localhost:8080/path",
|
||||
expectValid: true,
|
||||
description: "HTTPS localhost URL should be allowed for development",
|
||||
},
|
||||
{
|
||||
name: "HTTPLocalhostURL",
|
||||
url: "http://localhost:8080/path",
|
||||
expectValid: false,
|
||||
description: "HTTP localhost URL should fail validation for security",
|
||||
},
|
||||
{
|
||||
name: "PrivateIPURL",
|
||||
url: "https://192.168.1.1/path",
|
||||
expectValid: false,
|
||||
description: "Private IP URL should fail validation for security",
|
||||
},
|
||||
{
|
||||
name: "JavaScriptURL",
|
||||
url: "javascript:alert(1)",
|
||||
expectValid: false,
|
||||
description: "JavaScript URL should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateURL(tt.url)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputValidatorValidateClaim tests claim validation with security focus
|
||||
func TestInputValidatorValidateClaim(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claimName string
|
||||
claimValue string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidStringClaim",
|
||||
claimName: "email",
|
||||
claimValue: "user@example.com",
|
||||
expectValid: true,
|
||||
description: "Valid string claim should pass validation",
|
||||
},
|
||||
{
|
||||
name: "ValidNumberClaim",
|
||||
claimName: "exp",
|
||||
claimValue: "1516239022",
|
||||
expectValid: true,
|
||||
description: "Valid number claim should pass validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyClaimName",
|
||||
claimName: "",
|
||||
claimValue: "value",
|
||||
expectValid: false,
|
||||
description: "Empty claim name should fail validation",
|
||||
},
|
||||
{
|
||||
name: "ClaimWithNullBytes",
|
||||
claimName: "test",
|
||||
claimValue: "value\x00with_null",
|
||||
expectValid: false,
|
||||
description: "Claim with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "ClaimValueTooLong",
|
||||
claimName: "test",
|
||||
claimValue: strings.Repeat("a", config.MaxClaimLength+1),
|
||||
expectValid: false,
|
||||
description: "Claim value exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "ClaimWithControlCharacters",
|
||||
claimName: "test",
|
||||
claimValue: "value\x01with_control",
|
||||
expectValid: false,
|
||||
description: "Claim with control characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "MaliciousClaimWithHTML",
|
||||
claimName: "test",
|
||||
claimValue: "<script>alert('xss')</script>",
|
||||
expectValid: false,
|
||||
description: "Claim with HTML/script should fail validation",
|
||||
},
|
||||
{
|
||||
name: "ClaimWithExcessiveUnicode",
|
||||
claimName: "test",
|
||||
claimValue: strings.Repeat("🚀", 100), // Many unicode chars
|
||||
expectValid: false,
|
||||
description: "Claim with excessive unicode should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateClaim(tt.claimName, tt.claimValue)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputValidatorValidateHeader tests HTTP header validation
|
||||
func TestInputValidatorValidateHeader(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headerName string
|
||||
headerValue string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidHeader",
|
||||
headerName: "Authorization",
|
||||
headerValue: "Bearer token123",
|
||||
expectValid: true,
|
||||
description: "Valid header should pass validation",
|
||||
},
|
||||
{
|
||||
name: "ValidContentType",
|
||||
headerName: "Content-Type",
|
||||
headerValue: "application/json",
|
||||
expectValid: true,
|
||||
description: "Valid content type header should pass validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyHeaderName",
|
||||
headerName: "",
|
||||
headerValue: "value",
|
||||
expectValid: false,
|
||||
description: "Empty header name should fail validation",
|
||||
},
|
||||
{
|
||||
name: "HeaderWithNullBytes",
|
||||
headerName: "test",
|
||||
headerValue: "value\x00with_null",
|
||||
expectValid: false,
|
||||
description: "Header with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "HeaderValueTooLong",
|
||||
headerName: "test",
|
||||
headerValue: strings.Repeat("a", config.MaxHeaderLength+1),
|
||||
expectValid: false,
|
||||
description: "Header value exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "HeaderWithCRLF",
|
||||
headerName: "test",
|
||||
headerValue: "value\r\nMalicious: header",
|
||||
expectValid: false,
|
||||
description: "Header with CRLF should fail validation to prevent injection",
|
||||
},
|
||||
{
|
||||
name: "HeaderWithControlCharacters",
|
||||
headerName: "test",
|
||||
headerValue: "value\x01with_control",
|
||||
expectValid: false,
|
||||
description: "Header with control characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "MaliciousHeaderWithHTML",
|
||||
headerName: "test",
|
||||
headerValue: "<script>alert('xss')</script>",
|
||||
expectValid: false,
|
||||
description: "Header with HTML/script should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateHeader(tt.headerName, tt.headerValue)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestInputValidatorValidateUsername tests username validation
|
||||
func TestInputValidatorValidateUsername(t *testing.T) {
|
||||
config := DefaultInputValidationConfig()
|
||||
validator, _ := NewInputValidator(config, newNoOpLogger())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
expectValid bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "ValidUsername",
|
||||
username: "john_doe",
|
||||
expectValid: true,
|
||||
description: "Valid username should pass validation",
|
||||
},
|
||||
{
|
||||
name: "ValidUsernameWithNumbers",
|
||||
username: "user123",
|
||||
expectValid: true,
|
||||
description: "Valid username with numbers should pass validation",
|
||||
},
|
||||
{
|
||||
name: "EmptyUsername",
|
||||
username: "",
|
||||
expectValid: false,
|
||||
description: "Empty username should fail validation",
|
||||
},
|
||||
{
|
||||
name: "UsernameWithNullBytes",
|
||||
username: "user\x00name",
|
||||
expectValid: false,
|
||||
description: "Username with null bytes should fail validation",
|
||||
},
|
||||
{
|
||||
name: "UsernameTooLong",
|
||||
username: strings.Repeat("a", config.MaxUsernameLength+1),
|
||||
expectValid: false,
|
||||
description: "Username exceeding max length should fail validation",
|
||||
},
|
||||
{
|
||||
name: "UsernameWithSpecialChars",
|
||||
username: "user@name",
|
||||
expectValid: false,
|
||||
description: "Username with special characters should fail validation",
|
||||
},
|
||||
{
|
||||
name: "UsernameWithSpaces",
|
||||
username: "user name",
|
||||
expectValid: false,
|
||||
description: "Username with spaces should fail validation",
|
||||
},
|
||||
{
|
||||
name: "UsernameWithControlCharacters",
|
||||
username: "user\x01name",
|
||||
expectValid: false,
|
||||
description: "Username with control characters should fail validation",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.ValidateUsername(tt.username)
|
||||
|
||||
if result.IsValid != tt.expectValid {
|
||||
t.Errorf("Expected valid=%v, got %v. %s", tt.expectValid, result.IsValid, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,736 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockLegacySettings implements LegacySettings for testing
|
||||
type mockLegacySettings struct {
|
||||
issuerURL string
|
||||
authURL string
|
||||
scopes []string
|
||||
pkceEnabled bool
|
||||
clientID string
|
||||
refreshGracePeriod time.Duration
|
||||
overrideScopes bool
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) GetIssuerURL() string {
|
||||
return m.issuerURL
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) GetAuthURL() string {
|
||||
return m.authURL
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) GetScopes() []string {
|
||||
return m.scopes
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) IsPKCEEnabled() bool {
|
||||
return m.pkceEnabled
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) GetClientID() string {
|
||||
return m.clientID
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) GetRefreshGracePeriod() time.Duration {
|
||||
return m.refreshGracePeriod
|
||||
}
|
||||
|
||||
func (m *mockLegacySettings) IsOverrideScopes() bool {
|
||||
return m.overrideScopes
|
||||
}
|
||||
|
||||
func TestNewAdapter(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://accounts.google.com",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
scopes: []string{"openid", "email"},
|
||||
pkceEnabled: true,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
if adapter == nil {
|
||||
t.Fatal("expected non-nil adapter")
|
||||
}
|
||||
|
||||
if adapter.provider != provider {
|
||||
t.Error("expected provider to be set correctly")
|
||||
}
|
||||
|
||||
if adapter.legacySettings != settings {
|
||||
t.Error("expected legacy settings to be set correctly")
|
||||
}
|
||||
|
||||
if adapter.tokenVerifier != verifier {
|
||||
t.Error("expected token verifier to be set correctly")
|
||||
}
|
||||
|
||||
if adapter.tokenCache != cache {
|
||||
t.Error("expected token cache to be set correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_GetType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider OIDCProvider
|
||||
expectedType ProviderType
|
||||
}{
|
||||
{
|
||||
name: "Google provider",
|
||||
provider: NewGoogleProvider(),
|
||||
expectedType: ProviderTypeGoogle,
|
||||
},
|
||||
{
|
||||
name: "Azure provider",
|
||||
provider: NewAzureProvider(),
|
||||
expectedType: ProviderTypeAzure,
|
||||
},
|
||||
{
|
||||
name: "Generic provider",
|
||||
provider: NewGenericProvider(),
|
||||
expectedType: ProviderTypeGeneric,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
settings := &mockLegacySettings{clientID: "test-client"}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
adapter := NewAdapter(tt.provider, settings, verifier, cache)
|
||||
providerType := adapter.GetType()
|
||||
|
||||
if providerType != tt.expectedType {
|
||||
t.Errorf("expected provider type %d, got %d", tt.expectedType, providerType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_ValidateTokens(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{
|
||||
refreshGracePeriod: time.Minute * 5,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"valid-token": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *mockSession
|
||||
expectedResult *ValidationResult
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid authenticated session",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
idToken: "valid-token",
|
||||
accessToken: "access-token",
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthenticated with refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthenticated without refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
},
|
||||
expectedResult: &ValidationResult{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := adapter.ValidateTokens(tt.session)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Authenticated != tt.expectedResult.Authenticated {
|
||||
t.Errorf("expected Authenticated %t, got %t", tt.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
||||
t.Errorf("expected NeedsRefresh %t, got %t", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
||||
}
|
||||
|
||||
if result.IsExpired != tt.expectedResult.IsExpired {
|
||||
t.Errorf("expected IsExpired %t, got %t", tt.expectedResult.IsExpired, result.IsExpired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildAuthURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
provider OIDCProvider
|
||||
settings *mockLegacySettings
|
||||
redirectURL string
|
||||
state string
|
||||
nonce string
|
||||
codeChallenge string
|
||||
expectedSubstrs []string
|
||||
unexpectedSubstrs []string
|
||||
}{
|
||||
{
|
||||
name: "Google provider without override scopes",
|
||||
provider: NewGoogleProvider(),
|
||||
settings: &mockLegacySettings{
|
||||
clientID: "google-client-id",
|
||||
issuerURL: "https://accounts.google.com",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
scopes: []string{"openid", "email", "profile"},
|
||||
pkceEnabled: true,
|
||||
overrideScopes: false,
|
||||
},
|
||||
redirectURL: "https://example.com/callback",
|
||||
state: "random-state",
|
||||
nonce: "random-nonce",
|
||||
codeChallenge: "code-challenge",
|
||||
expectedSubstrs: []string{
|
||||
"client_id=google-client-id",
|
||||
"response_type=code",
|
||||
"redirect_uri=https%3A%2F%2Fexample.com%2Fcallback",
|
||||
"state=random-state",
|
||||
"nonce=random-nonce",
|
||||
"code_challenge=code-challenge",
|
||||
"code_challenge_method=S256",
|
||||
"access_type=offline",
|
||||
"prompt=consent",
|
||||
"scope=openid+email+profile",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure provider with override scopes",
|
||||
provider: NewAzureProvider(),
|
||||
settings: &mockLegacySettings{
|
||||
clientID: "azure-client-id",
|
||||
issuerURL: "https://login.microsoftonline.com/tenant",
|
||||
authURL: "https://login.microsoftonline.com/tenant/oauth2/v2.0/authorize",
|
||||
scopes: []string{"openid", "offline_access"},
|
||||
pkceEnabled: false,
|
||||
overrideScopes: true,
|
||||
},
|
||||
redirectURL: "https://example.com/azure-callback",
|
||||
state: "azure-state",
|
||||
nonce: "azure-nonce",
|
||||
codeChallenge: "",
|
||||
expectedSubstrs: []string{
|
||||
"client_id=azure-client-id",
|
||||
"response_type=code",
|
||||
"redirect_uri=https%3A%2F%2Fexample.com%2Fazure-callback",
|
||||
"state=azure-state",
|
||||
"nonce=azure-nonce",
|
||||
"response_mode=query",
|
||||
"scope=openid+offline_access",
|
||||
},
|
||||
unexpectedSubstrs: []string{
|
||||
"code_challenge",
|
||||
"access_type",
|
||||
"prompt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Generic provider with relative auth URL",
|
||||
provider: NewGenericProvider(),
|
||||
settings: &mockLegacySettings{
|
||||
clientID: "generic-client-id",
|
||||
issuerURL: "https://keycloak.example.com/auth/realms/master",
|
||||
authURL: "/auth/realms/master/protocol/openid-connect/auth",
|
||||
scopes: []string{"openid", "email"},
|
||||
pkceEnabled: false,
|
||||
overrideScopes: false,
|
||||
},
|
||||
redirectURL: "https://example.com/generic-callback",
|
||||
state: "generic-state",
|
||||
nonce: "generic-nonce",
|
||||
codeChallenge: "",
|
||||
expectedSubstrs: []string{
|
||||
"keycloak.example.com",
|
||||
"client_id=generic-client-id",
|
||||
"response_type=code",
|
||||
"scope=openid+email+offline_access", // Generic provider adds offline_access
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PKCE disabled",
|
||||
provider: NewGoogleProvider(),
|
||||
settings: &mockLegacySettings{
|
||||
clientID: "google-client-id-no-pkce",
|
||||
issuerURL: "https://accounts.google.com",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
scopes: []string{"openid", "email"},
|
||||
pkceEnabled: false,
|
||||
overrideScopes: false,
|
||||
},
|
||||
redirectURL: "https://example.com/callback",
|
||||
state: "state-no-pkce",
|
||||
nonce: "nonce-no-pkce",
|
||||
codeChallenge: "should-be-ignored",
|
||||
expectedSubstrs: []string{
|
||||
"client_id=google-client-id-no-pkce",
|
||||
"state=state-no-pkce",
|
||||
"nonce=nonce-no-pkce",
|
||||
},
|
||||
unexpectedSubstrs: []string{
|
||||
"code_challenge",
|
||||
"code_challenge_method",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
adapter := NewAdapter(tt.provider, tt.settings, verifier, cache)
|
||||
|
||||
authURL := adapter.BuildAuthURL(tt.redirectURL, tt.state, tt.nonce, tt.codeChallenge)
|
||||
|
||||
if authURL == "" {
|
||||
t.Fatal("expected non-empty auth URL")
|
||||
}
|
||||
|
||||
for _, expectedSubstr := range tt.expectedSubstrs {
|
||||
if !strings.Contains(authURL, expectedSubstr) {
|
||||
t.Errorf("expected auth URL to contain %q, got %q", expectedSubstr, authURL)
|
||||
}
|
||||
}
|
||||
|
||||
for _, unexpectedSubstr := range tt.unexpectedSubstrs {
|
||||
if strings.Contains(authURL, unexpectedSubstr) {
|
||||
t.Errorf("expected auth URL to NOT contain %q, got %q", unexpectedSubstr, authURL)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_BuildAuthURL_ErrorCases(t *testing.T) {
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
t.Run("invalid issuer URL", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "://invalid-url",
|
||||
authURL: "/relative/path",
|
||||
scopes: []string{"openid"},
|
||||
overrideScopes: false,
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
|
||||
if authURL != "" {
|
||||
t.Errorf("expected empty auth URL for invalid issuer URL, got %q", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid auth URL", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://example.com",
|
||||
authURL: "://invalid-auth-url",
|
||||
scopes: []string{"openid"},
|
||||
overrideScopes: false,
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
|
||||
if authURL != "" {
|
||||
t.Errorf("expected empty auth URL for invalid auth URL, got %q", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid absolute auth URL", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://example.com",
|
||||
authURL: "://invalid-absolute-url",
|
||||
scopes: []string{"openid"},
|
||||
overrideScopes: false,
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
|
||||
if authURL != "" {
|
||||
t.Errorf("expected empty auth URL for invalid absolute auth URL, got %q", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("provider BuildAuthParams error", func(t *testing.T) {
|
||||
// Create a mock provider that returns an error
|
||||
mockProvider := &mockProviderWithError{}
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://example.com",
|
||||
authURL: "https://example.com/auth",
|
||||
scopes: []string{"openid"},
|
||||
overrideScopes: false,
|
||||
}
|
||||
|
||||
adapter := NewAdapter(mockProvider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL("https://example.com/callback", "state", "nonce", "")
|
||||
|
||||
if authURL != "" {
|
||||
t.Errorf("expected empty auth URL when provider returns error, got %q", authURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockProviderWithError is a test helper that returns errors from BuildAuthParams
|
||||
type mockProviderWithError struct {
|
||||
*BaseProvider
|
||||
}
|
||||
|
||||
func (m *mockProviderWithError) GetType() ProviderType {
|
||||
return ProviderTypeGeneric
|
||||
}
|
||||
|
||||
func (m *mockProviderWithError) BuildAuthParams(baseParams url.Values, scopes []string) (*AuthParams, error) {
|
||||
return nil, fmt.Errorf("mock error from BuildAuthParams")
|
||||
}
|
||||
|
||||
func TestAdapter_ConcurrentAccess(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://accounts.google.com",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
scopes: []string{"openid", "email"},
|
||||
pkceEnabled: true,
|
||||
refreshGracePeriod: time.Minute,
|
||||
overrideScopes: false,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"valid-token": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Test concurrent access to adapter methods
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
// Test GetType
|
||||
providerType := adapter.GetType()
|
||||
if providerType != ProviderTypeGoogle {
|
||||
t.Errorf("worker %d: expected Google provider type", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test BuildAuthURL
|
||||
authURL := adapter.BuildAuthURL(
|
||||
fmt.Sprintf("https://example.com/callback-%d", workerID),
|
||||
fmt.Sprintf("state-%d-%d", workerID, j),
|
||||
fmt.Sprintf("nonce-%d-%d", workerID, j),
|
||||
fmt.Sprintf("challenge-%d-%d", workerID, j),
|
||||
)
|
||||
if authURL == "" {
|
||||
t.Errorf("worker %d: expected non-empty auth URL", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateTokens
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
idToken: "valid-token",
|
||||
accessToken: fmt.Sprintf("access-token-%d", workerID),
|
||||
refreshToken: fmt.Sprintf("refresh-token-%d", workerID),
|
||||
}
|
||||
|
||||
result, err := adapter.ValidateTokens(session)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in ValidateTokens: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if !result.Authenticated {
|
||||
t.Errorf("worker %d: expected authenticated result", workerID)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_MemorySafety(t *testing.T) {
|
||||
const numIterations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: fmt.Sprintf("client-%d", i),
|
||||
issuerURL: fmt.Sprintf("https://example%d.com", i),
|
||||
authURL: fmt.Sprintf("https://example%d.com/auth", i),
|
||||
scopes: []string{"openid", "email"},
|
||||
pkceEnabled: i%2 == 0, // Alternate PKCE setting
|
||||
refreshGracePeriod: time.Minute,
|
||||
overrideScopes: i%3 == 0, // Alternate override setting
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
fmt.Sprintf("token-%d", i): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
// Exercise all adapter methods
|
||||
_ = adapter.GetType()
|
||||
_ = adapter.BuildAuthURL("https://example.com/callback", "state", "nonce", "challenge")
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
idToken: fmt.Sprintf("token-%d", i),
|
||||
accessToken: fmt.Sprintf("access-%d", i),
|
||||
refreshToken: fmt.Sprintf("refresh-%d", i),
|
||||
}
|
||||
_, _ = adapter.ValidateTokens(session)
|
||||
}
|
||||
|
||||
// Force garbage collection
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter_EdgeCases(t *testing.T) {
|
||||
t.Run("empty parameters", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "",
|
||||
issuerURL: "",
|
||||
authURL: "",
|
||||
scopes: []string{},
|
||||
pkceEnabled: false,
|
||||
overrideScopes: false,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL("", "", "", "")
|
||||
|
||||
// Should not crash, but may return empty or invalid URL
|
||||
if authURL != "" {
|
||||
t.Logf("Got auth URL with empty parameters: %s", authURL)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("very long parameters", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
longString := strings.Repeat("a", 5000)
|
||||
settings := &mockLegacySettings{
|
||||
clientID: longString,
|
||||
issuerURL: "https://example.com/" + longString,
|
||||
authURL: "https://example.com/" + longString + "/auth",
|
||||
scopes: []string{"openid", longString},
|
||||
pkceEnabled: true,
|
||||
overrideScopes: false,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL(
|
||||
"https://example.com/callback",
|
||||
longString,
|
||||
longString,
|
||||
longString,
|
||||
)
|
||||
|
||||
// Should not crash
|
||||
if authURL == "" {
|
||||
t.Log("Long parameters resulted in empty auth URL")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special characters in parameters", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "client@example.com",
|
||||
issuerURL: "https://example.com/auth?param=value&other=test",
|
||||
authURL: "https://example.com/auth/endpoint?default=param",
|
||||
scopes: []string{"openid", "email+special", "profile/test"},
|
||||
pkceEnabled: true,
|
||||
overrideScopes: false,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
authURL := adapter.BuildAuthURL(
|
||||
"https://example.com/callback?return=url",
|
||||
"state+with/special=chars&more",
|
||||
"nonce_with_underscores",
|
||||
"challenge-with-dashes",
|
||||
)
|
||||
|
||||
if authURL == "" {
|
||||
t.Error("expected non-empty auth URL with special characters")
|
||||
}
|
||||
|
||||
// Verify URL is properly encoded
|
||||
if !strings.Contains(authURL, "%") {
|
||||
t.Error("expected auth URL to contain URL encoding")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkAdapter_GetType(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{clientID: "test-client"}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
adapter.GetType()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAdapter_BuildAuthURL(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{
|
||||
clientID: "test-client",
|
||||
issuerURL: "https://accounts.google.com",
|
||||
authURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
scopes: []string{"openid", "email", "profile"},
|
||||
pkceEnabled: true,
|
||||
overrideScopes: false,
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
adapter.BuildAuthURL(
|
||||
"https://example.com/callback",
|
||||
"test-state",
|
||||
"test-nonce",
|
||||
"test-challenge",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAdapter_ValidateTokens(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
settings := &mockLegacySettings{refreshGracePeriod: time.Minute}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"test-token": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
adapter := NewAdapter(provider, settings, verifier, cache)
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
idToken: "test-token",
|
||||
accessToken: "access-token",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := adapter.ValidateTokens(session)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,695 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockSession implements the Session interface for testing
|
||||
type mockSession struct {
|
||||
idToken string
|
||||
accessToken string
|
||||
refreshToken string
|
||||
authenticated bool
|
||||
}
|
||||
|
||||
func (m *mockSession) GetIDToken() string {
|
||||
return m.idToken
|
||||
}
|
||||
|
||||
func (m *mockSession) GetAccessToken() string {
|
||||
return m.accessToken
|
||||
}
|
||||
|
||||
func (m *mockSession) GetRefreshToken() string {
|
||||
return m.refreshToken
|
||||
}
|
||||
|
||||
func (m *mockSession) GetAuthenticated() bool {
|
||||
return m.authenticated
|
||||
}
|
||||
|
||||
// mockTokenVerifier implements TokenVerifier for testing
|
||||
type mockTokenVerifier struct {
|
||||
shouldFail bool
|
||||
expiredTokens map[string]bool
|
||||
}
|
||||
|
||||
func (m *mockTokenVerifier) VerifyToken(token string) error {
|
||||
if m.shouldFail {
|
||||
return errors.New("token verification failed")
|
||||
}
|
||||
if m.expiredTokens != nil && m.expiredTokens[token] {
|
||||
return errors.New("token has expired")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockTokenCache implements TokenCache for testing
|
||||
type mockTokenCache struct {
|
||||
data map[string]map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockTokenCache) Get(key string) (map[string]interface{}, bool) {
|
||||
if m.data == nil {
|
||||
return nil, false
|
||||
}
|
||||
claims, exists := m.data[key]
|
||||
return claims, exists
|
||||
}
|
||||
|
||||
func TestNewAzureProvider(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
if provider == nil {
|
||||
t.Fatal("expected non-nil Azure provider")
|
||||
}
|
||||
|
||||
if provider.BaseProvider == nil {
|
||||
t.Fatal("expected non-nil BaseProvider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_GetType(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
providerType := provider.GetType()
|
||||
|
||||
if providerType != ProviderTypeAzure {
|
||||
t.Errorf("expected provider type %d, got %d", ProviderTypeAzure, providerType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_GetCapabilities(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
capabilities := provider.GetCapabilities()
|
||||
|
||||
expectedCapabilities := ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: true,
|
||||
PreferredTokenValidation: "access",
|
||||
}
|
||||
|
||||
if capabilities.SupportsRefreshTokens != expectedCapabilities.SupportsRefreshTokens {
|
||||
t.Errorf("expected SupportsRefreshTokens %t, got %t", expectedCapabilities.SupportsRefreshTokens, capabilities.SupportsRefreshTokens)
|
||||
}
|
||||
|
||||
if capabilities.RequiresOfflineAccessScope != expectedCapabilities.RequiresOfflineAccessScope {
|
||||
t.Errorf("expected RequiresOfflineAccessScope %t, got %t", expectedCapabilities.RequiresOfflineAccessScope, capabilities.RequiresOfflineAccessScope)
|
||||
}
|
||||
|
||||
if capabilities.PreferredTokenValidation != expectedCapabilities.PreferredTokenValidation {
|
||||
t.Errorf("expected PreferredTokenValidation %q, got %q", expectedCapabilities.PreferredTokenValidation, capabilities.PreferredTokenValidation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_BuildAuthParams(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseParams url.Values
|
||||
scopes []string
|
||||
expectOfflineAccess bool
|
||||
expectResponseMode bool
|
||||
}{
|
||||
{
|
||||
name: "basic params with offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "offline_access", "email"},
|
||||
expectOfflineAccess: true,
|
||||
expectResponseMode: true,
|
||||
},
|
||||
{
|
||||
name: "basic params without offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "email"},
|
||||
expectOfflineAccess: true, // Should be added automatically
|
||||
expectResponseMode: true,
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
baseParams: url.Values{},
|
||||
scopes: []string{},
|
||||
expectOfflineAccess: true, // Should be added automatically
|
||||
expectResponseMode: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Fatal("expected non-nil auth params")
|
||||
}
|
||||
|
||||
// Check response_mode is set
|
||||
if tt.expectResponseMode {
|
||||
responseMode := authParams.URLValues.Get("response_mode")
|
||||
if responseMode != "query" {
|
||||
t.Errorf("expected response_mode 'query', got %q", responseMode)
|
||||
}
|
||||
}
|
||||
|
||||
// Check offline_access scope
|
||||
if tt.expectOfflineAccess {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range authParams.Scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
t.Error("expected offline_access scope to be present")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify other parameters are preserved
|
||||
for key, values := range tt.baseParams {
|
||||
if key == "response_mode" {
|
||||
continue // This gets overridden
|
||||
}
|
||||
paramValues := authParams.URLValues[key]
|
||||
if len(paramValues) != len(values) {
|
||||
t.Errorf("expected %d values for param %s, got %d", len(values), key, len(paramValues))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_ValidateTokens(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *mockSession
|
||||
verifier *mockTokenVerifier
|
||||
cache *mockTokenCache
|
||||
refreshGracePeriod time.Duration
|
||||
expectedResult *ValidationResult
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated with refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthenticated without refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
IsExpired: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with valid JWT access token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "header.payload.signature",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"header.payload.signature": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with invalid access token but valid ID token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "header.payload.signature",
|
||||
idToken: "id.token.here",
|
||||
},
|
||||
verifier: &mockTokenVerifier{shouldFail: true},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"id.token.here": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with opaque access token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "opaque-token-no-dots",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with ID token only",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
idToken: "id.token.here",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"id.token.here": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "expired ID token with refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
idToken: "expired.token.here",
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{
|
||||
expiredTokens: map[string]bool{
|
||||
"expired.token.here": true,
|
||||
},
|
||||
},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated but no tokens",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := provider.ValidateTokens(tt.session, tt.verifier, tt.cache, tt.refreshGracePeriod)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Authenticated != tt.expectedResult.Authenticated {
|
||||
t.Errorf("expected Authenticated %t, got %t", tt.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
||||
t.Errorf("expected NeedsRefresh %t, got %t", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
||||
}
|
||||
|
||||
if result.IsExpired != tt.expectedResult.IsExpired {
|
||||
t.Errorf("expected IsExpired %t, got %t", tt.expectedResult.IsExpired, result.IsExpired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_ValidateConfig(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
// Azure provider uses BaseProvider's ValidateConfig which always returns nil
|
||||
err := provider.ValidateConfig()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from ValidateConfig: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_HandleTokenRefresh(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
// Test that HandleTokenRefresh doesn't fail
|
||||
tokenData := &TokenResult{
|
||||
IDToken: "id-token",
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
}
|
||||
|
||||
err := provider.HandleTokenRefresh(tokenData)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from HandleTokenRefresh: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_ConcurrentAccess(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Test concurrent access to provider methods
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: fmt.Sprintf("access-token-%d", workerID),
|
||||
idToken: fmt.Sprintf("id-token-%d", workerID),
|
||||
refreshToken: fmt.Sprintf("refresh-token-%d", workerID),
|
||||
}
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
fmt.Sprintf("access-token-%d", workerID): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
fmt.Sprintf("id-token-%d", workerID): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
// Test GetType
|
||||
if provider.GetType() != ProviderTypeAzure {
|
||||
t.Errorf("worker %d: expected Azure provider type", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test GetCapabilities
|
||||
capabilities := provider.GetCapabilities()
|
||||
if !capabilities.SupportsRefreshTokens {
|
||||
t.Errorf("worker %d: expected refresh token support", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateTokens
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in ValidateTokens: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if !result.Authenticated {
|
||||
t.Errorf("worker %d: expected authenticated result", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test BuildAuthParams
|
||||
baseParams := url.Values{"client_id": []string{fmt.Sprintf("client-%d", workerID)}}
|
||||
scopes := []string{"openid", "email"}
|
||||
authParams, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in BuildAuthParams: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if authParams == nil {
|
||||
t.Errorf("worker %d: expected non-nil auth params", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateConfig
|
||||
err = provider.ValidateConfig()
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in ValidateConfig: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_MemorySafety(t *testing.T) {
|
||||
const numIterations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: fmt.Sprintf("access-token-%d.payload.signature", i),
|
||||
idToken: fmt.Sprintf("id-token-%d.payload.signature", i),
|
||||
refreshToken: fmt.Sprintf("refresh-token-%d", i),
|
||||
}
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
fmt.Sprintf("access-token-%d.payload.signature", i): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Exercise all provider methods
|
||||
_ = provider.GetType()
|
||||
_ = provider.GetCapabilities()
|
||||
_, _ = provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
_, _ = provider.BuildAuthParams(url.Values{}, []string{"openid"})
|
||||
_ = provider.ValidateConfig()
|
||||
_ = provider.HandleTokenRefresh(&TokenResult{})
|
||||
}
|
||||
|
||||
// Force garbage collection
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAzureProvider_EdgeCases(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
t.Run("nil session", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Expected behavior for nil session
|
||||
t.Logf("Recovered from expected panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
_, err := provider.ValidateTokens(nil, verifier, cache, time.Minute)
|
||||
if err == nil {
|
||||
t.Error("expected error with nil session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil verifier", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
session := &mockSession{authenticated: true, idToken: "test.token.here"}
|
||||
cache := &mockTokenCache{}
|
||||
_, err := provider.ValidateTokens(session, nil, cache, time.Minute)
|
||||
if err == nil {
|
||||
t.Error("expected error with nil verifier")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil cache", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic with nil cache: %v", r)
|
||||
}
|
||||
}()
|
||||
session := &mockSession{authenticated: true, accessToken: "test.token.here"}
|
||||
verifier := &mockTokenVerifier{}
|
||||
_, err := provider.ValidateTokens(session, verifier, nil, time.Minute)
|
||||
if err == nil {
|
||||
t.Error("expected error with nil cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty tokens", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "",
|
||||
idToken: "",
|
||||
refreshToken: "",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with empty tokens: %v", err)
|
||||
}
|
||||
if !result.IsExpired {
|
||||
t.Error("expected IsExpired=true for empty tokens without refresh token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("malformed JWT tokens", func(t *testing.T) {
|
||||
malformedTokens := []string{
|
||||
"not.enough.parts",
|
||||
"too.many.parts.in.this.token",
|
||||
"",
|
||||
"single-part-token",
|
||||
}
|
||||
|
||||
for _, token := range malformedTokens {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: token,
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with malformed token %q: %v", token, err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Errorf("expected non-nil result for malformed token %q", token)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("very long tokens", func(t *testing.T) {
|
||||
longToken := strings.Repeat("a", 10000) + "." + strings.Repeat("b", 10000) + "." + strings.Repeat("c", 10000)
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: longToken,
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
longToken: {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with very long token: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Error("expected non-nil result with very long token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkAzureProvider_GetType(b *testing.B) {
|
||||
provider := NewAzureProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetType()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAzureProvider_GetCapabilities(b *testing.B) {
|
||||
provider := NewAzureProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetCapabilities()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAzureProvider_BuildAuthParams(b *testing.B) {
|
||||
provider := NewAzureProvider()
|
||||
baseParams := url.Values{"client_id": []string{"test-client"}}
|
||||
scopes := []string{"openid", "email", "profile"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAzureProvider_ValidateTokens(b *testing.B) {
|
||||
provider := NewAzureProvider()
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "header.payload.signature",
|
||||
idToken: "id.token.here",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"header.payload.signature": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,650 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBaseProvider_GetType(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
providerType := provider.GetType()
|
||||
|
||||
if providerType != ProviderTypeGeneric {
|
||||
t.Errorf("expected provider type %d, got %d", ProviderTypeGeneric, providerType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_GetCapabilities(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
capabilities := provider.GetCapabilities()
|
||||
|
||||
expectedCapabilities := ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: true,
|
||||
PreferredTokenValidation: "id",
|
||||
}
|
||||
|
||||
if capabilities.SupportsRefreshTokens != expectedCapabilities.SupportsRefreshTokens {
|
||||
t.Errorf("expected SupportsRefreshTokens %t, got %t", expectedCapabilities.SupportsRefreshTokens, capabilities.SupportsRefreshTokens)
|
||||
}
|
||||
|
||||
if capabilities.RequiresOfflineAccessScope != expectedCapabilities.RequiresOfflineAccessScope {
|
||||
t.Errorf("expected RequiresOfflineAccessScope %t, got %t", expectedCapabilities.RequiresOfflineAccessScope, capabilities.RequiresOfflineAccessScope)
|
||||
}
|
||||
|
||||
if capabilities.PreferredTokenValidation != expectedCapabilities.PreferredTokenValidation {
|
||||
t.Errorf("expected PreferredTokenValidation %q, got %q", expectedCapabilities.PreferredTokenValidation, capabilities.PreferredTokenValidation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_BuildAuthParams(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseParams url.Values
|
||||
scopes []string
|
||||
expectOfflineAccess bool
|
||||
}{
|
||||
{
|
||||
name: "params with offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "offline_access", "email"},
|
||||
expectOfflineAccess: true,
|
||||
},
|
||||
{
|
||||
name: "params without offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "email"},
|
||||
expectOfflineAccess: true, // Should be added automatically
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
baseParams: url.Values{},
|
||||
scopes: []string{},
|
||||
expectOfflineAccess: true, // Should be added automatically
|
||||
},
|
||||
{
|
||||
name: "multiple offline_access scopes",
|
||||
baseParams: url.Values{},
|
||||
scopes: []string{"openid", "offline_access", "email", "offline_access"},
|
||||
expectOfflineAccess: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Fatal("expected non-nil auth params")
|
||||
}
|
||||
|
||||
// Check offline_access scope
|
||||
if tt.expectOfflineAccess {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range authParams.Scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
t.Error("expected offline_access scope to be present")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify other parameters are preserved
|
||||
for key, values := range tt.baseParams {
|
||||
paramValues := authParams.URLValues[key]
|
||||
if len(paramValues) != len(values) {
|
||||
t.Errorf("expected %d values for param %s, got %d", len(values), key, len(paramValues))
|
||||
}
|
||||
for i, expectedValue := range values {
|
||||
if i < len(paramValues) && paramValues[i] != expectedValue {
|
||||
t.Errorf("expected param %s[%d] to be %q, got %q", key, i, expectedValue, paramValues[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_ValidateTokenExpiry(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
session *mockSession
|
||||
cache *mockTokenCache
|
||||
refreshGracePeriod time.Duration
|
||||
expectedResult *ValidationResult
|
||||
}{
|
||||
{
|
||||
name: "token not in cache with refresh token",
|
||||
token: "missing-token",
|
||||
session: &mockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token not in cache without refresh token",
|
||||
token: "missing-token",
|
||||
session: &mockSession{
|
||||
refreshToken: "",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
IsExpired: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with missing exp claim with refresh token",
|
||||
token: "token-without-exp",
|
||||
session: &mockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"token-without-exp": {
|
||||
"sub": "user123",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with invalid exp claim type with refresh token",
|
||||
token: "token-with-invalid-exp",
|
||||
session: &mockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"token-with-invalid-exp": {
|
||||
"exp": "not-a-number",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token with invalid exp claim type without refresh token",
|
||||
token: "token-with-invalid-exp",
|
||||
session: &mockSession{
|
||||
refreshToken: "",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"token-with-invalid-exp": {
|
||||
"exp": "not-a-number",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
IsExpired: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token expired within grace period with refresh token",
|
||||
token: "expiring-token",
|
||||
refreshGracePeriod: time.Minute * 10,
|
||||
session: &mockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"expiring-token": {
|
||||
"exp": float64(time.Now().Add(time.Minute * 5).Unix()), // Expires in 5 minutes, within grace period
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token expired within grace period without refresh token",
|
||||
token: "expiring-token-no-refresh",
|
||||
refreshGracePeriod: time.Minute * 10,
|
||||
session: &mockSession{
|
||||
refreshToken: "",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"expiring-token-no-refresh": {
|
||||
"exp": float64(time.Now().Add(time.Minute * 5).Unix()), // Expires in 5 minutes, within grace period
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "token valid outside grace period",
|
||||
token: "valid-token",
|
||||
refreshGracePeriod: time.Minute * 5,
|
||||
session: &mockSession{
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"valid-token": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()), // Expires in 1 hour, outside grace period
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := provider.ValidateTokenExpiry(tt.session, tt.token, tt.cache, tt.refreshGracePeriod)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Authenticated != tt.expectedResult.Authenticated {
|
||||
t.Errorf("expected Authenticated %t, got %t", tt.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
||||
t.Errorf("expected NeedsRefresh %t, got %t", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
||||
}
|
||||
|
||||
if result.IsExpired != tt.expectedResult.IsExpired {
|
||||
t.Errorf("expected IsExpired %t, got %t", tt.expectedResult.IsExpired, result.IsExpired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_ValidateTokens_AdditionalCases(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
t.Run("authenticated with access token but no ID token and refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Authenticated {
|
||||
t.Error("expected Authenticated to be true")
|
||||
}
|
||||
if !result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be true when no ID token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated with access token but no ID token and no refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "",
|
||||
refreshToken: "",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Authenticated {
|
||||
t.Error("expected Authenticated to be true")
|
||||
}
|
||||
if result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be false when no refresh token available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("authenticated with no access token but has refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "",
|
||||
idToken: "",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Authenticated {
|
||||
t.Error("expected Authenticated to be false when no access token")
|
||||
}
|
||||
if !result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be true when refresh token available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token verification error containing 'token has expired'", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "expired-id-token",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{
|
||||
expiredTokens: map[string]bool{
|
||||
"expired-id-token": true,
|
||||
},
|
||||
}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Authenticated {
|
||||
t.Error("expected Authenticated to be false for expired token")
|
||||
}
|
||||
if !result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be true when token expired and refresh token available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token verification error containing 'token has expired' without refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "expired-id-token",
|
||||
refreshToken: "",
|
||||
}
|
||||
verifier := &mockTokenVerifier{
|
||||
expiredTokens: map[string]bool{
|
||||
"expired-id-token": true,
|
||||
},
|
||||
}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Authenticated {
|
||||
t.Error("expected Authenticated to be false for expired token")
|
||||
}
|
||||
if result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be false when no refresh token")
|
||||
}
|
||||
if !result.IsExpired {
|
||||
t.Error("expected IsExpired to be true for expired token without refresh")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token verification other error with refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "invalid-token",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{shouldFail: true}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Authenticated {
|
||||
t.Error("expected Authenticated to be false for invalid token")
|
||||
}
|
||||
if !result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be true when verification fails and refresh token available")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("token verification other error without refresh token", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "invalid-token",
|
||||
refreshToken: "",
|
||||
}
|
||||
verifier := &mockTokenVerifier{shouldFail: true}
|
||||
cache := &mockTokenCache{}
|
||||
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Authenticated {
|
||||
t.Error("expected Authenticated to be false for invalid token")
|
||||
}
|
||||
if result.NeedsRefresh {
|
||||
t.Error("expected NeedsRefresh to be false when no refresh token")
|
||||
}
|
||||
if !result.IsExpired {
|
||||
t.Error("expected IsExpired to be true for invalid token without refresh")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBaseProvider_HandleTokenRefresh(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
// Test that HandleTokenRefresh doesn't fail
|
||||
tokenData := &TokenResult{
|
||||
IDToken: "id-token",
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
}
|
||||
|
||||
err := provider.HandleTokenRefresh(tokenData)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from HandleTokenRefresh: %v", err)
|
||||
}
|
||||
|
||||
// Test with nil token data
|
||||
err = provider.HandleTokenRefresh(nil)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from HandleTokenRefresh with nil data: %v", err)
|
||||
}
|
||||
|
||||
// Test with empty token data
|
||||
emptyTokenData := &TokenResult{}
|
||||
err = provider.HandleTokenRefresh(emptyTokenData)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from HandleTokenRefresh with empty data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_ValidateConfig(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
// Base provider ValidateConfig should always return nil
|
||||
err := provider.ValidateConfig()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from ValidateConfig: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBaseProvider_EdgeCases(t *testing.T) {
|
||||
provider := NewBaseProvider()
|
||||
|
||||
t.Run("BuildAuthParams with nil scopes", func(t *testing.T) {
|
||||
baseParams := url.Values{"client_id": []string{"test-client"}}
|
||||
authParams, err := provider.BuildAuthParams(baseParams, nil)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Fatal("expected non-nil auth params")
|
||||
}
|
||||
|
||||
// Should still add offline_access
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range authParams.Scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOfflineAccess {
|
||||
t.Error("expected offline_access scope to be added even with nil input scopes")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BuildAuthParams with nil baseParams", func(t *testing.T) {
|
||||
scopes := []string{"openid", "email"}
|
||||
authParams, err := provider.BuildAuthParams(nil, scopes)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Fatal("expected non-nil auth params")
|
||||
}
|
||||
|
||||
// Note: nil baseParams results in nil URLValues, which is handled by the calling code
|
||||
if authParams.URLValues != nil {
|
||||
t.Logf("Got non-nil URLValues: %v", authParams.URLValues)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ValidateTokenExpiry with nil cache", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic with nil cache: %v", r)
|
||||
}
|
||||
}()
|
||||
session := &mockSession{refreshToken: "refresh-token"}
|
||||
_, err := provider.ValidateTokenExpiry(session, "test-token", nil, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Got expected error with nil cache: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ValidateTokenExpiry with nil session", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic with nil session: %v", r)
|
||||
}
|
||||
}()
|
||||
cache := &mockTokenCache{}
|
||||
_, err := provider.ValidateTokenExpiry(nil, "test-token", cache, time.Minute)
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Got expected error with nil session: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkBaseProvider_GetType(b *testing.B) {
|
||||
provider := NewBaseProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetType()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseProvider_GetCapabilities(b *testing.B) {
|
||||
provider := NewBaseProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetCapabilities()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseProvider_BuildAuthParams(b *testing.B) {
|
||||
provider := NewBaseProvider()
|
||||
baseParams := url.Values{"client_id": []string{"test-client"}}
|
||||
scopes := []string{"openid", "email", "profile"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBaseProvider_ValidateTokenExpiry(b *testing.B) {
|
||||
provider := NewBaseProvider()
|
||||
session := &mockSession{refreshToken: "refresh-token"}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"test-token": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ValidateTokenExpiry(session, "test-token", cache, time.Minute)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,508 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewProviderFactory(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
if factory == nil {
|
||||
t.Fatal("expected non-nil factory")
|
||||
}
|
||||
|
||||
if factory.registry == nil {
|
||||
t.Fatal("expected non-nil registry in factory")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_CreateProvider(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
wantType ProviderType
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "Google provider detection",
|
||||
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
||||
wantType: ProviderTypeGoogle,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure provider detection - login.microsoftonline.com",
|
||||
issuerURL: "https://login.microsoftonline.com/tenant-id/v2.0",
|
||||
wantType: ProviderTypeAzure,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure provider detection - sts.windows.net",
|
||||
issuerURL: "https://sts.windows.net/tenant-id/",
|
||||
wantType: ProviderTypeAzure,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Generic provider detection",
|
||||
issuerURL: "https://auth.example.com/realms/test",
|
||||
wantType: ProviderTypeGeneric,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty issuer URL",
|
||||
issuerURL: "",
|
||||
wantError: true,
|
||||
errorSubstr: "issuer URL cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "Invalid URL format",
|
||||
issuerURL: "not-a-valid-url",
|
||||
wantType: ProviderTypeGeneric,
|
||||
wantError: false, // url.Parse accepts this as a valid URL
|
||||
},
|
||||
{
|
||||
name: "URL with invalid scheme",
|
||||
issuerURL: "ftp://example.com/auth",
|
||||
wantType: ProviderTypeGeneric,
|
||||
wantError: false, // Should create generic provider for non-standard schemes
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := factory.CreateProvider(tt.issuerURL)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && err.Error() != tt.errorSubstr {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if provider == nil {
|
||||
t.Error("expected non-nil provider")
|
||||
return
|
||||
}
|
||||
|
||||
if provider.GetType() != tt.wantType {
|
||||
t.Errorf("expected provider type %d, got %d", tt.wantType, provider.GetType())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_CreateProviderByType(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "Generic provider",
|
||||
providerType: ProviderTypeGeneric,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Google provider",
|
||||
providerType: ProviderTypeGoogle,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure provider",
|
||||
providerType: ProviderTypeAzure,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid provider type",
|
||||
providerType: ProviderType(999),
|
||||
wantError: true,
|
||||
errorSubstr: "unsupported provider type: 999",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider, err := factory.CreateProviderByType(tt.providerType)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && err.Error() != tt.errorSubstr {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if provider == nil {
|
||||
t.Error("expected non-nil provider")
|
||||
return
|
||||
}
|
||||
|
||||
if provider.GetType() != tt.providerType {
|
||||
t.Errorf("expected provider type %d, got %d", tt.providerType, provider.GetType())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_GetSupportedProviders(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
supported := factory.GetSupportedProviders()
|
||||
|
||||
expectedProviders := map[ProviderType][]string{
|
||||
ProviderTypeGeneric: {"*"},
|
||||
ProviderTypeGoogle: {"accounts.google.com"},
|
||||
ProviderTypeAzure: {"login.microsoftonline.com", "sts.windows.net"},
|
||||
}
|
||||
|
||||
if len(supported) != len(expectedProviders) {
|
||||
t.Errorf("expected %d supported providers, got %d", len(expectedProviders), len(supported))
|
||||
}
|
||||
|
||||
for expectedType, expectedPatterns := range expectedProviders {
|
||||
patterns, exists := supported[expectedType]
|
||||
if !exists {
|
||||
t.Errorf("expected provider type %d to be supported", expectedType)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(patterns) != len(expectedPatterns) {
|
||||
t.Errorf("expected %d patterns for provider type %d, got %d", len(expectedPatterns), expectedType, len(patterns))
|
||||
continue
|
||||
}
|
||||
|
||||
for i, expectedPattern := range expectedPatterns {
|
||||
if patterns[i] != expectedPattern {
|
||||
t.Errorf("expected pattern %q for provider type %d, got %q", expectedPattern, expectedType, patterns[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_DetectProviderType(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
wantType ProviderType
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "Google detection",
|
||||
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
||||
wantType: ProviderTypeGoogle,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Azure detection",
|
||||
issuerURL: "https://login.microsoftonline.com/tenant/v2.0",
|
||||
wantType: ProviderTypeAzure,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Generic detection",
|
||||
issuerURL: "https://keycloak.example.com/auth/realms/master",
|
||||
wantType: ProviderTypeGeneric,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL",
|
||||
issuerURL: "",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
providerType, err := factory.DetectProviderType(tt.issuerURL)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if providerType != tt.wantType {
|
||||
t.Errorf("expected provider type %d, got %d", tt.wantType, providerType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_IsProviderSupported(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
wantSupport bool
|
||||
}{
|
||||
{
|
||||
name: "Google URL",
|
||||
issuerURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
wantSupport: true,
|
||||
},
|
||||
{
|
||||
name: "Azure URL - login.microsoftonline.com",
|
||||
issuerURL: "https://login.microsoftonline.com/tenant/v2.0",
|
||||
wantSupport: true,
|
||||
},
|
||||
{
|
||||
name: "Azure URL - sts.windows.net",
|
||||
issuerURL: "https://sts.windows.net/tenant/",
|
||||
wantSupport: true,
|
||||
},
|
||||
{
|
||||
name: "Generic URL",
|
||||
issuerURL: "https://auth.example.com/realms/master",
|
||||
wantSupport: true,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
issuerURL: "",
|
||||
wantSupport: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid URL",
|
||||
issuerURL: "not-a-valid-url",
|
||||
wantSupport: true, // Generic provider supports all URLs
|
||||
},
|
||||
{
|
||||
name: "Valid but generic URL",
|
||||
issuerURL: "https://keycloak.example.com/auth",
|
||||
wantSupport: true, // Should be supported as generic
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
supported := factory.IsProviderSupported(tt.issuerURL)
|
||||
if supported != tt.wantSupport {
|
||||
t.Errorf("expected support %v for URL %q, got %v", tt.wantSupport, tt.issuerURL, supported)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_ConcurrentAccess(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 100
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://auth.example.com/realms/master",
|
||||
}
|
||||
|
||||
// Test concurrent provider creation
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
url := urls[j%len(urls)]
|
||||
|
||||
// Test CreateProvider
|
||||
provider, err := factory.CreateProvider(url)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error creating provider: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if provider == nil {
|
||||
t.Errorf("worker %d: expected non-nil provider", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test IsProviderSupported
|
||||
supported := factory.IsProviderSupported(url)
|
||||
if !supported {
|
||||
t.Errorf("worker %d: expected URL %s to be supported", workerID, url)
|
||||
return
|
||||
}
|
||||
|
||||
// Test DetectProviderType
|
||||
_, err = factory.DetectProviderType(url)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error detecting provider type: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_MemorySafety(t *testing.T) {
|
||||
// Test that creating many providers doesn't cause memory leaks
|
||||
const numCreations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numCreations; i++ {
|
||||
factory := NewProviderFactory()
|
||||
_, err := factory.CreateProvider("https://accounts.google.com/.well-known/openid_configuration")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error creating provider: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Force garbage collection to cleanup any lingering resources
|
||||
runtime.GC()
|
||||
runtime.GC() // Call twice to ensure cleanup
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderFactory_EdgeCases(t *testing.T) {
|
||||
factory := NewProviderFactory()
|
||||
|
||||
t.Run("nil factory registry handling", func(t *testing.T) {
|
||||
// This test ensures we handle edge cases properly
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Expected behavior - nil registry should cause panic or be handled
|
||||
t.Logf("Recovered from panic as expected: %v", r)
|
||||
}
|
||||
}()
|
||||
brokenFactory := &ProviderFactory{registry: nil}
|
||||
_, err := brokenFactory.CreateProvider("https://accounts.google.com")
|
||||
if err == nil {
|
||||
t.Error("expected error with nil registry")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("malformed URLs", func(t *testing.T) {
|
||||
malformedURLs := []string{
|
||||
"://missing-scheme.com",
|
||||
"https://",
|
||||
"https:///missing-host",
|
||||
"https://example.com:port-not-number",
|
||||
}
|
||||
|
||||
for _, url := range malformedURLs {
|
||||
_, err := factory.CreateProvider(url)
|
||||
// Some malformed URLs might still be accepted by url.Parse,
|
||||
// but we ensure the system doesn't crash
|
||||
if err == nil {
|
||||
// Still check that we get a valid provider
|
||||
provider, err := factory.CreateProvider(url)
|
||||
if err == nil && provider == nil {
|
||||
t.Errorf("got nil provider without error for URL: %s", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("very long URLs", func(t *testing.T) {
|
||||
longURL := "https://accounts.google.com/" + string(make([]byte, 10000))
|
||||
for i := range longURL[len("https://accounts.google.com/"):] {
|
||||
longURL = longURL[:len("https://accounts.google.com/")+i] + "a" + longURL[len("https://accounts.google.com/")+i+1:]
|
||||
}
|
||||
|
||||
provider, err := factory.CreateProvider(longURL)
|
||||
if err != nil {
|
||||
t.Logf("long URL rejected as expected: %v", err)
|
||||
} else if provider == nil {
|
||||
t.Error("got nil provider without error for very long URL")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkProviderFactory_CreateProvider(b *testing.B) {
|
||||
factory := NewProviderFactory()
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://auth.example.com/realms/master",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
url := urls[i%len(urls)]
|
||||
_, err := factory.CreateProvider(url)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProviderFactory_DetectProviderType(b *testing.B) {
|
||||
factory := NewProviderFactory()
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://auth.example.com/realms/master",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
url := urls[i%len(urls)]
|
||||
_, err := factory.DetectProviderType(url)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProviderFactory_IsProviderSupported(b *testing.B) {
|
||||
factory := NewProviderFactory()
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://auth.example.com/realms/master",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
url := urls[i%len(urls)]
|
||||
factory.IsProviderSupported(url)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,724 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewGoogleProvider(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
if provider == nil {
|
||||
t.Fatal("expected non-nil Google provider")
|
||||
}
|
||||
|
||||
if provider.BaseProvider == nil {
|
||||
t.Fatal("expected non-nil BaseProvider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_GetType(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
providerType := provider.GetType()
|
||||
|
||||
if providerType != ProviderTypeGoogle {
|
||||
t.Errorf("expected provider type %d, got %d", ProviderTypeGoogle, providerType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_GetCapabilities(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
capabilities := provider.GetCapabilities()
|
||||
|
||||
expectedCapabilities := ProviderCapabilities{
|
||||
SupportsRefreshTokens: true,
|
||||
RequiresOfflineAccessScope: false,
|
||||
RequiresPromptConsent: true,
|
||||
PreferredTokenValidation: "id",
|
||||
}
|
||||
|
||||
if capabilities.SupportsRefreshTokens != expectedCapabilities.SupportsRefreshTokens {
|
||||
t.Errorf("expected SupportsRefreshTokens %t, got %t", expectedCapabilities.SupportsRefreshTokens, capabilities.SupportsRefreshTokens)
|
||||
}
|
||||
|
||||
if capabilities.RequiresOfflineAccessScope != expectedCapabilities.RequiresOfflineAccessScope {
|
||||
t.Errorf("expected RequiresOfflineAccessScope %t, got %t", expectedCapabilities.RequiresOfflineAccessScope, capabilities.RequiresOfflineAccessScope)
|
||||
}
|
||||
|
||||
if capabilities.RequiresPromptConsent != expectedCapabilities.RequiresPromptConsent {
|
||||
t.Errorf("expected RequiresPromptConsent %t, got %t", expectedCapabilities.RequiresPromptConsent, capabilities.RequiresPromptConsent)
|
||||
}
|
||||
|
||||
if capabilities.PreferredTokenValidation != expectedCapabilities.PreferredTokenValidation {
|
||||
t.Errorf("expected PreferredTokenValidation %q, got %q", expectedCapabilities.PreferredTokenValidation, capabilities.PreferredTokenValidation)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_BuildAuthParams(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
baseParams url.Values
|
||||
scopes []string
|
||||
expectAccessTypeOffline bool
|
||||
expectPromptConsent bool
|
||||
expectOfflineAccessRemoved bool
|
||||
}{
|
||||
{
|
||||
name: "basic params with offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "offline_access", "email"},
|
||||
expectAccessTypeOffline: true,
|
||||
expectPromptConsent: true,
|
||||
expectOfflineAccessRemoved: true,
|
||||
},
|
||||
{
|
||||
name: "basic params without offline_access scope",
|
||||
baseParams: url.Values{"client_id": []string{"test-client"}},
|
||||
scopes: []string{"openid", "email"},
|
||||
expectAccessTypeOffline: true,
|
||||
expectPromptConsent: true,
|
||||
expectOfflineAccessRemoved: false,
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
baseParams: url.Values{},
|
||||
scopes: []string{},
|
||||
expectAccessTypeOffline: true,
|
||||
expectPromptConsent: true,
|
||||
expectOfflineAccessRemoved: false,
|
||||
},
|
||||
{
|
||||
name: "multiple offline_access scopes",
|
||||
baseParams: url.Values{},
|
||||
scopes: []string{"openid", "offline_access", "email", "offline_access"},
|
||||
expectAccessTypeOffline: true,
|
||||
expectPromptConsent: true,
|
||||
expectOfflineAccessRemoved: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authParams, err := provider.BuildAuthParams(tt.baseParams, tt.scopes)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Fatal("expected non-nil auth params")
|
||||
}
|
||||
|
||||
// Check access_type is set to offline
|
||||
if tt.expectAccessTypeOffline {
|
||||
accessType := authParams.URLValues.Get("access_type")
|
||||
if accessType != "offline" {
|
||||
t.Errorf("expected access_type 'offline', got %q", accessType)
|
||||
}
|
||||
}
|
||||
|
||||
// Check prompt is set to consent
|
||||
if tt.expectPromptConsent {
|
||||
prompt := authParams.URLValues.Get("prompt")
|
||||
if prompt != "consent" {
|
||||
t.Errorf("expected prompt 'consent', got %q", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
// Check offline_access scope is filtered out
|
||||
if tt.expectOfflineAccessRemoved {
|
||||
hasOfflineAccess := false
|
||||
for _, scope := range authParams.Scopes {
|
||||
if scope == "offline_access" {
|
||||
hasOfflineAccess = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasOfflineAccess {
|
||||
t.Error("expected offline_access scope to be removed for Google provider")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify other scopes are preserved (except offline_access)
|
||||
expectedScopes := make([]string, 0, len(tt.scopes))
|
||||
for _, scope := range tt.scopes {
|
||||
if scope != "offline_access" {
|
||||
expectedScopes = append(expectedScopes, scope)
|
||||
}
|
||||
}
|
||||
|
||||
if len(authParams.Scopes) != len(expectedScopes) {
|
||||
t.Errorf("expected %d scopes after filtering, got %d", len(expectedScopes), len(authParams.Scopes))
|
||||
}
|
||||
|
||||
// Verify other parameters are preserved
|
||||
for key, values := range tt.baseParams {
|
||||
if key == "access_type" || key == "prompt" {
|
||||
continue // These get overridden
|
||||
}
|
||||
paramValues := authParams.URLValues[key]
|
||||
if len(paramValues) != len(values) {
|
||||
t.Errorf("expected %d values for param %s, got %d", len(values), key, len(paramValues))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_ValidateTokens(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
session *mockSession
|
||||
verifier *mockTokenVerifier
|
||||
cache *mockTokenCache
|
||||
refreshGracePeriod time.Duration
|
||||
expectedResult *ValidationResult
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "unauthenticated with refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unauthenticated without refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: false,
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{},
|
||||
},
|
||||
{
|
||||
name: "authenticated with valid ID token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "id.token.here",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"id.token.here": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with expired ID token and refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "expired.token.here",
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{
|
||||
expiredTokens: map[string]bool{
|
||||
"expired.token.here": true,
|
||||
},
|
||||
},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with no ID token but has access token and refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
Authenticated: true,
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with no tokens but has refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
refreshToken: "refresh-token",
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
NeedsRefresh: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "authenticated with no tokens and no refresh token",
|
||||
session: &mockSession{
|
||||
authenticated: true,
|
||||
},
|
||||
verifier: &mockTokenVerifier{},
|
||||
cache: &mockTokenCache{},
|
||||
expectedResult: &ValidationResult{
|
||||
IsExpired: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := provider.ValidateTokens(tt.session, tt.verifier, tt.cache, tt.refreshGracePeriod)
|
||||
|
||||
if tt.expectError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
|
||||
if result.Authenticated != tt.expectedResult.Authenticated {
|
||||
t.Errorf("expected Authenticated %t, got %t", tt.expectedResult.Authenticated, result.Authenticated)
|
||||
}
|
||||
|
||||
if result.NeedsRefresh != tt.expectedResult.NeedsRefresh {
|
||||
t.Errorf("expected NeedsRefresh %t, got %t", tt.expectedResult.NeedsRefresh, result.NeedsRefresh)
|
||||
}
|
||||
|
||||
if result.IsExpired != tt.expectedResult.IsExpired {
|
||||
t.Errorf("expected IsExpired %t, got %t", tt.expectedResult.IsExpired, result.IsExpired)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_ValidateConfig(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
// Google provider uses BaseProvider's ValidateConfig which always returns nil
|
||||
err := provider.ValidateConfig()
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from ValidateConfig: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_HandleTokenRefresh(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
// Test that HandleTokenRefresh doesn't fail
|
||||
tokenData := &TokenResult{
|
||||
IDToken: "id-token",
|
||||
AccessToken: "access-token",
|
||||
RefreshToken: "refresh-token",
|
||||
}
|
||||
|
||||
err := provider.HandleTokenRefresh(tokenData)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error from HandleTokenRefresh: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_ConcurrentAccess(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
// Test concurrent access to provider methods
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: fmt.Sprintf("access-token-%d", workerID),
|
||||
idToken: fmt.Sprintf("id-token-%d", workerID),
|
||||
refreshToken: fmt.Sprintf("refresh-token-%d", workerID),
|
||||
}
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
fmt.Sprintf("id-token-%d", workerID): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
// Test GetType
|
||||
if provider.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("worker %d: expected Google provider type", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test GetCapabilities
|
||||
capabilities := provider.GetCapabilities()
|
||||
if !capabilities.SupportsRefreshTokens {
|
||||
t.Errorf("worker %d: expected refresh token support", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateTokens
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in ValidateTokens: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if !result.Authenticated {
|
||||
t.Errorf("worker %d: expected authenticated result", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test BuildAuthParams
|
||||
baseParams := url.Values{"client_id": []string{fmt.Sprintf("client-%d", workerID)}}
|
||||
scopes := []string{"openid", "email"}
|
||||
authParams, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in BuildAuthParams: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
if authParams == nil {
|
||||
t.Errorf("worker %d: expected non-nil auth params", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateConfig
|
||||
err = provider.ValidateConfig()
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error in ValidateConfig: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_MemorySafety(t *testing.T) {
|
||||
const numIterations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: fmt.Sprintf("access-token-%d", i),
|
||||
idToken: fmt.Sprintf("id-token-%d", i),
|
||||
refreshToken: fmt.Sprintf("refresh-token-%d", i),
|
||||
}
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
fmt.Sprintf("id-token-%d", i): {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Exercise all provider methods
|
||||
_ = provider.GetType()
|
||||
_ = provider.GetCapabilities()
|
||||
_, _ = provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
_, _ = provider.BuildAuthParams(url.Values{}, []string{"openid"})
|
||||
_ = provider.ValidateConfig()
|
||||
_ = provider.HandleTokenRefresh(&TokenResult{})
|
||||
}
|
||||
|
||||
// Force garbage collection
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleProvider_EdgeCases(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
t.Run("nil session", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Expected behavior for nil session
|
||||
t.Logf("Recovered from expected panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
_, err := provider.ValidateTokens(nil, verifier, cache, time.Minute)
|
||||
if err == nil {
|
||||
t.Error("expected error with nil session")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil verifier", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
session := &mockSession{authenticated: true, idToken: "test.token.here"}
|
||||
cache := &mockTokenCache{}
|
||||
result, err := provider.ValidateTokens(session, nil, cache, time.Minute)
|
||||
// Google provider uses BaseProvider which handles nil verifier gracefully
|
||||
if err != nil {
|
||||
t.Logf("Got expected error with nil verifier: %v", err)
|
||||
} else if result != nil && result.NeedsRefresh {
|
||||
t.Logf("Provider handled nil verifier gracefully by requesting refresh")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil cache", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic with nil cache: %v", r)
|
||||
}
|
||||
}()
|
||||
session := &mockSession{authenticated: true, accessToken: "test-token"}
|
||||
verifier := &mockTokenVerifier{}
|
||||
result, err := provider.ValidateTokens(session, verifier, nil, time.Minute)
|
||||
// Google provider uses BaseProvider which handles nil cache gracefully
|
||||
if err != nil {
|
||||
t.Logf("Got expected error with nil cache: %v", err)
|
||||
} else if result != nil && result.NeedsRefresh {
|
||||
t.Logf("Provider handled nil cache gracefully by requesting refresh")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty tokens", func(t *testing.T) {
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "",
|
||||
idToken: "",
|
||||
refreshToken: "",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{}
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with empty tokens: %v", err)
|
||||
}
|
||||
if !result.IsExpired {
|
||||
t.Error("expected IsExpired=true for empty tokens without refresh token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("offline_access scope filtering", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputScopes []string
|
||||
expectScopes []string
|
||||
}{
|
||||
{
|
||||
name: "single offline_access",
|
||||
inputScopes: []string{"offline_access"},
|
||||
expectScopes: []string{},
|
||||
},
|
||||
{
|
||||
name: "offline_access with others",
|
||||
inputScopes: []string{"openid", "offline_access", "email"},
|
||||
expectScopes: []string{"openid", "email"},
|
||||
},
|
||||
{
|
||||
name: "multiple offline_access",
|
||||
inputScopes: []string{"offline_access", "openid", "offline_access", "profile"},
|
||||
expectScopes: []string{"openid", "profile"},
|
||||
},
|
||||
{
|
||||
name: "no offline_access",
|
||||
inputScopes: []string{"openid", "email", "profile"},
|
||||
expectScopes: []string{"openid", "email", "profile"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
authParams, err := provider.BuildAuthParams(url.Values{}, tt.inputScopes)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(authParams.Scopes) != len(tt.expectScopes) {
|
||||
t.Errorf("expected %d scopes, got %d", len(tt.expectScopes), len(authParams.Scopes))
|
||||
}
|
||||
|
||||
for i, expectedScope := range tt.expectScopes {
|
||||
if i >= len(authParams.Scopes) || authParams.Scopes[i] != expectedScope {
|
||||
t.Errorf("expected scope %q at position %d, got %q", expectedScope, i, authParams.Scopes[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("very long tokens", func(t *testing.T) {
|
||||
longToken := strings.Repeat("a", 5000)
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
idToken: longToken,
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
longToken: {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
result, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with very long token: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Error("expected non-nil result with very long token")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special characters in parameters", func(t *testing.T) {
|
||||
specialParams := url.Values{
|
||||
"client_id": []string{"client@example.com"},
|
||||
"redirect_uri": []string{"https://example.com/callback?param=value&other=test"},
|
||||
"state": []string{"state+with/special=chars&more"},
|
||||
}
|
||||
scopes := []string{"openid", "email+special", "profile/test"}
|
||||
|
||||
authParams, err := provider.BuildAuthParams(specialParams, scopes)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error with special characters: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if authParams == nil {
|
||||
t.Error("expected non-nil auth params with special characters")
|
||||
}
|
||||
|
||||
// Verify all special parameter values are preserved
|
||||
for key, expectedValues := range specialParams {
|
||||
if key == "access_type" || key == "prompt" {
|
||||
continue // These get overridden
|
||||
}
|
||||
actualValues := authParams.URLValues[key]
|
||||
if len(actualValues) != len(expectedValues) {
|
||||
t.Errorf("parameter %s: expected %d values, got %d", key, len(expectedValues), len(actualValues))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkGoogleProvider_GetType(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetType()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGoogleProvider_GetCapabilities(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
provider.GetCapabilities()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGoogleProvider_BuildAuthParams(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
baseParams := url.Values{"client_id": []string{"test-client"}}
|
||||
scopes := []string{"openid", "email", "profile", "offline_access"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGoogleProvider_ValidateTokens(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
session := &mockSession{
|
||||
authenticated: true,
|
||||
accessToken: "access-token",
|
||||
idToken: "id.token.here",
|
||||
refreshToken: "refresh-token",
|
||||
}
|
||||
verifier := &mockTokenVerifier{}
|
||||
cache := &mockTokenCache{
|
||||
data: map[string]map[string]interface{}{
|
||||
"id.token.here": {
|
||||
"exp": float64(time.Now().Add(time.Hour).Unix()),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.ValidateTokens(session, verifier, cache, time.Minute)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGoogleProvider_OfflineAccessFiltering(b *testing.B) {
|
||||
provider := NewGoogleProvider()
|
||||
baseParams := url.Values{"client_id": []string{"test-client"}}
|
||||
scopes := []string{"openid", "offline_access", "email", "offline_access", "profile", "offline_access"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := provider.BuildAuthParams(baseParams, scopes)
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProviderRegistry_GetProviderByType(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
googleProvider := NewGoogleProvider()
|
||||
azureProvider := NewAzureProvider()
|
||||
genericProvider := NewGenericProvider()
|
||||
|
||||
registry.RegisterProvider(googleProvider)
|
||||
registry.RegisterProvider(azureProvider)
|
||||
registry.RegisterProvider(genericProvider)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
expectNil bool
|
||||
}{
|
||||
{
|
||||
name: "Google provider",
|
||||
providerType: ProviderTypeGoogle,
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Azure provider",
|
||||
providerType: ProviderTypeAzure,
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Generic provider",
|
||||
providerType: ProviderTypeGeneric,
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid provider type",
|
||||
providerType: ProviderType(999),
|
||||
expectNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := registry.GetProviderByType(tt.providerType)
|
||||
|
||||
if tt.expectNil {
|
||||
if provider != nil {
|
||||
t.Errorf("expected nil provider for type %d, got %v", tt.providerType, provider)
|
||||
}
|
||||
} else {
|
||||
if provider == nil {
|
||||
t.Errorf("expected non-nil provider for type %d", tt.providerType)
|
||||
} else if provider.GetType() != tt.providerType {
|
||||
t.Errorf("expected provider type %d, got %d", tt.providerType, provider.GetType())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_GetRegisteredProviders(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Initially should be empty
|
||||
providers := registry.GetRegisteredProviders()
|
||||
if len(providers) != 0 {
|
||||
t.Errorf("expected 0 registered providers initially, got %d", len(providers))
|
||||
}
|
||||
|
||||
// Register providers one by one
|
||||
expectedTypes := []ProviderType{
|
||||
ProviderTypeGoogle,
|
||||
ProviderTypeAzure,
|
||||
ProviderTypeGeneric,
|
||||
}
|
||||
|
||||
for i, providerType := range expectedTypes {
|
||||
switch providerType {
|
||||
case ProviderTypeGoogle:
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
case ProviderTypeAzure:
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
case ProviderTypeGeneric:
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
}
|
||||
|
||||
providers := registry.GetRegisteredProviders()
|
||||
if len(providers) != i+1 {
|
||||
t.Errorf("expected %d registered providers after registering %d, got %d", i+1, i+1, len(providers))
|
||||
}
|
||||
|
||||
// Check that the newly registered type is present
|
||||
found := false
|
||||
for _, registeredType := range providers {
|
||||
if registeredType == providerType {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected provider type %d to be in registered providers list", providerType)
|
||||
}
|
||||
}
|
||||
|
||||
// Final check - all providers should be registered
|
||||
finalProviders := registry.GetRegisteredProviders()
|
||||
if len(finalProviders) != len(expectedTypes) {
|
||||
t.Errorf("expected %d final registered providers, got %d", len(expectedTypes), len(finalProviders))
|
||||
}
|
||||
|
||||
// Check all expected types are present
|
||||
for _, expectedType := range expectedTypes {
|
||||
found := false
|
||||
for _, registeredType := range finalProviders {
|
||||
if registeredType == expectedType {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("expected provider type %d to be in final registered providers list", expectedType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_ClearCache(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
// Populate cache by detecting providers
|
||||
googleURL := "https://accounts.google.com/.well-known/openid_configuration"
|
||||
azureURL := "https://login.microsoftonline.com/tenant/v2.0"
|
||||
genericURL := "https://keycloak.example.com/auth/realms/master"
|
||||
|
||||
// These calls should populate the cache
|
||||
provider1 := registry.DetectProvider(googleURL)
|
||||
provider2 := registry.DetectProvider(azureURL)
|
||||
provider3 := registry.DetectProvider(genericURL)
|
||||
|
||||
// Verify providers were detected correctly
|
||||
if provider1.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider, got %d", provider1.GetType())
|
||||
}
|
||||
if provider2.GetType() != ProviderTypeAzure {
|
||||
t.Errorf("expected Azure provider, got %d", provider2.GetType())
|
||||
}
|
||||
if provider3.GetType() != ProviderTypeGeneric {
|
||||
t.Errorf("expected Generic provider, got %d", provider3.GetType())
|
||||
}
|
||||
|
||||
// Clear cache
|
||||
registry.ClearCache()
|
||||
|
||||
// Detect again - should work but might create new instances internally
|
||||
provider1After := registry.DetectProvider(googleURL)
|
||||
provider2After := registry.DetectProvider(azureURL)
|
||||
provider3After := registry.DetectProvider(genericURL)
|
||||
|
||||
// Verify detection still works correctly after cache clear
|
||||
if provider1After.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider after cache clear, got %d", provider1After.GetType())
|
||||
}
|
||||
if provider2After.GetType() != ProviderTypeAzure {
|
||||
t.Errorf("expected Azure provider after cache clear, got %d", provider2After.GetType())
|
||||
}
|
||||
if provider3After.GetType() != ProviderTypeGeneric {
|
||||
t.Errorf("expected Generic provider after cache clear, got %d", provider3After.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_DetectProvider_EdgeCases(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
expectedType ProviderType
|
||||
expectNil bool
|
||||
}{
|
||||
{
|
||||
name: "Google URL with subdomain",
|
||||
issuerURL: "https://accounts.google.com.evil.com/",
|
||||
expectedType: ProviderTypeGoogle, // Contains "accounts.google.com"
|
||||
},
|
||||
{
|
||||
name: "Azure URL with subdomain",
|
||||
issuerURL: "https://login.microsoftonline.com.evil.com/",
|
||||
expectedType: ProviderTypeAzure, // Contains "login.microsoftonline.com"
|
||||
},
|
||||
{
|
||||
name: "Google URL case insensitive",
|
||||
issuerURL: "https://ACCOUNTS.GOOGLE.COM/auth",
|
||||
expectedType: ProviderTypeGeneric, // Case sensitive matching
|
||||
},
|
||||
{
|
||||
name: "Azure login URL case insensitive",
|
||||
issuerURL: "https://LOGIN.MICROSOFTONLINE.COM/tenant",
|
||||
expectedType: ProviderTypeGeneric, // Case sensitive matching
|
||||
},
|
||||
{
|
||||
name: "Azure STS URL case insensitive",
|
||||
issuerURL: "https://STS.WINDOWS.NET/tenant",
|
||||
expectedType: ProviderTypeGeneric, // Case sensitive matching
|
||||
},
|
||||
{
|
||||
name: "Invalid URL",
|
||||
issuerURL: "://invalid-url",
|
||||
expectNil: true,
|
||||
},
|
||||
{
|
||||
name: "Empty URL",
|
||||
issuerURL: "",
|
||||
expectedType: ProviderTypeGeneric, // Falls back to generic
|
||||
},
|
||||
{
|
||||
name: "URL without host",
|
||||
issuerURL: "/path/only",
|
||||
expectedType: ProviderTypeGeneric, // Falls back to generic
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
provider := registry.DetectProvider(tt.issuerURL)
|
||||
|
||||
if tt.expectNil {
|
||||
if provider != nil {
|
||||
t.Errorf("expected nil provider for URL %q, got %v", tt.issuerURL, provider)
|
||||
}
|
||||
} else {
|
||||
if provider == nil {
|
||||
t.Errorf("expected non-nil provider for URL %q", tt.issuerURL)
|
||||
} else if provider.GetType() != tt.expectedType {
|
||||
t.Errorf("expected provider type %d for URL %q, got %d", tt.expectedType, tt.issuerURL, provider.GetType())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_ConcurrentAccess(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 100
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://sts.windows.net/tenant/",
|
||||
"https://keycloak.example.com/auth/realms/master",
|
||||
}
|
||||
|
||||
// Test concurrent access to all registry methods
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
url := urls[j%len(urls)]
|
||||
|
||||
// Test DetectProvider
|
||||
provider := registry.DetectProvider(url)
|
||||
if provider == nil {
|
||||
t.Errorf("worker %d: expected non-nil provider for URL %s", workerID, url)
|
||||
return
|
||||
}
|
||||
|
||||
// Test GetProviderByType
|
||||
providerByType := registry.GetProviderByType(provider.GetType())
|
||||
if providerByType == nil {
|
||||
t.Errorf("worker %d: expected non-nil provider for type %d", workerID, provider.GetType())
|
||||
return
|
||||
}
|
||||
|
||||
// Test GetRegisteredProviders
|
||||
providers := registry.GetRegisteredProviders()
|
||||
if len(providers) == 0 {
|
||||
t.Errorf("worker %d: expected non-empty providers list", workerID)
|
||||
return
|
||||
}
|
||||
|
||||
// Occasionally clear cache to test concurrent cache operations
|
||||
if workerID%10 == 0 && j == 0 {
|
||||
registry.ClearCache()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_MemorySafety(t *testing.T) {
|
||||
const numIterations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
// Exercise all registry methods
|
||||
urls := []string{
|
||||
"https://accounts.google.com/config",
|
||||
"https://login.microsoftonline.com/tenant/config",
|
||||
"https://keycloak.example.com/auth",
|
||||
}
|
||||
|
||||
for _, url := range urls {
|
||||
provider := registry.DetectProvider(url)
|
||||
if provider != nil {
|
||||
_ = registry.GetProviderByType(provider.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
_ = registry.GetRegisteredProviders()
|
||||
registry.ClearCache()
|
||||
|
||||
// Create many entries to test cache memory management
|
||||
for j := 0; j < 10; j++ {
|
||||
testURL := fmt.Sprintf("https://example%d.com/auth", j)
|
||||
registry.DetectProvider(testURL)
|
||||
}
|
||||
}
|
||||
|
||||
// Force garbage collection
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderRegistry_CacheConsistency(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register providers
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
testURL := "https://accounts.google.com/.well-known/openid_configuration"
|
||||
|
||||
// First detection should populate cache
|
||||
provider1 := registry.DetectProvider(testURL)
|
||||
if provider1 == nil {
|
||||
t.Fatal("expected non-nil provider from first detection")
|
||||
}
|
||||
if provider1.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider, got %d", provider1.GetType())
|
||||
}
|
||||
|
||||
// Second detection should use cache (same result)
|
||||
provider2 := registry.DetectProvider(testURL)
|
||||
if provider2 == nil {
|
||||
t.Fatal("expected non-nil provider from second detection")
|
||||
}
|
||||
if provider2.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider from cache, got %d", provider2.GetType())
|
||||
}
|
||||
|
||||
// Clear cache and detect again
|
||||
registry.ClearCache()
|
||||
provider3 := registry.DetectProvider(testURL)
|
||||
if provider3 == nil {
|
||||
t.Fatal("expected non-nil provider after cache clear")
|
||||
}
|
||||
if provider3.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider after cache clear, got %d", provider3.GetType())
|
||||
}
|
||||
}
|
||||
|
||||
// Test registry without any providers registered
|
||||
func TestProviderRegistry_EmptyRegistry(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Test with empty registry
|
||||
provider := registry.DetectProvider("https://accounts.google.com/auth")
|
||||
if provider != nil {
|
||||
t.Errorf("expected nil provider from empty registry, got %v", provider)
|
||||
}
|
||||
|
||||
providerByType := registry.GetProviderByType(ProviderTypeGoogle)
|
||||
if providerByType != nil {
|
||||
t.Errorf("expected nil provider by type from empty registry, got %v", providerByType)
|
||||
}
|
||||
|
||||
providers := registry.GetRegisteredProviders()
|
||||
if len(providers) != 0 {
|
||||
t.Errorf("expected empty providers list from empty registry, got %d providers", len(providers))
|
||||
}
|
||||
|
||||
// Clear cache should not panic on empty registry
|
||||
registry.ClearCache()
|
||||
}
|
||||
|
||||
// Test multiple providers of same type (edge case)
|
||||
func TestProviderRegistry_DuplicateProviderTypes(t *testing.T) {
|
||||
registry := NewProviderRegistry()
|
||||
|
||||
// Register same provider type multiple times
|
||||
provider1 := NewGoogleProvider()
|
||||
provider2 := NewGoogleProvider()
|
||||
|
||||
registry.RegisterProvider(provider1)
|
||||
registry.RegisterProvider(provider2)
|
||||
|
||||
// GetProviderByType should return one of them (likely the latest)
|
||||
retrieved := registry.GetProviderByType(ProviderTypeGoogle)
|
||||
if retrieved == nil {
|
||||
t.Error("expected non-nil provider")
|
||||
}
|
||||
if retrieved.GetType() != ProviderTypeGoogle {
|
||||
t.Errorf("expected Google provider type, got %d", retrieved.GetType())
|
||||
}
|
||||
|
||||
// Both should be in the providers list
|
||||
allProviders := registry.GetRegisteredProviders()
|
||||
googleCount := 0
|
||||
for _, providerType := range allProviders {
|
||||
if providerType == ProviderTypeGoogle {
|
||||
googleCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Note: This behavior depends on implementation - the registry might deduplicate or keep all
|
||||
if googleCount == 0 {
|
||||
t.Error("expected at least one Google provider in registered list")
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkProviderRegistry_DetectProvider(b *testing.B) {
|
||||
registry := NewProviderRegistry()
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://keycloak.example.com/auth/realms/master",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
url := urls[i%len(urls)]
|
||||
registry.DetectProvider(url)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProviderRegistry_GetProviderByType(b *testing.B) {
|
||||
registry := NewProviderRegistry()
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
types := []ProviderType{
|
||||
ProviderTypeGoogle,
|
||||
ProviderTypeAzure,
|
||||
ProviderTypeGeneric,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
providerType := types[i%len(types)]
|
||||
registry.GetProviderByType(providerType)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkProviderRegistry_GetRegisteredProviders(b *testing.B) {
|
||||
registry := NewProviderRegistry()
|
||||
registry.RegisterProvider(NewGoogleProvider())
|
||||
registry.RegisterProvider(NewAzureProvider())
|
||||
registry.RegisterProvider(NewGenericProvider())
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
registry.GetRegisteredProviders()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,761 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewConfigValidator(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("expected non-nil config validator")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_ValidateIssuerURL(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid HTTPS URL",
|
||||
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid HTTP URL",
|
||||
issuerURL: "http://localhost:8080/auth/realms/master",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "empty URL",
|
||||
issuerURL: "",
|
||||
wantError: true,
|
||||
errorSubstr: "issuer URL cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "invalid URL format",
|
||||
issuerURL: "://invalid-url",
|
||||
wantError: true,
|
||||
errorSubstr: "invalid issuer URL format",
|
||||
},
|
||||
{
|
||||
name: "URL without scheme",
|
||||
issuerURL: "example.com/auth",
|
||||
wantError: true,
|
||||
errorSubstr: "issuer URL must include scheme",
|
||||
},
|
||||
{
|
||||
name: "URL with invalid scheme",
|
||||
issuerURL: "ftp://example.com/auth",
|
||||
wantError: true,
|
||||
errorSubstr: "issuer URL scheme must be http or https",
|
||||
},
|
||||
{
|
||||
name: "URL without host",
|
||||
issuerURL: "https:///path/only",
|
||||
wantError: true,
|
||||
errorSubstr: "issuer URL must include host",
|
||||
},
|
||||
{
|
||||
name: "URL with port",
|
||||
issuerURL: "https://example.com:8443/auth",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "URL with path and query",
|
||||
issuerURL: "https://example.com/auth/realms/master?param=value",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateIssuerURL(tt.issuerURL)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_ValidateClientID(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
clientID string
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid client ID",
|
||||
clientID: "valid-client-id",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "long client ID",
|
||||
clientID: "very-long-client-id-with-many-characters-12345678901234567890",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "empty client ID",
|
||||
clientID: "",
|
||||
wantError: true,
|
||||
errorSubstr: "client ID cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "very short client ID",
|
||||
clientID: "ab",
|
||||
wantError: true,
|
||||
errorSubstr: "client ID appears to be too short",
|
||||
},
|
||||
{
|
||||
name: "minimum valid length",
|
||||
clientID: "abc",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "client ID with special characters",
|
||||
clientID: "client@example.com",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "client ID with numbers",
|
||||
clientID: "client-123-456",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateClientID(tt.clientID)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_ValidateScopes(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []string
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid scopes with openid",
|
||||
scopes: []string{"openid", "email", "profile"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "openid scope only",
|
||||
scopes: []string{"openid"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
scopes: []string{},
|
||||
wantError: true,
|
||||
errorSubstr: "at least one scope must be provided",
|
||||
},
|
||||
{
|
||||
name: "nil scopes",
|
||||
scopes: nil,
|
||||
wantError: true,
|
||||
errorSubstr: "at least one scope must be provided",
|
||||
},
|
||||
{
|
||||
name: "scopes without openid",
|
||||
scopes: []string{"email", "profile"},
|
||||
wantError: true,
|
||||
errorSubstr: "'openid' scope is required for OIDC authentication",
|
||||
},
|
||||
{
|
||||
name: "scopes with openid and whitespace",
|
||||
scopes: []string{" openid ", "email", "profile"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "scopes with mixed case openid",
|
||||
scopes: []string{"OpenID", "email"}, // This should fail as it's case sensitive
|
||||
wantError: true,
|
||||
errorSubstr: "'openid' scope is required for OIDC authentication",
|
||||
},
|
||||
{
|
||||
name: "scopes with offline_access",
|
||||
scopes: []string{"openid", "offline_access", "email"},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "many scopes",
|
||||
scopes: []string{"openid", "email", "profile", "address", "phone", "offline_access", "custom_scope"},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateScopes(tt.scopes)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_ValidateRedirectURL(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
redirectURL string
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid HTTPS redirect URL",
|
||||
redirectURL: "https://example.com/callback",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid HTTP redirect URL",
|
||||
redirectURL: "http://localhost:8080/callback",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "empty redirect URL",
|
||||
redirectURL: "",
|
||||
wantError: true,
|
||||
errorSubstr: "redirect URL cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "invalid redirect URL format",
|
||||
redirectURL: "://invalid-url",
|
||||
wantError: true,
|
||||
errorSubstr: "invalid redirect URL format",
|
||||
},
|
||||
{
|
||||
name: "redirect URL without scheme",
|
||||
redirectURL: "example.com/callback",
|
||||
wantError: true,
|
||||
errorSubstr: "redirect URL must include scheme",
|
||||
},
|
||||
{
|
||||
name: "redirect URL with custom scheme",
|
||||
redirectURL: "myapp://callback",
|
||||
wantError: false, // Custom schemes are allowed for mobile apps
|
||||
},
|
||||
{
|
||||
name: "redirect URL with query parameters",
|
||||
redirectURL: "https://example.com/callback?state=123&code=456",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "redirect URL with fragment",
|
||||
redirectURL: "https://example.com/callback#section",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "localhost redirect URL",
|
||||
redirectURL: "http://localhost/callback",
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "IP address redirect URL",
|
||||
redirectURL: "http://192.168.1.1:8080/callback",
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateRedirectURL(tt.redirectURL)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_ValidateProviderSpecificConfig(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
t.Run("Google provider config", func(t *testing.T) {
|
||||
provider := NewGoogleProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]interface{}
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid Google issuer URL",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid Google issuer URL",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://example.com/auth",
|
||||
},
|
||||
wantError: true,
|
||||
errorSubstr: "google provider requires issuer URL to contain accounts.google.com",
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
config: map[string]interface{}{},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "non-string issuer URL",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": 12345,
|
||||
},
|
||||
wantError: false, // Type assertion fails, but doesn't error
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateProviderSpecificConfig(provider, tt.config)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Azure provider config", func(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config map[string]interface{}
|
||||
wantError bool
|
||||
errorSubstr string
|
||||
}{
|
||||
{
|
||||
name: "valid Azure issuer URL - login.microsoftonline.com",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://login.microsoftonline.com/12345678-1234-1234-1234-123456789012/v2.0",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid Azure issuer URL - sts.windows.net",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://sts.windows.net/12345678-1234-1234-1234-123456789012/",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "valid Azure issuer URL with proper tenant ID",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://login.microsoftonline.com/12345678-1234-1234-1234-123456789012/v2.0",
|
||||
},
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid Azure issuer URL",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://example.com/auth",
|
||||
},
|
||||
wantError: true,
|
||||
errorSubstr: "azure provider requires issuer URL to contain login.microsoftonline.com or sts.windows.net",
|
||||
},
|
||||
{
|
||||
name: "Azure issuer URL without tenant ID",
|
||||
config: map[string]interface{}{
|
||||
"issuer_url": "https://login.microsoftonline.com/v2.0",
|
||||
},
|
||||
wantError: true,
|
||||
errorSubstr: "azure issuer URL should include tenant ID",
|
||||
},
|
||||
{
|
||||
name: "empty config",
|
||||
config: map[string]interface{}{},
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateProviderSpecificConfig(provider, tt.config)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Error("expected error but got none")
|
||||
return
|
||||
}
|
||||
if tt.errorSubstr != "" && !strings.Contains(err.Error(), tt.errorSubstr) {
|
||||
t.Errorf("expected error to contain %q, got %q", tt.errorSubstr, err.Error())
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Generic provider config", func(t *testing.T) {
|
||||
provider := NewGenericProvider()
|
||||
|
||||
config := map[string]interface{}{
|
||||
"issuer_url": "https://example.com/auth",
|
||||
"any_key": "any_value",
|
||||
}
|
||||
|
||||
err := validator.ValidateProviderSpecificConfig(provider, config)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error for generic provider: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown provider type", func(t *testing.T) {
|
||||
// Create a mock provider with invalid type
|
||||
mockProvider := &struct {
|
||||
OIDCProvider
|
||||
}{}
|
||||
mockProvider.OIDCProvider = NewGenericProvider()
|
||||
|
||||
// Override GetType to return invalid type
|
||||
provider := &mockProviderWithInvalidType{mockProvider.OIDCProvider}
|
||||
|
||||
err := validator.ValidateProviderSpecificConfig(provider, map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown provider type")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unknown provider type") {
|
||||
t.Errorf("expected error about unknown provider type, got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockProviderWithInvalidType is a test helper that returns an invalid provider type
|
||||
type mockProviderWithInvalidType struct {
|
||||
OIDCProvider
|
||||
}
|
||||
|
||||
func (m *mockProviderWithInvalidType) GetType() ProviderType {
|
||||
return ProviderType(999) // Invalid provider type
|
||||
}
|
||||
|
||||
func TestConfigValidator_ConcurrentAccess(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
// Track initial goroutine count for memory safety
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
const numGoroutines = 50
|
||||
const numOperationsPerGoroutine = 10
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numGoroutines)
|
||||
|
||||
testData := []struct {
|
||||
issuerURL string
|
||||
clientID string
|
||||
scopes []string
|
||||
redirectURL string
|
||||
}{
|
||||
{
|
||||
issuerURL: "https://accounts.google.com/.well-known/openid_configuration",
|
||||
clientID: "google-client-id",
|
||||
scopes: []string{"openid", "email", "profile"},
|
||||
redirectURL: "https://example.com/callback",
|
||||
},
|
||||
{
|
||||
issuerURL: "https://login.microsoftonline.com/tenant/v2.0",
|
||||
clientID: "azure-client-id",
|
||||
scopes: []string{"openid", "offline_access"},
|
||||
redirectURL: "https://example.com/azure-callback",
|
||||
},
|
||||
{
|
||||
issuerURL: "https://keycloak.example.com/auth/realms/master",
|
||||
clientID: "generic-client-id",
|
||||
scopes: []string{"openid", "email"},
|
||||
redirectURL: "http://localhost:8080/callback",
|
||||
},
|
||||
}
|
||||
|
||||
// Test concurrent validation operations
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func(workerID int) {
|
||||
defer wg.Done()
|
||||
|
||||
data := testData[workerID%len(testData)]
|
||||
|
||||
for j := 0; j < numOperationsPerGoroutine; j++ {
|
||||
// Test ValidateIssuerURL
|
||||
err := validator.ValidateIssuerURL(data.issuerURL)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error validating issuer URL: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateClientID
|
||||
err = validator.ValidateClientID(data.clientID)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error validating client ID: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateScopes
|
||||
err = validator.ValidateScopes(data.scopes)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error validating scopes: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Test ValidateRedirectURL
|
||||
err = validator.ValidateRedirectURL(data.redirectURL)
|
||||
if err != nil {
|
||||
t.Errorf("worker %d: unexpected error validating redirect URL: %v", workerID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Check for potential goroutine leaks - allow some tolerance for test framework overhead
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_MemorySafety(t *testing.T) {
|
||||
const numIterations = 1000
|
||||
|
||||
initialGoroutines := runtime.NumGoroutine()
|
||||
|
||||
for i := 0; i < numIterations; i++ {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
// Exercise all validation methods
|
||||
_ = validator.ValidateIssuerURL(fmt.Sprintf("https://example%d.com/auth", i))
|
||||
_ = validator.ValidateClientID(fmt.Sprintf("client-id-%d", i))
|
||||
_ = validator.ValidateScopes([]string{"openid", fmt.Sprintf("scope-%d", i)})
|
||||
_ = validator.ValidateRedirectURL(fmt.Sprintf("https://example%d.com/callback", i))
|
||||
|
||||
// Test provider-specific validation
|
||||
provider := NewGoogleProvider()
|
||||
config := map[string]interface{}{
|
||||
"issuer_url": fmt.Sprintf("https://accounts.google.com/config-%d", i),
|
||||
}
|
||||
_ = validator.ValidateProviderSpecificConfig(provider, config)
|
||||
}
|
||||
|
||||
// Force garbage collection
|
||||
runtime.GC()
|
||||
runtime.GC()
|
||||
|
||||
finalGoroutines := runtime.NumGoroutine()
|
||||
if finalGoroutines > initialGoroutines+5 {
|
||||
t.Errorf("potential goroutine leak: started with %d goroutines, ended with %d", initialGoroutines, finalGoroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigValidator_EdgeCases(t *testing.T) {
|
||||
validator := NewConfigValidator()
|
||||
|
||||
t.Run("very long URLs and strings", func(t *testing.T) {
|
||||
longString := strings.Repeat("a", 10000)
|
||||
longURL := "https://" + longString + ".com/auth"
|
||||
|
||||
// These should not crash
|
||||
err := validator.ValidateIssuerURL(longURL)
|
||||
if err != nil {
|
||||
t.Logf("Long URL validation failed as expected: %v", err)
|
||||
}
|
||||
|
||||
err = validator.ValidateClientID(longString)
|
||||
if err != nil {
|
||||
t.Logf("Long client ID validation failed as expected: %v", err)
|
||||
}
|
||||
|
||||
err = validator.ValidateRedirectURL(longURL)
|
||||
if err != nil {
|
||||
t.Logf("Long redirect URL validation failed as expected: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special characters and encoding", func(t *testing.T) {
|
||||
specialURLs := []string{
|
||||
"https://example.com/auth?param=value%20with%20spaces",
|
||||
"https://example.com/auth#fragment",
|
||||
"https://example.com/auth/path with spaces",
|
||||
"https://example.com/auth?param=特殊字符",
|
||||
"https://xn--e1afmkfd.xn--p1ai/auth", // Punycode domain
|
||||
}
|
||||
|
||||
for _, url := range specialURLs {
|
||||
err := validator.ValidateIssuerURL(url)
|
||||
// Some may fail, but should not crash
|
||||
if err != nil {
|
||||
t.Logf("Special URL %q validation failed: %v", url, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil and empty inputs", func(t *testing.T) {
|
||||
// Test nil scopes
|
||||
err := validator.ValidateScopes(nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for nil scopes")
|
||||
}
|
||||
|
||||
// Test empty scopes
|
||||
err = validator.ValidateScopes([]string{})
|
||||
if err == nil {
|
||||
t.Error("expected error for empty scopes")
|
||||
}
|
||||
|
||||
// Test nil provider
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Logf("Recovered from expected panic with nil provider: %v", r)
|
||||
}
|
||||
}()
|
||||
err = validator.ValidateProviderSpecificConfig(nil, map[string]interface{}{})
|
||||
if err == nil {
|
||||
t.Error("expected error with nil provider")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("malformed tenant IDs", func(t *testing.T) {
|
||||
provider := NewAzureProvider()
|
||||
malformedTenantConfigs := []map[string]interface{}{
|
||||
{
|
||||
"issuer_url": "https://login.microsoftonline.com/not-a-guid/v2.0",
|
||||
},
|
||||
{
|
||||
"issuer_url": "https://login.microsoftonline.com/12345678-1234-1234-123456789012/v2.0", // Wrong length
|
||||
},
|
||||
{
|
||||
"issuer_url": "https://login.microsoftonline.com/12345678-1234-1234-1234-12345678901/v2.0", // Wrong format
|
||||
},
|
||||
}
|
||||
|
||||
for i, config := range malformedTenantConfigs {
|
||||
err := validator.ValidateProviderSpecificConfig(provider, config)
|
||||
if err == nil {
|
||||
t.Errorf("expected error for malformed tenant ID in config %d", i)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark tests for performance validation
|
||||
func BenchmarkConfigValidator_ValidateIssuerURL(b *testing.B) {
|
||||
validator := NewConfigValidator()
|
||||
urls := []string{
|
||||
"https://accounts.google.com/.well-known/openid_configuration",
|
||||
"https://login.microsoftonline.com/tenant/v2.0",
|
||||
"https://keycloak.example.com/auth/realms/master",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
url := urls[i%len(urls)]
|
||||
validator.ValidateIssuerURL(url)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConfigValidator_ValidateScopes(b *testing.B) {
|
||||
validator := NewConfigValidator()
|
||||
scopes := []string{"openid", "email", "profile", "offline_access"}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.ValidateScopes(scopes)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkConfigValidator_ValidateProviderSpecificConfig(b *testing.B) {
|
||||
validator := NewConfigValidator()
|
||||
provider := NewGoogleProvider()
|
||||
config := map[string]interface{}{
|
||||
"issuer_url": "https://accounts.google.com/.well-known/openid_configuration",
|
||||
"client_id": "test-client-id",
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
validator.ValidateProviderSpecificConfig(provider, config)
|
||||
}
|
||||
}
|
||||
@@ -370,6 +370,16 @@ func (cm *ChunkManager) validateJWTFormat(token string, tokenType string) error
|
||||
// Returns:
|
||||
// - An error if the opaque token format is invalid, nil if valid.
|
||||
func (cm *ChunkManager) validateOpaqueToken(token string, tokenType string) error {
|
||||
// Check for empty token
|
||||
if token == "" {
|
||||
return fmt.Errorf("%s opaque token cannot be empty", tokenType)
|
||||
}
|
||||
|
||||
// Check minimum length
|
||||
if len(token) < 20 {
|
||||
return fmt.Errorf("%s opaque token too short (length: %d, minimum: 20)", tokenType, len(token))
|
||||
}
|
||||
|
||||
if strings.Contains(token, " ") {
|
||||
err := fmt.Errorf("%s opaque token contains spaces", tokenType)
|
||||
return err
|
||||
@@ -533,6 +543,14 @@ func (cm *ChunkManager) validateTokenSanitization(token string, config TokenConf
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for control characters (ASCII 0-31 and 127)
|
||||
for i, char := range token {
|
||||
if char < 32 || char == 127 {
|
||||
err := fmt.Errorf("%s token contains control character at position %d", config.Type, i)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
suspiciousPatterns := []string{
|
||||
"\\x", "\\u", "\\n", "\\r", "\\t", "\\0",
|
||||
"<script", "</script", "javascript:", "data:",
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
package traefikoidc
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestChunkManagerValidateJWT tests JWT validation in chunk manager
|
||||
func TestChunkManagerValidateJWT(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
cm := NewChunkManager(ts.tOidc.logger)
|
||||
|
||||
// Test valid JWT format (using base64url encoded parts that are long enough)
|
||||
validJWT := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c"
|
||||
err := cm.validateJWTFormat(validJWT, "test")
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid JWT to pass, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test invalid JWT format - too few parts
|
||||
invalidJWT := "header.payload"
|
||||
err = cm.validateJWTFormat(invalidJWT, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected invalid JWT to fail validation")
|
||||
}
|
||||
|
||||
// Test invalid JWT format - too many parts
|
||||
invalidJWT2 := "header.payload.signature.extra"
|
||||
err = cm.validateJWTFormat(invalidJWT2, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected invalid JWT with extra parts to fail validation")
|
||||
}
|
||||
|
||||
// Test empty JWT
|
||||
err = cm.validateJWTFormat("", "test")
|
||||
if err == nil {
|
||||
t.Error("Expected empty JWT to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkManagerValidateOpaqueToken tests opaque token validation
|
||||
func TestChunkManagerValidateOpaqueToken(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
cm := NewChunkManager(ts.tOidc.logger)
|
||||
|
||||
// Test valid opaque token
|
||||
validOpaque := "valid_opaque_token_that_is_long_enough"
|
||||
err := cm.validateOpaqueToken(validOpaque, "test")
|
||||
if err != nil {
|
||||
t.Errorf("Expected valid opaque token to pass, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test too short opaque token
|
||||
shortOpaque := "short"
|
||||
err = cm.validateOpaqueToken(shortOpaque, "test")
|
||||
if err == nil {
|
||||
t.Error("Expected short opaque token to fail validation")
|
||||
}
|
||||
|
||||
// Test empty opaque token
|
||||
err = cm.validateOpaqueToken("", "test")
|
||||
if err == nil {
|
||||
t.Error("Expected empty opaque token to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkManagerValidateTokenSize tests token size validation
|
||||
func TestChunkManagerValidateTokenSize(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
cm := NewChunkManager(ts.tOidc.logger)
|
||||
|
||||
// Test normal token size
|
||||
normalToken := strings.Repeat("a", 1000)
|
||||
err := cm.validateTokenSize(normalToken, AccessTokenConfig)
|
||||
if err != nil {
|
||||
t.Errorf("Expected normal token to pass size validation, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test oversized token
|
||||
oversizedToken := strings.Repeat("a", AccessTokenConfig.MaxLength+1)
|
||||
err = cm.validateTokenSize(oversizedToken, AccessTokenConfig)
|
||||
if err == nil {
|
||||
t.Error("Expected oversized token to fail validation")
|
||||
}
|
||||
|
||||
// Test undersized token
|
||||
undersizedToken := "ab"
|
||||
err = cm.validateTokenSize(undersizedToken, AccessTokenConfig)
|
||||
if err == nil {
|
||||
t.Error("Expected undersized token to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkManagerValidateTokenContent tests token content validation
|
||||
func TestChunkManagerValidateTokenContent(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
cm := NewChunkManager(ts.tOidc.logger)
|
||||
|
||||
// Test normal token content
|
||||
normalToken := "normal_token_content_without_issues"
|
||||
err := cm.validateTokenContent(normalToken, AccessTokenConfig)
|
||||
if err != nil {
|
||||
t.Errorf("Expected normal token to pass content validation, got error: %v", err)
|
||||
}
|
||||
|
||||
// Test token with null bytes
|
||||
nullByteToken := "token_with\x00null_byte"
|
||||
err = cm.validateTokenContent(nullByteToken, AccessTokenConfig)
|
||||
if err == nil {
|
||||
t.Error("Expected token with null bytes to fail validation")
|
||||
}
|
||||
|
||||
// Test token with control characters
|
||||
controlCharToken := "token_with\x01control"
|
||||
err = cm.validateTokenContent(controlCharToken, AccessTokenConfig)
|
||||
if err == nil {
|
||||
t.Error("Expected token with control characters to fail validation")
|
||||
}
|
||||
}
|
||||
|
||||
// TestChunkManagerSingleTokenValidation tests single token validation path
|
||||
func TestChunkManagerSingleTokenValidation(t *testing.T) {
|
||||
ts := NewTestSuite(t)
|
||||
ts.Setup()
|
||||
|
||||
cm := NewChunkManager(ts.tOidc.logger)
|
||||
|
||||
// Create a valid JWT token
|
||||
validToken, _ := createTestJWT(ts.rsaPrivateKey, "RS256", "test-key-id", map[string]interface{}{
|
||||
"iss": "https://test-issuer.com",
|
||||
"aud": "test-client-id",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
|
||||
// Test valid token processing
|
||||
result := cm.processSingleToken(validToken, false, AccessTokenConfig)
|
||||
if result.Error != nil {
|
||||
t.Errorf("Expected valid token to process successfully, got error: %v", result.Error)
|
||||
}
|
||||
if result.Token != validToken {
|
||||
t.Error("Expected token to be returned unchanged")
|
||||
}
|
||||
|
||||
// Test invalid token processing
|
||||
invalidToken := "invalid.token"
|
||||
result = cm.processSingleToken(invalidToken, false, IDTokenConfig) // ID tokens require JWT format
|
||||
if result.Error == nil {
|
||||
t.Error("Expected invalid token to fail processing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTokenConfigValidation tests different token configurations
|
||||
func TestTokenConfigValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config TokenConfig
|
||||
}{
|
||||
{
|
||||
name: "AccessTokenConfig",
|
||||
config: AccessTokenConfig,
|
||||
},
|
||||
{
|
||||
name: "RefreshTokenConfig",
|
||||
config: RefreshTokenConfig,
|
||||
},
|
||||
{
|
||||
name: "IDTokenConfig",
|
||||
config: IDTokenConfig,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Verify config has expected fields
|
||||
if tt.config.Type == "" {
|
||||
t.Error("Expected config to have Type set")
|
||||
}
|
||||
if tt.config.MaxLength <= 0 {
|
||||
t.Error("Expected config to have positive MaxLength")
|
||||
}
|
||||
if tt.config.MinLength <= 0 {
|
||||
t.Error("Expected config to have positive MinLength")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user