diff --git a/server.go b/server.go index 9dc40a5..b8f3b68 100644 --- a/server.go +++ b/server.go @@ -107,67 +107,64 @@ func processGraphQLRequest(c *fiber.Ctx) error { startTime := time.Now() // Initialize variables with default values - extractedUserID := "-" - extractedRoleName := "-" - var queryCacheHash string + extractedUserID, extractedRoleName := "-", "-" - if cfg.Trace.Enable { - trace_header := c.Request().Header.Peek("X-Trace-Span") - if trace_header != nil { - traceHeaders := make(map[string]string) - err := json.Unmarshal([]byte(trace_header), &traceHeaders) - if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Error unmarshalling tracer header", - Pairs: map[string]interface{}{"error": err}, - }) - } + // Pre-fetch headers and trace header processing + headers := c.Request().Header + traceHeader := headers.Peek("X-Trace-Span") + authorization := headers.Peek("Authorization") + + if cfg.Trace.Enable && traceHeader != nil { + traceHeaders := make(map[string]string) + if err := json.Unmarshal(traceHeader, &traceHeaders); err != nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Error unmarshalling tracer header", + Pairs: map[string]interface{}{"error": err}, + }) + } else { ctx := libpack_trace.TraceContextExtract(context.Background(), traceHeaders) _, span := libpack_trace.ContinueSpanFromContext(ctx, "GraphQLRequest") defer span.End() - } else { - cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "No trace header found", - Pairs: nil, - }) } + } else if cfg.Trace.Enable { + cfg.Logger.Warning(&libpack_logger.LogMessage{ + Message: "No trace header found", + Pairs: nil, + }) } - authorization := c.Request().Header.Peek("Authorization") + // JWT and role extraction with pre-check if authorization != nil && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(string(authorization)) } + // Check for banned users early if checkIfUserIsBanned(c, extractedUserID) { - c.Status(403).SendString("User is banned") - return nil + return c.Status(403).SendString("User is banned") } + // Role extraction from header if len(cfg.Client.RoleFromHeader) > 0 { - extractedRoleName = string(c.Request().Header.Peek(cfg.Client.RoleFromHeader)) + extractedRoleName = string(headers.Peek(cfg.Client.RoleFromHeader)) if extractedRoleName == "" { extractedRoleName = "-" } } - // Implementing rate limiting if enabled - if cfg.Client.RoleRateLimit { + // Rate limiting check + if cfg.Client.RoleRateLimit && !rateLimitedRequest(extractedUserID, extractedRoleName) { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Rate limiting enabled", Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName}, }) - if !rateLimitedRequest(extractedUserID, extractedRoleName) { - c.Status(429).SendString("Rate limit exceeded, try again later") - return nil - } + return c.Status(429).SendString("Rate limit exceeded, try again later") } + // Parsing GraphQL query parsedResult := parseGraphQLQuery(c) if parsedResult.shouldBlock { - c.Status(403).SendString("Request blocked") - return nil + return c.Status(403).SendString("Request blocked") } - if parsedResult.shouldIgnore { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Request passed as-is - probably not a GraphQL", @@ -175,17 +172,10 @@ func processGraphQLRequest(c *fiber.Ctx) error { }) return proxyTheRequest(c, parsedResult.activeEndpoint) } - - calculatedQueryHash := libpack_cache.CalculateHash(c) - - if parsedResult.cacheTime > 0 { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Cache time set via query", - Pairs: map[string]interface{}{"cacheTime": parsedResult.cacheTime}, - }) - } else { - // If not set via query, try setting via header - cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query") + // Cache handling logic + queryCacheHash := libpack_cache.CalculateHash(c) + if parsedResult.cacheTime == 0 { + cacheQuery := headers.Peek("X-Cache-Graphql-Query") if cacheQuery != nil { parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery)) cfg.Logger.Debug(&libpack_logger.LogMessage{ @@ -197,39 +187,34 @@ func processGraphQLRequest(c *fiber.Ctx) error { } } - wasCached := false - if parsedResult.cacheRefresh { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Cache refresh requested via query", Pairs: map[string]interface{}{"user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, }) - libpack_cache.CacheDelete(calculatedQueryHash) + libpack_cache.CacheDelete(queryCacheHash) } - // Handling Cache Logic + wasCached := false if parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable { cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Cache enabled", Pairs: map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable}, }) - queryCacheHash = calculatedQueryHash - if cachedResponse := libpack_cache.CacheLookup(queryCacheHash); cachedResponse != nil { cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil) cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Cache hit", Pairs: map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, }) - c.Request().Header.Add("X-Cache-Hit", "true") - err := c.Send(cachedResponse) - if err != nil { + headers.Add("X-Cache-Hit", "true") + if err := c.Send(cachedResponse); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't send the cached response", Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't send the cached response - try again later") + return c.Status(500).SendString("Can't send the cached response - try again later") } wasCached = true } else { @@ -241,23 +226,18 @@ func processGraphQLRequest(c *fiber.Ctx) error { proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime, parsedResult.activeEndpoint) } } else { - err := proxyTheRequest(c, parsedResult.activeEndpoint) - if err != nil { + if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't proxy the request", Pairs: map[string]interface{}{"error": err.Error()}, }) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - c.Status(500).SendString("Can't proxy the request - try again later") - return nil + return c.Status(500).SendString("Can't proxy the request - try again later") } } timeTaken := time.Since(startTime) - - // Logging & Monitoring logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime) - return nil }