package main

import (
	"context"
	"flag"
	"fmt"
	"math"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/gofiber/fiber/v2/middleware/proxy"
	"github.com/gookit/goutil/envutil"
	graphql "github.com/lukaszraczylo/go-simple-graphql"
	libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
	libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
	libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
	libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
	libpack_tracing "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing"
)

// Package-level state shared between parseConfig, main and the background
// services started from main.
var (
	cfg             *config      // active configuration, assigned by parseConfig
	cfgMutex        sync.RWMutex // held while parseConfig publishes cfg
	once            sync.Once    // ensures background services are started only once
	tracer          *libpack_tracing.TracingSetup
	shutdownManager *ShutdownManager
)

// getDetailsFromEnv retrieves the value from the environment or returns the default.
// It first checks for a prefixed environment variable (GMP_KEY), then falls back to
// the unprefixed version.
//
// The type of defaultValue selects how the variable is parsed:
//   - string:  the prefixed value is returned verbatim; otherwise envutil.Getenv
//     is consulted for the unprefixed key.
//   - int:     the prefixed value is used when strconv.Atoi accepts it; otherwise
//     envutil.GetInt is consulted for the unprefixed key.
//   - bool:    a prefixed value is true only for "true" (any case) or "1";
//     otherwise envutil.GetBool is consulted for the unprefixed key.
//   - float64: the prefixed key, then the unprefixed key, is parsed with
//     strconv.ParseFloat; values that fail to parse or are NaN/±Inf are skipped.
//
// Any other type returns defaultValue unchanged.
func getDetailsFromEnv[T any](key string, defaultValue T) T {
	prefixedKey := "GMP_" + key

	switch v := any(defaultValue).(type) {
	case string:
		if val, ok := os.LookupEnv(prefixedKey); ok {
			return any(val).(T)
		}
		return any(envutil.Getenv(key, v)).(T)
	case int:
		if val, ok := os.LookupEnv(prefixedKey); ok {
			// An unparsable prefixed value falls through to the unprefixed key.
			if intVal, err := strconv.Atoi(val); err == nil {
				return any(intVal).(T)
			}
		}
		return any(envutil.GetInt(key, v)).(T)
	case bool:
		if val, ok := os.LookupEnv(prefixedKey); ok {
			boolVal := strings.ToLower(val) == "true" || val == "1"
			return any(boolVal).(T)
		}
		return any(envutil.GetBool(key, v)).(T)
	case float64:
		// Float defaults (e.g. CIRCUIT_FAILURE_RATIO) previously hit the default
		// branch, so the environment was silently ignored for them.
		for _, name := range []string{prefixedKey, key} {
			val, ok := os.LookupEnv(name)
			if !ok {
				continue
			}
			floatVal, err := strconv.ParseFloat(strings.TrimSpace(val), 64)
			// ParseFloat accepts "NaN" and "Inf"; neither is a usable setting.
			if err == nil && !math.IsNaN(floatVal) && !math.IsInf(floatVal, 0) {
				return any(floatVal).(T)
			}
		}
		return any(v).(T)
	default:
		return defaultValue
	}
}

// validateJWTClaimPath validates JWT claim paths to prevent injection attacks.
//
// An empty path is accepted (the feature is disabled). Otherwise the path must
// be a dot-separated list of at most 10 non-empty parts, each made only of
// ASCII letters, digits, '_' or '-'. Paths containing ".." or starting with "/"
// are rejected. A nil error means the path is acceptable.
func validateJWTClaimPath(path string) error {
	if path == "" {
		return nil // Empty path is valid (feature disabled)
	}

	// Prevent path traversal attempts
	if strings.Contains(path, "..") {
		return fmt.Errorf("invalid JWT claim path (contains '..'): %s", path)
	}

	// Prevent absolute paths
	if strings.HasPrefix(path, "/") {
		return fmt.Errorf("invalid JWT claim path (absolute path not allowed): %s", path)
	}

	// Limit depth to prevent DoS from deeply nested claims
	parts := strings.Split(path, ".")
	if len(parts) > 10 {
		return fmt.Errorf("invalid JWT claim path (too deep, max 10 levels): %s", path)
	}

	// Validate each part contains only allowed characters
	for _, part := range parts {
		if part == "" {
			return fmt.Errorf("invalid JWT claim path (empty part): %s", path)
		}
		// Allow alphanumeric, underscore, and hyphen
		for _, ch := range part {
			if !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' || ch == '-') {
				return fmt.Errorf("invalid JWT claim path (invalid character '%c'): %s", ch, path)
			}
		}
	}

	return nil
}

