From 6b31e5c4c0f36f18969bc6058c9b13ae4f45f59e Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Thu, 10 Oct 2024 10:34:23 +0100 Subject: [PATCH] Little code cleanup. (#19) --- graphql.go | 1 + graphql_test.go | 2 +- logging/logger.go | 136 ++++++++++++++++++----------------- logging/logger_bench_test.go | 5 -- logging/logger_test.go | 10 +-- main.go | 37 ++++++---- monitoring.go | 6 +- proxy.go | 15 ++-- ratelimit.go | 5 +- server.go | 33 +++++---- 10 files changed, 138 insertions(+), 112 deletions(-) diff --git a/graphql.go b/graphql.go index 22e2bc0..acf3338 100644 --- a/graphql.go +++ b/graphql.go @@ -30,6 +30,7 @@ func prepareQueriesAndExemptions() { for _, q := range cfg.Security.IntrospectionAllowed { introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} } + for _, u := range cfg.Server.AllowURLs { allowedUrls[u] = struct{}{} } diff --git a/graphql_test.go b/graphql_test.go index f49924a..73ef42a 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -393,7 +393,7 @@ func (suite *Tests) Test_checkAllowedURLs() { ctx.Request().SetRequestURI(tt.path) ctx.Request().URI().SetPath(tt.path) result := checkAllowedURLs(ctx) - assert.Equal(tt.expected, result) + assert.Equal(tt.expected, result, "Unexpected result in test case: "+tt.name) }) } } diff --git a/logging/logger.go b/logging/logger.go index 7b1a964..36f080b 100644 --- a/logging/logger.go +++ b/logging/logger.go @@ -2,7 +2,6 @@ package libpack_logger import ( "bytes" - "flag" "fmt" "io" "os" @@ -16,16 +15,14 @@ import ( ) const ( - _ = iota - LEVEL_DEBUG + LEVEL_DEBUG = iota LEVEL_INFO LEVEL_WARN LEVEL_ERROR LEVEL_FATAL ) -var LevelNames = [...]string{ - "none", +var levelNames = []string{ "debug", "info", "warn", @@ -34,74 +31,103 @@ var LevelNames = [...]string{ } const ( - defaultFormat = time.RFC3339 + defaultTimeFormat = time.RFC3339 defaultMinLevel = LEVEL_INFO defaultShowCaller = false ) -var defaultOutput = os.Stdout - +// Logger represents the logging object with configurations. type Logger struct { output io.Writer - format string + timeFormat string minLogLevel int showCaller bool } +// LogMessage represents a log message with optional pairs. type LogMessage struct { - output io.Writer - Pairs map[string]any + Pairs map[string]interface{} Message string } -func (m *LogMessage) String() string { - return m.Message +// bufferPool is used to reuse bytes.Buffer for efficiency. +var bufferPool = sync.Pool{ + New: func() interface{} { + return new(bytes.Buffer) + }, } +// fieldNames allows customization of output field names. var fieldNames = map[string]string{ "timestamp": "timestamp", "level": "level", "message": "message", } +// New creates a new Logger with default settings. func New() *Logger { return &Logger{ - format: defaultFormat, + timeFormat: defaultTimeFormat, minLogLevel: defaultMinLevel, - output: defaultOutput, + output: os.Stdout, showCaller: defaultShowCaller, } } +// SetOutput sets the output destination for the logger. func (l *Logger) SetOutput(output io.Writer) *Logger { l.output = output return l } -var bufferPool = sync.Pool{ - New: func() any { - return new(bytes.Buffer) - }, -} - -var defaultPairs = make(map[string]any) - +// GetLogLevel returns the log level integer corresponding to the given level name. func GetLogLevel(level string) int { - for i, name := range LevelNames { - if name == strings.ToLower(level) { + level = strings.ToLower(level) + for i, name := range levelNames { + if name == level { return i } } return defaultMinLevel } +// SetTimeFormat sets the time format for the logger's timestamp field. +func (l *Logger) SetTimeFormat(format string) *Logger { + l.timeFormat = format + return l +} + +// SetMinLogLevel sets the minimum log level for the logger. +func (l *Logger) SetMinLogLevel(level int) *Logger { + l.minLogLevel = level + return l +} + +// SetFieldName allows customizing the field names in log output. +func (l *Logger) SetFieldName(field, name string) *Logger { + fieldNames[field] = name + return l +} + +// SetShowCaller enables or disables including the caller information in log output. +func (l *Logger) SetShowCaller(show bool) *Logger { + l.showCaller = show + return l +} + +// shouldLog determines if the message should be logged based on the logger's minimum log level. +func (l *Logger) shouldLog(level int) bool { + return level >= l.minLogLevel +} + +// log writes the log message with the given level. func (l *Logger) log(level int, m *LogMessage) { if m.Pairs == nil { - m.Pairs = defaultPairs + m.Pairs = make(map[string]interface{}) } - m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.format) - m.Pairs[fieldNames["level"]] = LevelNames[level] + m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.timeFormat) + m.Pairs[fieldNames["level"]] = levelNames[level] m.Pairs[fieldNames["message"]] = m.Message if l.showCaller { @@ -109,93 +135,73 @@ func (l *Logger) log(level int, m *LogMessage) { } buffer := bufferPool.Get().(*bytes.Buffer) - defer bufferPool.Put(buffer) buffer.Reset() + defer bufferPool.Put(buffer) - var encoder = json.NewEncoder(buffer) + encoder := json.NewEncoder(buffer) err := encoder.Encode(m.Pairs) if err != nil { - fmt.Println("Error marshalling log message:", err) + fmt.Fprintln(os.Stderr, "Error marshalling log message:", err) return } - // if not running in test - use stderr and stdout, otherwise - use logger's output setting - if flag.Lookup("test.v") != nil { - m.output = os.Stdout - if level >= LEVEL_ERROR { - m.output = os.Stderr - } + _, err = l.output.Write(buffer.Bytes()) + if err != nil { + fmt.Fprintln(os.Stderr, "Error writing log message:", err) } - - // Use logger's output setting instead of os.Stdout or os.Stderr - l.output.Write(buffer.Bytes()) } +// Debug logs a debug-level message. func (l *Logger) Debug(m *LogMessage) { if l.shouldLog(LEVEL_DEBUG) { l.log(LEVEL_DEBUG, m) } } +// Info logs an info-level message. func (l *Logger) Info(m *LogMessage) { if l.shouldLog(LEVEL_INFO) { l.log(LEVEL_INFO, m) } } +// Warn logs a warning-level message. func (l *Logger) Warn(m *LogMessage) { if l.shouldLog(LEVEL_WARN) { l.log(LEVEL_WARN, m) } } +// Warning is an alias for Warn. func (l *Logger) Warning(m *LogMessage) { l.Warn(m) } +// Error logs an error-level message. func (l *Logger) Error(m *LogMessage) { if l.shouldLog(LEVEL_ERROR) { l.log(LEVEL_ERROR, m) } } +// Fatal logs a fatal-level message. func (l *Logger) Fatal(m *LogMessage) { if l.shouldLog(LEVEL_FATAL) { l.log(LEVEL_FATAL, m) } } +// Critical logs a critical-level message and exits the application. func (l *Logger) Critical(m *LogMessage) { l.Fatal(m) os.Exit(1) } -func (l *Logger) shouldLog(level int) bool { - return level >= l.minLogLevel -} - -func (l *Logger) SetFormat(format string) *Logger { - l.format = format - return l -} - -func (l *Logger) SetMinLogLevel(level int) *Logger { - l.minLogLevel = level - return l -} - -func (l *Logger) SetFieldName(field, name string) *Logger { - fieldNames[field] = name - return l -} - -func (l *Logger) SetShowCaller(show bool) *Logger { - l.showCaller = show - return l -} - +// getCaller retrieves the file and line number of the caller. func getCaller() string { - _, file, line, ok := runtime.Caller(3) + // Skip 3 stack frames: getCaller -> log -> [Debug|Info|...] + const depth = 3 + _, file, line, ok := runtime.Caller(depth) if !ok { return "unknown:0" } diff --git a/logging/logger_bench_test.go b/logging/logger_bench_test.go index 9d92425..6067026 100644 --- a/logging/logger_bench_test.go +++ b/logging/logger_bench_test.go @@ -56,11 +56,6 @@ func Benchmark_NewLogger(b *testing.B) { b.Run(tt.name, func(b *testing.B) { for i := 0; i < b.N; i++ { got := New() - - if tt.triggers.ModFormat.Format != "" { - got = got.SetFormat(tt.triggers.ModFormat.Format) - } - if tt.triggers.ModLevel.Level != 0 { got = got.SetMinLogLevel(tt.triggers.ModLevel.Level) } diff --git a/logging/logger_test.go b/logging/logger_test.go index 92d04a9..e802fe1 100644 --- a/logging/logger_test.go +++ b/logging/logger_test.go @@ -40,7 +40,7 @@ func (suite *LoggerTestSuite) Test_LogMessageString() { Message: "test message", } - assert.Equal("test message", msg.String()) + assert.Equal("test message", msg.Message) } func callLoggerMethod(logger *Logger, methodName string, message *LogMessage) { @@ -125,7 +125,7 @@ func (suite *LoggerTestSuite) Test_LogsLevelsPrint() { // Set logger's minimum log level logger.SetMinLogLevel(tt.loggerMinLevel) - fmt.Println("Logger min log level:", LevelNames[logger.minLogLevel]) + fmt.Println("Logger min log level:", levelNames[logger.minLogLevel]) // Call the logging method callLoggerMethod(logger, tt.method, msg) @@ -143,7 +143,7 @@ func (suite *LoggerTestSuite) Test_LogsLevelsPrint() { if !containsLogMessage(logOutput, tt.message) { t.Errorf("Expected log message %q, but got %q", tt.message, logOutput) } - assert.Equal(LevelNames[tt.messageLogLevel], loggedMessage["level"]) + assert.Equal(levelNames[tt.messageLogLevel], loggedMessage["level"]) if tt.pairs != nil { for k, v := range tt.pairs { assert.Equal(v, loggedMessage[k]) @@ -161,9 +161,9 @@ func containsLogMessage(logOutput, expectedMessage string) bool { } func (suite *LoggerTestSuite) Test_SetFormat() { - logger := New().SetFormat(time.RFC3339Nano) + logger := New().SetTimeFormat(time.RFC3339Nano) - assert.Equal(time.RFC3339Nano, logger.format) + assert.Equal(time.RFC3339Nano, logger.timeFormat) } func (suite *LoggerTestSuite) Test_SetMinLogLevel() { diff --git a/main.go b/main.go index b1b6e7a..2babe61 100644 --- a/main.go +++ b/main.go @@ -20,45 +20,49 @@ var ( once sync.Once ) -// function get value from the env where the value can be anything +// getDetailsFromEnv retrieves the value from the environment or returns the default. func getDetailsFromEnv[T any](key string, defaultValue T) T { var result any - if _, ok := os.LookupEnv("GMP_" + key); ok { - key = "GMP_" + key + envKey := "GMP_" + key + if _, ok := os.LookupEnv(envKey); !ok { + envKey = key } switch v := any(defaultValue).(type) { case string: - result = envutil.Getenv(key, v) + result = envutil.Getenv(envKey, v) case int: - result = envutil.GetInt(key, v) + result = envutil.GetInt(envKey, v) case bool: - result = envutil.GetBool(key, v) + result = envutil.GetBool(envKey, v) default: result = defaultValue } return result.(T) } +// parseConfig loads and parses the configuration. func parseConfig() { libpack_config.PKG_NAME = "graphql_proxy" c := config{} + // Server configurations c.Server.PortGraphQL = getDetailsFromEnv("PORT_GRAPHQL", 8080) c.Server.PortMonitoring = getDetailsFromEnv("MONITORING_PORT", 9393) c.Server.HostGraphQL = getDetailsFromEnv("HOST_GRAPHQL", "http://localhost/") c.Server.HostGraphQLReadOnly = getDetailsFromEnv("HOST_GRAPHQL_READONLY", "") + // Client configurations c.Client.JWTUserClaimPath = getDetailsFromEnv("JWT_USER_CLAIM_PATH", "") c.Client.JWTRoleClaimPath = getDetailsFromEnv("JWT_ROLE_CLAIM_PATH", "") c.Client.RoleFromHeader = getDetailsFromEnv("ROLE_FROM_HEADER", "") c.Client.RoleRateLimit = getDetailsFromEnv("ROLE_RATE_LIMIT", false) - /* in-memory cache */ + // In-memory cache c.Cache.CacheEnable = getDetailsFromEnv("ENABLE_GLOBAL_CACHE", false) c.Cache.CacheTTL = getDetailsFromEnv("CACHE_TTL", 60) - /* redis cache */ + // Redis cache c.Cache.CacheRedisEnable = getDetailsFromEnv("ENABLE_REDIS_CACHE", false) c.Cache.CacheRedisURL = getDetailsFromEnv("CACHE_REDIS_URL", "localhost:6379") c.Cache.CacheRedisPassword = getDetailsFromEnv("CACHE_REDIS_PASSWORD", "") c.Cache.CacheRedisDB = getDetailsFromEnv("CACHE_REDIS_DB", 0) - /* security */ + // Security configurations c.Security.BlockIntrospection = getDetailsFromEnv("BLOCK_SCHEMA_INTROSPECTION", false) c.Security.IntrospectionAllowed = func() []string { urls := getDetailsFromEnv("ALLOWED_INTROSPECTION", "") @@ -68,10 +72,14 @@ func parseConfig() { return strings.Split(urls, ",") }() c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info")) - c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false) + // Logger setup + c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)). + SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false) + // Health check c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "") c.Client.GQLClient = graphql.NewConnection() c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL) + // Server modes c.Server.AccessLog = getDetailsFromEnv("ENABLE_ACCESS_LOG", false) c.Server.ReadOnlyMode = getDetailsFromEnv("READ_ONLY_MODE", false) c.Server.AllowURLs = func() []string { @@ -83,22 +91,26 @@ func parseConfig() { }() c.Client.ClientTimeout = getDetailsFromEnv("PROXIED_CLIENT_TIMEOUT", 120) c.Client.FastProxyClient = createFasthttpClient(c.Client.ClientTimeout) - proxy.WithClient(c.Client.FastProxyClient) // setting the global proxy client here instead of per request + proxy.WithClient(c.Client.FastProxyClient) // Setting the global proxy client + // API configurations c.Server.EnableApi = getDetailsFromEnv("ENABLE_API", false) c.Server.ApiPort = getDetailsFromEnv("API_PORT", 9090) c.Api.BannedUsersFile = getDetailsFromEnv("BANNED_USERS_FILE", "/go/src/app/banned_users.json") c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false) c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0) + // Hasura event cleaner c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false) c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1) c.HasuraEventCleaner.EventMetadataDb = getDetailsFromEnv("HASURA_EVENT_METADATA_DB", "") cfg = &c + // Initialize cache if enabled if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable { cacheConfig := &libpack_cache.CacheConfig{ Logger: cfg.Logger, TTL: cfg.Cache.CacheTTL, } + // Redis cache configurations if cfg.Cache.CacheRedisEnable { cacheConfig.Redis.Enable = true cacheConfig.Redis.URL = cfg.Cache.CacheRedisURL @@ -113,7 +125,7 @@ func parseConfig() { go enableApi() go enableHasuraEventCleaner() }) - prepareQueriesAndExemptions() + prepareQueriesAndExemptions() // Ensure this function is defined elsewhere } func main() { @@ -123,6 +135,7 @@ func main() { StartHTTPProxy() } +// ifNotInTest checks if the program is not running in a test environment. func ifNotInTest() bool { return flag.Lookup("test.v") == nil } diff --git a/monitoring.go b/monitoring.go index bad9af8..5933f2e 100644 --- a/monitoring.go +++ b/monitoring.go @@ -4,8 +4,12 @@ import ( libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" ) +// StartMonitoringServer initializes and starts the monitoring server. func StartMonitoringServer() { - cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{PurgeOnCrawl: cfg.Server.PurgeOnCrawl, PurgeEvery: cfg.Server.PurgeEvery}) + cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{ + PurgeOnCrawl: cfg.Server.PurgeOnCrawl, + PurgeEvery: cfg.Server.PurgeEvery, + }) cfg.Monitoring.AddMetricsPrefix("graphql_proxy") cfg.Monitoring.RegisterDefaultMetrics() } diff --git a/proxy.go b/proxy.go index 2f9e90e..2fd8a5a 100644 --- a/proxy.go +++ b/proxy.go @@ -17,6 +17,7 @@ import ( "github.com/valyala/fasthttp" ) +// createFasthttpClient creates and configures a fasthttp client. func createFasthttpClient(timeout int) *fasthttp.Client { return &fasthttp.Client{ Name: "graphql_proxy", @@ -33,6 +34,7 @@ func createFasthttpClient(timeout int) *fasthttp.Client { } } +// proxyTheRequest handles the request proxying logic. func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { if !checkAllowedURLs(c) { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -51,7 +53,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { return fmt.Errorf("invalid URL: %v", err) } - if cfg.LogLevel == "debug" { + if cfg.LogLevel == "DEBUG" { logDebugRequest(c) } @@ -61,7 +63,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { if proxyErr != nil { return proxyErr } - if c.Response().StatusCode() != 200 { + if c.Response().StatusCode() != fiber.StatusOK { return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode()) } return nil @@ -94,11 +96,12 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { return fmt.Errorf("failed to proxy request: %v", err) } - if cfg.LogLevel == "debug" { + if cfg.LogLevel == "DEBUG" { logDebugResponse(c) } - if c.Response().Header.Peek("Content-Encoding") != nil && string(c.Response().Header.Peek("Content-Encoding")) == "gzip" { + if bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) { + // Decompress gzip response reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body())) if err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -122,7 +125,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { c.Response().Header.Del("Content-Encoding") } - if c.Response().StatusCode() != 200 { + if c.Response().StatusCode() != fiber.StatusOK { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } @@ -133,6 +136,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { return nil } +// logDebugRequest logs the request details when in debug mode. func logDebugRequest(c *fiber.Ctx) { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Proxying the request", @@ -145,6 +149,7 @@ func logDebugRequest(c *fiber.Ctx) { }) } +// logDebugResponse logs the response details when in debug mode. func logDebugResponse(c *fiber.Ctx) { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Received proxied response", diff --git a/ratelimit.go b/ratelimit.go index 18e9f85..ea1b6e4 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -10,6 +10,7 @@ import ( libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) +// RateLimitConfig holds the rate limit configuration for a role type RateLimitConfig struct { RateCounterTicker *goratecounter.RateCounter Interval time.Duration `json:"interval"` @@ -21,6 +22,7 @@ var ( rateLimitMu sync.RWMutex ) +// loadRatelimitConfig loads the rate limit configurations from file func loadRatelimitConfig() error { paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"} for _, path := range paths { @@ -59,7 +61,7 @@ func loadConfigFromPath(path string) error { Interval: value.Interval, }) - if cfg.LogLevel == "debug" { + if cfg.LogLevel == "DEBUG" { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Setting ratelimit config for role", Pairs: map[string]interface{}{ @@ -83,6 +85,7 @@ func loadConfigFromPath(path string) error { return nil } +// rateLimitedRequest checks if a request should be rate-limited func rateLimitedRequest(userID, userRole string) bool { rateLimitMu.RLock() roleConfig, ok := rateLimits[userRole] diff --git a/server.go b/server.go index 635be9e..99fca0c 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package main import ( "fmt" "strconv" - "sync" "time" "github.com/goccy/go-json" @@ -21,14 +20,7 @@ const ( healthCheckQueryStr = `{ __typename }` ) -var ( - ctxPool = sync.Pool{ - New: func() interface{} { - return new(fiber.Ctx) - }, - } -) - +// StartHTTPProxy initializes and starts the HTTP proxy server. func StartHTTPProxy() { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Starting the HTTP proxy", @@ -71,15 +63,18 @@ func StartHTTPProxy() { } } +// proxyTheRequestToDefault proxies the request to the default GraphQL endpoint. func proxyTheRequestToDefault(c *fiber.Ctx) error { return proxyTheRequest(c, cfg.Server.HostGraphQL) } +// AddRequestUUID adds a unique request UUID to the context. func AddRequestUUID(c *fiber.Ctx) error { c.Locals("request_uuid", uuid.NewString()) return c.Next() } +// checkAllowedURLs checks if the requested URL is allowed. func checkAllowedURLs(c *fiber.Ctx) bool { if len(allowedUrls) == 0 { return true @@ -89,6 +84,7 @@ func checkAllowedURLs(c *fiber.Ctx) bool { return ok } +// healthCheck performs a health check on the GraphQL server. func healthCheck(c *fiber.Ctx) error { if len(cfg.Server.HealthcheckGraphQL) > 0 { cfg.Logger.Debug(&libpack_logger.LogMessage{ @@ -103,16 +99,17 @@ func healthCheck(c *fiber.Ctx) error { Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - return c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query") + return c.Status(fiber.StatusInternalServerError).SendString("Can't reach the GraphQL server with {__typename} query") } } cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Health check returning OK", }) - return c.Status(200).SendString("Health check OK") + return c.Status(fiber.StatusOK).SendString("Health check OK") } +// processGraphQLRequest handles the incoming GraphQL requests. func processGraphQLRequest(c *fiber.Ctx) error { startTime := time.Now() @@ -124,7 +121,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { } if checkIfUserIsBanned(c, extractedUserID) { - return c.Status(403).SendString("User is banned") + return c.Status(fiber.StatusForbidden).SendString("User is banned") } if cfg.Client.RoleFromHeader != "" { @@ -139,13 +136,13 @@ func processGraphQLRequest(c *fiber.Ctx) error { Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName}, }) if !rateLimitedRequest(extractedUserID, extractedRoleName) { - return c.Status(429).SendString("Rate limit exceeded, try again later") + return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later") } } - parsedResult := parseGraphQLQuery(c) + parsedResult := parseGraphQLQuery(c) // Ensure this function is defined elsewhere if parsedResult.shouldBlock { - return c.Status(403).SendString("Request blocked") + return c.Status(fiber.StatusForbidden).SendString("Request blocked") } if parsedResult.shouldIgnore { @@ -208,7 +205,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - return c.Status(500).SendString("Can't proxy the request - try again later") + return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later") } } @@ -217,6 +214,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { return nil } +// proxyAndCacheTheRequest proxies and caches the request if needed. func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error { if err := proxyTheRequest(c, currentEndpoint); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -224,7 +222,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - return c.Status(500).SendString("Can't proxy the request - try again later") + return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later") } libpack_cache.CacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second) @@ -232,6 +230,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, return c.Send(c.Response().Body()) } +// logAndMonitorRequest logs and monitors the request processing. func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) { labels := map[string]string{ "op_type": opType,