From 3a18e0e935bf7a30a340a7a2b729e3f867df7af2 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 5 Mar 2024 22:40:06 +0000 Subject: [PATCH] Improve stats gathering and tests improvements. (#8) --- .github/workflows/pr.yaml | 82 +++++++++++++++++++++++++++++++++++++++ cache.go | 2 +- cache_test.go | 3 +- details_test.go | 4 +- graphql.go | 56 +++++++++++--------------- graphql_test.go | 45 +++++++++------------ logging/logging.go | 6 +++ logging/logging_test.go | 4 +- main.go | 5 +++ main_test.go | 20 ++++++++-- monitoring.go | 2 +- monitoring/monitoring.go | 39 +++++++++++-------- monitoring/structs.go | 10 +++-- proxy.go | 20 ++++++---- proxy_test.go | 82 +++++++++++++++++++++++++++++++++++++++ server.go | 7 ++-- 16 files changed, 284 insertions(+), 103 deletions(-) create mode 100644 .github/workflows/pr.yaml create mode 100644 proxy_test.go diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml new file mode 100644 index 0000000..3611939 --- /dev/null +++ b/.github/workflows/pr.yaml @@ -0,0 +1,82 @@ +name: Run tests on PR + +on: + pull_request: + branches: + - "main" + push: + paths-ignore: + - "**/**.md" + - "**/**.yaml" + - "static/**" + branches: + - "!main" + +env: + GO_VERSION: ">=1.21" + +jobs: + jobs: + # This job is responsible for preparation of the build + # environment variables. + prepare: + name: Preparing build context + runs-on: ubuntu-latest + + steps: + - name: Checkout repo + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + id: cache + with: + go-version: ${{env.GO_VERSION}} + cache-dependency-path: "**/*.sum" + + - name: Go get dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: | + go get ./... + + # This job is responsible for running tests and linting the codebase + test: + name: "Unit testing" + # needs: [prepare] + runs-on: ubuntu-latest + # container: github/super-linter:v4 + needs: [prepare] + + services: + # Label used to access the service container + redis: + # Docker Hub image + image: redis + # Set health checks to wait until redis has started + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ${{env.GO_VERSION}} + cache-dependency-path: "**/*.sum" + + - name: Install dependencies + run: | + go mod tidy + + - name: Run unit tests + env: + REDIS_HOST: redis + REDIS_PORT: 6379 + run: | + export REDIS_SERVER="$REDIS_HOST:$REDIS_PORT" + CI_RUN=${CI} make test diff --git a/cache.go b/cache.go index be8f398..c72a61b 100644 --- a/cache.go +++ b/cache.go @@ -13,7 +13,7 @@ func calculateHash(c *fiber.Ctx) string { } func enableCache() { - cfg.Cache.CacheClient = libpack_cache.New(time.Duration(cfg.Cache.CacheTTL) * time.Second * 100) + cfg.Cache.CacheClient = libpack_cache.New(time.Duration(cfg.Cache.CacheTTL) * time.Second) } func cacheLookup(hash string) []byte { diff --git a/cache_test.go b/cache_test.go index e076935..ca2ef96 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,7 +1,6 @@ package main import ( - "testing" "time" ) @@ -38,7 +37,7 @@ func (suite *Tests) Test_cacheLookup() { }, } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { if tt.addCache.data != nil { cfg.Cache.CacheClient.Set(tt.args.hash, tt.addCache.data, time.Duration(90*time.Second)) } diff --git a/details_test.go b/details_test.go index b58fb18..f5e5a70 100644 --- a/details_test.go +++ b/details_test.go @@ -1,7 +1,5 @@ package main -import "testing" - func (suite *Tests) Test_extractClaimsFromJWTHeader() { jwt_token_for_tests := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiSGFzdXJhIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsiZ3Vlc3QiLCJ1c2VyIiwiZ3JvdXBhZG1pbiIsInBheWFkbWluIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6Imd1ZXN0IiwieC1oYXN1cmEtdXNlci1pZCI6IjE2NyIsIngtaGFzdXJhLXVzZXItdXVpZCI6ImRkM2U2ZTM1LTA0MDktNDNiMC1iZmYxLWNlZjNjNmVkNWYxMCJ9LCJpc3MiOiJBdXRoU2VydmljZSIsImV4cCI6MTY5NjgwMTcyNiwibmJmIjoxNjk2NTg1NzI2LCJpYXQiOjE2OTY1ODU3MjZ9.dsJ5JKzG5tXOlqeZ_Gfe2XC-vyrcwtYwOGfhvt8q9UY" @@ -68,7 +66,7 @@ func (suite *Tests) Test_extractClaimsFromJWTHeader() { }, } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { if len(tt.jwt_token_path) > 0 { cfg.Client.JWTUserClaimPath = tt.jwt_token_path } diff --git a/graphql.go b/graphql.go index 8ecbfb3..25fd1c8 100644 --- a/graphql.go +++ b/graphql.go @@ -1,7 +1,6 @@ package main import ( - "flag" "strconv" "strings" @@ -41,40 +40,26 @@ var introspectionQuerySet = map[string]struct{}{} var introspectionAllowedQueries = map[string]struct{}{} var allowedUrls = map[string]struct{}{} +// Utility function to convert a slice of strings to a map for O(1) lookups. +func sliceToMap(slice []string) map[string]struct{} { + resultMap := make(map[string]struct{}, len(slice)) + for _, item := range slice { + resultMap[strings.ToLower(item)] = struct{}{} + } + return resultMap +} + func prepareQueriesAndExemptions() { - introspectionQuerySet = map[string]struct{}{} - introspectionQuerySet = func() map[string]struct{} { - rsqs := make(map[string]struct{}, len(introspection_queries)) - for _, query := range introspection_queries { - rsqs[strings.ToLower(query)] = struct{}{} - } - return rsqs - }() - - introspectionAllowedQueries = map[string]struct{}{} - introspectionAllowedQueries = func() map[string]struct{} { - rsqs := make(map[string]struct{}, len(cfg.Security.IntrospectionAllowed)) - for _, query := range cfg.Security.IntrospectionAllowed { - rsqs[strings.ToLower(query)] = struct{}{} - } - return rsqs - }() - - allowedUrls = map[string]struct{}{} - allowedUrls = func() map[string]struct{} { - rsqs := make(map[string]struct{}, len(cfg.Server.AllowURLs)) - for _, query := range cfg.Server.AllowURLs { - rsqs[strings.ToLower(query)] = struct{}{} - } - return rsqs - }() + introspectionQuerySet = sliceToMap(introspection_queries) + introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed) + allowedUrls = sliceToMap(cfg.Server.AllowURLs) } type parseGraphQLQueryResult struct { operationType string operationName string - cacheRequest bool cacheTime int + cacheRequest bool cacheRefresh bool shouldBlock bool shouldIgnore bool @@ -86,7 +71,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { err := json.Unmarshal(c.Body(), &m) if err != nil { cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())}) - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } return @@ -95,7 +80,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { query, ok := m["query"].(string) if !ok { cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m}) - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } return @@ -104,7 +89,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { p, err := parser.Parse(parser.ParseParams{Source: query}) if err != nil { cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m}) - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } return @@ -122,7 +107,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode { cfg.Logger.Warning("Mutation blocked", m) - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } c.Status(403).SendString("The server is in read-only mode") @@ -138,7 +123,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string)) if err != nil { cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}) - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } return @@ -184,8 +169,11 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) { whateverLower := strings.ToLower(whatever) got_exemption := false + + // If the query is an introspection query, we need to check if it's allowed. if _, exists := introspectionQuerySet[whateverLower]; exists { if len(cfg.Security.IntrospectionAllowed) > 0 { + if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever}) got_exemption = true @@ -197,7 +185,7 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bo } } if shouldBlock { - if flag.Lookup("test.v") == nil { + if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } c.Status(403).SendString("Introspection queries are not allowed") diff --git a/graphql_test.go b/graphql_test.go index a135947..7d5f664 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -1,10 +1,6 @@ package main import ( - "testing" - - fiber "github.com/gofiber/fiber/v2" - libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" "github.com/valyala/fasthttp" ) @@ -166,6 +162,7 @@ func (suite *Tests) Test_parseGraphQLQuery() { { name: "test mutation query with config: read only", suppliedSettings: func() *config { + parseConfig() cfg.Server.ReadOnlyMode = true return cfg }(), @@ -199,6 +196,7 @@ func (suite *Tests) Test_parseGraphQLQuery() { { name: "test simple query with introspection __schema config: block introspection", suppliedSettings: func() *config { + parseConfig() cfg.Security.BlockIntrospection = true return cfg }(), @@ -221,7 +219,6 @@ func (suite *Tests) Test_parseGraphQLQuery() { parseConfig() cfg.Security.BlockIntrospection = true cfg.Security.IntrospectionAllowed = []string{} - prepareQueriesAndExemptions() return cfg }(), suppliedQuery: queries{ @@ -243,7 +240,6 @@ func (suite *Tests) Test_parseGraphQLQuery() { parseConfig() cfg.Security.BlockIntrospection = true cfg.Security.IntrospectionAllowed = []string{"__schema"} - prepareQueriesAndExemptions() return cfg }(), suppliedQuery: queries{ @@ -275,15 +271,9 @@ func (suite *Tests) Test_parseGraphQLQuery() { } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { cfg = &config{} - cfg.Logger = libpack_logging.NewLogger() - defer func() { - cfg = &config{} - }() - - app := fiber.New() - + parseConfig() ctx_headers := func() *fasthttp.RequestHeader { h := fasthttp.RequestHeader{} for k, v := range tt.suppliedQuery.headers { @@ -298,28 +288,29 @@ func (suite *Tests) Test_parseGraphQLQuery() { ctx_request.AppendBody([]byte(tt.suppliedQuery.body)) - ctx := app.AcquireCtx(&fasthttp.RequestCtx{ + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{ Request: ctx_request, }) - defer app.ReleaseCtx(ctx) + // defer func() { + // cfg = &config{} + // parseConfig() + // suite.app.ReleaseCtx(ctx) + // }() + assert.NotNil(ctx, "Fiber context is nil") if tt.suppliedSettings != nil { cfg = tt.suppliedSettings } - - defer func() { - cfg = &config{} - }() - + prepareQueriesAndExemptions() parseResult := parseGraphQLQuery(ctx) - assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type", tt.name) - assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name", tt.name) - assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value", tt.name) - assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value", tt.name) - assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value", tt.name) - assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value", tt.name) + assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type "+tt.name) + assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name "+tt.name) + assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value "+tt.name) + assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value "+tt.name) + assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value "+tt.name) + assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value "+tt.name) if tt.wantResults.returnCode > 0 { assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name) diff --git a/logging/logging.go b/logging/logging.go index b8610e0..952e88a 100644 --- a/logging/logging.go +++ b/logging/logging.go @@ -59,10 +59,16 @@ func (lw *LogConfig) log(w io.Writer, level zerolog.Level, message string, v map } func (lw *LogConfig) Debug(message string, v ...map[string]interface{}) { + if !lw.logger.Debug().Enabled() { + return + } lw.log(os.Stdout, zerolog.DebugLevel, message, mergeMaps(v)) } func (lw *LogConfig) Info(message string, v ...map[string]interface{}) { + if !lw.logger.Info().Enabled() { + return + } lw.log(os.Stdout, zerolog.InfoLevel, message, mergeMaps(v)) } diff --git a/logging/logging_test.go b/logging/logging_test.go index 823eac6..37e2ec6 100644 --- a/logging/logging_test.go +++ b/logging/logging_test.go @@ -183,7 +183,7 @@ func (suite *LoggingTestSuite) TestLogConfig_AllHandlers() { } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { if tt.envMinLogLevel != "" { os.Setenv("LOG_LEVEL", tt.envMinLogLevel) defer os.Unsetenv("LOG_LEVEL") @@ -274,7 +274,7 @@ func (suite *LoggingTestSuite) TestFullMessage() { } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { if tt.envMinLogLevel != "" { os.Setenv("LOG_LEVEL", tt.envMinLogLevel) defer os.Unsetenv("LOG_LEVEL") diff --git a/main.go b/main.go index 0bd8825..a5f53f0 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "flag" "os" "strings" @@ -86,3 +87,7 @@ func main() { StartMonitoringServer() StartHTTPProxy() } + +func ifNotInTest() bool { + return flag.Lookup("test.v") == nil +} diff --git a/main_test.go b/main_test.go index 2e89a78..458e976 100644 --- a/main_test.go +++ b/main_test.go @@ -4,12 +4,16 @@ import ( "os" "testing" + "github.com/goccy/go-json" + "github.com/gofiber/fiber/v2" + libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" assertions "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) type Tests struct { suite.Suite + app *fiber.App } var ( @@ -21,6 +25,16 @@ func (suite *Tests) BeforeTest(suiteName, testName string) { func (suite *Tests) SetupTest() { assert = assertions.New(suite.T()) + suite.app = fiber.New( + fiber.Config{ + DisableStartupMessage: true, + JSONEncoder: json.Marshal, + JSONDecoder: json.Unmarshal, + }, + ) + parseConfig() + StartMonitoringServer() + cfg.Logger = libpack_logging.NewLogger() // Setup environment variables here if needed os.Setenv("GMP_TEST_STRING", "testValue") os.Setenv("GMP_TEST_INT", "123") @@ -48,10 +62,10 @@ func TestSuite(t *testing.T) { func (suite *Tests) Test_envVariableSetting() { tests := []struct { - name string - envKey string defaultValue any expected any + name string + envKey string }{ { name: "test_string", @@ -86,7 +100,7 @@ func (suite *Tests) Test_envVariableSetting() { } for _, tt := range tests { - suite.T().Run(tt.name, func(t *testing.T) { + suite.Run(tt.name, func() { result := getDetailsFromEnv(tt.envKey, tt.defaultValue) assert.Equal(tt.expected, result) }) diff --git a/monitoring.go b/monitoring.go index e115fd8..bad9af8 100644 --- a/monitoring.go +++ b/monitoring.go @@ -5,7 +5,7 @@ import ( ) func StartMonitoringServer() { - cfg.Monitoring = libpack_monitoring.NewMonitoring(cfg.Server.PurgeOnCrawl, cfg.Server.PurgeEvery) + cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{PurgeOnCrawl: cfg.Server.PurgeOnCrawl, PurgeEvery: cfg.Server.PurgeEvery}) cfg.Monitoring.AddMetricsPrefix("graphql_proxy") cfg.Monitoring.RegisterDefaultMetrics() } diff --git a/monitoring/monitoring.go b/monitoring/monitoring.go index fe582bf..b8fe825 100644 --- a/monitoring/monitoring.go +++ b/monitoring/monitoring.go @@ -4,6 +4,7 @@ package libpack_monitoring import ( + "flag" "fmt" "time" @@ -17,31 +18,37 @@ import ( type MetricsSetup struct { metrics_set *metrics.Set metrics_set_custom *metrics.Set + ic *InitConfig metrics_prefix string } var ( - log *logging.LogConfig - purgeMetricsOnCrawl bool - purgeMetricsEvery int + log *logging.LogConfig ) -func NewMonitoring(purgeOnCrawl bool, purgeEvery int) *MetricsSetup { - purgeMetricsOnCrawl = purgeOnCrawl - purgeMetricsEvery = purgeEvery +type InitConfig struct { + PurgeOnCrawl bool + PurgeEvery int +} + +func NewMonitoring(ic *InitConfig) *MetricsSetup { log = logging.NewLogger() - ms := &MetricsSetup{} + ms := &MetricsSetup{ic: ic} ms.metrics_set = metrics.NewSet() ms.metrics_set_custom = metrics.NewSet() - go ms.startPrometheusEndpoint() + // if not testing, start the prometheus endpoint - if purgeEvery > 0 { - ticker := time.NewTicker(time.Duration(purgeEvery) * time.Second) - go func() { - for range ticker.C { - ms.PurgeMetrics() - } - }() + if flag.Lookup("test.v") == nil { + go ms.startPrometheusEndpoint() + + if ic.PurgeEvery > 0 { + ticker := time.NewTicker(time.Duration(ic.PurgeEvery) * time.Second) + go func() { + for range ticker.C { + ms.PurgeMetrics() + } + }() + } } return ms @@ -63,7 +70,7 @@ func (ms *MetricsSetup) metricsEndpoint(c *fiber.Ctx) error { ms.metrics_set.WritePrometheus(c.Response().BodyWriter()) ms.metrics_set_custom.WritePrometheus(c.Response().BodyWriter()) - if purgeMetricsOnCrawl && purgeMetricsEvery == 0 { + if ms.ic.PurgeOnCrawl && ms.ic.PurgeEvery == 0 { ms.PurgeMetrics() } return nil diff --git a/monitoring/structs.go b/monitoring/structs.go index 99f7710..fe1cd5e 100644 --- a/monitoring/structs.go +++ b/monitoring/structs.go @@ -1,8 +1,10 @@ package libpack_monitoring const ( - MetricsSucceeded = "requests_succesful" - MetricsFailed = "requests_failed" - MetricsDuration = "requests_duration" - MetricsSkipped = "requests_skipped" + MetricsSucceeded = "requests_succesful" + MetricsFailed = "requests_failed" + MetricsDuration = "requests_duration" + MetricsSkipped = "requests_skipped" + MetricsExecutedQuery = "executed_query" + MetricsTimedQuery = "timed_query" ) diff --git a/proxy.go b/proxy.go index f615dbf..b8a3386 100644 --- a/proxy.go +++ b/proxy.go @@ -31,7 +31,9 @@ func createFasthttpClient(timeout int) *fasthttp.Client { func proxyTheRequest(c *fiber.Ctx) error { if !checkAllowedURLs(c) { cfg.Logger.Error("Request blocked", map[string]interface{}{"path": c.Path()}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } c.Status(403).SendString("Request blocked - not allowed URL") return nil } @@ -44,11 +46,13 @@ func proxyTheRequest(c *fiber.Ctx) error { err := retry.Do( func() error { - err := proxy.DoRedirects(c, cfg.Server.HostGraphQL+c.Path(), 3, cfg.Client.FastProxyClient) - if err != nil { - cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - return err + errInt := proxy.DoRedirects(c, cfg.Server.HostGraphQL+c.Path(), 3, cfg.Client.FastProxyClient) + if errInt != nil { + cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": errInt.Error()}) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } + return errInt } return nil }, @@ -69,7 +73,9 @@ func proxyTheRequest(c *fiber.Ctx) error { cfg.Logger.Debug("Received proxied response", map[string]interface{}{"path": c.Path(), "response_body": string(c.Response().Body()), "response_code": c.Response().StatusCode(), "headers": c.GetRespHeaders(), "request_uuid": c.Locals("request_uuid")}) if c.Response().StatusCode() != 200 { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } return fmt.Errorf("Received non-200 response from the GraphQL server: %d", c.Response().StatusCode()) } diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..5918d66 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,82 @@ +package main + +import ( + "github.com/valyala/fasthttp" +) + +func (suite *Tests) Test_proxyTheRequest() { + + supplied_headers := map[string]string{ + "X-Forwarded-For": "127.0.0.1", + "Content-Type": "application/json", + } + + tests := []struct { + name string + query string + host string + path string + headers map[string]string + wantErr bool + }{ + { + name: "test_empty", + query: `query { + __type(name: "Query") { + name + } + }`, + host: "https://telegram-bot.app/", + path: "/v1/graphql", + headers: supplied_headers, + wantErr: false, + }, + { + name: "test_wrong_url", + query: `query { + __type(name: "Query") { + name + } + }`, + host: "https://google.com/", + path: "/v1/wrongURL", + headers: supplied_headers, + wantErr: true, + }, + } + + for _, tt := range tests { + suite.Run(tt.name, func() { + + cfg = &config{} + parseConfig() + cfg.Server.HostGraphQL = tt.host + + ctx_headers := func() *fasthttp.RequestHeader { + h := fasthttp.RequestHeader{} + for k, v := range tt.headers { + h.Add(k, v) + } + return &h + }() + + ctx_request := fasthttp.Request{ + Header: *ctx_headers, + } + ctx_request.SetRequestURI(tt.path) + ctx_request.Header.SetMethod("POST") + + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{ + Request: ctx_request, + }) + + assert.NotNil(ctx, "Fiber context is nil", tt.name) + err := proxyTheRequest(ctx) + if tt.wantErr { + assert.NotNil(err, "Error is nil", tt.name) + } else { + assert.Nil(err, "Error is not nil", tt.name) + } + }) + } +} diff --git a/server.go b/server.go index 959ba49..0da9986 100644 --- a/server.go +++ b/server.go @@ -148,6 +148,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil { cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}) + c.Request().Header.Add("X-Cache-Hit", "true") c.Send(cachedResponse) wasCached = true } else { @@ -201,10 +202,10 @@ func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached } cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil) - cfg.Monitoring.Increment("executed_query", labels) + cfg.Monitoring.Increment(libpack_monitoring.MetricsExecutedQuery, labels) if !wasCached { - cfg.Monitoring.UpdateDuration("timed_query", labels, startTime) - cfg.Monitoring.Update("timed_query", labels, float64(duration.Milliseconds())) + cfg.Monitoring.UpdateDuration(libpack_monitoring.MetricsTimedQuery, labels, startTime) + cfg.Monitoring.Update(libpack_monitoring.MetricsTimedQuery, labels, float64(duration.Milliseconds())) } }