// parseConfig loads and parses the configuration.
//
// It builds a config value from environment variables (via getDetailsFromEnv),
// publishes it in the package-level cfg, and then initializes the subsystems
// that are enabled: tracing, metrics aggregator, cache, circuit breaker,
// request coalescer, WebSocket proxy, backend health manager and rate limits.
// The process exits with status 1 when a JWT claim path is invalid, or when the
// rate limit configuration fails to load outside of a test run.
func parseConfig() {
	libpack_config.PKG_NAME = "graphql_proxy"
	c := config{}

	// Server configurations
	c.Server.PortGraphQL = getDetailsFromEnv("PORT_GRAPHQL", 8080)
	c.Server.PortMonitoring = getDetailsFromEnv("MONITORING_PORT", 9393)
	c.Server.HostGraphQL = getDetailsFromEnv("HOST_GRAPHQL", "http://localhost/")
	c.Server.HostGraphQLReadOnly = getDetailsFromEnv("HOST_GRAPHQL_READONLY", "")

	// Client configurations
	c.Client.JWTUserClaimPath = getDetailsFromEnv("JWT_USER_CLAIM_PATH", "")
	c.Client.JWTRoleClaimPath = getDetailsFromEnv("JWT_ROLE_CLAIM_PATH", "")

	// Validate JWT claim paths for security
	if err := validateJWTClaimPath(c.Client.JWTUserClaimPath); err != nil {
		fmt.Fprintf(os.Stderr, "❌ CRITICAL ERROR: Invalid JWT_USER_CLAIM_PATH: %v\n", err)
		os.Exit(1)
	}
	if err := validateJWTClaimPath(c.Client.JWTRoleClaimPath); err != nil {
		fmt.Fprintf(os.Stderr, "❌ CRITICAL ERROR: Invalid JWT_ROLE_CLAIM_PATH: %v\n", err)
		os.Exit(1)
	}

	c.Client.RoleFromHeader = getDetailsFromEnv("ROLE_FROM_HEADER", "")
	c.Client.RoleRateLimit = getDetailsFromEnv("ROLE_RATE_LIMIT", false)

	// In-memory cache
	c.Cache.CacheEnable = getDetailsFromEnv("ENABLE_GLOBAL_CACHE", false)
	c.Cache.CacheTTL = getDetailsFromEnv("CACHE_TTL", 60)
	c.Cache.CacheMaxMemorySize = getDetailsFromEnv("CACHE_MAX_MEMORY_SIZE", 100) // Default 100MB
	c.Cache.CacheMaxEntries = getDetailsFromEnv("CACHE_MAX_ENTRIES", 10000)      // Default 10000 entries
	c.Cache.CacheUseLRU = getDetailsFromEnv("CACHE_USE_LRU", false)              // Use LRU eviction algorithm

	// GraphQL query parsing cache - auto-calculate based on CPU cores if not set
	c.Cache.GraphQLQueryCacheSize = getDetailsFromEnv("GRAPHQL_QUERY_CACHE_SIZE", runtime.GOMAXPROCS(0)*250)

	// SECURITY: Per-user cache isolation (enabled by default for security)
	// Set CACHE_PER_USER_DISABLED=true ONLY if you have a single-user application
	// or understand the security implications of shared cache across users
	c.Cache.PerUserCacheDisabled = getDetailsFromEnv("CACHE_PER_USER_DISABLED", false)

	// Log warning if per-user caching is disabled. The logger is created further
	// down, so the warning is deferred until parseConfig returns.
	if c.Cache.PerUserCacheDisabled {
		defer func() {
			if c.Logger != nil {
				c.Logger.Warning(&libpack_logging.LogMessage{
					Message: "⚠️ Per-user cache isolation is DISABLED - Users may see each other's cached data!",
					Pairs: map[string]interface{}{
						"security_risk":  "CRITICAL - Do not use in multi-user applications",
						"recommendation": "Remove CACHE_PER_USER_DISABLED or set it to false",
					},
				})
			}
		}()
	}

	// Redis cache
	c.Cache.CacheRedisEnable = getDetailsFromEnv("ENABLE_REDIS_CACHE", false)
	c.Cache.CacheRedisURL = getDetailsFromEnv("CACHE_REDIS_URL", "localhost:6379")
	c.Cache.CacheRedisPassword = getDetailsFromEnv("CACHE_REDIS_PASSWORD", "")
	c.Cache.CacheRedisDB = getDetailsFromEnv("CACHE_REDIS_DB", 0)

	// Security configurations
	c.Security.BlockIntrospection = getDetailsFromEnv("BLOCK_SCHEMA_INTROSPECTION", false)
	c.Security.IntrospectionAllowed = func() []string {
		urls := getDetailsFromEnv("ALLOWED_INTROSPECTION", "")
		if urls == "" {
			return nil
		}
		return strings.Split(urls, ",")
	}()

	c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))

	// Logger setup
	c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).
		SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)

	// Health check
	c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "")
	c.Client.GQLClient = graphql.NewConnection()
	c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL)

	// Server modes
	c.Server.AccessLog = getDetailsFromEnv("ENABLE_ACCESS_LOG", false)
	c.Server.ReadOnlyMode = getDetailsFromEnv("READ_ONLY_MODE", false)
	c.Server.AllowURLs = func() []string {
		urls := getDetailsFromEnv("ALLOWED_URLS", "")
		if urls == "" {
			return nil
		}
		return strings.Split(urls, ",")
	}()

	// Client timeout and connection configurations with bounds checking
	clientTimeout := getDetailsFromEnv("PROXIED_CLIENT_TIMEOUT", 120)
	if clientTimeout < 1 || clientTimeout > 3600 { // 1 second to 1 hour max
		c.Logger.Warning(&libpack_logging.LogMessage{
			Message: "Invalid client timeout, using default",
			Pairs:   map[string]interface{}{"requested": clientTimeout, "default": 120},
		})
		clientTimeout = 120
	}
	c.Client.ClientTimeout = clientTimeout

	// Configure HTTP connection pool and timeouts with sensible defaults
	// MaxConnsPerHost limits parallel connections to prevent overwhelming backends
	maxConns := getDetailsFromEnv("MAX_CONNS_PER_HOST", 1024)
	if maxConns < 1 || maxConns > 10000 { // Reasonable bounds
		c.Logger.Warning(&libpack_logging.LogMessage{
			Message: "Invalid max connections per host, using default",
			Pairs:   map[string]interface{}{"requested": maxConns, "default": 1024},
		})
		maxConns = 1024
	}
	c.Client.MaxConnsPerHost = maxConns

	// Configure distinct timeout values for more granular control with bounds checking
	readTimeout := getDetailsFromEnv("CLIENT_READ_TIMEOUT", c.Client.ClientTimeout)
	if readTimeout < 1 || readTimeout > 3600 {
		readTimeout = c.Client.ClientTimeout
	}
	c.Client.ReadTimeout = readTimeout

	writeTimeout := getDetailsFromEnv("CLIENT_WRITE_TIMEOUT", c.Client.ClientTimeout)
	if writeTimeout < 1 || writeTimeout > 3600 {
		writeTimeout = c.Client.ClientTimeout
	}
	c.Client.WriteTimeout = writeTimeout

	// MaxIdleConnDuration controls how long connections stay in the pool
	idleDuration := getDetailsFromEnv("CLIENT_MAX_IDLE_CONN_DURATION", 300)
	if idleDuration < 1 || idleDuration > 7200 { // 1 second to 2 hours max
		idleDuration = 300
	}
	c.Client.MaxIdleConnDuration = idleDuration

	// Secure by default: TLS verification is enabled unless explicitly disabled
	c.Client.DisableTLSVerify = getDetailsFromEnv("CLIENT_DISABLE_TLS_VERIFY", false)

	// Warn if TLS verification is disabled (security risk)
	if c.Client.DisableTLSVerify {
		// The logger already exists here; the warning is deferred so that it is
		// emitted when parseConfig returns, after the rest of the startup logs.
		defer func() {
			if c.Logger != nil {
				c.Logger.Warning(&libpack_logging.LogMessage{
					Message: "⚠️ TLS certificate verification is DISABLED - This is a security risk in production!",
					Pairs: map[string]interface{}{
						"recommendation": "Enable TLS verification by removing CLIENT_DISABLE_TLS_VERIFY or setting it to false",
					},
				})
			}
		}()
	}

	// Create HTTP client with the optimized parameters
	c.Client.FastProxyClient = createFasthttpClient(&c)
	proxy.WithClient(c.Client.FastProxyClient) // Setting the global proxy client

	// API configurations
	c.Server.EnableApi = getDetailsFromEnv("ENABLE_API", false)
	c.Server.ApiPort = getDetailsFromEnv("API_PORT", 9090)

	// Validate and sanitize banned users file path to prevent path traversal
	bannedUsersFile := getDetailsFromEnv("BANNED_USERS_FILE", "/go/src/app/banned_users.json")
	if validatedPath, err := validateFilePath(bannedUsersFile); err != nil {
		c.Logger.Error(&libpack_logging.LogMessage{
			Message: "Invalid banned users file path, using default",
			Pairs:   map[string]interface{}{"requested": bannedUsersFile, "error": err.Error()},
		})
		c.Api.BannedUsersFile = "/go/src/app/banned_users.json"
	} else {
		c.Api.BannedUsersFile = validatedPath
	}

	c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false)
	c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 1800) // Default: purge metrics every 30 minutes

	// Hasura event cleaner
	c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false)
	c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1)
	c.HasuraEventCleaner.EventMetadataDb = getDetailsFromEnv("HASURA_EVENT_METADATA_DB", "")

	// Tracing configuration
	c.Tracing.Enable = getDetailsFromEnv("ENABLE_TRACE", false)
	c.Tracing.Endpoint = getDetailsFromEnv("TRACE_ENDPOINT", "localhost:4317")

	// Circuit Breaker configuration - optimized for high-traffic production environments
	c.CircuitBreaker.Enable = getDetailsFromEnv("ENABLE_CIRCUIT_BREAKER", false)
	c.CircuitBreaker.MaxFailures = getDetailsFromEnv("CIRCUIT_MAX_FAILURES", 10)                     // Higher tolerance for transient failures
	c.CircuitBreaker.FailureRatio = getDetailsFromEnv("CIRCUIT_FAILURE_RATIO", 0.5)                  // Trip at 50% failure rate
	c.CircuitBreaker.SampleSize = getDetailsFromEnv("CIRCUIT_SAMPLE_SIZE", 100)                      // Statistically significant sample
	c.CircuitBreaker.Timeout = getDetailsFromEnv("CIRCUIT_TIMEOUT_SECONDS", 60)                      // Longer recovery time for stability
	c.CircuitBreaker.MaxRequestsInHalfOpen = getDetailsFromEnv("CIRCUIT_MAX_HALF_OPEN_REQUESTS", 5) // More probe requests
	c.CircuitBreaker.ReturnCachedOnOpen = getDetailsFromEnv("CIRCUIT_RETURN_CACHED_ON_OPEN", true)
	c.CircuitBreaker.TripOnTimeouts = getDetailsFromEnv("CIRCUIT_TRIP_ON_TIMEOUTS", true)
	c.CircuitBreaker.TripOn5xx = getDetailsFromEnv("CIRCUIT_TRIP_ON_5XX", true)
	c.CircuitBreaker.TripOn4xx = getDetailsFromEnv("CIRCUIT_TRIP_ON_4XX", false)                  // 4xx are usually client errors
	c.CircuitBreaker.BackoffMultiplier = getDetailsFromEnv("CIRCUIT_BACKOFF_MULTIPLIER", 1.0)     // No backoff by default
	c.CircuitBreaker.MaxBackoffTimeout = getDetailsFromEnv("CIRCUIT_MAX_BACKOFF_TIMEOUT", 300)    // 5 minutes max

	// Initialize endpoint configs map
	c.CircuitBreaker.EndpointConfigs = make(map[string]*EndpointCBConfig)

	// Retry budget configuration
	c.RetryBudget.Enable = getDetailsFromEnv("RETRY_BUDGET_ENABLE", true)
	c.RetryBudget.TokensPerSecond = getDetailsFromEnv("RETRY_BUDGET_TOKENS_PER_SEC", 10.0)
	c.RetryBudget.MaxTokens = getDetailsFromEnv("RETRY_BUDGET_MAX_TOKENS", 100)

	// Request coalescing configuration
	c.RequestCoalescing.Enable = getDetailsFromEnv("REQUEST_COALESCING_ENABLE", true)

	// WebSocket configuration
	c.WebSocket.Enable = getDetailsFromEnv("WEBSOCKET_ENABLE", false)
	c.WebSocket.PingInterval = getDetailsFromEnv("WEBSOCKET_PING_INTERVAL", 30)
	c.WebSocket.PongTimeout = getDetailsFromEnv("WEBSOCKET_PONG_TIMEOUT", 60)
	c.WebSocket.MaxMessageSize = int64(getDetailsFromEnv("WEBSOCKET_MAX_MESSAGE_SIZE", 524288)) // 512KB

	// Admin dashboard configuration
	c.AdminDashboard.Enable = getDetailsFromEnv("ADMIN_DASHBOARD_ENABLE", true)

	// Publish the fully populated configuration.
	cfgMutex.Lock()
	cfg = &c
	cfgMutex.Unlock()

	// Initialize tracing if enabled
	if cfg.Tracing.Enable {
		if cfg.Tracing.Endpoint == "" {
			cfg.Logger.Warning(&libpack_logging.LogMessage{
				Message: "Tracing endpoint not configured, using default localhost:4317",
			})
			cfg.Tracing.Endpoint = "localhost:4317"
		}

		var err error
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		tracer, err = libpack_tracing.NewTracing(ctx, cfg.Tracing.Endpoint)
		if err != nil {
			cfg.Logger.Error(&libpack_logging.LogMessage{
				Message: "Failed to initialize tracing",
				Pairs:   map[string]interface{}{"error": err.Error()},
			})
		} else {
			cfg.Logger.Info(&libpack_logging.LogMessage{
				Message: "Tracing initialized",
				Pairs:   map[string]interface{}{"endpoint": cfg.Tracing.Endpoint},
			})
		}
	}

	// Initialize metrics aggregator FIRST if Redis is enabled (even if cache is disabled)
	// This allows cluster mode monitoring even when cache is off
	if cfg.Cache.CacheRedisEnable {
		cfg.Logger.Info(&libpack_logging.LogMessage{
			Message: "Initializing metrics aggregator for cluster mode",
			Pairs: map[string]interface{}{
				"redis_url": cfg.Cache.CacheRedisURL,
				"redis_db":  cfg.Cache.CacheRedisDB,
			},
		})

		if err := InitializeMetricsAggregator(
			cfg.Cache.CacheRedisURL,
			cfg.Cache.CacheRedisPassword,
			cfg.Cache.CacheRedisDB,
			cfg.Logger,
		); err != nil {
			cfg.Logger.Error(&libpack_logging.LogMessage{
				Message: "FAILED to initialize metrics aggregator - cluster mode will not work",
				Pairs: map[string]interface{}{
					"error": err.Error(),
				},
			})
		} else {
			cfg.Logger.Info(&libpack_logging.LogMessage{
				Message: "✓ Metrics aggregator successfully initialized",
				Pairs: map[string]interface{}{
					"instance_id": GetMetricsAggregator().GetInstanceID(),
				},
			})
		}
	}

	// Initialize cache if enabled
	if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
		cacheConfig := &libpack_cache.CacheConfig{
			Logger:               cfg.Logger,
			TTL:                  cfg.Cache.CacheTTL,
			PerUserCacheDisabled: cfg.Cache.PerUserCacheDisabled,
		}

		// Redis cache configurations
		if cfg.Cache.CacheRedisEnable {
			cacheConfig.Redis.Enable = true
			cacheConfig.Redis.URL = cfg.Cache.CacheRedisURL
			cacheConfig.Redis.Password = cfg.Cache.CacheRedisPassword
			cacheConfig.Redis.DB = cfg.Cache.CacheRedisDB
		} else {
			// Memory cache configurations
			cacheConfig.Memory.MaxMemorySize = int64(cfg.Cache.CacheMaxMemorySize) * 1024 * 1024 // Convert MB to bytes
			cacheConfig.Memory.MaxEntries = int64(cfg.Cache.CacheMaxEntries)
			cacheConfig.Memory.UseLRU = cfg.Cache.CacheUseLRU

			cacheType := "standard"
			if cfg.Cache.CacheUseLRU {
				cacheType = "LRU"
			}

			cfg.Logger.Info(&libpack_logging.LogMessage{
				Message: "Configuring memory cache with limits",
				Pairs: map[string]interface{}{
					"type":          cacheType,
					"max_memory_mb": cfg.Cache.CacheMaxMemorySize,
					"max_entries":   cfg.Cache.CacheMaxEntries,
				},
			})
		}

		libpack_cache.EnableCache(cacheConfig)

		// Start memory monitoring for in-memory cache if it's not Redis
		// Will be started with context in main()
	}

	// Initialize circuit breaker if enabled
	if cfg.CircuitBreaker.Enable {
		initCircuitBreaker(cfg)
	}

	// Note: Retry budget is initialized in main() with context for graceful shutdown

	// Initialize request coalescer
	if cfg.RequestCoalescing.Enable {
		InitializeRequestCoalescer(true, cfg.Logger, cfg.Monitoring)
	}

	// Initialize WebSocket proxy
	if cfg.WebSocket.Enable {
		wsConfig := WebSocketConfig{
			Enabled:        cfg.WebSocket.Enable,
			PingInterval:   time.Duration(cfg.WebSocket.PingInterval) * time.Second,
			PongTimeout:    time.Duration(cfg.WebSocket.PongTimeout) * time.Second,
			MaxMessageSize: cfg.WebSocket.MaxMessageSize,
		}
		InitializeWebSocketProxy(cfg.Server.HostGraphQL, wsConfig, cfg.Logger, cfg.Monitoring)
	}

	// Initialize backend health manager
	if cfg.Server.HostGraphQL != "" {
		healthMgr := InitializeBackendHealth(cfg.Client.FastProxyClient, cfg.Server.HostGraphQL, cfg.Logger)
		// Start health checking in background
		healthMgr.StartHealthChecking()
	}

	// Note: RPS tracker is initialized in main() with context for graceful shutdown

	// Load rate limit configuration with improved error handling
	if err := loadRatelimitConfig(); err != nil {
		// Log the error with clear guidance
		detailedError := err.Error()
		cfg.Logger.Error(&libpack_logging.LogMessage{
			Message: "Failed to start service due to rate limit configuration error",
			Pairs: map[string]interface{}{
				"error": detailedError,
			},
		})

		// If we're not in a test environment, print to stderr and exit if config error
		if ifNotInTest() {
			fmt.Fprintln(os.Stderr, "⚠️ CRITICAL ERROR: Rate limit configuration problem detected")
			fmt.Fprintln(os.Stderr, detailedError)
			os.Exit(1)
		}
	}

	// API and event cleaner will be started with context in main()
	prepareQueriesAndExemptions()

	// Initialize GraphQL parsing optimizations
	initGraphQLParsing()
}

