diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..2b6bfa4 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,306 @@ +package traefikoidc + +import ( + "reflect" + "testing" + "time" +) + +func TestCache(t *testing.T) { + t.Run("Basic Set and Get", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value := "test-value" + expiration := 1 * time.Second + + // Test Set + cache.Set(key, value, expiration) + + // Test Get + got, found := cache.Get(key) + if !found { + t.Error("Expected to find key in cache") + } + if got != value { + t.Errorf("Expected value %v, got %v", value, got) + } + }) + + t.Run("Expiration", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value := "test-value" + expiration := 10 * time.Millisecond + + // Set with short expiration + cache.Set(key, value, expiration) + + // Wait for expiration + time.Sleep(20 * time.Millisecond) + + // Should not find expired key + _, found := cache.Get(key) + if found { + t.Error("Expected key to be expired") + } + }) + + t.Run("Delete", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value := "test-value" + expiration := 1 * time.Second + + // Set and then delete + cache.Set(key, value, expiration) + cache.Delete(key) + + // Should not find deleted key + _, found := cache.Get(key) + if found { + t.Error("Expected key to be deleted") + } + }) + + t.Run("Cleanup", func(t *testing.T) { + cache := NewCache() + // Add multiple items with different expirations + cache.Set("expired1", "value1", 10*time.Millisecond) + cache.Set("expired2", "value2", 10*time.Millisecond) + cache.Set("valid", "value3", 1*time.Second) + + // Wait for some items to expire + time.Sleep(20 * time.Millisecond) + + // Run cleanup + cache.Cleanup() + + // Check expired items are removed + _, found1 := cache.Get("expired1") + _, found2 := cache.Get("expired2") + _, found3 := cache.Get("valid") + + if found1 { + t.Error("Expected expired1 to be cleaned up") + } + if found2 { + t.Error("Expected expired2 to be cleaned up") + } + if !found3 { + t.Error("Expected valid item to remain in cache") + } + }) + + t.Run("Concurrent Access", func(t *testing.T) { + cache := NewCache() + done := make(chan bool) + + // Start multiple goroutines to access cache concurrently + for i := 0; i < 10; i++ { + go func(id int) { + key := "key" + value := "value" + expiration := 1 * time.Second + + // Perform multiple operations + cache.Set(key, value, expiration) + cache.Get(key) + cache.Delete(key) + cache.Cleanup() + + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + }) + + t.Run("Zero Expiration", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value := "test-value" + + // Set with zero expiration + cache.Set(key, value, 0) + + // Should not find the key + _, found := cache.Get(key) + if found { + t.Error("Expected key with zero expiration to be immediately expired") + } + }) + + t.Run("Negative Expiration", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value := "test-value" + + // Set with negative expiration + cache.Set(key, value, -1*time.Second) + + // Should not find the key + _, found := cache.Get(key) + if found { + t.Error("Expected key with negative expiration to be immediately expired") + } + }) + + t.Run("Update Existing Key", func(t *testing.T) { + cache := NewCache() + key := "test-key" + value1 := "value1" + value2 := "value2" + expiration := 1 * time.Second + + // Set initial value + cache.Set(key, value1, expiration) + + // Update value + cache.Set(key, value2, expiration) + + // Check updated value + got, found := cache.Get(key) + if !found { + t.Error("Expected to find key in cache") + } + if got != value2 { + t.Errorf("Expected updated value %v, got %v", value2, got) + } + }) + + t.Run("Different Value Types", func(t *testing.T) { + cache := NewCache() + expiration := 1 * time.Second + + // Test with different value types + testCases := []struct { + key string + value interface{} + }{ + {"string", "test"}, + {"int", 42}, + {"float", 3.14}, + {"bool", true}, + {"slice", []string{"a", "b", "c"}}, + {"map", map[string]int{"a": 1, "b": 2}}, + {"struct", struct{ Name string }{"test"}}, + } + + for _, tc := range testCases { + t.Run(tc.key, func(t *testing.T) { + cache.Set(tc.key, tc.value, expiration) + got, found := cache.Get(tc.key) + if !found { + t.Error("Expected to find key in cache") + } + // Use reflect.DeepEqual for comparing complex types like slices and maps + if !reflect.DeepEqual(got, tc.value) { + t.Errorf("Expected value %v, got %v", tc.value, got) + } + }) + } + }) +} + +func TestTokenCache(t *testing.T) { + t.Run("Basic Operations", func(t *testing.T) { + tc := NewTokenCache() + token := "test-token" + claims := map[string]interface{}{ + "sub": "1234567890", + "name": "John Doe", + "admin": true, + } + expiration := 1 * time.Second + + // Test Set and Get + tc.Set(token, claims, expiration) + gotClaims, found := tc.Get(token) + if !found { + t.Error("Expected to find token in cache") + } + if len(gotClaims) != len(claims) { + t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims)) + } + for k, v := range claims { + if gotClaims[k] != v { + t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k]) + } + } + + // Test Delete + tc.Delete(token) + _, found = tc.Get(token) + if found { + t.Error("Expected token to be deleted") + } + }) + + t.Run("Expiration", func(t *testing.T) { + tc := NewTokenCache() + token := "test-token" + claims := map[string]interface{}{"sub": "1234567890"} + expiration := 10 * time.Millisecond + + // Set with short expiration + tc.Set(token, claims, expiration) + + // Wait for expiration + time.Sleep(20 * time.Millisecond) + + // Should not find expired token + _, found := tc.Get(token) + if found { + t.Error("Expected token to be expired") + } + }) + + t.Run("Cleanup", func(t *testing.T) { + tc := NewTokenCache() + + // Add multiple tokens with different expirations + tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond) + tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond) + tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second) + + // Wait for some tokens to expire + time.Sleep(20 * time.Millisecond) + + // Run cleanup + tc.Cleanup() + + // Check expired tokens are removed + _, found1 := tc.Get("expired1") + _, found2 := tc.Get("expired2") + _, found3 := tc.Get("valid") + + if found1 { + t.Error("Expected expired1 to be cleaned up") + } + if found2 { + t.Error("Expected expired2 to be cleaned up") + } + if !found3 { + t.Error("Expected valid token to remain in cache") + } + }) + + t.Run("Token Prefix", func(t *testing.T) { + tc := NewTokenCache() + token := "test-token" + claims := map[string]interface{}{"sub": "1234567890"} + expiration := 1 * time.Second + + // Set token + tc.Set(token, claims, expiration) + + // Verify internal storage uses prefix + _, found := tc.cache.Get("t-" + token) + if !found { + t.Error("Expected to find prefixed token in underlying cache") + } + }) +} diff --git a/settings.go b/settings.go index 95cbaaf..2ae5209 100644 --- a/settings.go +++ b/settings.go @@ -187,8 +187,10 @@ func NewLogger(logLevel string) *Logger { logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime) logError.SetOutput(os.Stderr) - logInfo.SetOutput(os.Stdout) - + + if logLevel == "debug" || logLevel == "info" { + logInfo.SetOutput(os.Stdout) + } if logLevel == "debug" { logDebug.SetOutput(os.Stdout) } diff --git a/settings_test.go b/settings_test.go new file mode 100644 index 0000000..0662c97 --- /dev/null +++ b/settings_test.go @@ -0,0 +1,362 @@ +package traefikoidc + +import ( + "bytes" + "log" + "net/http" + "testing" +) + +func TestCreateConfig(t *testing.T) { + t.Run("Default Values", func(t *testing.T) { + config := CreateConfig() + + // Check default scopes + expectedScopes := []string{"openid", "profile", "email"} + if len(config.Scopes) != len(expectedScopes) { + t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes)) + } + for i, scope := range expectedScopes { + if config.Scopes[i] != scope { + t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i]) + } + } + + // Check default log level + if config.LogLevel != "info" { + t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel) + } + + // Check default rate limit + if config.RateLimit != 100 { + t.Errorf("Expected default rate limit 100, got %d", config.RateLimit) + } + }) + + t.Run("Custom Values Preserved", func(t *testing.T) { + config := CreateConfig() + config.Scopes = []string{"custom_scope"} + config.LogLevel = "debug" + config.RateLimit = 50 + + // Verify custom values are not overwritten + if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" { + t.Error("Custom scopes were overwritten") + } + if config.LogLevel != "debug" { + t.Error("Custom log level was overwritten") + } + if config.RateLimit != 50 { + t.Error("Custom rate limit was overwritten") + } + }) +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config *Config + expectedError string + }{ + { + name: "Empty Config", + config: &Config{}, + expectedError: "providerURL is required", + }, + { + name: "Missing CallbackURL", + config: &Config{ + ProviderURL: "https://provider.com", + }, + expectedError: "callbackURL is required", + }, + { + name: "Missing ClientID", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + }, + expectedError: "clientID is required", + }, + { + name: "Missing ClientSecret", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + }, + expectedError: "clientSecret is required", + }, + { + name: "Missing SessionEncryptionKey", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + }, + expectedError: "sessionEncryptionKey is required", + }, + { + name: "Invalid ProviderURL", + config: &Config{ + ProviderURL: "not-a-url", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "encryption-key", + }, + expectedError: "providerURL must be a valid URL", + }, + { + name: "Invalid CallbackURL", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "callback", // Missing leading slash + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "encryption-key", + }, + expectedError: "callbackURL must start with /", + }, + { + name: "Short SessionEncryptionKey", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "short", + }, + expectedError: "sessionEncryptionKey must be at least 32 characters long", + }, + { + name: "Negative RateLimit", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + RateLimit: -1, + }, + expectedError: "rateLimit must be non-negative", + }, + { + name: "Invalid LogLevel", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + LogLevel: "invalid", + }, + expectedError: "logLevel must be one of: debug, info, error", + }, + { + name: "Valid Config", + config: &Config{ + ProviderURL: "https://provider.com", + CallbackURL: "/callback", + ClientID: "client-id", + ClientSecret: "client-secret", + SessionEncryptionKey: "this-is-a-long-enough-encryption-key", + LogLevel: "debug", + RateLimit: 100, + }, + expectedError: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := tc.config.Validate() + if tc.expectedError == "" { + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected error containing '%s', got nil", tc.expectedError) + } else if err.Error() != tc.expectedError { + t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error()) + } + } + }) + } +} + +func TestLogger(t *testing.T) { + // Capture log output + var debugBuf, infoBuf, errorBuf bytes.Buffer + + tests := []struct { + name string + logLevel string + testFunc func(*Logger) + checkFunc func(t *testing.T, debugOut, infoOut, errorOut string) + }{ + { + name: "Debug Level", + logLevel: "debug", + testFunc: func(l *Logger) { + l.Debug("debug message") + l.Info("info message") + l.Error("error message") + }, + checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) { + if debugOut == "" { + t.Error("Expected debug message in output") + } + if infoOut == "" { + t.Error("Expected info message in output") + } + if errorOut == "" { + t.Error("Expected error message in output") + } + }, + }, + { + name: "Info Level", + logLevel: "info", + testFunc: func(l *Logger) { + l.Debug("debug message") + l.Info("info message") + l.Error("error message") + }, + checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) { + if debugOut != "" { + t.Error("Did not expect debug message in output") + } + if infoOut == "" { + t.Error("Expected info message in output") + } + if errorOut == "" { + t.Error("Expected error message in output") + } + }, + }, + { + name: "Error Level", + logLevel: "error", + testFunc: func(l *Logger) { + l.Debug("debug message") + l.Info("info message") + l.Error("error message") + }, + checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) { + if debugOut != "" { + t.Error("Did not expect debug message in output") + } + if infoOut != "" { + t.Error("Did not expect info message in output") + } + if errorOut == "" { + t.Error("Expected error message in output") + } + }, + }, + { + name: "Printf Methods", + logLevel: "debug", + testFunc: func(l *Logger) { + l.Debugf("debug %s", "formatted") + l.Infof("info %s", "formatted") + l.Errorf("error %s", "formatted") + }, + checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) { + if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) { + t.Error("Expected formatted debug message") + } + if !bytes.Contains([]byte(infoOut), []byte("info formatted")) { + t.Error("Expected formatted info message") + } + if !bytes.Contains([]byte(errorOut), []byte("error formatted")) { + t.Error("Expected formatted error message") + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Reset buffers + debugBuf.Reset() + infoBuf.Reset() + errorBuf.Reset() + + // Create logger with test buffers + logger := NewLogger(tc.logLevel) + logger.logError.SetOutput(&errorBuf) + + if tc.logLevel == "debug" || tc.logLevel == "info" { + logger.logInfo.SetOutput(&infoBuf) + } + if tc.logLevel == "debug" { + logger.logDebug.SetOutput(&debugBuf) + } + + // Run test + tc.testFunc(logger) + + // Check results + tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String()) + }) + } +} + +func TestHandleError(t *testing.T) { + // Create a test logger with captured output + var errorBuf bytes.Buffer + logger := &Logger{ + logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime), + } + logger.logError.SetOutput(&errorBuf) + + // Create a test response recorder + rr := &testResponseRecorder{ + headers: make(map[string][]string), + } + + // Test error handling + message := "test error message" + code := 400 + handleError(rr, message, code, logger) + + // Check response code + if rr.statusCode != code { + t.Errorf("Expected status code %d, got %d", code, rr.statusCode) + } + + // Check response body + expectedBody := message + "\n" + if rr.body != expectedBody { + t.Errorf("Expected body %q, got %q", expectedBody, rr.body) + } + + // Check error was logged + if !bytes.Contains(errorBuf.Bytes(), []byte(message)) { + t.Error("Error message was not logged") + } +} + +// Test helper types +type testResponseRecorder struct { + statusCode int + body string + headers map[string][]string +} + +func (r *testResponseRecorder) Header() http.Header { + return r.headers +} + +func (r *testResponseRecorder) Write(b []byte) (int, error) { + r.body = string(b) + return len(b), nil +} + +func (r *testResponseRecorder) WriteHeader(code int) { + r.statusCode = code +}