Improve test coverage

This commit is contained in:
2025-01-06 11:01:20 +00:00
parent a8d65688c4
commit eecb7dfc92
3 changed files with 672 additions and 2 deletions
+306
View File
@@ -0,0 +1,306 @@
package traefikoidc
import (
"reflect"
"testing"
"time"
)
func TestCache(t *testing.T) {
t.Run("Basic Set and Get", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Test Set
cache.Set(key, value, expiration)
// Test Get
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value {
t.Errorf("Expected value %v, got %v", value, got)
}
})
t.Run("Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 10 * time.Millisecond
// Set with short expiration
cache.Set(key, value, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be expired")
}
})
t.Run("Delete", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
expiration := 1 * time.Second
// Set and then delete
cache.Set(key, value, expiration)
cache.Delete(key)
// Should not find deleted key
_, found := cache.Get(key)
if found {
t.Error("Expected key to be deleted")
}
})
t.Run("Cleanup", func(t *testing.T) {
cache := NewCache()
// Add multiple items with different expirations
cache.Set("expired1", "value1", 10*time.Millisecond)
cache.Set("expired2", "value2", 10*time.Millisecond)
cache.Set("valid", "value3", 1*time.Second)
// Wait for some items to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
cache.Cleanup()
// Check expired items are removed
_, found1 := cache.Get("expired1")
_, found2 := cache.Get("expired2")
_, found3 := cache.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid item to remain in cache")
}
})
t.Run("Concurrent Access", func(t *testing.T) {
cache := NewCache()
done := make(chan bool)
// Start multiple goroutines to access cache concurrently
for i := 0; i < 10; i++ {
go func(id int) {
key := "key"
value := "value"
expiration := 1 * time.Second
// Perform multiple operations
cache.Set(key, value, expiration)
cache.Get(key)
cache.Delete(key)
cache.Cleanup()
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
})
t.Run("Zero Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with zero expiration
cache.Set(key, value, 0)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with zero expiration to be immediately expired")
}
})
t.Run("Negative Expiration", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value := "test-value"
// Set with negative expiration
cache.Set(key, value, -1*time.Second)
// Should not find the key
_, found := cache.Get(key)
if found {
t.Error("Expected key with negative expiration to be immediately expired")
}
})
t.Run("Update Existing Key", func(t *testing.T) {
cache := NewCache()
key := "test-key"
value1 := "value1"
value2 := "value2"
expiration := 1 * time.Second
// Set initial value
cache.Set(key, value1, expiration)
// Update value
cache.Set(key, value2, expiration)
// Check updated value
got, found := cache.Get(key)
if !found {
t.Error("Expected to find key in cache")
}
if got != value2 {
t.Errorf("Expected updated value %v, got %v", value2, got)
}
})
t.Run("Different Value Types", func(t *testing.T) {
cache := NewCache()
expiration := 1 * time.Second
// Test with different value types
testCases := []struct {
key string
value interface{}
}{
{"string", "test"},
{"int", 42},
{"float", 3.14},
{"bool", true},
{"slice", []string{"a", "b", "c"}},
{"map", map[string]int{"a": 1, "b": 2}},
{"struct", struct{ Name string }{"test"}},
}
for _, tc := range testCases {
t.Run(tc.key, func(t *testing.T) {
cache.Set(tc.key, tc.value, expiration)
got, found := cache.Get(tc.key)
if !found {
t.Error("Expected to find key in cache")
}
// Use reflect.DeepEqual for comparing complex types like slices and maps
if !reflect.DeepEqual(got, tc.value) {
t.Errorf("Expected value %v, got %v", tc.value, got)
}
})
}
})
}
func TestTokenCache(t *testing.T) {
t.Run("Basic Operations", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"admin": true,
}
expiration := 1 * time.Second
// Test Set and Get
tc.Set(token, claims, expiration)
gotClaims, found := tc.Get(token)
if !found {
t.Error("Expected to find token in cache")
}
if len(gotClaims) != len(claims) {
t.Errorf("Expected %d claims, got %d", len(claims), len(gotClaims))
}
for k, v := range claims {
if gotClaims[k] != v {
t.Errorf("Expected claim %s to be %v, got %v", k, v, gotClaims[k])
}
}
// Test Delete
tc.Delete(token)
_, found = tc.Get(token)
if found {
t.Error("Expected token to be deleted")
}
})
t.Run("Expiration", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 10 * time.Millisecond
// Set with short expiration
tc.Set(token, claims, expiration)
// Wait for expiration
time.Sleep(20 * time.Millisecond)
// Should not find expired token
_, found := tc.Get(token)
if found {
t.Error("Expected token to be expired")
}
})
t.Run("Cleanup", func(t *testing.T) {
tc := NewTokenCache()
// Add multiple tokens with different expirations
tc.Set("expired1", map[string]interface{}{"sub": "1"}, 10*time.Millisecond)
tc.Set("expired2", map[string]interface{}{"sub": "2"}, 10*time.Millisecond)
tc.Set("valid", map[string]interface{}{"sub": "3"}, 1*time.Second)
// Wait for some tokens to expire
time.Sleep(20 * time.Millisecond)
// Run cleanup
tc.Cleanup()
// Check expired tokens are removed
_, found1 := tc.Get("expired1")
_, found2 := tc.Get("expired2")
_, found3 := tc.Get("valid")
if found1 {
t.Error("Expected expired1 to be cleaned up")
}
if found2 {
t.Error("Expected expired2 to be cleaned up")
}
if !found3 {
t.Error("Expected valid token to remain in cache")
}
})
t.Run("Token Prefix", func(t *testing.T) {
tc := NewTokenCache()
token := "test-token"
claims := map[string]interface{}{"sub": "1234567890"}
expiration := 1 * time.Second
// Set token
tc.Set(token, claims, expiration)
// Verify internal storage uses prefix
_, found := tc.cache.Get("t-" + token)
if !found {
t.Error("Expected to find prefixed token in underlying cache")
}
})
}
+4 -2
View File
@@ -187,8 +187,10 @@ func NewLogger(logLevel string) *Logger {
logDebug := log.New(io.Discard, "DEBUG: TraefikOidcPlugin: ", log.Ldate|log.Ltime)
logError.SetOutput(os.Stderr)
logInfo.SetOutput(os.Stdout)
if logLevel == "debug" || logLevel == "info" {
logInfo.SetOutput(os.Stdout)
}
if logLevel == "debug" {
logDebug.SetOutput(os.Stdout)
}
+362
View File
@@ -0,0 +1,362 @@
package traefikoidc
import (
"bytes"
"log"
"net/http"
"testing"
)
func TestCreateConfig(t *testing.T) {
t.Run("Default Values", func(t *testing.T) {
config := CreateConfig()
// Check default scopes
expectedScopes := []string{"openid", "profile", "email"}
if len(config.Scopes) != len(expectedScopes) {
t.Errorf("Expected %d default scopes, got %d", len(expectedScopes), len(config.Scopes))
}
for i, scope := range expectedScopes {
if config.Scopes[i] != scope {
t.Errorf("Expected scope %s at position %d, got %s", scope, i, config.Scopes[i])
}
}
// Check default log level
if config.LogLevel != "info" {
t.Errorf("Expected default log level 'info', got '%s'", config.LogLevel)
}
// Check default rate limit
if config.RateLimit != 100 {
t.Errorf("Expected default rate limit 100, got %d", config.RateLimit)
}
})
t.Run("Custom Values Preserved", func(t *testing.T) {
config := CreateConfig()
config.Scopes = []string{"custom_scope"}
config.LogLevel = "debug"
config.RateLimit = 50
// Verify custom values are not overwritten
if len(config.Scopes) != 1 || config.Scopes[0] != "custom_scope" {
t.Error("Custom scopes were overwritten")
}
if config.LogLevel != "debug" {
t.Error("Custom log level was overwritten")
}
if config.RateLimit != 50 {
t.Error("Custom rate limit was overwritten")
}
})
}
func TestConfigValidate(t *testing.T) {
tests := []struct {
name string
config *Config
expectedError string
}{
{
name: "Empty Config",
config: &Config{},
expectedError: "providerURL is required",
},
{
name: "Missing CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
},
expectedError: "callbackURL is required",
},
{
name: "Missing ClientID",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
},
expectedError: "clientID is required",
},
{
name: "Missing ClientSecret",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
},
expectedError: "clientSecret is required",
},
{
name: "Missing SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
},
expectedError: "sessionEncryptionKey is required",
},
{
name: "Invalid ProviderURL",
config: &Config{
ProviderURL: "not-a-url",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "providerURL must be a valid URL",
},
{
name: "Invalid CallbackURL",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "callback", // Missing leading slash
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "encryption-key",
},
expectedError: "callbackURL must start with /",
},
{
name: "Short SessionEncryptionKey",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "short",
},
expectedError: "sessionEncryptionKey must be at least 32 characters long",
},
{
name: "Negative RateLimit",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
RateLimit: -1,
},
expectedError: "rateLimit must be non-negative",
},
{
name: "Invalid LogLevel",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "invalid",
},
expectedError: "logLevel must be one of: debug, info, error",
},
{
name: "Valid Config",
config: &Config{
ProviderURL: "https://provider.com",
CallbackURL: "/callback",
ClientID: "client-id",
ClientSecret: "client-secret",
SessionEncryptionKey: "this-is-a-long-enough-encryption-key",
LogLevel: "debug",
RateLimit: 100,
},
expectedError: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := tc.config.Validate()
if tc.expectedError == "" {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
} else {
if err == nil {
t.Errorf("Expected error containing '%s', got nil", tc.expectedError)
} else if err.Error() != tc.expectedError {
t.Errorf("Expected error '%s', got '%s'", tc.expectedError, err.Error())
}
}
})
}
}
func TestLogger(t *testing.T) {
// Capture log output
var debugBuf, infoBuf, errorBuf bytes.Buffer
tests := []struct {
name string
logLevel string
testFunc func(*Logger)
checkFunc func(t *testing.T, debugOut, infoOut, errorOut string)
}{
{
name: "Debug Level",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut == "" {
t.Error("Expected debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Info Level",
logLevel: "info",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut == "" {
t.Error("Expected info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Error Level",
logLevel: "error",
testFunc: func(l *Logger) {
l.Debug("debug message")
l.Info("info message")
l.Error("error message")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if debugOut != "" {
t.Error("Did not expect debug message in output")
}
if infoOut != "" {
t.Error("Did not expect info message in output")
}
if errorOut == "" {
t.Error("Expected error message in output")
}
},
},
{
name: "Printf Methods",
logLevel: "debug",
testFunc: func(l *Logger) {
l.Debugf("debug %s", "formatted")
l.Infof("info %s", "formatted")
l.Errorf("error %s", "formatted")
},
checkFunc: func(t *testing.T, debugOut, infoOut, errorOut string) {
if !bytes.Contains([]byte(debugOut), []byte("debug formatted")) {
t.Error("Expected formatted debug message")
}
if !bytes.Contains([]byte(infoOut), []byte("info formatted")) {
t.Error("Expected formatted info message")
}
if !bytes.Contains([]byte(errorOut), []byte("error formatted")) {
t.Error("Expected formatted error message")
}
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Reset buffers
debugBuf.Reset()
infoBuf.Reset()
errorBuf.Reset()
// Create logger with test buffers
logger := NewLogger(tc.logLevel)
logger.logError.SetOutput(&errorBuf)
if tc.logLevel == "debug" || tc.logLevel == "info" {
logger.logInfo.SetOutput(&infoBuf)
}
if tc.logLevel == "debug" {
logger.logDebug.SetOutput(&debugBuf)
}
// Run test
tc.testFunc(logger)
// Check results
tc.checkFunc(t, debugBuf.String(), infoBuf.String(), errorBuf.String())
})
}
}
func TestHandleError(t *testing.T) {
// Create a test logger with captured output
var errorBuf bytes.Buffer
logger := &Logger{
logError: log.New(&errorBuf, "ERROR: ", log.Ldate|log.Ltime),
}
logger.logError.SetOutput(&errorBuf)
// Create a test response recorder
rr := &testResponseRecorder{
headers: make(map[string][]string),
}
// Test error handling
message := "test error message"
code := 400
handleError(rr, message, code, logger)
// Check response code
if rr.statusCode != code {
t.Errorf("Expected status code %d, got %d", code, rr.statusCode)
}
// Check response body
expectedBody := message + "\n"
if rr.body != expectedBody {
t.Errorf("Expected body %q, got %q", expectedBody, rr.body)
}
// Check error was logged
if !bytes.Contains(errorBuf.Bytes(), []byte(message)) {
t.Error("Error message was not logged")
}
}
// Test helper types
type testResponseRecorder struct {
statusCode int
body string
headers map[string][]string
}
func (r *testResponseRecorder) Header() http.Header {
return r.headers
}
func (r *testResponseRecorder) Write(b []byte) (int, error) {
r.body = string(b)
return len(b), nil
}
func (r *testResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}