From bc128493b0cdb740262417d6c142886a079394dc Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 10 Oct 2023 13:54:03 +0100 Subject: [PATCH] Add tests (#2) * Initial tests draft. * Add tests for parsing jwt token * Further code optimisations --- Makefile | 4 +- cache_test.go | 49 +++++++++++++++++++++++++ details.go | 47 ++++++++++++------------ details_test.go | 83 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 +++ go.sum | 9 +++++ main_test.go | 34 +++++++++++++++++ ratelimit.go | 97 ++++++++++++++++++++++++++++++++----------------- server.go | 88 +++++++++++++++++++++++++++----------------- 9 files changed, 322 insertions(+), 94 deletions(-) create mode 100644 cache_test.go create mode 100644 details_test.go create mode 100644 main_test.go diff --git a/Makefile b/Makefile index c1857f7..13ea839 100644 --- a/Makefile +++ b/Makefile @@ -10,8 +10,8 @@ help: ## display this help @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n\nTargets:\n"} /^[a-zA-Z0-9_-]+:.*?##/ { printf " \033[36m%-20s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST) .PHONY: run -run: ## run application - @LOG_LEVEL=warn BLOCK_SCHEMA_INTROSPECTION=false JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/v1/graphql go run *.go +run: build ## run application + @LOG_LEVEL=warn BLOCK_SCHEMA_INTROSPECTION=false JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/v1/graphql ./graphql-proxy .PHONY: build build: ## build the binary diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..307e533 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "testing" + "time" +) + +func (suite *Tests) Test_cacheLookup() { + type args struct { + hash string + } + tests := []struct { + name string + args args + want []byte + addCache struct { + data []byte + } + }{ + { + name: "test_non_existent", + args: args{ + hash: "00000000000000000000000000000000000000", + }, + want: nil, + }, + { + name: "test_existent", + args: args{ + hash: "00000000000000000000000000000000000001", + }, + want: []byte("it's fine."), + addCache: struct { + data []byte + }{ + data: []byte("it's fine."), + }, + }, + } + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + if tt.addCache.data != nil { + cfg.Cache.CacheClient.Set(tt.args.hash, tt.addCache.data, time.Duration(1)*time.Second) + } + got := cacheLookup(tt.args.hash) + assert.Equal(tt.want, got, "Unexpected cache lookup result") + }) + } +} diff --git a/details.go b/details.go index da6587f..94a654e 100644 --- a/details.go +++ b/details.go @@ -2,6 +2,7 @@ package main import ( "encoding/base64" + "fmt" "strings" "github.com/lukaszraczylo/ask" @@ -9,45 +10,43 @@ import ( ) func extractClaimsFromJWTHeader(authorization string) (usr string, role string) { + usr, role = "-", "-" + + handleError := func(msg string, details map[string]interface{}) { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + cfg.Logger.Error(msg, details) + } + tokenParts := strings.Split(authorization, ".") if len(tokenParts) != 3 { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error("Can't split the token", map[string]interface{}{"token": authorization}) + handleError("Can't split the token", map[string]interface{}{"token": authorization}) return } + claim, err := base64.RawURLEncoding.DecodeString(tokenParts[1]) if err != nil { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error("Can't decode the token", map[string]interface{}{"token": authorization}) + handleError("Can't decode the token", map[string]interface{}{"token": authorization}) return } + var claimMap map[string]interface{} - err = json.Unmarshal(claim, &claimMap) - if err != nil { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error("Can't unmarshal the claim", map[string]interface{}{"token": authorization}) + if err = json.Unmarshal(claim, &claimMap); err != nil { + handleError("Can't unmarshal the claim", map[string]interface{}{"token": authorization}) return } - if len(cfg.Client.JWTUserClaimPath) > 0 { - var ok bool - usr, ok = ask.For(claimMap, cfg.Client.JWTUserClaimPath).String("-") - if !ok { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error("Can't find the user id", map[string]interface{}{"claim_map": claimMap, "path": cfg.Client.JWTUserClaimPath}) - return + extractClaim := func(claimPath string, target *string, name string) { + if len(claimPath) > 0 { + var ok bool + *target, ok = ask.For(claimMap, claimPath).String("-") + if !ok { + handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath}) + } } } - if len(cfg.Client.JWTRoleClaimPath) > 0 { - var ok bool - role, ok = ask.For(claimMap, cfg.Client.JWTRoleClaimPath).String("-") - if !ok { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - cfg.Logger.Error("Can't find the role", map[string]interface{}{"claim_map": claimMap, "path": cfg.Client.JWTRoleClaimPath}) - return - } - } + extractClaim(cfg.Client.JWTUserClaimPath, &usr, "user id") + extractClaim(cfg.Client.JWTRoleClaimPath, &role, "role") return } diff --git a/details_test.go b/details_test.go new file mode 100644 index 0000000..b58fb18 --- /dev/null +++ b/details_test.go @@ -0,0 +1,83 @@ +package main + +import "testing" + +func (suite *Tests) Test_extractClaimsFromJWTHeader() { + jwt_token_for_tests := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiSGFzdXJhIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsiZ3Vlc3QiLCJ1c2VyIiwiZ3JvdXBhZG1pbiIsInBheWFkbWluIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6Imd1ZXN0IiwieC1oYXN1cmEtdXNlci1pZCI6IjE2NyIsIngtaGFzdXJhLXVzZXItdXVpZCI6ImRkM2U2ZTM1LTA0MDktNDNiMC1iZmYxLWNlZjNjNmVkNWYxMCJ9LCJpc3MiOiJBdXRoU2VydmljZSIsImV4cCI6MTY5NjgwMTcyNiwibmJmIjoxNjk2NTg1NzI2LCJpYXQiOjE2OTY1ODU3MjZ9.dsJ5JKzG5tXOlqeZ_Gfe2XC-vyrcwtYwOGfhvt8q9UY" + + type args struct { + authorization string + } + + tests := []struct { + name string + args args + wantUsr string + wantRole string + jwt_token_path string + jwt_role_path string + }{ + { + name: "test_empty", + wantUsr: "-", + wantRole: "-", + }, + { + name: "test_invalid_path", + args: args{ + authorization: jwt_token_for_tests, + }, + wantUsr: "-", + wantRole: "-", + jwt_token_path: "invalid", + }, + { + name: "test_invalid_role_path", + args: args{ + authorization: jwt_token_for_tests, + }, + wantUsr: "-", + wantRole: "-", + jwt_role_path: "invalid", + }, + { + name: "test_valid", + args: args{ + authorization: jwt_token_for_tests, + }, + wantUsr: "167", + wantRole: "guest", + jwt_token_path: "Hasura.x-hasura-user-id", + jwt_role_path: "Hasura.x-hasura-default-role", + }, + { + name: "test_invalid_token", + args: args{ + authorization: "invalid", + }, + wantUsr: "-", + wantRole: "-", + }, + { + name: "test_invalid_three_part_token", + args: args{ + authorization: "invalid.threepart.token", + }, + wantUsr: "-", + wantRole: "-", + }, + } + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + if len(tt.jwt_token_path) > 0 { + cfg.Client.JWTUserClaimPath = tt.jwt_token_path + } + if len(tt.jwt_role_path) > 0 { + cfg.Client.JWTRoleClaimPath = tt.jwt_role_path + } + gotUsr, gotRole := extractClaimsFromJWTHeader(tt.args.authorization) + assert.Equal(tt.wantUsr, gotUsr, "Unexpected user ID") + assert.Equal(tt.wantRole, gotRole, "Unexpected role") + }) + } +} diff --git a/go.mod b/go.mod index a253b1d..4084a15 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 github.com/lukaszraczylo/go-ratecounter v0.1.8 github.com/lukaszraczylo/go-simple-graphql v1.1.31 + github.com/stretchr/testify v1.8.4 github.com/telegram-bot-app/libpack v0.0.0-20231008100411-9f7f8bf94315 ) @@ -19,15 +20,18 @@ require ( github.com/VictoriaMetrics/metrics v1.24.0 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/avast/retry-go/v4 v4.5.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.3.1 // indirect github.com/gookit/color v1.5.4 // indirect github.com/klauspost/compress v1.17.0 // indirect + github.com/kr/text v0.2.0 // indirect github.com/lukaszraczylo/pandati v0.0.29 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.4 // indirect github.com/rs/zerolog v1.31.0 // indirect github.com/telegram-bot-app/lib-logging v0.0.19 // indirect @@ -43,4 +47,5 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/term v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index e7c6318..959e8c7 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,7 @@ github.com/avast/retry-go/v4 v4.5.0/go.mod h1:7hLEXp0oku2Nir2xBAsg0PTphp9z71bN5A github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -30,6 +31,10 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 h1:lvI8Wlbg4PxkRcg2f10wgoaRpfN19v+YdRek3+dLtlM= github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415/go.mod h1:M+UVdyqZs++xtEPrascaVmZdOMhCnxjZ2SgH+xHpR0c= github.com/lukaszraczylo/go-ratecounter v0.1.8 h1:ZYm6Wkn58ZAlFWRmC7PaD4oAYHWcu8/0MUDWGe3PnJQ= @@ -56,6 +61,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= @@ -97,5 +104,7 @@ golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..e5145d3 --- /dev/null +++ b/main_test.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + "testing" + + assertions "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type Tests struct { + suite.Suite +} + +var ( + assert *assertions.Assertions +) + +func (suite *Tests) SetupTest() { + assert = assertions.New(suite.T()) +} + +func (suite *Tests) BeforeTest(suiteName, testName string) { + fmt.Println("BeforeTest") + cfg = &config{} + parseConfig() + StartMonitoringServer() +} + +// func (suite *Tests) AfterTest(suiteName, testName string) {) + +func TestSuite(t *testing.T) { + suite.Run(t, new(Tests)) +} diff --git a/ratelimit.go b/ratelimit.go index 41e51bc..3042ce9 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -25,60 +25,89 @@ var ratelimit_intervals = map[string]time.Duration{ } func loadRatelimitConfig() error { - paths := [3]string{"/app/ratelimit.json", "./ratelimit.json", "./static/default-ratelimit.json"} + paths := []string{"/app/ratelimit.json", "./ratelimit.json", "./static/default-ratelimit.json"} + for _, path := range paths { - file, err := os.Open(path) - if err != nil { - continue + err := loadConfigFromPath(path) + if err == nil { + return nil } - defer file.Close() - decoder := json.NewDecoder(file) - config := struct { - RateLimit map[string]RateLimitConfig `json:"ratelimit"` - }{} - err = decoder.Decode(&config) - if err != nil { - return err - } - - for key, value := range config.RateLimit { - value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ - Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval], - }) - cfg.Logger.Debug("Setting ratelimit config for role", map[string]interface{}{"role": key, "interval_provided": value.Interval, "interval_used": ratelimit_intervals[value.Interval], "ratelimit": value.Req}) - config.RateLimit[key] = value - } - - rateLimits = config.RateLimit - cfg.Logger.Debug("Rate limit config loaded", map[string]interface{}{"ratelimit": rateLimits}) - return nil + cfg.Logger.Error("Failed to load config", map[string]interface{}{"path": path, "error": err}) } + cfg.Logger.Debug("Rate limit config not found") return os.ErrNotExist } -func rateLimitedRequest(userId string, userRole string) (shouldAllow bool) { +func loadConfigFromPath(path string) error { + file, err := os.Open(path) + if err != nil { + return err + } + defer file.Close() + + config := struct { + RateLimit map[string]RateLimitConfig `json:"ratelimit"` + }{} + + decoder := json.NewDecoder(file) + if err := decoder.Decode(&config); err != nil { + return err + } + + for key, value := range config.RateLimit { + value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{ + Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval], + }) + cfg.Logger.Debug("Setting ratelimit config for role", map[string]interface{}{ + "role": key, + "interval_provided": value.Interval, + "interval_used": ratelimit_intervals[value.Interval], + "ratelimit": value.Req, + }) + config.RateLimit[key] = value + } + + rateLimits = config.RateLimit + cfg.Logger.Debug("Rate limit config loaded", map[string]interface{}{"ratelimit": rateLimits}) + return nil +} + +func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) { if rateLimits == nil { cfg.Logger.Debug("Rate limit config not found", map[string]interface{}{"user_role": userRole}) return true } - // check if userRole is in rateLimits - if _, ok := rateLimits[userRole]; !ok { + + // Fetch role config once to avoid multiple map lookups + roleConfig, ok := rateLimits[userRole] + if !ok { cfg.Logger.Warning("Rate limit role not found", map[string]interface{}{"user_role": userRole}) return true } - if rateLimits[userRole].RateCounterTicker == nil { + if roleConfig.RateCounterTicker == nil { cfg.Logger.Warning("Rate limit ticker not found", map[string]interface{}{"user_role": userRole}) return true } - rateLimits[userRole].RateCounterTicker.Incr(1) - ticker_rate := rateLimits[userRole].RateCounterTicker.GetRate() - cfg.Logger.Debug("Rate limit ticker", map[string]interface{}{"user_role": userRole, "user_id": userId, "rate": ticker_rate, "config_rate": rateLimits[userRole].Req, "interval": rateLimits[userRole].Interval, "interval_duration": rateLimits[userRole].Interval}) - if ticker_rate > float64(rateLimits[userRole].Req) { - cfg.Logger.Debug("Rate limit exceeded", map[string]interface{}{"user_role": userRole, "user_id": userId, "rate": ticker_rate, "config_rate": rateLimits[userRole].Req, "interval": rateLimits[userRole].Interval, "interval_duration": rateLimits[userRole].Interval}) + roleConfig.RateCounterTicker.Incr(1) + tickerRate := roleConfig.RateCounterTicker.GetRate() + + logDetails := map[string]interface{}{ + "user_role": userRole, + "user_id": userID, + "rate": tickerRate, + "config_rate": roleConfig.Req, + "interval": roleConfig.Interval, + } + + cfg.Logger.Debug("Rate limit ticker", logDetails) + + if tickerRate > float64(roleConfig.Req) { + cfg.Logger.Debug("Rate limit exceeded", logDetails) return false } + return true } diff --git a/server.go b/server.go index 53eec3e..345d105 100644 --- a/server.go +++ b/server.go @@ -42,70 +42,90 @@ func healthCheck(c *fiber.Ctx) error { } func processGraphQLRequest(c *fiber.Ctx) error { - t := time.Now() + startTime := time.Now() - var extracted_user_id string = "-" - var extracted_role_name string = "-" - var query_cache_hash string + // Initialize variables with default values + extractedUserID := "-" + extractedRoleName := "-" + var queryCacheHash string authorization := c.Request().Header.Peek("Authorization") if authorization != nil && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) { - extracted_user_id, extracted_role_name = extractClaimsFromJWTHeader(string(authorization)) + extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(string(authorization)) } + // Implementing rate limiting if enabled if cfg.Client.JWTRoleRateLimit { - cfg.Logger.Debug("Rate limiting enabled", map[string]interface{}{"user_id": extracted_user_id, "role_name": extracted_role_name}) - if !rateLimitedRequest(extracted_user_id, extracted_role_name) { + cfg.Logger.Debug("Rate limiting enabled", map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName}) + if !rateLimitedRequest(extractedUserID, extractedRoleName) { c.Status(429).SendString("Rate limit exceeded, try again later") return nil } } - opType, opName, cache_from_query, should_block := parseGraphQLQuery(c) - - if should_block { + opType, opName, cacheFromQuery, shouldBlock := parseGraphQLQuery(c) + if shouldBlock { return nil } - was_cached := false + wasCached := false - if cache_from_query || cfg.Cache.CacheEnable { - cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cache_from_query, "via_env": cfg.Cache.CacheEnable}) - query_cache_hash = calculateHash(c) - cachedResponse := cacheLookup(query_cache_hash) - if cachedResponse != nil { - cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": query_cache_hash, "user_id": extracted_user_id}) + // Handling Cache Logic + if cacheFromQuery || cfg.Cache.CacheEnable { + cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cacheFromQuery, "via_env": cfg.Cache.CacheEnable}) + queryCacheHash = calculateHash(c) + + if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil { + cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID}) c.Send(cachedResponse) - was_cached = true + wasCached = true } else { - cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": query_cache_hash, "user_id": extracted_user_id}) - proxyTheRequest(c) - cfg.Cache.CacheClient.Set(query_cache_hash, c.Response().Body(), time.Duration(cfg.Cache.CacheTTL)*time.Second) - c.Send(c.Response().Body()) + cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID}) + proxyAndCacheTheRequest(c, queryCacheHash) } } else { proxyTheRequest(c) } - time_taken := time.Since(t) - if cfg.Server.AccessLog { - cfg.Logger.Info("Request processed", map[string]interface{}{"ip": c.IP(), "user_id": extracted_user_id, "op_type": opType, "op_name": opName, "time": time_taken, "cache": was_cached}) - } - cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil) + timeTaken := time.Since(startTime) + // Logging & Monitoring + logAndMonitorRequest(c, extractedUserID, opType, opName, wasCached, timeTaken, startTime) + + return nil +} + +// Additional helper function to avoid code repetition +func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string) { + proxyTheRequest(c) + cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cfg.Cache.CacheTTL)*time.Second) + c.Send(c.Response().Body()) +} + +func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) { labels := map[string]string{ "op_type": opType, "op_name": opName, - "cached": fmt.Sprintf("%t", was_cached), - "user_id": extracted_user_id, + "cached": fmt.Sprintf("%t", wasCached), + "user_id": userID, } + if cfg.Server.AccessLog { + cfg.Logger.Info("Request processed", map[string]interface{}{ + "ip": c.IP(), + "user_id": userID, + "op_type": opType, + "op_name": opName, + "time": duration, + "cache": wasCached, + }) + } + + cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil) cfg.Monitoring.Increment("executed_query", labels) - if !was_cached { - cfg.Monitoring.UpdateDuration("timed_query", labels, t) - cfg.Monitoring.Update("timed_query", labels, float64(time_taken.Milliseconds())) + if !wasCached { + cfg.Monitoring.UpdateDuration("timed_query", labels, startTime) + cfg.Monitoring.Update("timed_query", labels, float64(duration.Milliseconds())) } - // // cfg.Monitoring.Set("timed_query", time_taken.Milliseconds()) - return nil }