diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 620b036..77e47e7 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -78,7 +78,24 @@ jobs: apt-get install ca-certificates make -y update-ca-certificates go mod tidy + git config --global --add safe.directory "$GITHUB_WORKSPACE" - name: Run unit tests run: | CI_RUN=${CI} make test + + - name: Run benchmark + run: | + go test -bench=. -benchmem ./... -run=^# | tee output.txt + + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + tool: "go" + output-file-path: output.txt + fail-on-alert: true + github-token: ${{ secrets.GHCR_TOKEN }} + comment-on-alert: true + summary-always: true + # auto-push only if it's on main branch + auto-push: false diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b9a5bf4..aca62b9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -10,6 +10,9 @@ on: branches: - "main" +env: + GO_VERSION: ">=1.21" + jobs: shared: uses: telegram-bot-app/ci-scripts/.github/workflows/build-test-publish-inject.yaml@main @@ -18,3 +21,44 @@ jobs: should-deploy: false secrets: ghcr-token: ${{ secrets.GHCR_TOKEN }} + + test: + name: "Unit testing" + # needs: [prepare] + runs-on: ubuntu-latest + container: golang:1 + # container: github/super-linter:v4 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ${{env.GO_VERSION}} + cache-dependency-path: "**/*.sum" + + - name: Install dependencies + run: | + apt-get update + apt-get install ca-certificates make -y + update-ca-certificates + go mod tidy + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Run benchmark + run: | + go test -bench=. -benchmem ./... -run=^# | tee output.txt + + - name: Store benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + tool: "go" + output-file-path: output.txt + fail-on-alert: true + github-token: ${{ secrets.GHCR_TOKEN }} + comment-on-alert: true + summary-always: true + # auto-push only if it's on main branch + auto-push: true diff --git a/api.go b/api.go index c990081..bb76d1d 100644 --- a/api.go +++ b/api.go @@ -3,6 +3,7 @@ package main import ( "fmt" "os" + "sync" "time" "github.com/goccy/go-json" @@ -13,55 +14,66 @@ import ( libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) -var bannedUsersIDs map[string]string = make(map[string]string) +var ( + bannedUsersIDs = make(map[string]string) + bannedUsersIDsMutex sync.RWMutex +) func enableApi() { - if cfg.Server.EnableApi { - apiserver := fiber.New(fiber.Config{ - DisableStartupMessage: true, - AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION), + if !cfg.Server.EnableApi { + return + } + + apiserver := fiber.New(fiber.Config{ + DisableStartupMessage: true, + AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION), + }) + + api := apiserver.Group("/api") + api.Post("/user-ban", apiBanUser) + api.Post("/user-unban", apiUnbanUser) + api.Post("/cache-clear", apiClearCache) + api.Get("/cache-stats", apiCacheStats) + + go periodicallyReloadBannedUsers() + + if err := apiserver.Listen(fmt.Sprintf(":%d", cfg.Server.ApiPort)); err != nil { + cfg.Logger.Critical(&libpack_logger.LogMessage{ + Message: "Can't start the service", + Pairs: map[string]interface{}{"port": cfg.Server.ApiPort}, }) - - api := apiserver.Group("/api") - api.Post("/user-ban", apiBanUser) - api.Post("/user-unban", apiUnbanUser) - api.Post("/cache-clear", apiClearCache) - api.Get("/cache-stats", apiCacheStats) - - go periodicallyReloadBannedUsers() - err := apiserver.Listen(fmt.Sprintf(":%d", cfg.Server.ApiPort)) - if err != nil { - cfg.Logger.Critical(&libpack_logger.LogMessage{ - Message: "Can't start the service", - Pairs: map[string]interface{}{"port": cfg.Server.ApiPort}, - }) - } } } func periodicallyReloadBannedUsers() { - for { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for range ticker.C { loadBannedUsers() cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Banned users reloaded", Pairs: map[string]interface{}{"users": bannedUsersIDs}, }) - <-time.After(10 * time.Second) } } func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool { + bannedUsersIDsMutex.RLock() _, found := bannedUsersIDs[userID] + bannedUsersIDsMutex.RUnlock() + cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Checking if user is banned", - Pairs: map[string]interface{}{"user_id": userID, "found": found}, + Pairs: map[string]interface{}{"user_id": userID, "banned": found}, }) + if found { cfg.Logger.Info(&libpack_logger.LogMessage{ Message: "User is banned", Pairs: map[string]interface{}{"user_id": userID}, }) - c.Status(403).SendString("User is banned") + c.Status(fiber.StatusForbidden).SendString("User is banned") } return found } @@ -69,28 +81,16 @@ func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool { func apiClearCache(c *fiber.Ctx) error { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Clearing cache via API", - Pairs: nil, }) libpack_cache.CacheClear() cfg.Logger.Info(&libpack_logger.LogMessage{ Message: "Cache cleared via API", - Pairs: nil, }) - c.Status(200).SendString("OK: cache cleared") - return nil + return c.SendString("OK: cache cleared") } func apiCacheStats(c *fiber.Ctx) error { - stats := libpack_cache.GetCacheStats() - err := c.JSON(stats) - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't marshal cache stats", - Pairs: map[string]interface{}{"error": err.Error()}, - }) - return err - } - return nil + return c.JSON(libpack_cache.GetCacheStats()) } type apiBanUserRequest struct { @@ -100,71 +100,92 @@ type apiBanUserRequest struct { func apiBanUser(c *fiber.Ctx) error { var req apiBanUserRequest - err := c.BodyParser(&req) - if err != nil { + if err := c.BodyParser(&req); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't parse the ban user request", Pairs: map[string]interface{}{"error": err.Error()}, }) - return err + return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload") } + + if req.UserID == "" || req.Reason == "" { + return c.Status(fiber.StatusBadRequest).SendString("user_id and reason are required") + } + + bannedUsersIDsMutex.Lock() bannedUsersIDs[req.UserID] = req.Reason + bannedUsersIDsMutex.Unlock() + cfg.Logger.Info(&libpack_logger.LogMessage{ Message: "Banned user", Pairs: map[string]interface{}{"user_id": req.UserID, "reason": req.Reason}, }) - storeBannedUsers() - c.Status(200).SendString("OK: user banned") - return nil + + if err := storeBannedUsers(); err != nil { + return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users") + } + + return c.SendString("OK: user banned") } func apiUnbanUser(c *fiber.Ctx) error { var req apiBanUserRequest - err := c.BodyParser(&req) - if err != nil { + if err := c.BodyParser(&req); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't parse the unban user request", Pairs: map[string]interface{}{"error": err.Error()}, }) - return err + return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload") } + + if req.UserID == "" { + return c.Status(fiber.StatusBadRequest).SendString("user_id is required") + } + + bannedUsersIDsMutex.Lock() delete(bannedUsersIDs, req.UserID) + bannedUsersIDsMutex.Unlock() + cfg.Logger.Info(&libpack_logger.LogMessage{ Message: "Unbanned user", Pairs: map[string]interface{}{"user_id": req.UserID}, }) - storeBannedUsers() - c.Status(200).SendString("OK: user unbanned") - return nil + + if err := storeBannedUsers(); err != nil { + return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users") + } + + return c.SendString("OK: user unbanned") } -func storeBannedUsers() { +func storeBannedUsers() error { fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) - err := fileLock.Lock() - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't lock the file", - Pairs: map[string]interface{}{"error": err.Error()}, - }) - return + if err := lockFile(fileLock); err != nil { + return err } defer fileLock.Unlock() + + bannedUsersIDsMutex.RLock() data, err := json.Marshal(bannedUsersIDs) + bannedUsersIDsMutex.RUnlock() + if err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't marshal banned users", Pairs: map[string]interface{}{"error": err.Error()}, }) - return + return err } - err = os.WriteFile(cfg.Api.BannedUsersFile, data, 0644) - if err != nil { + + if err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0644); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't write banned users to file", Pairs: map[string]interface{}{"error": err.Error()}, }) - return + return err } + + return nil } func loadBannedUsers() { @@ -173,19 +194,9 @@ func loadBannedUsers() { Message: "Banned users file doesn't exist - creating it", Pairs: map[string]interface{}{"file": cfg.Api.BannedUsersFile}, }) - _, err := os.Create(cfg.Api.BannedUsersFile) - if err != nil { + if err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{}"), 0644); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't create the file", - Pairs: map[string]interface{}{"error": err.Error()}, - }) - return - } - // write empty json to the file - err = os.WriteFile(cfg.Api.BannedUsersFile, []byte("{}"), 0644) - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't write to the file", + Message: "Can't create and write to the file", Pairs: map[string]interface{}{"error": err.Error()}, }) return @@ -193,8 +204,7 @@ func loadBannedUsers() { } fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile)) - err := fileLock.RLock() // Use RLock for read lock - if err != nil { + if err := lockFileRead(fileLock); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't lock the file [load]", Pairs: map[string]interface{}{"error": err.Error()}, @@ -211,12 +221,39 @@ func loadBannedUsers() { }) return } - err = json.Unmarshal(data, &bannedUsersIDs) - if err != nil { + + var newBannedUsers map[string]string + if err := json.Unmarshal(data, &newBannedUsers); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't unmarshal banned users", Pairs: map[string]interface{}{"error": err.Error()}, }) return } + + bannedUsersIDsMutex.Lock() + bannedUsersIDs = newBannedUsers + bannedUsersIDsMutex.Unlock() +} + +func lockFile(fileLock *flock.Flock) error { + if err := fileLock.Lock(); err != nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't lock the file", + Pairs: map[string]interface{}{"error": err.Error()}, + }) + return err + } + return nil +} + +func lockFileRead(fileLock *flock.Flock) error { + if err := fileLock.RLock(); err != nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't lock the file for reading", + Pairs: map[string]interface{}{"error": err.Error()}, + }) + return err + } + return nil } diff --git a/cache/cache_test.go b/cache/cache_test.go index bc1f16e..b6851d2 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -1,8 +1,11 @@ package libpack_cache import ( + "fmt" + "sync" "time" + "github.com/alicebob/miniredis/v2" libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory" libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" @@ -117,3 +120,96 @@ func (suite *Tests) Test_cacheLookupRedis() { }) } } + +func (suite *Tests) Test_cacheConcurrency() { + config = &CacheConfig{ + Logger: libpack_logger.New(), + Client: libpack_cache_memory.New(5 * time.Second), + TTL: 5, + } + + const numGoroutines = 10 + const numOperations = 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("key-%d-%d", id, j) + value := []byte(fmt.Sprintf("value-%d-%d", id, j)) + CacheStore(key, value) + retrieved := CacheLookup(key) + assert.Equal(string(value), string(retrieved), "Concurrent cache operation failed") + } + }(i) + } + + wg.Wait() +} + +// func (suite *Tests) Test_cacheEviction() { +// config = &CacheConfig{ +// Logger: libpack_logger.New(), +// Client: libpack_cache_memory.New(3 * time.Second), // 3 seconds TTL +// TTL: 3, +// } + +// // Fill the cache +// for i := 0; i < 20; i++ { +// key := fmt.Sprintf("key-%d", i) +// value := []byte(fmt.Sprintf("value-%d", i)) +// CacheStore(key, value) +// time.Sleep(100 * time.Millisecond) // Ensure different creation times +// } + +// // Wait for the TTL to expire for the first half of the items +// time.Sleep(3100 * time.Millisecond) + +// // Check that the oldest items have been evicted +// for i := 0; i < 10; i++ { +// key := fmt.Sprintf("key-%d", i) +// retrieved := CacheLookup(key) +// assert.Nil(retrieved, fmt.Sprintf("Old item %s should have been evicted", key)) +// } + +// // Check that the newer items are still in the cache +// for i := 10; i < 20; i++ { +// key := fmt.Sprintf("key-%d", i) +// expected := []byte(fmt.Sprintf("value-%d", i)) +// retrieved := CacheLookup(key) +// assert.Equal(expected, retrieved, fmt.Sprintf("Recent item %s should be in cache", key)) +// } +// } + +func (suite *Tests) Test_cacheRedisFailure() { + mr, err := miniredis.Run() + if err != nil { + suite.T().Fatal(err) + } + defer mr.Close() + + config = &CacheConfig{ + Logger: libpack_logger.New(), + Client: libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ + RedisServer: mr.Addr(), + RedisDB: 0, + }), + TTL: 5, + } + + // Test normal operation + CacheStore("test-key", []byte("test-value")) + retrieved := CacheLookup("test-key") + assert.Equal([]byte("test-value"), retrieved) + + // Simulate Redis failure + mr.Close() + + // Operations should not panic, but should return errors or nil values + CacheStore("another-key", []byte("another-value")) + retrieved = CacheLookup("another-key") + assert.Nil(retrieved, "Lookup should return nil when Redis is down") +} diff --git a/cache/memory/memory.go b/cache/memory/memory.go index a738ee7..fafaa68 100644 --- a/cache/memory/memory.go +++ b/cache/memory/memory.go @@ -69,11 +69,17 @@ func (c *Cache) Set(key string, value []byte, ttl time.Duration) { func (c *Cache) Get(key string) ([]byte, bool) { entry, ok := c.entries.Load(key) - if !ok || entry.(CacheEntry).ExpiresAt.Before(time.Now()) { + if !ok { return nil, false } - compressedValue := entry.(CacheEntry).Value - value, err := c.decompress(compressedValue) + + cacheEntry := entry.(CacheEntry) + if cacheEntry.ExpiresAt.Before(time.Now()) { + c.entries.Delete(key) + return nil, false + } + + value, err := c.decompress(cacheEntry.Value) if err != nil { log.Printf("Error decompressing value for key %s: %v", key, err) return nil, false diff --git a/cache/memory/memory_bench_test.go b/cache/memory/memory_bench_test.go index fffd786..054cbfb 100644 --- a/cache/memory/memory_bench_test.go +++ b/cache/memory/memory_bench_test.go @@ -1,6 +1,7 @@ package libpack_cache_memory import ( + "fmt" "testing" "time" ) @@ -52,3 +53,30 @@ func BenchmarkMemCacheStats(b *testing.B) { cache.Set(key, value, 5*time.Second) // Pre-set a value to retrieve cache.Get(key) } + +func BenchmarkCacheSet(b *testing.B) { + cache := New(5 * time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.Set(fmt.Sprintf("key-%d", i), []byte("value"), 5*time.Second) + } +} + +func BenchmarkCacheGet(b *testing.B) { + cache := New(5 * time.Second) + cache.Set("test-key", []byte("test-value"), 5*time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + cache.Get("test-key") + } +} + +func BenchmarkCacheDelete(b *testing.B) { + cache := New(5 * time.Second) + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("key-%d", i) + cache.Set(key, []byte("value"), 5*time.Second) + cache.Delete(key) + } +} diff --git a/cache/memory/memory_test.go b/cache/memory/memory_test.go index b057312..948091c 100644 --- a/cache/memory/memory_test.go +++ b/cache/memory/memory_test.go @@ -1,6 +1,8 @@ package libpack_cache_memory import ( + "fmt" + "sync" "testing" "time" @@ -110,3 +112,57 @@ func (suite *MemoryTestSuite) Test_CacheExpire() { }) } } + +func (suite *MemoryTestSuite) Test_ConcurrentReadWrite() { + cache := New(5 * time.Second) + const numGoroutines = 100 + const numOperations = 1000 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < numOperations; j++ { + key := fmt.Sprintf("key-%d-%d", id, j) + value := []byte(fmt.Sprintf("value-%d-%d", id, j)) + + if j%2 == 0 { + cache.Set(key, value, 5*time.Second) + } else { + _, _ = cache.Get(key) + } + } + }(i) + } + + wg.Wait() +} + +func (suite *MemoryTestSuite) Test_LargeItems() { + cache := New(5 * time.Second) + largeValue := make([]byte, 10*1024*1024) // 10MB + cache.Set("large-key", largeValue, 5*time.Second) + + retrieved, found := cache.Get("large-key") + suite.Assert().True(found) + suite.Assert().Equal(largeValue, retrieved) +} + +func (suite *MemoryTestSuite) Test_ZeroTTL() { + cache := New(5 * time.Second) + cache.Set("zero-ttl", []byte("value"), 0) + + _, found := cache.Get("zero-ttl") + suite.Assert().False(found, "Item with zero TTL should not be stored") +} + +func (suite *MemoryTestSuite) Test_LongTTL() { + cache := New(5 * time.Second) + cache.Set("long-ttl", []byte("value"), 24*365*time.Hour) // 1 year + + retrieved, found := cache.Get("long-ttl") + suite.Assert().True(found) + suite.Assert().Equal([]byte("value"), retrieved) +} diff --git a/details.go b/details.go index 25940ca..dc35a1a 100644 --- a/details.go +++ b/details.go @@ -11,18 +11,14 @@ import ( libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" ) -func extractClaimsFromJWTHeader(authorization string) (usr string, role string) { - usr, role = "-", "-" +const defaultValue = "-" - handleError := func(msg string, details map[string]interface{}) { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: msg, - Pairs: details, - }) - } +var emptyMetrics = map[string]string{} - tokenParts := strings.Split(authorization, ".") +func extractClaimsFromJWTHeader(authorization string) (usr, role string) { + usr, role = defaultValue, defaultValue + + tokenParts := strings.SplitN(authorization, ".", 3) if len(tokenParts) != 3 { handleError("Can't split the token", map[string]interface{}{"token": authorization}) return @@ -40,18 +36,30 @@ func extractClaimsFromJWTHeader(authorization string) (usr string, role string) return } - extractClaim := func(claimPath string, target *string, name string) { - if len(claimPath) > 0 { - var ok bool - *target, ok = ask.For(claimMap, claimPath).String("-") - if !ok { - handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath}) - } - } - } - - extractClaim(cfg.Client.JWTUserClaimPath, &usr, "user id") - extractClaim(cfg.Client.JWTRoleClaimPath, &role, "role") + usr = extractClaim(claimMap, cfg.Client.JWTUserClaimPath, "user id") + role = extractClaim(claimMap, cfg.Client.JWTRoleClaimPath, "role") return } + +func extractClaim(claimMap map[string]interface{}, claimPath, name string) string { + if claimPath == "" { + return defaultValue + } + + value, ok := ask.For(claimMap, claimPath).String(defaultValue) + if !ok { + handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath}) + return defaultValue + } + + return value +} + +func handleError(msg string, details map[string]interface{}) { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, emptyMetrics) + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: msg, + Pairs: details, + }) +} diff --git a/events.go b/events.go index b917db7..81ccc47 100644 --- a/events.go +++ b/events.go @@ -5,70 +5,76 @@ import ( "fmt" "time" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) -func enableHasuraEventCleaner() { - if cfg.HasuraEventCleaner.Enable { - if cfg.HasuraEventCleaner.EventMetadataDb == "" { - cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "Event metadata db URL not specified, event cleaner not active", - Pairs: nil, - }) - return - } +const ( + initialDelay = 60 * time.Second + cleanupInterval = 1 * time.Hour +) - ticker := time.NewTicker(1 * time.Hour) - defer ticker.Stop() - cfg.Logger.Info(&libpack_logger.LogMessage{ - Message: "Event cleaner enabled", - Pairs: map[string]interface{}{"interval_in_days": cfg.HasuraEventCleaner.ClearOlderThan}, - }) - - time.Sleep(60 * time.Second) // wait for everything to start and settle down - cfg.Logger.Info(&libpack_logger.LogMessage{ - Message: "Initial cleanup of old events", - Pairs: nil, - }) - cleanEvents() - - for { - select { - case <-ticker.C: - cfg.Logger.Info(&libpack_logger.LogMessage{ - Message: "Cleaning up old events", - Pairs: nil, - }) - cleanEvents() - } - } - } +var delQueries = []string{ + "DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < now() - interval '%d days';", + "DELETE FROM hdb_catalog.event_log WHERE created_at < now() - interval '%d days';", + "DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL '%d days';", + "DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", + "DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", } -func cleanEvents() { - conn, err := pgx.Connect(context.Background(), cfg.HasuraEventCleaner.EventMetadataDb) +func enableHasuraEventCleaner() { + if !cfg.HasuraEventCleaner.Enable { + return + } + + if cfg.HasuraEventCleaner.EventMetadataDb == "" { + cfg.Logger.Warning(&libpack_logger.LogMessage{ + Message: "Event metadata db URL not specified, event cleaner not active", + }) + return + } + + cfg.Logger.Info(&libpack_logger.LogMessage{ + Message: "Event cleaner enabled", + Pairs: map[string]interface{}{"interval_in_days": cfg.HasuraEventCleaner.ClearOlderThan}, + }) + + pool, err := pgxpool.New(context.Background(), cfg.HasuraEventCleaner.EventMetadataDb) if err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Failed to connect to event metadata db", + Message: "Failed to create connection pool", Pairs: map[string]interface{}{"error": err}, }) return } - defer conn.Close(context.Background()) + defer pool.Close() - delQueries := []string{ - fmt.Sprintf("DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < now() - interval '%d days';", cfg.HasuraEventCleaner.ClearOlderThan), - fmt.Sprintf("DELETE FROM hdb_catalog.event_log WHERE created_at < now() - interval '%d days';", cfg.HasuraEventCleaner.ClearOlderThan), - fmt.Sprintf("DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan), - fmt.Sprintf("DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan), - fmt.Sprintf("DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan), - } + go func() { + time.Sleep(initialDelay) + cfg.Logger.Info(&libpack_logger.LogMessage{ + Message: "Initial cleanup of old events", + }) + cleanEvents(pool) + + ticker := time.NewTicker(cleanupInterval) + defer ticker.Stop() + + for range ticker.C { + cfg.Logger.Info(&libpack_logger.LogMessage{ + Message: "Cleaning up old events", + }) + cleanEvents(pool) + } + }() +} + +func cleanEvents(pool *pgxpool.Pool) { + ctx := context.Background() for _, query := range delQueries { - _, err := conn.Exec(context.Background(), query) + _, err := pool.Exec(ctx, fmt.Sprintf(query, cfg.HasuraEventCleaner.ClearOlderThan)) if err != nil { - cfg.Logger.Debug(&libpack_logger.LogMessage{ + cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Failed to execute query", Pairs: map[string]interface{}{"query": query, "error": err}, }) diff --git a/go.mod b/go.mod index 0b2282e..11c9d67 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/lukaszraczylo/graphql-monitoring-proxy -go 1.21.0 - -toolchain go1.22.4 +go 1.22.4 require ( github.com/VictoriaMetrics/metrics v1.34.0 @@ -16,8 +14,8 @@ require ( github.com/graphql-go/graphql v0.8.1 github.com/jackc/pgx/v5 v5.6.0 github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 - github.com/lukaszraczylo/go-ratecounter v0.1.10 - github.com/lukaszraczylo/go-simple-graphql v1.2.14 + github.com/lukaszraczylo/go-ratecounter v0.1.12 + github.com/lukaszraczylo/go-simple-graphql v1.2.17 github.com/redis/go-redis/v9 v9.5.3 github.com/stretchr/testify v1.9.0 github.com/valyala/fasthttp v1.55.0 @@ -32,13 +30,13 @@ require ( github.com/gookit/color v1.5.4 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect - github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/rs/zerolog v1.33.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fastrand v1.1.0 // indirect @@ -47,7 +45,6 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/gopher-lua v1.1.1 // indirect golang.org/x/crypto v0.24.0 // indirect - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect diff --git a/go.sum b/go.sum index c9f0ccc..0e7902c 100644 --- a/go.sum +++ b/go.sum @@ -51,10 +51,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 h1:lvI8Wlbg4PxkRcg2f10wgoaRpfN19v+YdRek3+dLtlM= github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415/go.mod h1:M+UVdyqZs++xtEPrascaVmZdOMhCnxjZ2SgH+xHpR0c= -github.com/lukaszraczylo/go-ratecounter v0.1.10 h1:rwlKGNRXK7nLpDJxRlBqhIW+8cm2OEbnsslYSjDZwls= -github.com/lukaszraczylo/go-ratecounter v0.1.10/go.mod h1:TqXEOCtFJStk1i0tkipprv1kiDHGon1MVUisjSTBSKM= -github.com/lukaszraczylo/go-simple-graphql v1.2.14 h1:Dth+yZ+1ialCpnslSb6UgHbXszExjDUu/I95QZbnWVU= -github.com/lukaszraczylo/go-simple-graphql v1.2.14/go.mod h1:pSKmm9OLGoS9pjmIvhBB/fo0+LganRrL29CN3fdkRPw= +github.com/lukaszraczylo/go-ratecounter v0.1.12 h1:VO6hHYGw/Jy9JUizXf/bS0AI2QX1ueWWAWckMFVJ/w4= +github.com/lukaszraczylo/go-ratecounter v0.1.12/go.mod h1:TqXEOCtFJStk1i0tkipprv1kiDHGon1MVUisjSTBSKM= +github.com/lukaszraczylo/go-simple-graphql v1.2.17 h1:XxUUgxcCIZSVLzI4UfhBDXoFoMlygcXHfAJwXxawr1s= +github.com/lukaszraczylo/go-simple-graphql v1.2.17/go.mod h1:pSKmm9OLGoS9pjmIvhBB/fo0+LganRrL29CN3fdkRPw= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -71,8 +71,8 @@ github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -97,8 +97,8 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= diff --git a/graphql.go b/graphql.go index f8a0861..2072427 100644 --- a/graphql.go +++ b/graphql.go @@ -3,6 +3,8 @@ package main import ( "strconv" "strings" + "sync" + "unsafe" "github.com/goccy/go-json" fiber "github.com/gofiber/fiber/v2" @@ -12,48 +14,29 @@ import ( libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" ) -var introspection_queries = []string{ - "__schema", - "__type", - "__typename", - "__directive", - "__directivelocation", - "__field", - "__inputvalue", - "__enumvalue", - "__typekind", - "__fieldtype", - "__inputobjecttype", - "__enumtype", - "__uniontype", - "__scalars", - "__objects", - "__interfaces", - "__unions", - "__enums", - "__inputobjects", - "__directives", -} - -// Saving the introspection queries as a map O(1) operation instead of O(n) for a slice. - -var introspectionQuerySet = map[string]struct{}{} -var introspectionAllowedQueries = map[string]struct{}{} -var allowedUrls = map[string]struct{}{} - -// Utility function to convert a slice of strings to a map for O(1) lookups. -func sliceToMap(slice []string) map[string]struct{} { - resultMap := make(map[string]struct{}, len(slice)) - for _, item := range slice { - resultMap[strings.ToLower(item)] = struct{}{} +var ( + introspectionQueries = map[string]struct{}{ + "__schema": {}, "__type": {}, "__typename": {}, "__directive": {}, + "__directivelocation": {}, "__field": {}, "__inputvalue": {}, + "__enumvalue": {}, "__typekind": {}, "__fieldtype": {}, + "__inputobjecttype": {}, "__enumtype": {}, "__uniontype": {}, + "__scalars": {}, "__objects": {}, "__interfaces": {}, + "__unions": {}, "__enums": {}, "__inputobjects": {}, "__directives": {}, } - return resultMap -} + introspectionAllowedQueries = make(map[string]struct{}) + allowedUrls = make(map[string]struct{}) + mu sync.RWMutex +) func prepareQueriesAndExemptions() { - introspectionQuerySet = sliceToMap(introspection_queries) - introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed) - allowedUrls = sliceToMap(cfg.Server.AllowURLs) + mu.Lock() + defer mu.Unlock() + for _, q := range cfg.Security.IntrospectionAllowed { + introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} + } + for _, u := range cfg.Server.AllowURLs { + allowedUrls[u] = struct{}{} + } } type parseGraphQLQueryResult struct { @@ -67,21 +50,41 @@ type parseGraphQLQueryResult struct { shouldIgnore bool } -func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { - res = &parseGraphQLQueryResult{shouldIgnore: true} - m := make(map[string]interface{}) - err := json.Unmarshal(c.Body(), &m) - if err != nil { +var ( + queryPool = sync.Pool{ + New: func() interface{} { + return make(map[string]interface{}, 4) + }, + } + resultPool = sync.Pool{ + New: func() interface{} { + return &parseGraphQLQueryResult{} + }, + } +) + +func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { + res := resultPool.Get().(*parseGraphQLQueryResult) + defer resultPool.Put(res) + *res = parseGraphQLQueryResult{shouldIgnore: true} + + m := queryPool.Get().(map[string]interface{}) + defer queryPool.Put(m) + for k := range m { + delete(m, k) + } + + if err := json.Unmarshal(c.Body(), &m); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't unmarshal the request", - Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())}, + Pairs: map[string]interface{}{"error": err.Error(), "body": unsafeString(c.Body())}, }) if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } - return + return res } - // get the query + query, ok := m["query"].(string) if !ok { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -91,7 +94,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } - return + return res } p, err := parser.Parse(parser.ParseParams{Source: query}) @@ -103,7 +106,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } - return + return res } res.shouldIgnore = false @@ -112,14 +115,14 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { for _, d := range p.Definitions { if oper, ok := d.(*ast.OperationDefinition); ok { - res.operationType = strings.ToLower(oper.Operation) - - if oper.Name != nil { - res.operationName = oper.Name.Value + // If we haven't set an operation type yet, use this one + if res.operationType == "" { + res.operationType = strings.ToLower(oper.Operation) + if oper.Name != nil { + res.operationName = oper.Name.Value + } } - // If the query is a mutation then direct it to the RW endpoint, - // otherwise direct it to the RO endpoint if it's set. if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" { res.activeEndpoint = cfg.Server.HostGraphQLReadOnly } @@ -132,30 +135,24 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } - c.Status(403).SendString("The server is in read-only mode") + _ = c.Status(403).SendString("The server is in read-only mode") res.shouldBlock = true - return + return res } for _, dir := range oper.Directives { if dir.Name.Value == "cached" { res.cacheRequest = true for _, arg := range dir.Arguments { - if arg.Name.Value == "ttl" { - res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string)) - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't parse the ttl, using global", - Pairs: map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - } - return + switch arg.Name.Value { + case "ttl": + if v, ok := arg.Value.GetValue().(string); ok { + res.cacheTime, _ = strconv.Atoi(v) + } + case "refresh": + if v, ok := arg.Value.GetValue().(bool); ok { + res.cacheRefresh = v } - } - if arg.Name.Value == "refresh" { - res.cacheRefresh = arg.Value.GetValue().(bool) } } } @@ -164,26 +161,25 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { if cfg.Security.BlockIntrospection { res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections) if res.shouldBlock { - return + return res } } } } - return + return res +} + +func unsafeString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) } func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { for _, s := range selections { - field, ok := s.(*ast.Field) - if !ok { - continue // or handle the case where the type assertion fails - } - shouldBlock := checkIfContainsIntrospection(c, field.Name.Value) - if shouldBlock { - return true - } - if field.SelectionSet != nil { - if checkSelections(c, field.GetSelectionSet().Selections) { + if field, ok := s.(*ast.Field); ok { + if checkIfContainsIntrospection(c, field.Name.Value) { + return true + } + if field.SelectionSet != nil && checkSelections(c, field.GetSelectionSet().Selections) { return true } } @@ -191,32 +187,26 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { return false } -func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) { +func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) bool { whateverLower := strings.ToLower(whatever) - got_exemption := false + mu.RLock() + defer mu.RUnlock() - // If the query is an introspection query, we need to check if it's allowed. - if _, exists := introspectionQuerySet[whateverLower]; exists { + if _, exists := introspectionQueries[whateverLower]; exists { if len(cfg.Security.IntrospectionAllowed) > 0 { - - if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { + if _, allowed := introspectionAllowedQueries[whateverLower]; allowed { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Introspection query allowed, passing through", Pairs: map[string]interface{}{"query": whatever}, }) - got_exemption = true - shouldBlock = false + return false } } - if !got_exemption { - shouldBlock = true - } - } - if shouldBlock { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } - c.Status(403).SendString("Introspection queries are not allowed") + _ = c.Status(403).SendString("Introspection queries are not allowed") + return true } - return + return false } diff --git a/graphql_test.go b/graphql_test.go index 7d5f664..f49924a 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -1,6 +1,10 @@ package main import ( + "fmt" + "strings" + + fiber "github.com/gofiber/fiber/v2" "github.com/valyala/fasthttp" ) @@ -318,3 +322,111 @@ func (suite *Tests) Test_parseGraphQLQuery() { }) } } + +func (suite *Tests) Test_parseGraphQLQuery_complex() { + // ... existing tests ... + + // Add these new test cases + suite.Run("test complex query with multiple operations", func() { + query := ` + query GetUser($id: ID!) { + user(id: $id) { + name + email + } + } + mutation UpdateUser($id: ID!, $name: String!) { + updateUser(id: $id, name: $name) { + id + name + } + } + ` + body := fmt.Sprintf(`{"query": %q}`, query) + ctx := createTestContext(body) + result := parseGraphQLQuery(ctx) + assert.Equal("query", result.operationType) + assert.Equal("GetUser", result.operationName) + assert.False(result.shouldBlock) + }) + + suite.Run("test query with custom directives", func() { + query := ` + query GetUser($id: ID!) @custom(directive: "value") { + user(id: $id) { + name + email + } + } + ` + body := fmt.Sprintf(`{"query": %q}`, query) + ctx := createTestContext(body) + result := parseGraphQLQuery(ctx) + assert.Equal("query", result.operationType) + assert.Equal("GetUser", result.operationName) + assert.False(result.shouldBlock) + assert.False(result.shouldBlock) + }) +} + +func (suite *Tests) Test_checkAllowedURLs() { + tests := []struct { + name string + path string + allowed []string + expected bool + }{ + {"allowed path", "/v1/graphql", []string{"/v1/graphql"}, true}, + {"disallowed path", "/v2/graphql", []string{"/v1/graphql"}, false}, + {"empty allowed list", "/v1/graphql", []string{}, true}, + {"multiple allowed paths", "/v2/graphql", []string{"/v1/graphql", "/v2/graphql"}, true}, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + allowedUrls = make(map[string]struct{}) + for _, url := range tt.allowed { + allowedUrls[url] = struct{}{} + } + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().SetRequestURI(tt.path) + ctx.Request().URI().SetPath(tt.path) + result := checkAllowedURLs(ctx) + assert.Equal(tt.expected, result) + }) + } +} + +func (suite *Tests) Test_checkIfContainsIntrospection() { + tests := []struct { + name string + query string + allowed []string + expected bool + }{ + {"allowed introspection", "__schema", []string{"__schema"}, false}, + {"disallowed introspection", "__type", []string{"__schema"}, true}, + {"non-introspection query", "normalQuery", []string{}, false}, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + cfg.Security.IntrospectionAllowed = tt.allowed + introspectionAllowedQueries = make(map[string]struct{}) + for _, q := range tt.allowed { + introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} + } + ctx := createTestContext("") + result := checkIfContainsIntrospection(ctx, tt.query) + assert.Equal(tt.expected, result) + }) + } +} + +func createTestContext(body string) *fiber.Ctx { + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().SetBody([]byte(body)) + return ctx +} diff --git a/main_test.go b/main_test.go index 7e3e9ee..e2ea74c 100644 --- a/main_test.go +++ b/main_test.go @@ -112,3 +112,29 @@ func (suite *Tests) Test_envVariableSetting() { }) } } + +func (suite *Tests) Test_getDetailsFromEnv() { + tests := []struct { + name string + key string + defaultValue interface{} + envValue string + expected interface{} + }{ + {"string value", "TEST_STRING", "default", "envValue", "envValue"}, + {"int value", "TEST_INT", 0, "123", 123}, + {"bool value", "TEST_BOOL", false, "true", true}, + {"default value", "NON_EXISTENT", "default", "", "default"}, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + if tt.envValue != "" { + os.Setenv("GMP_"+tt.key, tt.envValue) + defer os.Unsetenv("GMP_" + tt.key) + } + result := getDetailsFromEnv(tt.key, tt.defaultValue) + assert.Equal(tt.expected, result) + }) + } +} diff --git a/monitoring/helpers.go b/monitoring/helpers.go index 7412b91..4773866 100644 --- a/monitoring/helpers.go +++ b/monitoring/helpers.go @@ -12,17 +12,14 @@ import ( libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config" ) -// Cache for sorted label keys to avoid repeated sorting var sortedLabelKeysCache = struct { - m map[string][]string - sync.RWMutex -}{m: make(map[string][]string)} + m sync.Map +}{} func (ms *MetricsSetup) get_metrics_name(name string, labels map[string]string) string { const unknownPodName = "unknown" var buf bytes.Buffer - // Prepare default labels without initializing a new map podName := getPodName() if labels == nil { labels = defaultLabels(podName) @@ -30,18 +27,16 @@ func (ms *MetricsSetup) get_metrics_name(name string, labels map[string]string) ensureDefaultLabels(&labels, podName) } - // Prefix handling if ms.metrics_prefix != "" { buf.WriteString(ms.metrics_prefix) - buf.WriteString("_") + buf.WriteByte('_') } buf.WriteString(name) - // Append labels if any if len(labels) > 0 { - buf.WriteString("{") + buf.WriteByte('{') appendSortedLabels(&buf, labels) - buf.WriteString("}") + buf.WriteByte('}') } return buf.String() @@ -78,34 +73,30 @@ func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) { keys := getSortedKeys(labels) for i, k := range keys { if i > 0 { - buf.WriteString(",") + buf.WriteByte(',') } buf.WriteString(k) - buf.WriteString("=\"") + buf.WriteString(`="`) buf.WriteString(labels[k]) - buf.WriteString("\"") + buf.WriteByte('"') } } func getSortedKeys(labels map[string]string) []string { labelsKey := labelsToString(labels) - sortedLabelKeysCache.RLock() - keys, exists := sortedLabelKeysCache.m[labelsKey] - sortedLabelKeysCache.RUnlock() - - if !exists { - keys = make([]string, 0, len(labels)) - for k := range labels { - keys = append(keys, k) - } - sort.Strings(keys) - - sortedLabelKeysCache.Lock() - sortedLabelKeysCache.m[labelsKey] = keys - sortedLabelKeysCache.Unlock() + if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok { + return keys.([]string) } + keys := make([]string, 0, len(labels)) + for k := range labels { + keys = append(keys, k) + } + sort.Strings(keys) + + sortedLabelKeysCache.m.Store(labelsKey, keys) + return keys } @@ -113,29 +104,24 @@ func labelsToString(labels map[string]string) string { var sb strings.Builder for k, v := range labels { sb.WriteString(k) - sb.WriteString("=") + sb.WriteByte('=') sb.WriteString(v) - sb.WriteString(";") + sb.WriteByte(';') } return sb.String() } -// validate_metrics_name validates the name of the metric to adhere to the Prometheus naming conventions -// https://prometheus.io/docs/practices/naming/ func validate_metrics_name(name string) error { cleanedName := clean_metric_name(name) - // Trim leading and trailing underscores finalName := strings.Trim(cleanedName, "_") - // Check if the processed name matches the original input if finalName != name { - return fmt.Errorf("Invalid metric name: %s, expected %s", name, finalName) + return fmt.Errorf("invalid metric name: %s, expected %s", name, finalName) } return nil } -// clean_metric_name processes the metric name according to Prometheus naming conventions func clean_metric_name(name string) string { var buf bytes.Buffer lastWasUnderscore := false @@ -144,31 +130,27 @@ func clean_metric_name(name string) string { if is_allowed_rune(r) { if is_special_rune(r) { if lastWasUnderscore { - continue // Skip if the previous character was also an underscore + continue } - r = '_' // Convert spaces and special characters to underscores + r = '_' lastWasUnderscore = true } else { lastWasUnderscore = false } buf.WriteRune(r) } else if !lastWasUnderscore { - buf.WriteRune('_') + buf.WriteByte('_') lastWasUnderscore = true } } - // Remove trailing underscore - result := buf.String() - return strings.Trim(result, "_") + return strings.Trim(buf.String(), "_") } -// is_allowed_rune checks if the rune is allowed in the metric name func is_allowed_rune(r rune) bool { return unicode.IsLetter(r) || unicode.IsDigit(r) || r == ' ' || r == '_' } -// is_special_rune checks if the rune is a space or an underscore func is_special_rune(r rune) bool { return r == ' ' || r == '_' } @@ -178,14 +160,12 @@ func compile_metrics_with_labels(name string, labels map[string]string) string { buf.WriteString(name) - // Collect keys and sort them keys := getSortedKeys(labels) - // Append sorted key-value pairs to the buffer for _, k := range keys { - buf.WriteString("_") + buf.WriteByte('_') buf.WriteString(k) - buf.WriteString("_") + buf.WriteByte('_') buf.WriteString(labels[k]) } diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index 32f08f6..165a0df 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -1,6 +1,3 @@ -// Package `libpack_monitoring` provides and easy way to add prometheus metrics to your application. -// It also provides a way to add custom metrics to the already started prometheus registry. - package libpack_monitoring import ( @@ -22,9 +19,7 @@ type MetricsSetup struct { metrics_prefix string } -var ( - log *libpack_logger.Logger -) +var log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO) type InitConfig struct { PurgeOnCrawl bool @@ -32,11 +27,11 @@ type InitConfig struct { } func NewMonitoring(ic *InitConfig) *MetricsSetup { - log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO) - ms := &MetricsSetup{ic: ic} - ms.metrics_set = metrics.NewSet() - ms.metrics_set_custom = metrics.NewSet() - // if not testing, start the prometheus endpoint + ms := &MetricsSetup{ + ic: ic, + metrics_set: metrics.NewSet(), + metrics_set_custom: metrics.NewSet(), + } if flag.Lookup("test.v") == nil { go ms.startPrometheusEndpoint() @@ -60,9 +55,11 @@ func (ms *MetricsSetup) startPrometheusEndpoint() { AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION), }) app.Get("/metrics", ms.metricsEndpoint) - err := app.Listen(fmt.Sprintf(":%d", envutil.GetInt("MONITORING_PORT", 9393))) - if err != nil { - fmt.Println("Can't start the service: ", err) + if err := app.Listen(fmt.Sprintf(":%d", envutil.GetInt("MONITORING_PORT", 9393))); err != nil { + log.Critical(&libpack_logger.LogMessage{ + Message: "Can't start the service", + Pairs: map[string]interface{}{"error": err}, + }) } } @@ -85,7 +82,7 @@ func (ms *MetricsSetup) ListActiveMetrics() []string { } func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[string]string, val float64) *metrics.Gauge { - if validate_metrics_name(metric_name) != nil { + if err := validate_metrics_name(metric_name); err != nil { log.Critical(&libpack_logger.LogMessage{ Message: "RegisterMetricsGauge() error", Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name}, @@ -93,13 +90,12 @@ func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[stri return nil } return ms.metrics_set_custom.GetOrCreateGauge(ms.get_metrics_name(metric_name, labels), func() float64 { - // get current value of the gauge and add val to it return val }) } func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[string]string) *metrics.Counter { - if validate_metrics_name(metric_name) != nil { + if err := validate_metrics_name(metric_name); err != nil { log.Critical(&libpack_logger.LogMessage{ Message: "RegisterMetricsCounter() error", Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name}, @@ -113,7 +109,7 @@ func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[st } func (ms *MetricsSetup) RegisterFloatCounter(metric_name string, labels map[string]string) *metrics.FloatCounter { - if validate_metrics_name(metric_name) != nil { + if err := validate_metrics_name(metric_name); err != nil { log.Critical(&libpack_logger.LogMessage{ Message: "RegisterFloatCounter() error", Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name}, @@ -124,7 +120,7 @@ func (ms *MetricsSetup) RegisterFloatCounter(metric_name string, labels map[stri } func (ms *MetricsSetup) RegisterMetricsSummary(metric_name string, labels map[string]string) *metrics.Summary { - if validate_metrics_name(metric_name) != nil { + if err := validate_metrics_name(metric_name); err != nil { log.Critical(&libpack_logger.LogMessage{ Message: "RegisterMetricsSummary() error", Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name}, @@ -135,7 +131,7 @@ func (ms *MetricsSetup) RegisterMetricsSummary(metric_name string, labels map[st } func (ms *MetricsSetup) RegisterMetricsHistogram(metric_name string, labels map[string]string) *metrics.Histogram { - if validate_metrics_name(metric_name) != nil { + if err := validate_metrics_name(metric_name); err != nil { log.Critical(&libpack_logger.LogMessage{ Message: "RegisterMetricsHistogram() error", Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name}, diff --git a/proxy.go b/proxy.go index 93b6bd9..3ff07ff 100644 --- a/proxy.go +++ b/proxy.go @@ -3,16 +3,25 @@ package main import ( "crypto/tls" "fmt" + "net/url" "time" "github.com/avast/retry-go/v4" - fiber "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/proxy" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" "github.com/valyala/fasthttp" ) +var ( + httpClient *fasthttp.Client +) + +func init() { + httpClient = createFasthttpClient(30) // Assuming a default timeout of 30 seconds +} + func createFasthttpClient(timeout int) *fasthttp.Client { return &fasthttp.Client{ Name: "graphql_proxy", @@ -21,14 +30,13 @@ func createFasthttpClient(timeout int) *fasthttp.Client { InsecureSkipVerify: true, }, MaxConnsPerHost: 2048, - ReadTimeout: time.Second * time.Duration(timeout), - WriteTimeout: time.Second * time.Duration(timeout), - MaxIdleConnDuration: time.Second * time.Duration(timeout), - MaxConnDuration: time.Second * time.Duration(timeout), + ReadTimeout: time.Duration(timeout) * time.Second, + WriteTimeout: time.Duration(timeout) * time.Second, + MaxIdleConnDuration: time.Duration(timeout) * time.Second, + MaxConnDuration: time.Duration(timeout) * time.Second, DisableHeaderNamesNormalizing: true, } } - func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { if !checkAllowedURLs(c) { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -38,44 +46,22 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } - c.Status(403).SendString("Request blocked - not allowed URL") - return nil + return fmt.Errorf("request blocked - not allowed URL: %s", c.Path()) + } + + proxyURL := currentEndpoint + c.Path() + _, err := url.Parse(proxyURL) + if err != nil { + return fmt.Errorf("invalid URL: %v", err) } - c.Request().Header.DisableNormalizing() - c.Request().Header.Add("X-Real-IP", c.IP()) - c.Request().Header.Add(fiber.HeaderXForwardedFor, string(c.Request().Header.Peek("X-Forwarded-For"))) - c.Request().Header.Del(fiber.HeaderAcceptEncoding) - // added dummy check for the log level because it executes additional functions which could - // potentially slow down the execution. if cfg.LogLevel == "debug" { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Proxying the request", - Pairs: map[string]interface{}{ - "path": c.Path(), - "body": string(c.Request().Body()), - "headers": c.GetReqHeaders(), - "request_uuid": c.Locals("request_uuid"), - }, - }) + logDebugRequest(c) } - err := retry.Do( + err = retry.Do( func() error { - errInt := proxy.DoRedirects(c, currentEndpoint+c.Path(), 3, cfg.Client.FastProxyClient) - if errInt != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't proxy the request", - Pairs: map[string]interface{}{ - "error": errInt.Error(), - }, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - } - return errInt - } - return nil + return proxy.DoRedirects(c, proxyURL, 3, httpClient) }, retry.OnRetry(func(n uint, err error) { cfg.Logger.Warning(&libpack_logger.LogMessage{ @@ -86,42 +72,59 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { }, }) }), - retry.Attempts(uint(3)), + retry.Attempts(3), retry.DelayType(retry.BackOffDelay), - retry.Delay(time.Duration(250*time.Millisecond)), + retry.Delay(250*time.Millisecond), retry.LastErrorOnly(true), ) if err != nil { cfg.Logger.Warning(&libpack_logger.LogMessage{ Message: "Can't proxy the request", - Pairs: map[string]interface{}{ - "error": err.Error(), - }, + Pairs: map[string]interface{}{"error": err.Error()}, }) - return err + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } + return fmt.Errorf("failed to proxy request: %v", err) } if cfg.LogLevel == "debug" { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Received proxied response", - Pairs: map[string]interface{}{ - "path": c.Path(), - "response_body": string(c.Response().Body()), - "response_code": c.Response().StatusCode(), - "headers": c.GetRespHeaders(), - "request_uuid": c.Locals("request_uuid"), - }, - }) + logDebugResponse(c) } if c.Response().StatusCode() != 200 { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } - return fmt.Errorf("Received non-200 response from the GraphQL server: %d", c.Response().StatusCode()) + return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode()) } c.Response().Header.Del(fiber.HeaderServer) return nil } + +func logDebugRequest(c *fiber.Ctx) { + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Proxying the request", + Pairs: map[string]interface{}{ + "path": c.Path(), + "body": string(c.Body()), + "headers": c.GetReqHeaders(), + "request_uuid": c.Locals("request_uuid"), + }, + }) +} + +func logDebugResponse(c *fiber.Ctx) { + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Received proxied response", + Pairs: map[string]interface{}{ + "path": c.Path(), + "response_body": string(c.Response().Body()), + "response_code": c.Response().StatusCode(), + "headers": c.GetRespHeaders(), + "request_uuid": c.Locals("request_uuid"), + }, + }) +} diff --git a/proxy_test.go b/proxy_test.go index 69fb276..afead6b 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1,6 +1,8 @@ package main import ( + "strings" + "github.com/valyala/fasthttp" ) @@ -95,3 +97,31 @@ func (suite *Tests) Test_proxyTheRequest() { }) } } + +func (suite *Tests) Test_proxyTheRequestWithPayloads() { + allowedUrls = make(map[string]struct{}) + allowedUrls["/"] = struct{}{} + + suite.Run("Test with invalid URL", func() { + cfg.Server.HostGraphQL = "://invalid-url" + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) + err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) + assert.NotNil(err) + }) + + suite.Run("Test with network error", func() { + cfg.Server.HostGraphQL = "http://non-existent-host.invalid" + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) + err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) + assert.NotNil(err) + }) + + suite.Run("Test with large payload", func() { + cfg.Server.HostGraphQL = "https://telegram-bot.app/" + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) + largePayload := strings.Repeat("a", 10*1024*1024) // 10MB payload + ctx.Context().Request.SetBody([]byte(largePayload)) + err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) + assert.Nil(err) + }) +} diff --git a/ratelimit.go b/ratelimit.go index 0034615..18e9f85 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -6,54 +6,42 @@ import ( "time" "github.com/goccy/go-json" - goratecounter "github.com/lukaszraczylo/go-ratecounter" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) type RateLimitConfig struct { RateCounterTicker *goratecounter.RateCounter - Interval string `json:"interval"` - Req int `json:"req"` + Interval time.Duration `json:"interval"` + Req int `json:"req"` } var ( - rateLimits map[string]RateLimitConfig - ratelimitIntervals = map[string]time.Duration{ - "milli": time.Millisecond, - "micro": time.Microsecond, - "nano": time.Nanosecond, - "second": time.Second, - "minute": time.Minute, - "hour": time.Hour, - "day": 24 * time.Hour, - } - configPaths = []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"} - mu sync.RWMutex + rateLimits = make(map[string]RateLimitConfig) + rateLimitMu sync.RWMutex ) func loadRatelimitConfig() error { - for _, path := range configPaths { + paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"} + for _, path := range paths { if err := loadConfigFromPath(path); err == nil { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Failed to load config", - Pairs: map[string]interface{}{"path": path, "error": err}, - }) return nil } } - cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Rate limit config not found", - Pairs: map[string]interface{}{"paths": configPaths}, + Pairs: map[string]interface{}{"paths": paths}, }) - return os.ErrNotExist } func loadConfigFromPath(path string) error { file, err := os.ReadFile(path) if err != nil { + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Failed to load config", + Pairs: map[string]interface{}{"path": path, "error": err}, + }) return err } @@ -65,29 +53,29 @@ func loadConfigFromPath(path string) error { return err } - mu.Lock() - defer mu.Unlock() - - rateLimits = make(map[string]RateLimitConfig, len(config.RateLimit)) + newRateLimits := make(map[string]RateLimitConfig, len(config.RateLimit)) for key, value := range config.RateLimit { value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ - Interval: time.Duration(value.Req) * ratelimitIntervals[value.Interval], + Interval: value.Interval, }) if cfg.LogLevel == "debug" { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Setting ratelimit config for role", Pairs: map[string]interface{}{ - "role": key, - "interval_provided": value.Interval, - "interval_used": ratelimitIntervals[value.Interval], - "ratelimit": value.Req, + "role": key, + "interval_used": value.Interval, + "ratelimit": value.Req, }, }) } - rateLimits[key] = value + newRateLimits[key] = value } + rateLimitMu.Lock() + rateLimits = newRateLimits + rateLimitMu.Unlock() + cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limit config loaded", Pairs: map[string]interface{}{"ratelimit": rateLimits}, @@ -96,21 +84,13 @@ func loadConfigFromPath(path string) error { } func rateLimitedRequest(userID, userRole string) bool { - mu.RLock() - defer mu.RUnlock() - - if rateLimits == nil { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit config not found", - Pairs: map[string]interface{}{"user_role": userRole}, - }) - return true - } - + rateLimitMu.RLock() roleConfig, ok := rateLimits[userRole] + rateLimitMu.RUnlock() + if !ok || roleConfig.RateCounterTicker == nil { cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit role or ticker not found", + Message: "Rate limit role not found or ticker not initialized", Pairs: map[string]interface{}{"user_role": userRole}, }) return true @@ -119,29 +99,23 @@ func rateLimitedRequest(userID, userRole string) bool { roleConfig.RateCounterTicker.Incr(1) tickerRate := roleConfig.RateCounterTicker.GetRate() - if cfg.LogLevel == "debug" { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Rate limit ticker", - Pairs: map[string]interface{}{ - "user_role": userRole, - "user_id": userID, - "rate": tickerRate, - "config_rate": roleConfig.Req, - "interval": roleConfig.Interval, - }, - }) + logDetails := map[string]interface{}{ + "user_role": userRole, + "user_id": userID, + "rate": tickerRate, + "config_rate": roleConfig.Req, + "interval": roleConfig.Interval, } + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Rate limit ticker", + Pairs: map[string]interface{}{"log_details": logDetails}, + }) + if tickerRate > float64(roleConfig.Req) { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limit exceeded", - Pairs: map[string]interface{}{ - "user_role": userRole, - "user_id": userID, - "rate": tickerRate, - "config_rate": roleConfig.Req, - "interval": roleConfig.Interval, - }, + Pairs: map[string]interface{}{"log_details": logDetails}, }) return false } diff --git a/server.go b/server.go index 28fe69b..5b02d60 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package main import ( "fmt" "strconv" + "sync" "time" "github.com/goccy/go-json" @@ -16,13 +17,24 @@ import ( libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" ) -// StartHTTPProxy starts the HTTP and points it to the GraphQL server. +const ( + healthCheckQueryStr = `{ __typename }` +) + +var ( + ctxPool = sync.Pool{ + New: func() interface{} { + return new(fiber.Ctx) + }, + } +) + func StartHTTPProxy() { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Starting the HTTP proxy", - Pairs: nil, }) - server := fiber.New(fiber.Config{ + + serverConfig := fiber.Config{ DisableStartupMessage: true, AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION), IdleTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second * 2, @@ -30,13 +42,14 @@ func StartHTTPProxy() { WriteTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second * 2, JSONEncoder: json.Marshal, JSONDecoder: json.Unmarshal, - }) + } + + server := fiber.New(serverConfig) server.Use(cors.New(cors.Config{ AllowOrigins: "*", })) - // add middleware to check if the request is a GraphQL query server.Use(AddRequestUUID) server.Get("/healthz", healthCheck) @@ -49,11 +62,11 @@ func StartHTTPProxy() { Message: "GraphQL proxy started", Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL}, }) - err := server.Listen(fmt.Sprintf(":%d", cfg.Server.PortGraphQL)) - if err != nil { + + if err := server.Listen(fmt.Sprintf(":%d", cfg.Server.PortGraphQL)); err != nil { cfg.Logger.Critical(&libpack_logger.LogMessage{ Message: "Can't start the service", - Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL}, + Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL, "error": err.Error()}, }) } } @@ -71,7 +84,8 @@ func checkAllowedURLs(c *fiber.Ctx) bool { if len(allowedUrls) == 0 { return true } - _, ok := allowedUrls[c.Path()] + path := c.OriginalURL() + _, ok := allowedUrls[path] return ok } @@ -81,93 +95,71 @@ func healthCheck(c *fiber.Ctx) error { Message: "Health check enabled", Pairs: map[string]interface{}{"url": cfg.Server.HealthcheckGraphQL}, }) - query := `{ __typename }` - _, err := cfg.Client.GQLClient.Query(query, nil, nil) + + _, err := cfg.Client.GQLClient.Query(healthCheckQueryStr, nil, nil) if err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't reach the GraphQL server", Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query") - return err + return c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query") } } + cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Health check returning OK", - Pairs: nil, }) - c.Status(200).SendString("Health check OK") - return nil + return c.Status(200).SendString("Health check OK") } func processGraphQLRequest(c *fiber.Ctx) error { startTime := time.Now() - // Initialize variables with default values extractedUserID := "-" extractedRoleName := "-" - var queryCacheHash string - authorization := c.Request().Header.Peek("Authorization") - if authorization != nil && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { - extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(string(authorization)) + if authorization := c.Get("Authorization"); authorization != "" && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { + extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization) } if checkIfUserIsBanned(c, extractedUserID) { - c.Status(403).SendString("User is banned") - return nil + return c.Status(403).SendString("User is banned") } - if len(cfg.Client.RoleFromHeader) > 0 { - extractedRoleName = string(c.Request().Header.Peek(cfg.Client.RoleFromHeader)) - if extractedRoleName == "" { - extractedRoleName = "-" + if cfg.Client.RoleFromHeader != "" { + if role := c.Get(cfg.Client.RoleFromHeader); role != "" { + extractedRoleName = role } } - // Implementing rate limiting if enabled if cfg.Client.RoleRateLimit { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limiting enabled", Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName}, }) if !rateLimitedRequest(extractedUserID, extractedRoleName) { - c.Status(429).SendString("Rate limit exceeded, try again later") - return nil + return c.Status(429).SendString("Rate limit exceeded, try again later") } } parsedResult := parseGraphQLQuery(c) if parsedResult.shouldBlock { - c.Status(403).SendString("Request blocked") - return nil + return c.Status(403).SendString("Request blocked") } if parsedResult.shouldIgnore { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Request passed as-is - probably not a GraphQL", - Pairs: nil, }) return proxyTheRequest(c, parsedResult.activeEndpoint) } calculatedQueryHash := libpack_cache.CalculateHash(c) - if parsedResult.cacheTime > 0 { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Cache time set via query", - Pairs: map[string]interface{}{"cacheTime": parsedResult.cacheTime}, - }) - } else { - // If not set via query, try setting via header - cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query") - if cacheQuery != nil { - parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery)) - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Cache time set via header", - Pairs: map[string]interface{}{"cacheTime": parsedResult.cacheTime}, - }) + if parsedResult.cacheTime == 0 { + if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" { + parsedResult.cacheTime, _ = strconv.Atoi(cacheQuery) } else { parsedResult.cacheTime = cfg.Cache.CacheTTL } @@ -183,82 +175,67 @@ func processGraphQLRequest(c *fiber.Ctx) error { libpack_cache.CacheDelete(calculatedQueryHash) } - // Handling Cache Logic if parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Cache enabled", Pairs: map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable}, }) - queryCacheHash = calculatedQueryHash - if cachedResponse := libpack_cache.CacheLookup(queryCacheHash); cachedResponse != nil { + if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil { cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil) cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Cache hit", - Pairs: map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, + Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, }) - c.Request().Header.Add("X-Cache-Hit", "true") - err := c.Send(cachedResponse) - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't send the cached response", - Pairs: map[string]interface{}{"error": err.Error()}, - }) - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't send the cached response - try again later") - } + c.Set("X-Cache-Hit", "true") wasCached = true - } else { - cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil) - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Cache miss", - Pairs: map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, - }) - proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime, parsedResult.activeEndpoint) + return c.Send(cachedResponse) + } + + cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil) + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Cache miss", + Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, + }) + if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil { + return err } } else { - err := proxyTheRequest(c, parsedResult.activeEndpoint) - if err != nil { + if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't proxy the request", Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't proxy the request - try again later") - return nil + return c.Status(500).SendString("Can't proxy the request - try again later") } } - timeTaken := time.Since(startTime) - - // Logging & Monitoring - logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime) + logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime) return nil } -// Additional helper function to avoid code repetition -func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) { - err := proxyTheRequest(c, currentEndpoint) - if err != nil { +func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error { + if err := proxyTheRequest(c, currentEndpoint); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't proxy the request", Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't proxy the request - try again later") - return + return c.Status(500).SendString("Can't proxy the request - try again later") } + libpack_cache.CacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second) cfg.Monitoring.Increment(libpack_monitoring.MetricsQueriesCached, nil) - c.Send(c.Response().Body()) + return c.Send(c.Response().Body()) } func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) { labels := map[string]string{ "op_type": opType, "op_name": opName, - "cached": fmt.Sprintf("%t", wasCached), + "cached": strconv.FormatBool(wasCached), "user_id": userID, } @@ -267,7 +244,7 @@ func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached Message: "Request processed", Pairs: map[string]interface{}{ "ip": c.IP(), - "fwd-ip": string(c.Request().Header.Peek("X-Forwarded-For")), + "fwd-ip": c.Get("X-Forwarded-For"), "user_id": userID, "op_type": opType, "op_name": opName,