diff --git a/proxy.go b/proxy.go index d2aff9b..8f7a1d1 100644 --- a/proxy.go +++ b/proxy.go @@ -1,15 +1,18 @@ package main import ( + "context" "crypto/tls" "fmt" "time" "github.com/avast/retry-go/v4" + "github.com/goccy/go-json" fiber "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/proxy" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" + libpack_trace "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing" "github.com/valyala/fasthttp" ) @@ -29,7 +32,7 @@ func createFasthttpClient(timeout int) *fasthttp.Client { } } -func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { +func proxyTheRequest(c *fiber.Ctx, currentEndpoint string, ctx context.Context) error { if !checkAllowedURLs(c) { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Request blocked", @@ -129,5 +132,25 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error { } c.Response().Header.Del(fiber.HeaderServer) + if cfg.Trace.Enable { + tracingContext := libpack_trace.TraceContextInject(ctx) + if tracingContext == nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't inject empty tracing context", + }) + return nil + } + traceJsonEncoded, err := json.Marshal(tracingContext) + if err != nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't convert tracing context to JSON", + Pairs: map[string]interface{}{ + "error": err.Error(), + }, + }) + return err + } + c.Response().Header.Set("X-Trace-Span", string(traceJsonEncoded)) + } return nil } diff --git a/server.go b/server.go index b8f3b68..54b2e68 100644 --- a/server.go +++ b/server.go @@ -77,6 +77,26 @@ func checkAllowedURLs(c *fiber.Ctx) bool { return ok } +func extractTraceHeaders(c *fiber.Ctx) (found bool, traceHeaders map[string]string) { + if !cfg.Trace.Enable { + return + } + headers := c.Request().Header + traceHeader := headers.Peek("X-Trace-Span") + if traceHeader != nil { + traceHeaders = make(map[string]string) + if err := json.Unmarshal(traceHeader, &traceHeaders); err != nil { + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Error unmarshalling tracer header", + Pairs: map[string]interface{}{"error": err}, + }) + return + } + found = true + } + return +} + func healthCheck(c *fiber.Ctx) error { if len(cfg.Server.HealthcheckGraphQL) > 0 { cfg.Logger.Debug(&libpack_logger.LogMessage{ @@ -111,26 +131,14 @@ func processGraphQLRequest(c *fiber.Ctx) error { // Pre-fetch headers and trace header processing headers := c.Request().Header - traceHeader := headers.Peek("X-Trace-Span") authorization := headers.Peek("Authorization") + ctx := context.Background() + traceHeaderFound, traceHeader := extractTraceHeaders(c) - if cfg.Trace.Enable && traceHeader != nil { - traceHeaders := make(map[string]string) - if err := json.Unmarshal(traceHeader, &traceHeaders); err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Error unmarshalling tracer header", - Pairs: map[string]interface{}{"error": err}, - }) - } else { - ctx := libpack_trace.TraceContextExtract(context.Background(), traceHeaders) - _, span := libpack_trace.ContinueSpanFromContext(ctx, "GraphQLRequest") - defer span.End() - } - } else if cfg.Trace.Enable { - cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "No trace header found", - Pairs: nil, - }) + if traceHeaderFound { + ctx = libpack_trace.TraceContextExtract(ctx, traceHeader) + _, span := libpack_trace.ContinueSpanFromContext(ctx, "GraphQLRequest") + defer span.End() } // JWT and role extraction with pre-check @@ -170,7 +178,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { Message: "Request passed as-is - probably not a GraphQL", Pairs: nil, }) - return proxyTheRequest(c, parsedResult.activeEndpoint) + return proxyTheRequest(c, parsedResult.activeEndpoint, ctx) } // Cache handling logic queryCacheHash := libpack_cache.CalculateHash(c) @@ -223,10 +231,10 @@ func processGraphQLRequest(c *fiber.Ctx) error { Message: "Cache miss", Pairs: map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}, }) - proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime, parsedResult.activeEndpoint) + proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime, parsedResult.activeEndpoint, ctx) } } else { - if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil { + if err := proxyTheRequest(c, parsedResult.activeEndpoint, ctx); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't proxy the request", Pairs: map[string]interface{}{"error": err.Error()}, @@ -242,8 +250,8 @@ func processGraphQLRequest(c *fiber.Ctx) error { } // Additional helper function to avoid code repetition -func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) { - err := proxyTheRequest(c, currentEndpoint) +func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string, ctx context.Context) { + err := proxyTheRequest(c, currentEndpoint, ctx) if err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ Message: "Can't proxy the request",