// main parses the configuration, starts the background services, the
// monitoring server and the HTTP proxy, then blocks until SIGINT/SIGTERM
// cancels the root context and performs a bounded graceful shutdown.
func main() {
	// Parse configuration
	parseConfig()

	// Setup graceful shutdown
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Initialize shutdown manager
	shutdownManager = NewShutdownManager(ctx)

	// Initialize RPS tracker with context for graceful shutdown
	InitializeRPSTracker(ctx)
	cfg.Logger.Info(&libpack_logging.LogMessage{
		Message: "RPS tracker initialized",
	})

	// Initialize retry budget with context for graceful shutdown
	if cfg.RetryBudget.Enable {
		retryBudgetConfig := RetryBudgetConfig{
			TokensPerSecond: cfg.RetryBudget.TokensPerSecond,
			MaxTokens:       cfg.RetryBudget.MaxTokens,
			Enabled:         cfg.RetryBudget.Enable,
		}
		InitializeRetryBudgetWithContext(ctx, retryBudgetConfig, cfg.Logger)
	}

	// Create a wait group to manage goroutines
	var wg sync.WaitGroup

	// Setup signal handling for graceful shutdown
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)

	go func() {
		<-sigCh
		cfg.Logger.Info(&libpack_logging.LogMessage{
			Message: "Shutdown signal received, stopping services...",
		})
		cancel()
	}()

	// Start background services with context
	once.Do(func() {
		// Start API server
		shutdownManager.RunGoroutine("api-server", func(ctx context.Context) {
			if err := enableApi(ctx); err != nil {
				cfg.Logger.Error(&libpack_logging.LogMessage{
					Message: "API server error",
					Pairs:   map[string]interface{}{"error": err.Error()},
				})
			}
		})

		// Start event cleaner
		shutdownManager.RunGoroutine("event-cleaner", func(ctx context.Context) {
			if err := enableHasuraEventCleaner(ctx); err != nil {
				cfg.Logger.Error(&libpack_logging.LogMessage{
					Message: "Event cleaner error",
					Pairs:   map[string]interface{}{"error": err.Error()},
				})
			}
		})

		// Start cache memory monitoring if not using Redis
		if cfg.Cache.CacheEnable && !cfg.Cache.CacheRedisEnable {
			shutdownManager.RunGoroutine("cache-memory-monitoring", startCacheMemoryMonitoring)
		}
	})

	// Register connection pool for cleanup
	shutdownManager.RegisterComponent("http-connection-pool", func(ctx context.Context) error {
		if connectionPoolManager != nil {
			return connectionPoolManager.Shutdown()
		}
		return nil
	})

	// Register backend health manager for cleanup
	shutdownManager.RegisterComponent("backend-health-manager", func(ctx context.Context) error {
		if healthMgr := GetBackendHealthManager(); healthMgr != nil {
			healthMgr.Shutdown()
		}
		return nil
	})

	// Register metrics aggregator for cleanup
	shutdownManager.RegisterComponent("metrics-aggregator", func(ctx context.Context) error {
		if aggregator := GetMetricsAggregator(); aggregator != nil {
			aggregator.Shutdown()
		}
		return nil
	})

	// Cache shutdown is handled internally by the cache implementation

	// Start monitoring server
	cfg.Logger.Info(&libpack_logging.LogMessage{
		Message: "Starting monitoring server...",
		Pairs:   map[string]interface{}{"port": cfg.Server.PortMonitoring},
	})

	// Start monitoring server in a goroutine
	wg.Add(1)
	monitoringErrCh := make(chan error, 1)
	go func() {
		defer wg.Done()
		if err := StartMonitoringServer(); err != nil {
			monitoringErrCh <- err
		}
	}()

	// Give monitoring server time to initialize
	select {
	case err := <-monitoringErrCh:
		cfg.Logger.Critical(&libpack_logging.LogMessage{
			Message: "Failed to start monitoring server",
			Pairs: map[string]interface{}{
				"error": err.Error(),
				"port":  cfg.Server.PortMonitoring,
			},
		})
		os.Exit(1)
	case <-time.After(2 * time.Second):
		// Continue if no error received within timeout
	}

	// Wait for GraphQL backend to be ready before starting proxy
	if healthMgr := GetBackendHealthManager(); healthMgr != nil {
		startupTimeout := time.Duration(getDetailsFromEnv("BACKEND_STARTUP_TIMEOUT", 300)) * time.Second
		cfg.Logger.Info(&libpack_logging.LogMessage{
			Message: "Waiting for GraphQL backend to be ready",
			Pairs: map[string]interface{}{
				"timeout_seconds": int(startupTimeout.Seconds()),
			},
		})

		if err := healthMgr.WaitForBackendReady(startupTimeout); err != nil {
			cfg.Logger.Critical(&libpack_logging.LogMessage{
				Message: "GraphQL backend did not become ready in time",
				Pairs: map[string]interface{}{
					"error":   err.Error(),
					"timeout": startupTimeout.String(),
				},
			})
			// Don't exit immediately, but warn that backend is not ready
			cfg.Logger.Warning(&libpack_logging.LogMessage{
				Message: "Starting proxy anyway - requests will fail until backend becomes available",
			})
		}
	}

	// Start HTTP proxy
	cfg.Logger.Info(&libpack_logging.LogMessage{
		Message: "Starting HTTP proxy server...",
		Pairs:   map[string]interface{}{"port": cfg.Server.PortGraphQL},
	})

	// Start HTTP proxy in a goroutine
	wg.Add(1)
	proxyErrCh := make(chan error, 1)
	go func() {
		defer wg.Done()
		if err := StartHTTPProxy(); err != nil {
			proxyErrCh <- err
		}
	}()

	// Block for a moment to check for immediate startup errors
	select {
	case err := <-proxyErrCh:
		cfg.Logger.Critical(&libpack_logging.LogMessage{
			Message: "Failed to start HTTP proxy server",
			Pairs: map[string]interface{}{
				"error": err.Error(),
				"port":  cfg.Server.PortGraphQL,
			},
		})
		os.Exit(1)
	case <-time.After(1 * time.Second):
		// Continue if no error received within timeout
	}

	// Wait for context cancellation
	<-ctx.Done()

	// Perform cleanup
	cfg.Logger.Info(&libpack_logging.LogMessage{
		Message: "Shutting down services...",
	})

	// Register tracer shutdown
	if tracer != nil {
		shutdownManager.RegisterComponent("tracer", func(ctx context.Context) error {
			return tracer.Shutdown(ctx)
		})
	}

	// Perform graceful shutdown of all components
	if err := shutdownManager.Shutdown(30 * time.Second); err != nil {
		cfg.Logger.Error(&libpack_logging.LogMessage{
			Message: "Error during shutdown",
			Pairs:   map[string]interface{}{"error": err.Error()},
		})
	}

	// Wait for all goroutines to finish (with timeout)
	waitCh := make(chan struct{})
	go func() {
		wg.Wait()
		close(waitCh)
	}()

	select {
	case <-waitCh:
		cfg.Logger.Info(&libpack_logging.LogMessage{
			Message: "All services shut down gracefully",
		})
	case <-time.After(10 * time.Second):
		cfg.Logger.Warning(&libpack_logging.LogMessage{
			Message: "Some services didn't shut down gracefully within timeout",
		})
	}
}

