From 6a69694ab38e5e8efc6c217fe1a5ed9c0ed7f99d Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sat, 29 Nov 2025 14:21:09 +0000 Subject: [PATCH] November improvements. (#29) * Tackling the CPU / memory spikes after some time. * Update admin dashboard, fix the circuit breaker and request coalescing. --- admin_dashboard.go | 170 ++++++++++++++++++++-- api.go | 4 +- circuit_breaker_metrics.go | 16 +-- integration_test.go | 288 +++++++++++++++++++++++++++++++++++++ main.go | 34 +++-- monitoring/monitoring.go | 42 +++++- proxy.go | 67 ++++++++- retry_budget.go | 38 ++++- rps_tracker.go | 35 +++-- 9 files changed, 633 insertions(+), 61 deletions(-) diff --git a/admin_dashboard.go b/admin_dashboard.go index 20ef196..7a6c2cb 100644 --- a/admin_dashboard.go +++ b/admin_dashboard.go @@ -79,9 +79,44 @@ func (ad *AdminDashboard) serveDashboard(c *fiber.Ctx) error { } // getStats returns overall proxy statistics +// In cluster mode (when metrics aggregator is available), returns aggregated stats from all instances func (ad *AdminDashboard) getStats(c *fiber.Ctx) error { + // Check if cluster mode is enabled - if so, return aggregated stats + if aggregator := GetMetricsAggregator(); aggregator != nil { + metrics, err := aggregator.GetAggregatedMetrics() + if err != nil { + if ad.logger != nil { + ad.logger.Error(&libpack_logger.LogMessage{ + Message: "Failed to get aggregated metrics, falling back to local stats", + Pairs: map[string]interface{}{"error": err.Error()}, + }) + } + // Fall through to local stats on error + } else { + // Return aggregated cluster stats + response := map[string]interface{}{ + "cluster_mode": true, + "total_instances": metrics.TotalInstances, + "healthy_instances": metrics.HealthyInstances, + "timestamp": metrics.LastUpdate.Format(time.RFC3339), + "version": "0.27.0", + } + + // Add combined stats from aggregation + if metrics.CombinedStats != nil { + for k, v := range metrics.CombinedStats { + response[k] = v + } + } + + return c.JSON(response) + } + } + + // Local instance stats (fallback or non-cluster mode) uptimeSeconds := time.Since(startTime).Seconds() stats := map[string]interface{}{ + "cluster_mode": false, "timestamp": time.Now().Format(time.RFC3339), "uptime_seconds": uptimeSeconds, "uptime_human": formatDuration(time.Since(startTime)), @@ -233,9 +268,62 @@ func (ad *AdminDashboard) getCircuitBreakerStatus(c *fiber.Ctx) error { } // getCacheStats returns cache statistics +// In cluster mode, returns aggregated cache stats from all instances func (ad *AdminDashboard) getCacheStats(c *fiber.Ctx) error { + // Check if cluster mode is enabled - if so, return aggregated cache stats + if aggregator := GetMetricsAggregator(); aggregator != nil { + metrics, err := aggregator.GetAggregatedMetrics() + if err != nil { + if ad.logger != nil { + ad.logger.Error(&libpack_logger.LogMessage{ + Message: "Failed to get aggregated cache metrics, falling back to local stats", + Pairs: map[string]interface{}{"error": err.Error()}, + }) + } + // Fall through to local stats on error + } else { + // Build aggregated cache stats from combined stats + response := map[string]interface{}{ + "cluster_mode": true, + "total_instances": metrics.TotalInstances, + } + + // Add cache config from local config + if cfg != nil { + response["enabled"] = cfg.Cache.CacheEnable + response["redis_enabled"] = cfg.Cache.CacheRedisEnable + response["ttl_seconds"] = cfg.Cache.CacheTTL + response["max_memory_mb"] = cfg.Cache.CacheMaxMemorySize + response["max_entries"] = cfg.Cache.CacheMaxEntries + } + + // Extract aggregated cache stats from combined stats + if metrics.CombinedStats != nil { + if cacheHits, ok := metrics.CombinedStats["cache_hits"]; ok { + response["cache_hits"] = cacheHits + } + if cacheMisses, ok := metrics.CombinedStats["cache_misses"]; ok { + response["cache_misses"] = cacheMisses + } + if cachedQueries, ok := metrics.CombinedStats["cached_queries"]; ok { + response["cached_queries"] = cachedQueries + } + if hitRate, ok := metrics.CombinedStats["cache_hit_rate_pct"]; ok { + response["hit_rate_pct"] = hitRate + } + if memoryMB, ok := metrics.CombinedStats["memory_usage_mb"]; ok { + response["memory_usage_mb"] = memoryMB + } + } + + return c.JSON(response) + } + } + + // Local instance stats (fallback or non-cluster mode) stats := map[string]interface{}{ - "enabled": false, + "cluster_mode": false, + "enabled": false, } if cfg != nil { @@ -590,8 +678,8 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() - // Send initial stats immediately - if stats := ad.gatherAllStats(); stats != nil { + // Send initial stats immediately (cluster-aware for dashboard) + if stats := ad.gatherAllStatsClusterAware(); stats != nil { if data, err := json.Marshal(stats); err == nil { c.WriteMessage(websocket.TextMessage, data) } @@ -601,8 +689,8 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) { for { select { case <-ticker.C: - // Gather all stats - stats := ad.gatherAllStats() + // Gather all stats (cluster-aware for dashboard) + stats := ad.gatherAllStatsClusterAware() // Marshal to JSON data, err := json.Marshal(stats) @@ -635,8 +723,56 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) { } // gatherAllStats collects all statistics into a single structure +// This always returns LOCAL stats for this instance (used by metrics aggregator) func (ad *AdminDashboard) gatherAllStats() map[string]interface{} { + return ad.gatherAllStatsWithMode(false) +} + +// gatherAllStatsClusterAware collects statistics with cluster awareness +// If cluster mode is available, returns aggregated stats from all instances +func (ad *AdminDashboard) gatherAllStatsClusterAware() map[string]interface{} { + return ad.gatherAllStatsWithMode(true) +} + +// gatherAllStatsWithMode collects statistics with optional cluster mode +func (ad *AdminDashboard) gatherAllStatsWithMode(useClusterMode bool) map[string]interface{} { + // Check if cluster mode is requested and available + if useClusterMode { + if aggregator := GetMetricsAggregator(); aggregator != nil { + metrics, err := aggregator.GetAggregatedMetrics() + if err == nil && metrics != nil { + // Return aggregated cluster stats + result := map[string]interface{}{ + "cluster_mode": true, + "total_instances": metrics.TotalInstances, + "healthy_instances": metrics.HealthyInstances, + } + + // Build stats section from combined stats + stats := map[string]interface{}{ + "timestamp": metrics.LastUpdate.Format(time.RFC3339), + "version": "0.27.0", + } + + // Copy all combined stats + if metrics.CombinedStats != nil { + for k, v := range metrics.CombinedStats { + stats[k] = v + } + } + result["stats"] = stats + + // Add per-instance details + result["instances"] = metrics.Instances + + return result + } + } + } + + // Fall back to local stats result := make(map[string]interface{}) + result["cluster_mode"] = false // Main stats uptimeSeconds := time.Since(startTime).Seconds() @@ -787,16 +923,24 @@ func (ad *AdminDashboard) gatherAllStats() map[string]interface{} { } cacheStats["hit_rate_pct"] = hitRate - memoryUsage := libpack_cache.GetCacheMemoryUsage() - maxMemory := libpack_cache.GetCacheMaxMemorySize() - cacheStats["memory_usage_bytes"] = memoryUsage - cacheStats["memory_usage_mb"] = float64(memoryUsage) / (1024 * 1024) + // Only get memory usage for in-memory cache (not Redis) + if cfg.Cache.CacheEnable && !cfg.Cache.CacheRedisEnable { + memoryUsage := libpack_cache.GetCacheMemoryUsage() + maxMemory := libpack_cache.GetCacheMaxMemorySize() + cacheStats["memory_usage_bytes"] = memoryUsage + cacheStats["memory_usage_mb"] = float64(memoryUsage) / (1024 * 1024) - memoryUsagePct := 0.0 - if maxMemory > 0 { - memoryUsagePct = float64(memoryUsage) / float64(maxMemory) * 100 + memoryUsagePct := 0.0 + if maxMemory > 0 { + memoryUsagePct = float64(memoryUsage) / float64(maxMemory) * 100 + } + cacheStats["memory_usage_pct"] = memoryUsagePct + } else { + // For Redis cache, memory tracking is not available per instance + cacheStats["memory_usage_bytes"] = int64(-1) + cacheStats["memory_usage_mb"] = float64(-1) + cacheStats["memory_usage_pct"] = float64(-1) } - cacheStats["memory_usage_pct"] = memoryUsagePct } } result["cache"] = cacheStats diff --git a/api.go b/api.go index a7f0aac..a602f0a 100644 --- a/api.go +++ b/api.go @@ -170,9 +170,11 @@ func apiCircuitBreakerHealth(c *fiber.Ctx) error { }) } - // Get circuit breaker state + // Get circuit breaker state with proper mutex protection + cbMutex.RLock() state := cb.State() counts := cb.Counts() + cbMutex.RUnlock() // Determine health status var status string diff --git a/circuit_breaker_metrics.go b/circuit_breaker_metrics.go index df02a2c..cb4da6f 100644 --- a/circuit_breaker_metrics.go +++ b/circuit_breaker_metrics.go @@ -23,18 +23,14 @@ func NewCircuitBreakerMetrics(monitoring *libpack_monitoring.MetricsSetup) *Circ // Initialize state value cbm.stateValue.Store(float64(0)) - // Create gauge with callback that reads the atomic value - cbm.stateGauge = monitoring.RegisterMetricsGauge( + // Create gauge with callback that reads the atomic value on every scrape + // This ensures the metric always reflects the current circuit breaker state + cbm.stateGauge = monitoring.RegisterMetricsGaugeFunc( libpack_monitoring.MetricsCircuitState, nil, - 0, // Initial value doesn't matter as callback will be used - ) - - // Override the gauge callback to read from atomic value - cbm.stateGauge = monitoring.RegisterMetricsGauge( - libpack_monitoring.MetricsCircuitState, - nil, - cbm.GetState(), + func() float64 { + return cbm.GetState() + }, ) return cbm diff --git a/integration_test.go b/integration_test.go index 4ccc1eb..08c25fa 100644 --- a/integration_test.go +++ b/integration_test.go @@ -5,6 +5,8 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" + "sync/atomic" "time" "github.com/gofiber/fiber/v2" @@ -495,3 +497,289 @@ func getMetricValue(metricName string) int { } return int(counter.Get()) } + +// TestRequestCoalescingIntegration tests that request coalescing works end-to-end +// through the proxy layer, ensuring concurrent identical requests result in only +// one backend call while all clients receive the correct response. +func (suite *Tests) TestRequestCoalescingIntegration() { + // Save original config + originalCoalescing := cfg.RequestCoalescing + originalClient := cfg.Client.FastProxyClient + originalHostGraphQL := cfg.Server.HostGraphQL + + // Restore after test + defer func() { + cfg.RequestCoalescing = originalCoalescing + cfg.Client.FastProxyClient = originalClient + cfg.Server.HostGraphQL = originalHostGraphQL + }() + + // Track backend calls + var backendCallCount atomic.Int32 + var requestDelay = 100 * time.Millisecond + + // Create test server that counts requests and introduces delay + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCallCount.Add(1) + time.Sleep(requestDelay) // Delay to allow concurrent requests to coalesce + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":{"users":[{"id":"1","name":"Test User"}]}}`)) + })) + defer server.Close() + + // Configure for test + cfg.Server.HostGraphQL = server.URL + cfg.Client.ClientTimeout = 5 + cfg.Client.FastProxyClient = createFasthttpClient(cfg) + cfg.RequestCoalescing.Enable = true + + // Initialize request coalescer for this test + // Reset the global coalescer by creating a new one + testCoalescer := NewRequestCoalescer(true, cfg.Logger, cfg.Monitoring) + + // Temporarily replace global coalescer + originalCoalescer := requestCoalescer + requestCoalescer = testCoalescer + defer func() { + requestCoalescer = originalCoalescer + }() + + // Test Case 1: Concurrent identical requests should coalesce + suite.Run("concurrent_identical_requests_coalesce", func() { + backendCallCount.Store(0) + testCoalescer.Reset() + + concurrentRequests := 10 + var wg sync.WaitGroup + wg.Add(concurrentRequests) + + responses := make([]string, concurrentRequests) + errors := make([]error, concurrentRequests) + + // Launch concurrent requests with identical query + for i := 0; i < concurrentRequests; i++ { + go func(index int) { + defer wg.Done() + + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.SetRequestURI("/graphql") + reqCtx.Request.Header.SetMethod("POST") + reqCtx.Request.Header.Set("Content-Type", "application/json") + reqCtx.Request.SetBody([]byte(`{"query": "query { users { id name } }"}`)) + + ctx := suite.app.AcquireCtx(reqCtx) + err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) + errors[index] = err + responses[index] = string(ctx.Response().Body()) + suite.app.ReleaseCtx(ctx) + }(i) + } + + wg.Wait() + + // Verify only 1 backend call was made + suite.Equal(int32(1), backendCallCount.Load(), + "Should make only 1 backend call for %d concurrent identical requests", concurrentRequests) + + // Verify all requests succeeded with same response + expectedResponse := `{"data":{"users":[{"id":"1","name":"Test User"}]}}` + for i := 0; i < concurrentRequests; i++ { + suite.Nil(errors[i], "Request %d should succeed", i) + suite.Equal(expectedResponse, responses[i], + "Request %d should have correct response", i) + } + + // Verify coalescing stats + stats := testCoalescer.GetStats() + suite.Equal(int64(concurrentRequests), stats["total_requests"], + "Total requests should match") + suite.Equal(int64(1), stats["primary_requests"], + "Should have 1 primary request") + suite.Equal(int64(concurrentRequests-1), stats["coalesced_requests"], + "Should have %d coalesced requests", concurrentRequests-1) + }) + + // Test Case 2: Different queries should NOT coalesce + suite.Run("different_queries_not_coalesced", func() { + backendCallCount.Store(0) + testCoalescer.Reset() + + // Create server that returns query-specific responses + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCallCount.Add(1) + time.Sleep(50 * time.Millisecond) + + body := make([]byte, r.ContentLength) + _, _ = r.Body.Read(body) + + var response string + if strings.Contains(string(body), "query1") { + response = `{"data":{"result":"query1"}}` + } else if strings.Contains(string(body), "query2") { + response = `{"data":{"result":"query2"}}` + } else { + response = `{"data":{"result":"unknown"}}` + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(response)) + })) + defer server2.Close() + + cfg.Server.HostGraphQL = server2.URL + cfg.Client.FastProxyClient = createFasthttpClient(cfg) + + var wg sync.WaitGroup + wg.Add(2) + + var response1, response2 string + var err1, err2 error + + // Launch two requests with different queries concurrently + go func() { + defer wg.Done() + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.SetRequestURI("/graphql") + reqCtx.Request.Header.SetMethod("POST") + reqCtx.Request.Header.Set("Content-Type", "application/json") + reqCtx.Request.SetBody([]byte(`{"query": "query { query1 }"}`)) + + ctx := suite.app.AcquireCtx(reqCtx) + err1 = proxyTheRequest(ctx, cfg.Server.HostGraphQL) + response1 = string(ctx.Response().Body()) + suite.app.ReleaseCtx(ctx) + }() + + go func() { + defer wg.Done() + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.SetRequestURI("/graphql") + reqCtx.Request.Header.SetMethod("POST") + reqCtx.Request.Header.Set("Content-Type", "application/json") + reqCtx.Request.SetBody([]byte(`{"query": "query { query2 }"}`)) + + ctx := suite.app.AcquireCtx(reqCtx) + err2 = proxyTheRequest(ctx, cfg.Server.HostGraphQL) + response2 = string(ctx.Response().Body()) + suite.app.ReleaseCtx(ctx) + }() + + wg.Wait() + + // Both requests should succeed + suite.Nil(err1, "Query1 should succeed") + suite.Nil(err2, "Query2 should succeed") + + // Should have made 2 backend calls (no coalescing for different queries) + suite.Equal(int32(2), backendCallCount.Load(), + "Should make 2 backend calls for 2 different queries") + + // Responses should be different + suite.Contains(response1, "query1", "Response1 should be for query1") + suite.Contains(response2, "query2", "Response2 should be for query2") + }) + + // Test Case 3: Coalescing disabled should make separate calls + suite.Run("coalescing_disabled", func() { + // Create a fresh server for this test + var disabledCallCount atomic.Int32 + serverDisabled := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + disabledCallCount.Add(1) + time.Sleep(50 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":{"users":[{"id":"1"}]}}`)) + })) + defer serverDisabled.Close() + + cfg.Server.HostGraphQL = serverDisabled.URL + cfg.Client.FastProxyClient = createFasthttpClient(cfg) + + // Disable coalescing + cfg.RequestCoalescing.Enable = false + + concurrentRequests := 5 + var wg sync.WaitGroup + wg.Add(concurrentRequests) + + // Launch concurrent identical requests + for i := 0; i < concurrentRequests; i++ { + go func() { + defer wg.Done() + + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.SetRequestURI("/graphql") + reqCtx.Request.Header.SetMethod("POST") + reqCtx.Request.Header.Set("Content-Type", "application/json") + reqCtx.Request.SetBody([]byte(`{"query": "query { users { id } }"}`)) + + ctx := suite.app.AcquireCtx(reqCtx) + _ = proxyTheRequest(ctx, cfg.Server.HostGraphQL) + suite.app.ReleaseCtx(ctx) + }() + } + + wg.Wait() + + // Should make separate backend calls when coalescing is disabled + suite.Equal(int32(concurrentRequests), disabledCallCount.Load(), + "Should make %d backend calls when coalescing is disabled", concurrentRequests) + + // Re-enable for subsequent tests + cfg.RequestCoalescing.Enable = true + }) + + // Test Case 4: Error responses should be shared correctly + suite.Run("error_responses_coalesced", func() { + backendCallCount.Store(0) + testCoalescer.Reset() + + // Create server that returns errors + serverError := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + backendCallCount.Add(1) + time.Sleep(50 * time.Millisecond) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"errors":[{"message":"Internal server error"}]}`)) + })) + defer serverError.Close() + + cfg.Server.HostGraphQL = serverError.URL + cfg.Client.FastProxyClient = createFasthttpClient(cfg) + + concurrentRequests := 5 + var wg sync.WaitGroup + wg.Add(concurrentRequests) + + errors := make([]error, concurrentRequests) + + for i := 0; i < concurrentRequests; i++ { + go func(index int) { + defer wg.Done() + + reqCtx := &fasthttp.RequestCtx{} + reqCtx.Request.SetRequestURI("/graphql") + reqCtx.Request.Header.SetMethod("POST") + reqCtx.Request.Header.Set("Content-Type", "application/json") + reqCtx.Request.SetBody([]byte(`{"query": "query { fail }"}`)) + + ctx := suite.app.AcquireCtx(reqCtx) + errors[index] = proxyTheRequest(ctx, cfg.Server.HostGraphQL) + suite.app.ReleaseCtx(ctx) + }(i) + } + + wg.Wait() + + // Should still only make 1 backend call + suite.Equal(int32(1), backendCallCount.Load(), + "Should make only 1 backend call even for error responses") + + // All requests should receive the same error + for i := 0; i < concurrentRequests; i++ { + suite.NotNil(errors[i], "Request %d should have error", i) + } + }) +} diff --git a/main.go b/main.go index 17aa141..7406f55 100644 --- a/main.go +++ b/main.go @@ -267,7 +267,7 @@ func parseConfig() { c.Api.BannedUsersFile = validatedPath } c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false) - c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0) + c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 1800) // Default: purge metrics every 30 minutes // Hasura event cleaner c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false) c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1) @@ -409,15 +409,7 @@ func parseConfig() { initCircuitBreaker(cfg) } - // Initialize retry budget - if cfg.RetryBudget.Enable { - retryBudgetConfig := RetryBudgetConfig{ - TokensPerSecond: cfg.RetryBudget.TokensPerSecond, - MaxTokens: cfg.RetryBudget.MaxTokens, - Enabled: cfg.RetryBudget.Enable, - } - InitializeRetryBudget(retryBudgetConfig, cfg.Logger) - } + // Note: Retry budget is initialized in main() with context for graceful shutdown // Initialize request coalescer if cfg.RequestCoalescing.Enable { @@ -442,11 +434,7 @@ func parseConfig() { healthMgr.StartHealthChecking() } - // Initialize RPS tracker for real-time requests per second monitoring - InitializeRPSTracker() - cfg.Logger.Info(&libpack_logging.LogMessage{ - Message: "RPS tracker initialized", - }) + // Note: RPS tracker is initialized in main() with context for graceful shutdown // Load rate limit configuration with improved error handling if err := loadRatelimitConfig(); err != nil { @@ -484,6 +472,22 @@ func main() { // Initialize shutdown manager shutdownManager = NewShutdownManager(ctx) + // Initialize RPS tracker with context for graceful shutdown + InitializeRPSTracker(ctx) + cfg.Logger.Info(&libpack_logging.LogMessage{ + Message: "RPS tracker initialized", + }) + + // Initialize retry budget with context for graceful shutdown + if cfg.RetryBudget.Enable { + retryBudgetConfig := RetryBudgetConfig{ + TokensPerSecond: cfg.RetryBudget.TokensPerSecond, + MaxTokens: cfg.RetryBudget.MaxTokens, + Enabled: cfg.RetryBudget.Enable, + } + InitializeRetryBudgetWithContext(ctx, retryBudgetConfig, cfg.Logger) + } + // Create a wait group to manage goroutines var wg sync.WaitGroup diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index aa2ff20..6aaa20e 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -1,6 +1,7 @@ package libpack_monitoring import ( + "context" "flag" "fmt" "time" @@ -17,6 +18,8 @@ type MetricsSetup struct { metrics_set_custom *metrics.Set ic *InitConfig metrics_prefix string + ctx context.Context + cancel context.CancelFunc } var log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO) @@ -27,10 +30,18 @@ type InitConfig struct { } func NewMonitoring(ic *InitConfig) *MetricsSetup { + return NewMonitoringWithContext(context.Background(), ic) +} + +// NewMonitoringWithContext creates a new monitoring instance with context for graceful shutdown +func NewMonitoringWithContext(ctx context.Context, ic *InitConfig) *MetricsSetup { + monCtx, cancel := context.WithCancel(ctx) ms := &MetricsSetup{ ic: ic, metrics_set: metrics.NewSet(), metrics_set_custom: metrics.NewSet(), + ctx: monCtx, + cancel: cancel, } if flag.Lookup("test.v") == nil { @@ -39,8 +50,14 @@ func NewMonitoring(ic *InitConfig) *MetricsSetup { if ic.PurgeEvery > 0 { ticker := time.NewTicker(time.Duration(ic.PurgeEvery) * time.Second) go func() { - for range ticker.C { - ms.PurgeMetrics() + defer ticker.Stop() + for { + select { + case <-ms.ctx.Done(): + return + case <-ticker.C: + ms.PurgeMetrics() + } } }() } @@ -49,6 +66,13 @@ func NewMonitoring(ic *InitConfig) *MetricsSetup { return ms } +// Shutdown stops the monitoring goroutines +func (ms *MetricsSetup) Shutdown() { + if ms.cancel != nil { + ms.cancel() + } +} + func (ms *MetricsSetup) startPrometheusEndpoint() { app := fiber.New(fiber.Config{ DisableStartupMessage: true, @@ -95,6 +119,20 @@ func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[stri }) } +// RegisterMetricsGaugeFunc registers a gauge with a callback function that is called on every scrape +// This is useful for metrics that need to return a dynamic value +func (ms *MetricsSetup) RegisterMetricsGaugeFunc(metric_name string, labels map[string]string, fn func() float64) *metrics.Gauge { + if err := validate_metrics_name(metric_name); err != nil { + log.Error(&libpack_logger.LogMessage{ + Message: "RegisterMetricsGaugeFunc() error - invalid metric name", + Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name}, + }) + // Return a dummy gauge instead of nil to prevent panics + return &metrics.Gauge{} + } + return ms.metrics_set_custom.GetOrCreateGauge(ms.get_metrics_name(metric_name, labels), fn) +} + func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[string]string) *metrics.Counter { if err := validate_metrics_name(metric_name); err != nil { log.Error(&libpack_logger.LogMessage{ diff --git a/proxy.go b/proxy.go index 217a78b..6dcbc64 100644 --- a/proxy.go +++ b/proxy.go @@ -325,19 +325,72 @@ func setupTracing(c *fiber.Ctx) context.Context { return ctx } -// performProxyRequest executes the proxy request with retries and circuit breaker +// performProxyRequest executes the proxy request with retries, circuit breaker, and request coalescing func performProxyRequest(c *fiber.Ctx, proxyURL string) error { + // Extract user context for cache key (needed for coalescing and circuit breaker fallback) + userID, userRole := extractUserInfo(c) + + // Calculate cache key - includes user context for security + // This key is used for both request coalescing and cache fallback + cacheKey := libpack_cache.CalculateHash(c, userID, userRole) + + // Check if request coalescing is enabled + rc := GetRequestCoalescer() + if rc != nil && cfg.RequestCoalescing.Enable { + // Use request coalescing to deduplicate identical concurrent requests + response, err := rc.Do(cacheKey, func() (*CoalescedResponse, error) { + // Execute the actual proxy request + proxyErr := performProxyRequestCore(c, proxyURL, cacheKey) + + // Capture the response for coalescing + if proxyErr != nil { + return &CoalescedResponse{ + Err: proxyErr, + StatusCode: c.Response().StatusCode(), + }, proxyErr + } + + return &CoalescedResponse{ + Body: c.Response().Body(), + StatusCode: c.Response().StatusCode(), + Headers: make(map[string]string), + }, nil + }) + + // Check for error from rc.Do (though it typically returns nil) + if err != nil { + return err + } + + // Check for error stored in the response (for coalesced requests) + if response != nil && response.Err != nil { + return response.Err + } + + // For coalesced requests (not the primary), we need to copy the response + if response != nil && response.Body != nil && len(response.Body) > 0 { + // Only set response if this is a coalesced request (body would be empty otherwise) + if len(c.Response().Body()) == 0 { + c.Response().SetStatusCode(response.StatusCode) + c.Response().SetBody(response.Body) + } + } + + return nil + } + + // No coalescing - execute directly + return performProxyRequestCore(c, proxyURL, cacheKey) +} + +// performProxyRequestCore executes the proxy request with retries and circuit breaker +// This is the core implementation used by both direct calls and coalesced requests +func performProxyRequestCore(c *fiber.Ctx, proxyURL string, cacheKey string) error { // If circuit breaker is not enabled, use the original method if !cfg.CircuitBreaker.Enable || cb == nil { return performProxyRequestWithRetries(c, proxyURL) } - // Extract user context for cache key (needed for circuit breaker fallback) - userID, userRole := extractUserInfo(c) - - // Calculate cache key for potential fallback - includes user context for security - cacheKey := libpack_cache.CalculateHash(c, userID, userRole) - // Execute request through circuit breaker _, err := cb.Execute(func() (interface{}, error) { // Execute the request with retries diff --git a/retry_budget.go b/retry_budget.go index 8ed9541..5063716 100644 --- a/retry_budget.go +++ b/retry_budget.go @@ -1,6 +1,7 @@ package main import ( + "context" "sync" "sync/atomic" "time" @@ -18,6 +19,8 @@ type RetryBudget struct { mu sync.RWMutex enabled bool logger *libpack_logger.Logger + ctx context.Context + cancel context.CancelFunc // Statistics totalAttempts atomic.Int64 @@ -32,13 +35,21 @@ type RetryBudgetConfig struct { Enabled bool // Whether retry budget is enabled } -// NewRetryBudget creates a new retry budget +// NewRetryBudget creates a new retry budget (deprecated, use NewRetryBudgetWithContext) func NewRetryBudget(config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget { + return NewRetryBudgetWithContext(context.Background(), config, logger) +} + +// NewRetryBudgetWithContext creates a new retry budget with context for graceful shutdown +func NewRetryBudgetWithContext(ctx context.Context, config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget { + budgetCtx, cancel := context.WithCancel(ctx) rb := &RetryBudget{ tokensPerSecond: config.TokensPerSecond, maxTokens: int64(config.MaxTokens), enabled: config.Enabled, logger: logger, + ctx: budgetCtx, + cancel: cancel, } // Initialize with full bucket @@ -91,8 +102,20 @@ func (rb *RetryBudget) refillLoop() { ticker := time.NewTicker(100 * time.Millisecond) // Refill every 100ms defer ticker.Stop() - for range ticker.C { - rb.refill() + for { + select { + case <-rb.ctx.Done(): + return + case <-ticker.C: + rb.refill() + } + } +} + +// Shutdown stops the retry budget goroutine +func (rb *RetryBudget) Shutdown() { + if rb.cancel != nil { + rb.cancel() } } @@ -187,10 +210,15 @@ var ( retryBudgetOnce sync.Once ) -// InitializeRetryBudget initializes the global retry budget +// InitializeRetryBudget initializes the global retry budget (deprecated, use InitializeRetryBudgetWithContext) func InitializeRetryBudget(config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget { + return InitializeRetryBudgetWithContext(context.Background(), config, logger) +} + +// InitializeRetryBudgetWithContext initializes the global retry budget with context for graceful shutdown +func InitializeRetryBudgetWithContext(ctx context.Context, config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget { retryBudgetOnce.Do(func() { - retryBudget = NewRetryBudget(config, logger) + retryBudget = NewRetryBudgetWithContext(ctx, config, logger) if logger != nil && config.Enabled { logger.Info(&libpack_logger.LogMessage{ Message: "Retry budget initialized", diff --git a/rps_tracker.go b/rps_tracker.go index 3f14332..f5f6dcd 100644 --- a/rps_tracker.go +++ b/rps_tracker.go @@ -1,6 +1,7 @@ package main import ( + "context" "sync" "sync/atomic" "time" @@ -12,11 +13,17 @@ type RPSTracker struct { lastSampleTime atomic.Int64 // Unix nano currentRPS uint64 // stored as uint64, accessed with atomic operations mu sync.RWMutex // for currentRPS updates + ctx context.Context + cancel context.CancelFunc } -// NewRPSTracker creates a new RPS tracker -func NewRPSTracker() *RPSTracker { - tracker := &RPSTracker{} +// NewRPSTracker creates a new RPS tracker with context for graceful shutdown +func NewRPSTracker(ctx context.Context) *RPSTracker { + trackerCtx, cancel := context.WithCancel(ctx) + tracker := &RPSTracker{ + ctx: trackerCtx, + cancel: cancel, + } tracker.lastSampleTime.Store(time.Now().UnixNano()) go tracker.updateLoop() return tracker @@ -33,8 +40,20 @@ func (r *RPSTracker) updateLoop() { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() - for range ticker.C { - r.sample() + for { + select { + case <-r.ctx.Done(): + return + case <-ticker.C: + r.sample() + } + } +} + +// Shutdown stops the RPS tracker +func (r *RPSTracker) Shutdown() { + if r.cancel != nil { + r.cancel() } } @@ -75,10 +94,10 @@ func (r *RPSTracker) GetCurrentRPS() float64 { var globalRPSTracker *RPSTracker -// InitializeRPSTracker initializes the global RPS tracker -func InitializeRPSTracker() *RPSTracker { +// InitializeRPSTracker initializes the global RPS tracker with context for graceful shutdown +func InitializeRPSTracker(ctx context.Context) *RPSTracker { if globalRPSTracker == nil { - globalRPSTracker = NewRPSTracker() + globalRPSTracker = NewRPSTracker(ctx) } return globalRPSTracker }