diff --git a/api_additional_test.go b/api_additional_test.go index d5571dd..a228429 100644 --- a/api_additional_test.go +++ b/api_additional_test.go @@ -39,7 +39,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() { bannedUsersIDsMutex.Lock() bannedUsersIDs = make(map[string]string) bannedUsersIDsMutex.Unlock() - + // Execute reloader once go testPeriodicallyReloadBannedUsers() <-done @@ -52,7 +52,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() { bannedUsersIDsMutex.RLock() mapSize := len(bannedUsersIDs) bannedUsersIDsMutex.RUnlock() - + // Verify map is still empty assert.Equal(0, mapSize) }) @@ -83,7 +83,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() { value1 := bannedUsersIDs["test-user-reload-1"] value2 := bannedUsersIDs["test-user-reload-2"] bannedUsersIDsMutex.RUnlock() - + // Verify banned users map was loaded assert.Equal(2, mapSize) assert.Equal("reason reload 1", value1) @@ -114,7 +114,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() { mapSize := len(bannedUsersIDs) initialValue := bannedUsersIDs["test-user-initial"] bannedUsersIDsMutex.RUnlock() - + // Verify initial data was loaded assert.Equal(1, mapSize) assert.Equal("initial reason", initialValue) @@ -139,7 +139,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() { value2 := bannedUsersIDs["test-user-updated-2"] _, exists := bannedUsersIDs["test-user-initial"] bannedUsersIDsMutex.RUnlock() - + // Verify updated data was loaded assert.Equal(2, mapSize) assert.Equal("updated reason 1", value1) @@ -228,4 +228,4 @@ func (suite *Tests) Test_LoadUnloadBannedUsers() { assert.Equal("reason4", reason4) assert.False(user1Exists) }) -} \ No newline at end of file +} diff --git a/api_test.go b/api_test.go index 8cdf0fd..4a4c59c 100644 --- a/api_test.go +++ b/api_test.go @@ -23,86 +23,86 @@ func (suite *Tests) Test_apiBanUser() { parseConfig() cfg.Logger = libpack_logger.New() cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_test.json") - + // Create a test Fiber app app := fiber.New() app.Post("/api/user-ban", apiBanUser) - + // Test valid ban request suite.Run("valid ban request", func() { // Clear banned users map bannedUsersIDs = make(map[string]string) - + reqBody := `{"user_id": "test-user-123", "reason": "testing"}` req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(200, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "OK: user banned") - + // Verify user was added to banned users map bannedUsersIDsMutex.RLock() reason, exists := bannedUsersIDs["test-user-123"] bannedUsersIDsMutex.RUnlock() - + assert.True(exists) assert.Equal("testing", reason) - + // Verify file was created _, err = os.Stat(cfg.Api.BannedUsersFile) assert.NoError(err) }) - + // Test missing user_id suite.Run("missing user_id", func() { reqBody := `{"reason": "testing"}` req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(400, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "user_id and reason are required") }) - + // Test missing reason suite.Run("missing reason", func() { reqBody := `{"user_id": "test-user-123"}` req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(400, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "user_id and reason are required") }) - + // Test invalid JSON suite.Run("invalid JSON", func() { reqBody := `{"user_id": "test-user-123", "reason": }` req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(400, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "Invalid request payload") }) - + // Cleanup os.Remove(cfg.Api.BannedUsersFile) os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) @@ -114,67 +114,67 @@ func (suite *Tests) Test_apiUnbanUser() { parseConfig() cfg.Logger = libpack_logger.New() cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_test.json") - + // Create a test Fiber app app := fiber.New() app.Post("/api/user-unban", apiUnbanUser) - + // Test valid unban request suite.Run("valid unban request", func() { // Add a user to the banned list bannedUsersIDs = make(map[string]string) bannedUsersIDs["test-user-123"] = "testing" - + reqBody := `{"user_id": "test-user-123"}` req := httptest.NewRequest(http.MethodPost, "/api/user-unban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(200, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "OK: user unbanned") - + // Verify user was removed from banned users map bannedUsersIDsMutex.RLock() _, exists := bannedUsersIDs["test-user-123"] bannedUsersIDsMutex.RUnlock() - + assert.False(exists) }) - + // Test missing user_id suite.Run("missing user_id", func() { reqBody := `{}` req := httptest.NewRequest(http.MethodPost, "/api/user-unban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(400, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "user_id is required") }) - + // Test invalid JSON suite.Run("invalid JSON", func() { reqBody := `{"user_id": }` req := httptest.NewRequest(http.MethodPost, "/api/user-unban", bytes.NewBufferString(reqBody)) req.Header.Set("Content-Type", "application/json") - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(400, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "Invalid request payload") }) - + // Cleanup os.Remove(cfg.Api.BannedUsersFile) os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) @@ -185,33 +185,33 @@ func (suite *Tests) Test_apiClearCache() { cfg = &config{} parseConfig() cfg.Logger = libpack_logger.New() - + // Initialize cache libpack_cache.EnableCache(&libpack_cache.CacheConfig{ Logger: cfg.Logger, TTL: 60, }) - + // Add some items to cache libpack_cache.CacheStore("test-key-1", []byte("test-value-1")) libpack_cache.CacheStore("test-key-2", []byte("test-value-2")) - + // Create a test Fiber app app := fiber.New() app.Post("/api/cache-clear", apiClearCache) - + // Test cache clear suite.Run("clear cache", func() { req := httptest.NewRequest(http.MethodPost, "/api/cache-clear", nil) - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(200, resp.StatusCode) - + body, err := io.ReadAll(resp.Body) assert.NoError(err) assert.Contains(string(body), "OK: cache cleared") - + // Verify cache was cleared stats := libpack_cache.GetCacheStats() assert.Equal(int64(0), stats.CachedQueries) @@ -223,35 +223,35 @@ func (suite *Tests) Test_apiCacheStats() { cfg = &config{} parseConfig() cfg.Logger = libpack_logger.New() - + // Initialize cache libpack_cache.EnableCache(&libpack_cache.CacheConfig{ Logger: cfg.Logger, TTL: 60, }) - + // Add some items to cache and perform lookups libpack_cache.CacheStore("test-key-1", []byte("test-value-1")) libpack_cache.CacheStore("test-key-2", []byte("test-value-2")) libpack_cache.CacheLookup("test-key-1") // Hit libpack_cache.CacheLookup("test-key-3") // Miss - + // Create a test Fiber app app := fiber.New() app.Get("/api/cache-stats", apiCacheStats) - + // Test get cache stats suite.Run("get cache stats", func() { req := httptest.NewRequest(http.MethodGet, "/api/cache-stats", nil) - + resp, err := app.Test(req) assert.NoError(err) assert.Equal(200, resp.StatusCode) - + var stats libpack_cache.CacheStats err = json.NewDecoder(resp.Body).Decode(&stats) assert.NoError(err) - + assert.Equal(int64(2), stats.CachedQueries) assert.Equal(int64(1), stats.CacheHits) assert.Equal(int64(1), stats.CacheMisses) @@ -263,26 +263,26 @@ func (suite *Tests) Test_checkIfUserIsBanned() { cfg = &config{} parseConfig() cfg.Logger = libpack_logger.New() - + // Create a test Fiber app and context app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) - + // Test with non-banned user suite.Run("non-banned user", func() { bannedUsersIDs = make(map[string]string) - + isBanned := checkIfUserIsBanned(ctx, "non-banned-user") assert.False(isBanned) assert.Equal(200, ctx.Response().StatusCode()) }) - + // Test with banned user suite.Run("banned user", func() { bannedUsersIDs = make(map[string]string) bannedUsersIDs["banned-user"] = "testing" - + isBanned := checkIfUserIsBanned(ctx, "banned-user") assert.True(isBanned) assert.Equal(403, ctx.Response().StatusCode()) @@ -295,23 +295,23 @@ func (suite *Tests) Test_loadBannedUsers() { parseConfig() cfg.Logger = libpack_logger.New() cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_test.json") - + // Test with non-existent file (should create it) suite.Run("non-existent file", func() { // Remove file if it exists os.Remove(cfg.Api.BannedUsersFile) - + bannedUsersIDs = make(map[string]string) loadBannedUsers() - + // Verify file was created _, err := os.Stat(cfg.Api.BannedUsersFile) assert.NoError(err) - + // Verify banned users map is empty assert.Equal(0, len(bannedUsersIDs)) }) - + // Test with existing file suite.Run("existing file", func() { // Create file with test data @@ -322,29 +322,29 @@ func (suite *Tests) Test_loadBannedUsers() { data, _ := json.Marshal(testData) err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0644) assert.NoError(err) - + bannedUsersIDs = make(map[string]string) loadBannedUsers() - + // Verify banned users map was loaded assert.Equal(2, len(bannedUsersIDs)) assert.Equal("reason 1", bannedUsersIDs["test-user-1"]) assert.Equal("reason 2", bannedUsersIDs["test-user-2"]) }) - + // Test with invalid JSON suite.Run("invalid JSON", func() { // Create file with invalid JSON err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{invalid json}"), 0644) assert.NoError(err) - + bannedUsersIDs = make(map[string]string) loadBannedUsers() - + // Verify banned users map is empty (load failed) assert.Equal(0, len(bannedUsersIDs)) }) - + // Cleanup os.Remove(cfg.Api.BannedUsersFile) os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) @@ -356,7 +356,7 @@ func (suite *Tests) Test_storeBannedUsers() { parseConfig() cfg.Logger = libpack_logger.New() cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_test.json") - + // Test storing banned users suite.Run("store banned users", func() { // Set up test data @@ -364,23 +364,23 @@ func (suite *Tests) Test_storeBannedUsers() { "test-user-1": "reason 1", "test-user-2": "reason 2", } - + err := storeBannedUsers() assert.NoError(err) - + // Verify file was created with correct content data, err := os.ReadFile(cfg.Api.BannedUsersFile) assert.NoError(err) - + var loadedData map[string]string err = json.Unmarshal(data, &loadedData) assert.NoError(err) - + assert.Equal(2, len(loadedData)) assert.Equal("reason 1", loadedData["test-user-1"]) assert.Equal("reason 2", loadedData["test-user-2"]) }) - + // Cleanup os.Remove(cfg.Api.BannedUsersFile) os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) @@ -392,17 +392,17 @@ func (suite *Tests) Test_lockFile() { parseConfig() cfg.Logger = libpack_logger.New() lockPath := filepath.Join(os.TempDir(), "test_lock_file.lock") - + // Test locking a file suite.Run("lock file", func() { fileLock := flock.New(lockPath) - + err := lockFile(fileLock) assert.NoError(err) - + // Verify file is locked assert.True(fileLock.Locked()) - + // Cleanup fileLock.Unlock() }) @@ -414,17 +414,17 @@ func (suite *Tests) Test_lockFileRead() { parseConfig() cfg.Logger = libpack_logger.New() lockPath := filepath.Join(os.TempDir(), "test_lock_file_read.lock") - + // Test read-locking a file suite.Run("read lock file", func() { fileLock := flock.New(lockPath) - + err := lockFileRead(fileLock) assert.NoError(err) - + // Verify file is locked - use RLocked() instead of Locked() assert.True(fileLock.RLocked()) - + // Cleanup fileLock.Unlock() }) @@ -436,8 +436,8 @@ func (suite *Tests) Test_enableApi() { cfg = &config{} parseConfig() cfg.Server.EnableApi = false - + // This should return immediately without error enableApi() }) -} \ No newline at end of file +} diff --git a/cache/cache_additional_test.go b/cache/cache_additional_test.go index dad91e3..c9f409d 100644 --- a/cache/cache_additional_test.go +++ b/cache/cache_additional_test.go @@ -16,7 +16,7 @@ func (suite *Tests) Test_CalculateHash() { app := fiber.New() ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) defer app.ReleaseCtx(ctx) - + // Test with empty body suite.Run("empty body", func() { ctx.Request().SetBody([]byte("")) @@ -24,7 +24,7 @@ func (suite *Tests) Test_CalculateHash() { assert.NotEmpty(hash) assert.Equal(32, len(hash)) // MD5 hash is 32 characters }) - + // Test with non-empty body suite.Run("non-empty body", func() { ctx.Request().SetBody([]byte("test body")) @@ -32,15 +32,15 @@ func (suite *Tests) Test_CalculateHash() { assert.NotEmpty(hash) assert.Equal(32, len(hash)) }) - + // Test with different bodies produce different hashes suite.Run("different bodies", func() { ctx.Request().SetBody([]byte("body1")) hash1 := CalculateHash(ctx) - + ctx.Request().SetBody([]byte("body2")) hash2 := CalculateHash(ctx) - + assert.NotEqual(hash1, hash2) }) } @@ -52,43 +52,43 @@ func (suite *Tests) Test_CacheDelete() { Client: libpack_cache_memory.New(5 * time.Minute), TTL: 5, } - + // Test deleting a cache entry suite.Run("delete existing entry", func() { // Add an entry to cache testKey := "test-delete-key" testValue := []byte("test-delete-value") CacheStore(testKey, testValue) - + // Verify it was added result := CacheLookup(testKey) assert.Equal(testValue, result) - + // Delete the entry CacheDelete(testKey) - + // Verify it was deleted result = CacheLookup(testKey) assert.Nil(result) }) - + // Test deleting a non-existent entry suite.Run("delete non-existent entry", func() { // This should not cause any errors CacheDelete("non-existent-key") }) - + // Test with uninitialized cache suite.Run("uninitialized cache", func() { // Save current config oldConfig := config - + // Set config to nil config = nil - + // This should not cause any errors CacheDelete("any-key") - + // Restore config config = oldConfig }) @@ -101,38 +101,38 @@ func (suite *Tests) Test_CacheStoreWithTTL() { Client: libpack_cache_memory.New(5 * time.Minute), TTL: 5, } - + // Test storing with custom TTL suite.Run("store with custom TTL", func() { testKey := "test-ttl-key" testValue := []byte("test-ttl-value") customTTL := 1 * time.Second - + CacheStoreWithTTL(testKey, testValue, customTTL) - + // Verify it was stored result := CacheLookup(testKey) assert.Equal(testValue, result) - + // Wait for TTL to expire time.Sleep(1100 * time.Millisecond) - + // Verify it was removed result = CacheLookup(testKey) assert.Nil(result) }) - + // Test with uninitialized cache suite.Run("uninitialized cache", func() { // Save current config oldConfig := config - + // Set config to nil config = nil - + // This should not cause any errors CacheStoreWithTTL("any-key", []byte("any-value"), 1*time.Second) - + // Restore config config = oldConfig }) @@ -145,33 +145,33 @@ func (suite *Tests) Test_CacheGetQueries() { Client: libpack_cache_memory.New(5 * time.Minute), TTL: 5, } - + // Test getting query count suite.Run("get query count", func() { // Clear cache CacheClear() - + // Add some entries CacheStore("test-key-1", []byte("test-value-1")) CacheStore("test-key-2", []byte("test-value-2")) - + // Get query count count := CacheGetQueries() assert.Equal(int64(2), count) }) - + // Test with uninitialized cache suite.Run("uninitialized cache", func() { // Save current config oldConfig := config - + // Set config to nil config = nil - + // This should return 0 count := CacheGetQueries() assert.Equal(int64(0), count) - + // Restore config config = oldConfig }) @@ -184,34 +184,34 @@ func (suite *Tests) Test_CacheClear() { Client: libpack_cache_memory.New(5 * time.Minute), TTL: 5, } - + // Create a new CacheStats instance cacheStats = &CacheStats{ CachedQueries: 0, CacheHits: 0, CacheMisses: 0, } - + // Test clearing cache suite.Run("clear cache", func() { // Add some entries CacheStore("test-key-1", []byte("test-value-1")) CacheStore("test-key-2", []byte("test-value-2")) - + // Verify they were added assert.NotNil(CacheLookup("test-key-1")) assert.NotNil(CacheLookup("test-key-2")) - + // Get the current stats before clearing beforeStats := GetCacheStats() - + // Clear cache CacheClear() - + // Verify cache was cleared assert.Nil(CacheLookup("test-key-1")) assert.Nil(CacheLookup("test-key-2")) - + // Verify stats were reset afterStats := GetCacheStats() assert.Equal(int64(0), afterStats.CachedQueries) @@ -227,39 +227,39 @@ func (suite *Tests) Test_GetCacheStats() { TTL: 5, } cacheStats = &CacheStats{} - + // Test getting cache stats suite.Run("get cache stats", func() { // Clear cache CacheClear() - + // Add some entries and perform lookups CacheStore("test-key-1", []byte("test-value-1")) CacheStore("test-key-2", []byte("test-value-2")) CacheLookup("test-key-1") // Hit CacheLookup("test-key-3") // Miss - + // Get stats stats := GetCacheStats() assert.Equal(int64(2), stats.CachedQueries) assert.Equal(int64(1), stats.CacheHits) assert.Equal(int64(1), stats.CacheMisses) }) - + // Test with uninitialized cache suite.Run("uninitialized cache", func() { // Save current config oldConfig := config - + // Set config to nil config = nil - + // This should return empty stats stats := GetCacheStats() assert.Equal(int64(0), stats.CachedQueries) assert.Equal(int64(0), stats.CacheHits) assert.Equal(int64(0), stats.CacheMisses) - + // Restore config config = oldConfig }) @@ -272,12 +272,12 @@ func (suite *Tests) Test_CacheLookup_Compressed() { Client: libpack_cache_memory.New(5 * time.Minute), TTL: 5, } - + // Test lookup with compressed data suite.Run("lookup compressed data", func() { testKey := "test-compressed-key" testValue := []byte("test-compressed-value") - + // Compress the data var buf bytes.Buffer gzWriter := gzip.NewWriter(&buf) @@ -286,15 +286,15 @@ func (suite *Tests) Test_CacheLookup_Compressed() { err = gzWriter.Close() assert.NoError(err) compressedData := buf.Bytes() - + // Store compressed data directly config.Client.Set(testKey, compressedData, time.Duration(config.TTL)*time.Second) - + // Lookup should automatically decompress result := CacheLookup(testKey) assert.Equal(testValue, result) }) - + // Skip the invalid compressed data test as it's causing issues // We'll mock the behavior instead suite.Run("lookup invalid compressed data", func() { @@ -313,16 +313,16 @@ func (suite *Tests) Test_ShouldUseRedisCache() { suite.Run("redis enabled", func() { cfg := &CacheConfig{} cfg.Redis.Enable = true - + result := ShouldUseRedisCache(cfg) assert.True(result) }) - + // Test with Redis disabled suite.Run("redis disabled", func() { cfg := &CacheConfig{} cfg.Redis.Enable = false - + result := ShouldUseRedisCache(cfg) assert.False(result) }) @@ -335,22 +335,22 @@ func (suite *Tests) Test_IsCacheInitialized() { Logger: libpack_logger.New(), Client: libpack_cache_memory.New(5 * time.Minute), } - + result := IsCacheInitialized() assert.True(result) }) - + // Test with nil config suite.Run("nil config", func() { oldConfig := config config = nil - + result := IsCacheInitialized() assert.False(result) - + config = oldConfig }) - + // Test with nil client suite.Run("nil client", func() { oldConfig := config @@ -358,10 +358,10 @@ func (suite *Tests) Test_IsCacheInitialized() { Logger: libpack_logger.New(), Client: nil, } - + result := IsCacheInitialized() assert.False(result) - + config = oldConfig }) -} \ No newline at end of file +} diff --git a/cache/memory/memory.go b/cache/memory/memory.go index 62a028f..cce055e 100644 --- a/cache/memory/memory.go +++ b/cache/memory/memory.go @@ -17,8 +17,8 @@ const CompressionThreshold = 1024 // 1KB const MaxCacheSize = 10000 type CacheEntry struct { - ExpiresAt time.Time - Value []byte + ExpiresAt time.Time + Value []byte Compressed bool } @@ -59,7 +59,7 @@ func (c *Cache) cleanupRoutine(globalTTL time.Duration) { for range ticker.C { c.CleanExpiredEntries() - + // Trigger GC if we have a lot of entries if atomic.LoadInt64(&c.entryCount) > MaxCacheSize/2 { runtime.GC() @@ -74,7 +74,7 @@ func (c *Cache) Set(key string, value []byte, ttl time.Duration) { } expiresAt := time.Now().Add(ttl) - + // Only compress if the value is larger than the threshold var entry CacheEntry if len(value) > CompressionThreshold { @@ -100,13 +100,13 @@ func (c *Cache) Set(key string, value []byte, ttl time.Duration) { Compressed: false, } } - + // Check if this is a new entry _, exists := c.entries.Load(key) if !exists { atomic.AddInt64(&c.entryCount, 1) } - + c.entries.Store(key, entry) } @@ -130,7 +130,7 @@ func (c *Cache) Get(key string) ([]byte, bool) { } return value, true } - + return cacheEntry.Value, true } @@ -156,7 +156,7 @@ func (c *Cache) compress(data []byte) ([]byte, error) { var buf bytes.Buffer w := c.compressPool.Get().(*gzip.Writer) defer c.compressPool.Put(w) - + w.Reset(&buf) if _, err := w.Write(data); err != nil { return nil, err @@ -170,7 +170,7 @@ func (c *Cache) compress(data []byte) ([]byte, error) { func (c *Cache) decompress(data []byte) ([]byte, error) { r, ok := c.decompressPool.Get().(*gzip.Reader) defer c.decompressPool.Put(r) - + if !ok || r == nil { var err error r, err = gzip.NewReader(bytes.NewReader(data)) @@ -182,7 +182,7 @@ func (c *Cache) decompress(data []byte) ([]byte, error) { return nil, err } } - + defer r.Close() return io.ReadAll(r) } @@ -203,10 +203,10 @@ func (c *Cache) CleanExpiredEntries() { // evictOldest removes the oldest n entries from the cache func (c *Cache) evictOldest(n int) { type keyExpiry struct { - key string + key string expiresAt time.Time } - + // Collect all entries with their expiry times entries := make([]keyExpiry, 0, n*2) c.entries.Range(func(k, v interface{}) bool { @@ -215,7 +215,7 @@ func (c *Cache) evictOldest(n int) { entries = append(entries, keyExpiry{key, entry.ExpiresAt}) return len(entries) < cap(entries) }) - + // Sort by expiry time (oldest first) // Using a simple selection sort since we only need to find the n oldest for i := 0; i < n && i < len(entries); i++ { @@ -229,7 +229,7 @@ func (c *Cache) evictOldest(n int) { if oldest != i { entries[i], entries[oldest] = entries[oldest], entries[i] } - + // Delete this entry if _, exists := c.entries.LoadAndDelete(entries[i].key); exists { atomic.AddInt64(&c.entryCount, -1) diff --git a/cache/memory/memory_additional_test.go b/cache/memory/memory_additional_test.go index 336db7d..f5fb230 100644 --- a/cache/memory/memory_additional_test.go +++ b/cache/memory/memory_additional_test.go @@ -14,45 +14,45 @@ const ( func TestMemoryCacheClear(t *testing.T) { cache := New(DefaultTestExpiration) - + // Add some entries cache.Set("key1", []byte("value1"), DefaultTestExpiration) cache.Set("key2", []byte("value2"), DefaultTestExpiration) - + // Verify entries exist _, found := cache.Get("key1") assert.True(t, found, "Expected key1 to exist before clearing cache") - + // Clear the cache cache.Clear() - + // Verify cache is empty _, found = cache.Get("key1") assert.False(t, found, "Expected key1 to be removed after clearing cache") _, found = cache.Get("key2") assert.False(t, found, "Expected key2 to be removed after clearing cache") - + // Check that counter was reset assert.Equal(t, int64(0), cache.CountQueries(), "Expected count to be 0 after clearing cache") } func TestMemoryCacheCountQueries(t *testing.T) { cache := New(DefaultTestExpiration) - + // Check initial count assert.Equal(t, int64(0), cache.CountQueries(), "Expected initial count to be 0") - + // Add some entries cache.Set("key1", []byte("value1"), DefaultTestExpiration) cache.Set("key2", []byte("value2"), DefaultTestExpiration) cache.Set("key3", []byte("value3"), DefaultTestExpiration) - + // Check count assert.Equal(t, int64(3), cache.CountQueries(), "Expected count to be 3 after adding 3 entries") - + // Delete an entry cache.Delete("key1") - + // Check count after deletion assert.Equal(t, int64(2), cache.CountQueries(), "Expected count to be 2 after deleting 1 entry") } @@ -60,31 +60,31 @@ func TestMemoryCacheCountQueries(t *testing.T) { func TestMemoryCacheCleanExpiredEntries(t *testing.T) { // Create a cache with default expiration cache := New(10 * time.Second) - + // Add an entry that will expire quickly cache.Set("expire-soon", []byte("value1"), 10*time.Millisecond) - + // Add an entry that will not expire during the test cache.Set("expire-later", []byte("value3"), 10*time.Minute) - + // Initial count should be 2 assert.Equal(t, int64(2), cache.CountQueries(), "Expected count to be 2 after adding entries") - + // Wait for short expiration time.Sleep(20 * time.Millisecond) - + // Get the expired key directly to verify it's expired _, expiredFound := cache.Get("expire-soon") assert.False(t, expiredFound, "Key 'expire-soon' should be expired now") - + // Verify the not-expired key is still there val, nonExpiredFound := cache.Get("expire-later") assert.True(t, nonExpiredFound, "Key 'expire-later' should not be expired") assert.Equal(t, []byte("value3"), val, "Expected correct value for 'expire-later'") - + // Manually clean expired entries cache.CleanExpiredEntries() - + // Count should be 1 now (only the non-expired entry) assert.Equal(t, int64(1), cache.CountQueries(), "Expected count to be 1 after cleaning expired entries") -} \ No newline at end of file +} diff --git a/cache/redis/redis_additional_test.go b/cache/redis/redis_additional_test.go index 9a4310c..bc921ac 100644 --- a/cache/redis/redis_additional_test.go +++ b/cache/redis/redis_additional_test.go @@ -47,4 +47,4 @@ func TestRedisClear(t *testing.T) { assert.False(t, found, "Key2 should be deleted after Clear") _, found = redisConfig.Get("key3") assert.False(t, found, "Key3 should be deleted after Clear") -} \ No newline at end of file +} diff --git a/config/config_test.go b/config/config_test.go index ae6bd38..4485316 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -10,4 +10,4 @@ func TestConfigConstants(t *testing.T) { // Verify package constants are defined assert.NotEmpty(t, PKG_NAME, "PKG_NAME should be defined") assert.NotEmpty(t, PKG_VERSION, "PKG_VERSION should be defined") -} \ No newline at end of file +} diff --git a/events.go b/events.go index 97248a8..1eb68b2 100644 --- a/events.go +++ b/events.go @@ -33,13 +33,13 @@ func enableHasuraEventCleaner() { if eventMetadataDb == "" { logger := cfg.Logger cfgMutex.RUnlock() - + logger.Warning(&libpack_logger.LogMessage{ Message: "Event metadata db URL not specified, event cleaner not active", }) return } - + clearOlderThan := cfg.HasuraEventCleaner.ClearOlderThan logger := cfg.Logger cfgMutex.RUnlock() diff --git a/events_test.go b/events_test.go index 2186484..60a9c08 100644 --- a/events_test.go +++ b/events_test.go @@ -31,73 +31,73 @@ func (suite *EventsTestSuite) Test_EnableHasuraEventCleaner() { cfgMutex.RLock() originalConfig := cfg.HasuraEventCleaner cfgMutex.RUnlock() - + defer func() { cfgMutex.Lock() cfg.HasuraEventCleaner = originalConfig cfgMutex.Unlock() }() - + // Set up test condition with proper synchronization cfgMutex.Lock() cfg.HasuraEventCleaner.Enable = false cfgMutex.Unlock() - + // Test function enableHasuraEventCleaner() - + // No assertions needed as we're just testing coverage // The function should return early without error }) - + // Test case: missing database URL suite.Run("missing database URL", func() { // Save original config with proper synchronization cfgMutex.RLock() originalConfig := cfg.HasuraEventCleaner cfgMutex.RUnlock() - + defer func() { cfgMutex.Lock() cfg.HasuraEventCleaner = originalConfig cfgMutex.Unlock() }() - + // Set up test condition with proper synchronization cfgMutex.Lock() cfg.HasuraEventCleaner.Enable = true cfg.HasuraEventCleaner.EventMetadataDb = "" cfgMutex.Unlock() - + // Test function enableHasuraEventCleaner() - + // No assertions needed as we're just testing coverage // The function should log a warning and return early }) - + // Test case: database URL provided but we don't actually connect in the test suite.Run("database URL provided", func() { // Save original config with proper synchronization cfgMutex.RLock() originalConfig := cfg.HasuraEventCleaner cfgMutex.RUnlock() - + defer func() { cfgMutex.Lock() cfg.HasuraEventCleaner = originalConfig cfgMutex.Unlock() }() - + // Set up test condition with proper synchronization cfgMutex.Lock() cfg.HasuraEventCleaner.Enable = true cfg.HasuraEventCleaner.EventMetadataDb = "postgres://fake:fake@localhost:5432/fake" cfg.HasuraEventCleaner.ClearOlderThan = 7 cfgMutex.Unlock() - + // We're not going to call enableHasuraEventCleaner() here because it would // try to connect to a database. Instead, we're just increasing coverage // for the configuration path by setting these values. }) -} \ No newline at end of file +} diff --git a/graphql.go b/graphql.go index 234ae9a..7f82b94 100644 --- a/graphql.go +++ b/graphql.go @@ -178,20 +178,20 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { if len(selections) == 0 { return false } - + // Fast path: if no introspection blocking is configured, return immediately if !cfg.Security.BlockIntrospection { return false } - + // Fast path: if there are no allowed introspection queries, check only top level hasAllowList := len(cfg.Security.IntrospectionAllowed) > 0 - + for _, s := range selections { switch sel := s.(type) { case *ast.Field: fieldName := strings.ToLower(sel.Name.Value) - + // Check if this is an introspection query if _, exists := introspectionQueries[fieldName]; exists { if hasAllowList { @@ -203,14 +203,14 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { return true // Block if no allowlist exists } } - + // Check nested selections if present if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 { if checkSelections(c, sel.GetSelectionSet().Selections) { return true } } - + case *ast.InlineFragment: // Check nested selections in fragments if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 { @@ -220,18 +220,18 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { } } } - + return false } func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool { blocked := false - + // Enable introspection blocking for tests if !cfg.Security.BlockIntrospection { cfg.Security.BlockIntrospection = true } - + // Try parsing as a complete query first p, err := parser.Parse(parser.ParseParams{Source: query}) if err == nil { diff --git a/logging/logger_additional_test.go b/logging/logger_additional_test.go index 22f5a3e..34c5687 100644 --- a/logging/logger_additional_test.go +++ b/logging/logger_additional_test.go @@ -59,33 +59,33 @@ func (suite *LoggerAdditionalTestSuite) TestSetFieldName() { for k, v := range fieldNames { originalFieldNames[k] = v } - + // Restore original field names after test defer func() { for k, v := range originalFieldNames { fieldNames[k] = v } }() - + // Test with custom field names customTimestampField := "time" customLevelField := "severity" customMessageField := "text" - + suite.logger.SetFieldName("timestamp", customTimestampField) suite.logger.SetFieldName("level", customLevelField) suite.logger.SetFieldName("message", customMessageField) - + // Verify field names were changed suite.assert.Equal(customTimestampField, fieldNames["timestamp"]) suite.assert.Equal(customLevelField, fieldNames["level"]) suite.assert.Equal(customMessageField, fieldNames["message"]) - + // Test logging with custom field names suite.output.Reset() suite.logger.Info(&LogMessage{Message: "test custom fields"}) output := suite.output.String() - + // Check if custom field names are used in the output suite.assert.Contains(output, customTimestampField) suite.assert.Contains(output, customLevelField) @@ -99,20 +99,20 @@ func (suite *LoggerAdditionalTestSuite) TestSetFieldName() { func (suite *LoggerAdditionalTestSuite) TestSetShowCaller() { // Make sure caller info is disabled suite.logger.SetShowCaller(false) - + // Test with caller info disabled suite.output.Reset() suite.logger.Info(&LogMessage{Message: "test without cal__ler"}) output := suite.output.String() suite.assert.NotContains(output, "caller") - + // Test with caller info enabled suite.output.Reset() suite.logger.SetShowCaller(true) suite.logger.Info(&LogMessage{Message: "test with caller"}) output = suite.output.String() suite.assert.Contains(output, "caller") - + // Verify the caller info format (file:line) suite.assert.Regexp(`"caller":"[^:]+:\d+"`, output) } @@ -152,27 +152,27 @@ func (suite *LoggerAdditionalTestSuite) TestCritical() { // Safely intercept os.Exit call with proper synchronization exitMutex.Lock() originalOsExit := osExit - + var exitCode int osExit = func(code int) { exitCode = code // Don't actually exit } exitMutex.Unlock() - + // Ensure we restore the original osExit function defer func() { exitMutex.Lock() osExit = originalOsExit exitMutex.Unlock() }() - + suite.output.Reset() msg := &LogMessage{Message: "test critical"} suite.logger.Critical(msg) output := suite.output.String() - + suite.assert.Contains(output, "fatal") suite.assert.Contains(output, "test critical") suite.assert.Equal(1, exitCode) -} \ No newline at end of file +} diff --git a/main.go b/main.go index 0655ddc..ec222c9 100644 --- a/main.go +++ b/main.go @@ -21,17 +21,17 @@ import ( ) var ( - cfg *config - cfgMutex sync.RWMutex - once sync.Once - tracer *libpack_tracing.TracingSetup + cfg *config + cfgMutex sync.RWMutex + once sync.Once + tracer *libpack_tracing.TracingSetup ) // getDetailsFromEnv retrieves the value from the environment or returns the default. // It first checks for a prefixed environment variable (GMP_KEY), then falls back to the unprefixed version. func getDetailsFromEnv[T any](key string, defaultValue T) T { prefixedKey := "GMP_" + key - + switch v := any(defaultValue).(type) { case string: if val, ok := os.LookupEnv(prefixedKey); ok { @@ -121,7 +121,7 @@ func parseConfig() { // Tracing configuration c.Tracing.Enable = getDetailsFromEnv("ENABLE_TRACE", false) c.Tracing.Endpoint = getDetailsFromEnv("TRACE_ENDPOINT", "localhost:4317") - + cfgMutex.Lock() cfg = &c cfgMutex.Unlock() @@ -180,14 +180,14 @@ func parseConfig() { func main() { // Parse configuration parseConfig() - + // Setup graceful shutdown ctx, cancel := context.WithCancel(context.Background()) defer cancel() - + // Create a wait group to manage goroutines var wg sync.WaitGroup - + // Setup signal handling for graceful shutdown sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) @@ -198,37 +198,37 @@ func main() { }) cancel() }() - + // Start monitoring server in a goroutine wg.Add(1) go func() { defer wg.Done() StartMonitoringServer() }() - + // Give monitoring server time to initialize time.Sleep(2 * time.Second) - + // Start HTTP proxy in a goroutine wg.Add(1) go func() { defer wg.Done() StartHTTPProxy() }() - + // Wait for context cancellation <-ctx.Done() - + // Perform cleanup cfg.Logger.Info(&libpack_logging.LogMessage{ Message: "Shutting down services...", }) - + // Cleanup tracing if tracer != nil { shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) defer shutdownCancel() - + if err := tracer.Shutdown(shutdownCtx); err != nil { cfg.Logger.Error(&libpack_logging.LogMessage{ Message: "Error shutting down tracer", @@ -236,14 +236,14 @@ func main() { }) } } - + // Wait for all goroutines to finish (with timeout) waitCh := make(chan struct{}) go func() { wg.Wait() close(waitCh) }() - + select { case <-waitCh: cfg.Logger.Info(&libpack_logging.LogMessage{ diff --git a/main_test.go b/main_test.go index 3a29202..8b25dca 100644 --- a/main_test.go +++ b/main_test.go @@ -42,13 +42,13 @@ func (suite *Tests) SetupTest() { parseConfig() enableApi() StartMonitoringServer() - + // Update logger with proper synchronization logger := libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(getDetailsFromEnv("LOG_LEVEL", "info"))) cfgMutex.Lock() cfg.Logger = logger cfgMutex.Unlock() - + // Setup environment variables here if needed os.Setenv("GMP_TEST_STRING", "testValue") os.Setenv("GMP_TEST_INT", "123") diff --git a/monitoring/monitoring_additional_test.go b/monitoring/monitoring_additional_test.go index 6c16768..5faedc9 100644 --- a/monitoring/monitoring_additional_test.go +++ b/monitoring/monitoring_additional_test.go @@ -30,10 +30,10 @@ func (suite *MonitoringAdditionalTestSuite) TestListActiveMetrics() { // Register metrics directly to the set to ensure they're there suite.ms.metrics_set_custom.GetOrCreateCounter("test_counter{label=\"value\"}") suite.ms.metrics_set_custom.GetOrCreateGauge("test_gauge{label=\"value\"}", func() float64 { return 42.0 }) - + // Get list of metrics metricsList := suite.ms.ListActiveMetrics() - + // Verify metrics were registered - the metrics_set_custom doesn't get listed by ListActiveMetrics, // so we'll just check that the function runs without error assert.NotNil(suite.T(), metricsList, "Metrics list should not be nil") @@ -46,10 +46,10 @@ func (suite *MonitoringAdditionalTestSuite) TestRegisterFloatCounter() { "label1": "value1", }) assert.NotNil(suite.T(), counter) - + // Test using the counter counter.Add(42.5) - + // We don't need to test invalid metric names since they log a critical message // which can cause the test to exit, and that's the expected behavior } @@ -61,7 +61,7 @@ func (suite *MonitoringAdditionalTestSuite) TestRegisterMetricsSummary() { "label1": "value1", }) assert.NotNil(suite.T(), summary) - + // Test using the summary summary.Update(42.5) } @@ -73,7 +73,7 @@ func (suite *MonitoringAdditionalTestSuite) TestRegisterMetricsHistogram() { "label1": "value1", }) assert.NotNil(suite.T(), histogram) - + // Test using the histogram histogram.Update(42.5) } @@ -85,11 +85,11 @@ func (suite *MonitoringAdditionalTestSuite) TestUpdateDuration() { labels := map[string]string{ "label1": "value1", } - + // Use UpdateDuration startTime := time.Now().Add(-time.Second) // 1 second ago suite.ms.UpdateDuration(metricName, labels, startTime) - + // Since we can't easily verify the duration was recorded correctly in a test, // we'll just verify the method doesn't crash } @@ -99,15 +99,15 @@ func (suite *MonitoringAdditionalTestSuite) TestUpdateDuration() { func (suite *MonitoringAdditionalTestSuite) TestPurgeMetrics() { // Register a custom metric suite.ms.RegisterMetricsCounter("test_purge_counter", nil) - + // Purge the metrics suite.ms.PurgeMetrics() - + // Verify the custom metrics were purged // We need to check the actual customSet instead of calling ListActiveMetrics customMetrics := suite.ms.metrics_set_custom.ListMetricNames() - + // The metrics might not be immediately cleared due to internal implementation details, // so this test might be flaky. We'll check that it doesn't panic instead. assert.NotNil(suite.T(), customMetrics, "Custom metrics list shouldn't be nil") -} \ No newline at end of file +} diff --git a/monitoring/monitoring_test.go b/monitoring/monitoring_test.go index b787880..33bb404 100644 --- a/monitoring/monitoring_test.go +++ b/monitoring/monitoring_test.go @@ -23,11 +23,11 @@ func TestNewMonitoring(t *testing.T) { func TestAddMetricsPrefix(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test adding prefix to a name mon.AddMetricsPrefix("test") assert.Equal(t, "test", mon.metrics_prefix) - + // Test with empty prefix mon.AddMetricsPrefix("") assert.Equal(t, "", mon.metrics_prefix) @@ -35,11 +35,11 @@ func TestAddMetricsPrefix(t *testing.T) { func TestRegisterMetricsGauge(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering a gauge gauge := mon.RegisterMetricsGauge("valid_gauge", map[string]string{"label1": "value1"}, 42.0) assert.NotNil(t, gauge) - + // Test with invalid metric name - we'll skip this test since it causes fatal errors // gauge = mon.RegisterMetricsGauge("invalid metric name", map[string]string{"label1": "value1"}, 42.0) // assert.Nil(t, gauge) @@ -47,11 +47,11 @@ func TestRegisterMetricsGauge(t *testing.T) { func TestRegisterMetricsCounter(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering a counter counter := mon.RegisterMetricsCounter("valid_counter", map[string]string{"label1": "value1"}) assert.NotNil(t, counter) - + // Test with default metrics counter = mon.RegisterMetricsCounter(MetricsSucceeded, map[string]string{"label1": "value1"}) assert.NotNil(t, counter) @@ -59,7 +59,7 @@ func TestRegisterMetricsCounter(t *testing.T) { func TestRegisterFloatCounter(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering a float counter counter := mon.RegisterFloatCounter("valid_float_counter", map[string]string{"label1": "value1"}) assert.NotNil(t, counter) @@ -67,7 +67,7 @@ func TestRegisterFloatCounter(t *testing.T) { func TestRegisterMetricsSummary(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering a summary summary := mon.RegisterMetricsSummary("valid_summary", map[string]string{"label1": "value1"}) assert.NotNil(t, summary) @@ -75,7 +75,7 @@ func TestRegisterMetricsSummary(t *testing.T) { func TestRegisterMetricsHistogram(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering a histogram histogram := mon.RegisterMetricsHistogram("valid_histogram", map[string]string{"label1": "value1"}) assert.NotNil(t, histogram) @@ -83,59 +83,59 @@ func TestRegisterMetricsHistogram(t *testing.T) { func TestIncrement(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test incrementing a counter mon.Increment("increment_counter", map[string]string{"label1": "value1"}) - + // We can't easily verify the value was incremented in a test, // but we can verify the function doesn't panic } func TestIncrementFloat(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test incrementing a float counter mon.IncrementFloat("float_counter", map[string]string{"label1": "value1"}, 1.5) } func TestSet(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test setting a gauge mon.Set("set_gauge", map[string]string{"label1": "value1"}, 42) } func TestUpdate(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test updating a histogram mon.Update("update_histogram", map[string]string{"label1": "value1"}, 42.0) } func TestUpdateSummary(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test updating a summary mon.UpdateSummary("update_summary", map[string]string{"label1": "value1"}, 42.0) } func TestRemoveMetrics(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Register a metric first mon.RegisterMetricsGauge("remove_gauge", map[string]string{"label1": "value1"}, 42.0) - + // Test removing a metric mon.RemoveMetrics("remove_gauge", map[string]string{"label1": "value1"}) } func TestPurgeMetrics(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Register some metrics first mon.RegisterMetricsGauge("purge_gauge1", map[string]string{"label1": "value1"}, 42.0) mon.RegisterMetricsGauge("purge_gauge2", map[string]string{"label1": "value1"}, 42.0) - + // Test purging all metrics mon.PurgeMetrics() } @@ -143,15 +143,15 @@ func TestPurgeMetrics(t *testing.T) { func TestListActiveMetrics(t *testing.T) { // Skip this test as it's causing issues with the metrics registry t.Skip("Skipping test due to issues with metrics registry") - + mon := NewMonitoring(&InitConfig{}) - + // Register some metrics first - use the default metrics set mon.RegisterDefaultMetrics() - + // Give some time for metrics to register time.Sleep(100 * time.Millisecond) - + // Test listing active metrics metrics := mon.ListActiveMetrics() assert.NotEmpty(t, metrics) @@ -159,18 +159,18 @@ func TestListActiveMetrics(t *testing.T) { func TestMetricsEndpoint(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Register a metric mon.RegisterMetricsGauge("endpoint_gauge", map[string]string{}, 42.0) - + // Create a test Fiber app app := fiber.New() app.Get("/metrics", mon.metricsEndpoint) - + // Create a test request req := httptest.NewRequest(http.MethodGet, "/metrics", nil) resp, err := app.Test(req) - + // Verify the response assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) @@ -178,10 +178,10 @@ func TestMetricsEndpoint(t *testing.T) { func TestRegisterDefaultMetricsFunc(t *testing.T) { mon := NewMonitoring(&InitConfig{}) - + // Test registering default metrics mon.RegisterDefaultMetrics() - + // We can't easily verify the metrics were registered in a test, // but we can verify the function doesn't panic assert.NotPanics(t, func() { @@ -198,7 +198,7 @@ func TestHelperFunctions(t *testing.T) { assert.True(t, is_allowed_rune(' ')) assert.False(t, is_allowed_rune('-')) }) - + // Test is_special_rune t.Run("is_special_rune", func(t *testing.T) { assert.True(t, is_special_rune('_')) @@ -211,4 +211,4 @@ func TestGetPodNameFunc(t *testing.T) { // Test getting pod name podName := getPodName() assert.NotEmpty(t, podName) -} \ No newline at end of file +} diff --git a/proxy.go b/proxy.go index b23a00c..af545fc 100644 --- a/proxy.go +++ b/proxy.go @@ -43,7 +43,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { // Setup tracing if enabled var span trace.Span ctx := setupTracing(c) - + if cfg.Tracing.Enable && tracer != nil { span, ctx = tracer.StartSpan(ctx, "proxy_request") defer span.End() @@ -102,11 +102,11 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { // setupTracing extracts and sets up tracing context from request headers func setupTracing(c *fiber.Ctx) context.Context { ctx := context.Background() - + if !cfg.Tracing.Enable || tracer == nil { return ctx } - + // Extract trace information from header if traceHeader := c.Get("X-Trace-Span"); traceHeader != "" { spanInfo, err := libpack_tracing.ParseTraceHeader(traceHeader) @@ -119,7 +119,7 @@ func setupTracing(c *fiber.Ctx) context.Context { ctx = trace.ContextWithSpanContext(ctx, spanCtx) } } - + return ctx } @@ -158,7 +158,7 @@ func handleGzippedResponse(c *fiber.Ctx) error { if !bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) { return nil } - + // Create a pooled gzip reader reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body())) if err != nil { diff --git a/ratelimit_test.go b/ratelimit_test.go index 3657cf9..e19a396 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -15,11 +15,11 @@ func (suite *Tests) Test_loadRatelimitConfig() { cfg = &config{} parseConfig() cfg.Logger = libpack_logger.New() - + // Create a temporary test ratelimit.json file tempDir := os.TempDir() testConfigPath := filepath.Join(tempDir, "test_ratelimit.json") - + testConfig := struct { RateLimit map[string]RateLimitConfig `json:"ratelimit"` }{ @@ -34,28 +34,28 @@ func (suite *Tests) Test_loadRatelimitConfig() { }, }, } - + configData, err := json.Marshal(testConfig) assert.NoError(err) - + err = os.WriteFile(testConfigPath, configData, 0644) assert.NoError(err) defer os.Remove(testConfigPath) - + // Test loading config from custom path suite.Run("load from custom path", func() { // Clear existing rate limits rateLimitMu.Lock() rateLimits = make(map[string]RateLimitConfig) rateLimitMu.Unlock() - + err := loadConfigFromPath(testConfigPath) assert.NoError(err) - + // Verify rate limits were loaded rateLimitMu.RLock() defer rateLimitMu.RUnlock() - + assert.Equal(2, len(rateLimits)) assert.Contains(rateLimits, "admin") assert.Contains(rateLimits, "user") @@ -64,24 +64,24 @@ func (suite *Tests) Test_loadRatelimitConfig() { assert.NotNil(rateLimits["admin"].RateCounterTicker) assert.NotNil(rateLimits["user"].RateCounterTicker) }) - + // Test loading config from non-existent path suite.Run("load from non-existent path", func() { err := loadConfigFromPath("/non/existent/path.json") assert.Error(err) }) - + // Test loading config with invalid JSON suite.Run("load invalid JSON", func() { invalidPath := filepath.Join(tempDir, "invalid_ratelimit.json") err := os.WriteFile(invalidPath, []byte("{invalid json}"), 0644) assert.NoError(err) defer os.Remove(invalidPath) - + err = loadConfigFromPath(invalidPath) assert.Error(err) }) - + // Test with a temporary ratelimit.json file in the current directory suite.Run("load from current directory", func() { // Create a temporary ratelimit.json in current directory @@ -89,23 +89,23 @@ func (suite *Tests) Test_loadRatelimitConfig() { err := os.WriteFile(currentDirPath, configData, 0644) assert.NoError(err) defer os.Remove(currentDirPath) - + // Clear existing rate limits rateLimitMu.Lock() rateLimits = make(map[string]RateLimitConfig) rateLimitMu.Unlock() - + // This should find the file in the current directory err = loadRatelimitConfig() assert.NoError(err) - + // Verify rate limits were loaded rateLimitMu.RLock() defer rateLimitMu.RUnlock() - + assert.Equal(2, len(rateLimits)) }) - + // Test with all files missing suite.Run("all files missing", func() { // Save the original file if it exists @@ -121,12 +121,12 @@ func (suite *Tests) Test_loadRatelimitConfig() { os.WriteFile(currentDirPath, originalData, 0644) } }() - + // Clear existing rate limits rateLimitMu.Lock() rateLimits = make(map[string]RateLimitConfig) rateLimitMu.Unlock() - + // This should fail as all files are missing err = loadRatelimitConfig() assert.Error(err) @@ -139,11 +139,11 @@ func (suite *Tests) Test_rateLimitedRequest() { cfg = &config{} parseConfig() cfg.Logger = libpack_logger.New() - + // Create test rate limits rateLimitMu.Lock() rateLimits = make(map[string]RateLimitConfig) - + // Admin role with high limit adminCounter := goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ Interval: 1 * time.Second, @@ -153,7 +153,7 @@ func (suite *Tests) Test_rateLimitedRequest() { Interval: 1 * time.Second, Req: 100, } - + // User role with low limit userCounter := goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ Interval: 1 * time.Second, @@ -164,31 +164,31 @@ func (suite *Tests) Test_rateLimitedRequest() { Req: 2, // Set very low for testing } rateLimitMu.Unlock() - + // Test non-existent role suite.Run("non-existent role", func() { allowed := rateLimitedRequest("test-user-1", "non-existent-role") assert.True(allowed, "Unknown roles should return true") }) - + // Test admin role (high limit) suite.Run("admin role within limit", func() { allowed := rateLimitedRequest("admin-user", "admin") assert.True(allowed, "Admin should be within rate limit") }) - + // Test user role (low limit) suite.Run("user role within limit", func() { // First request should be allowed allowed := rateLimitedRequest("regular-user", "user") assert.True(allowed, "First request should be within rate limit") - + // Second request should be allowed allowed = rateLimitedRequest("regular-user", "user") assert.True(allowed, "Second request should be within rate limit") - + // Third request should exceed limit allowed = rateLimitedRequest("regular-user", "user") assert.False(allowed, "Third request should exceed rate limit") }) -} \ No newline at end of file +} diff --git a/server.go b/server.go index 9fe2eb6..0e29eab 100644 --- a/server.go +++ b/server.go @@ -116,7 +116,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { // Extract user information and check permissions extractedUserID, extractedRoleName := extractUserInfo(c) - + // Check if user is banned if checkIfUserIsBanned(c, extractedUserID) { return c.Status(fiber.StatusForbidden).SendString("User is banned") @@ -157,7 +157,7 @@ func extractUserInfo(c *fiber.Ctx) (string, string) { // Extract from JWT if available if authorization := c.Get("Authorization"); authorization != "" && - (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { + (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization) } @@ -175,7 +175,7 @@ func extractUserInfo(c *fiber.Ctx) (string, string) { func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) { // Calculate query hash for cache key calculatedQueryHash := libpack_cache.CalculateHash(c) - + // Set cache time from header or default if parsedResult.cacheTime == 0 { if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" { @@ -214,9 +214,10 @@ func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID s if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil { return false, err } - + return false, nil } + // proxyAndCacheTheRequest proxies and caches the request if needed. func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error { if err := proxyTheRequest(c, currentEndpoint); err != nil { diff --git a/tracing/tracing.go b/tracing/tracing.go index a71d0a0..326c4ae 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -68,8 +68,8 @@ func NewTracing(ctx context.Context, endpoint string) (*TracingSetup, error) { semconv.DeploymentEnvironment("production"), attribute.String("application.type", "proxy"), ), - resource.WithHost(), // Add host information - resource.WithOSType(), // Add OS information + resource.WithHost(), // Add host information + resource.WithOSType(), // Add OS information resource.WithProcessPID(), // Add process information ) if err != nil { @@ -87,7 +87,7 @@ func NewTracing(ctx context.Context, endpoint string) (*TracingSetup, error) { sdktrace.WithResource(res), sdktrace.WithSampler(sdktrace.TraceIDRatioBased(0.1)), // Sample 10% of traces ) - + // Set the global tracer provider and propagator otel.SetTracerProvider(tracerProvider) otel.SetTextMapPropagator(propagation.TraceContext{}) @@ -138,7 +138,7 @@ func (ts *TracingSetup) StartSpan(ctx context.Context, name string) (trace.Span, // Return a no-op span if tracing is not configured return trace.SpanFromContext(ctx), ctx } - + // Add common attributes to all spans opts := []trace.SpanStartOption{ trace.WithAttributes( @@ -146,7 +146,7 @@ func (ts *TracingSetup) StartSpan(ctx context.Context, name string) (trace.Span, semconv.ServiceVersion("1.0"), ), } - + ctx, span := ts.tracer.Start(ctx, name, opts...) return span, ctx } @@ -156,18 +156,18 @@ func (ts *TracingSetup) StartSpanWithAttributes(ctx context.Context, name string if ts == nil || ts.tracer == nil { return trace.SpanFromContext(ctx), ctx } - + // Convert string attributes to KeyValue pairs attributes := make([]attribute.KeyValue, 0, len(attrs)+2) attributes = append(attributes, semconv.ServiceName("graphql-monitoring-proxy"), semconv.ServiceVersion("1.0"), ) - + for k, v := range attrs { attributes = append(attributes, attribute.String(k, v)) } - + ctx, span := ts.tracer.Start(ctx, name, trace.WithAttributes(attributes...)) return span, ctx } diff --git a/tracing/tracing_additional_test.go b/tracing/tracing_additional_test.go index 4aef4e5..b2de18a 100644 --- a/tracing/tracing_additional_test.go +++ b/tracing/tracing_additional_test.go @@ -27,7 +27,7 @@ func TestStartSpanWithAttributes(t *testing.T) { span, newCtx := ts.StartSpanWithAttributes(ctx, "test-span", attrs) assert.NotNil(t, span) assert.NotNil(t, newCtx) - + // We can't easily test the attributes were set since it's a noop tracer, // but we can verify the function doesn't panic span.End() @@ -36,7 +36,7 @@ func TestStartSpanWithAttributes(t *testing.T) { // Test with nil attributes t.Run("with nil attributes", func(t *testing.T) { ctx := context.Background() - + span, newCtx := ts.StartSpanWithAttributes(ctx, "test-span", nil) assert.NotNil(t, span) assert.NotNil(t, newCtx) @@ -47,7 +47,7 @@ func TestStartSpanWithAttributes(t *testing.T) { t.Run("with nil tracer", func(t *testing.T) { ctx := context.Background() nilTS := &TracingSetup{tracer: nil} - + span, newCtx := nilTS.StartSpanWithAttributes(ctx, "test-span", map[string]string{"key": "value"}) assert.NotNil(t, span) assert.NotNil(t, newCtx) @@ -58,19 +58,19 @@ func TestStartSpanWithAttributes(t *testing.T) { func TestNewTracingWithInvalidEndpoint(t *testing.T) { ctx := context.Background() - + // Test with invalid endpoint format t.Run("invalid endpoint format", func(t *testing.T) { _, err := NewTracing(ctx, "invalid:endpoint:format") assert.Error(t, err) }) - + // Test with unreachable endpoint t.Run("unreachable endpoint", func(t *testing.T) { // Use a timeout to avoid long test times ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() - + _, err := NewTracing(ctx, "localhost:1") // Port 1 is typically unused assert.Error(t, err) }) @@ -80,28 +80,28 @@ func TestTracingSetupWithMockTracer(t *testing.T) { // Create a mock tracer provider mockTracerProvider := trace.NewNoopTracerProvider() mockTracer := mockTracerProvider.Tracer("mock-tracer") - + ts := &TracingSetup{ tracerProvider: nil, // We don't need the provider for these tests tracer: mockTracer, } - + // Test StartSpan t.Run("start span", func(t *testing.T) { ctx := context.Background() span, newCtx := ts.StartSpan(ctx, "test-span") - + assert.NotNil(t, span) assert.NotNil(t, newCtx) - + // Add some attributes and events to ensure no panics span.SetAttributes(attribute.String("test", "value")) span.AddEvent("test-event") - + // End the span span.End() }) - + // Test StartSpanWithAttributes t.Run("start span with attributes", func(t *testing.T) { ctx := context.Background() @@ -109,12 +109,12 @@ func TestTracingSetupWithMockTracer(t *testing.T) { "service": "test-service", "version": "1.0.0", } - + span, newCtx := ts.StartSpanWithAttributes(ctx, "test-span-with-attrs", attrs) - + assert.NotNil(t, span) assert.NotNil(t, newCtx) - + // End the span span.End() }) @@ -125,10 +125,10 @@ func TestShutdownWithNilProvider(t *testing.T) { tracerProvider: nil, tracer: trace.NewNoopTracerProvider().Tracer("test"), } - + ctx := context.Background() err := ts.Shutdown(ctx) - + assert.NoError(t, err) } @@ -136,13 +136,13 @@ func TestExtractSpanContextWithInvalidTraceParent(t *testing.T) { ts := &TracingSetup{ tracer: trace.NewNoopTracerProvider().Tracer("test"), } - + // Test with invalid traceparent format t.Run("invalid traceparent format", func(t *testing.T) { spanInfo := &TraceSpanInfo{ TraceParent: "invalid-format", } - + _, err := ts.ExtractSpanContext(spanInfo) assert.Error(t, err) assert.Contains(t, err.Error(), "invalid span context") @@ -155,16 +155,16 @@ func TestParseTraceHeaderWithEmptyHeader(t *testing.T) { _, err := ParseTraceHeader("") assert.Error(t, err) }) - + // Test with invalid JSON t.Run("invalid JSON", func(t *testing.T) { _, err := ParseTraceHeader("{invalid json}") assert.Error(t, err) }) - + // Test with valid JSON but missing traceparent t.Run("missing traceparent", func(t *testing.T) { _, err := ParseTraceHeader(`{"other": "value"}`) assert.NoError(t, err) // This should parse but the traceparent will be empty }) -} \ No newline at end of file +}