// startCacheMemoryMonitoring polls memory cache usage and updates metrics.
//
// It publishes the configured memory limit once at start, then every 15 seconds
// publishes usage, limit and percent-used gauges, logging a warning when usage
// exceeds 80%. It returns when ctx is cancelled.
func startCacheMemoryMonitoring(ctx context.Context) {
	// Check every few seconds (more frequent than cleanup routine)
	ticker := time.NewTicker(15 * time.Second)
	defer ticker.Stop()

	cfg.Logger.Info(&libpack_logging.LogMessage{
		Message: "Starting memory cache monitoring",
	})

	// Publish the initial limit. main starts this goroutine before it starts
	// the monitoring server, so cfg.Monitoring may still be nil here; it gets
	// the same guard as the ticker loop below.
	if cfg.Monitoring != nil {
		cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryLimit, nil, float64(libpack_cache.GetCacheMaxMemorySize()))
	}

	for {
		select {
		case <-ctx.Done():
			cfg.Logger.Info(&libpack_logging.LogMessage{
				Message: "Stopping cache memory monitoring",
			})
			return
		case <-ticker.C:
			// Skip if monitoring not initialized or cache not initialized
			if cfg.Monitoring == nil || !libpack_cache.IsCacheInitialized() {
				continue
			}

			// Get current memory usage
			memoryUsage := libpack_cache.GetCacheMemoryUsage()
			memoryLimit := libpack_cache.GetCacheMaxMemorySize()

			// Update metrics
			cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryUsage, nil, float64(memoryUsage))
			cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryLimit, nil, float64(memoryLimit))

			// Calculate percentage (protect against division by zero)
			var percentUsed float64
			if memoryLimit > 0 {
				percentUsed = float64(memoryUsage) / float64(memoryLimit) * 100.0
			}
			cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryPercent, nil, percentUsed)

			// Log if memory usage is high (over 80%)
			if percentUsed > 80.0 {
				cfg.Logger.Warning(&libpack_logging.LogMessage{
					Message: "Memory cache usage is high",
					Pairs: map[string]interface{}{
						"memory_usage_bytes": memoryUsage,
						"memory_limit_bytes": memoryLimit,
						"percent_used":       percentUsed,
					},
				})
			}
		}
	}
}

// validateFilePath validates and sanitizes file paths to prevent path traversal attacks.
//
// The path is rejected when it is empty, is exactly ".", contains ".." (either
// literally or after up to three rounds of URL decoding), contains one of
// ';', '|', '\n', '\r', contains a NUL byte, or resolves outside the allowed
// directories (working directory, /tmp, /var/tmp, /go/src/app).
//
// On success it returns the original path when it was relative and lies under
// the working directory, and the absolute path otherwise.
func validateFilePath(path string) (string, error) {
	if path == "" {
		return "", fmt.Errorf("empty path not allowed")
	}

	// Reject bare current directory for security
	if path == "." {
		return "", fmt.Errorf("bare current directory not allowed")
	}

	// URL decode the path to detect encoded traversal attempts
	decodedPath := path
	if strings.Contains(path, "%") {
		// Try to decode URL encoding (single and double)
		for i := 0; i < 3; i++ { // Handle multiple levels of encoding
			if decoded, err := url.QueryUnescape(decodedPath); err == nil {
				decodedPath = decoded
			} else {
				break
			}
		}
	}

	// Check for path traversal patterns (in both original and decoded)
	checkPaths := []string{path, decodedPath}
	for _, checkPath := range checkPaths {
		if strings.Contains(checkPath, "..") {
			return "", fmt.Errorf("path traversal attempt detected")
		}
	}

	// Check for dangerous characters
	dangerousChars := []string{";", "|", "\n", "\r"}
	for _, char := range dangerousChars {
		if strings.Contains(path, char) {
			return "", fmt.Errorf("dangerous character detected in path")
		}
	}

	// Clean and normalize the path
	cleaned := filepath.Clean(path)

	// Get absolute path
	absPath, err := filepath.Abs(cleaned)
	if err != nil {
		return "", fmt.Errorf("invalid file path: %w", err)
	}

	// Get working directory as base
	workDir, err := os.Getwd()
	if err != nil {
		return "", fmt.Errorf("cannot determine working directory: %w", err)
	}

	// Define allowed directories
	allowedDirs := []string{
		workDir,       // Current working directory
		"/tmp",        // Temporary files
		"/var/tmp",    // System temporary files
		"/go/src/app", // Docker container default
	}

	// Check if the path is within any allowed directory
	isAllowed := false
	for _, allowedDir := range allowedDirs {
		// Match on a separator boundary so "/tmpfoo" is not treated as inside "/tmp".
		cleanedAllowed := filepath.Clean(allowedDir)
		if strings.HasPrefix(absPath, cleanedAllowed+string(filepath.Separator)) || absPath == cleanedAllowed {
			isAllowed = true
			break
		}
	}

	if !isAllowed {
		return "", fmt.Errorf("path not in allowed directories")
	}

	// Additional security checks
	if strings.Contains(absPath, "\x00") {
		return "", fmt.Errorf("null byte in path")
	}

	// Return the original path if it's within the current working directory and is relative
	if strings.HasPrefix(absPath, workDir) && !filepath.IsAbs(path) {
		return path, nil
	}

	return absPath, nil
}

// ifNotInTest checks if the program is not running in a test environment.
// It reports true when the "test.v" flag is not registered.
func ifNotInTest() bool {
	return flag.Lookup("test.v") == nil
}