diff --git a/graphql.go b/graphql.go index fda2f6e..67de270 100644 --- a/graphql.go +++ b/graphql.go @@ -70,8 +70,17 @@ func prepareQueriesAndExemptions() { }() } -func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) { - should_ignore = true +type parseGraphQLQueryResult struct { + operationType string + operationName string + cacheRequest bool + cacheTime int + shouldBlock bool + shouldIgnore bool +} + +func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) { + res = &parseGraphQLQueryResult{shouldIgnore: true} m := make(map[string]interface{}) err := json.Unmarshal(c.Body(), &m) if err != nil { @@ -100,32 +109,32 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache return } - should_ignore = false - operationName = "undefined" + res.shouldIgnore = false + res.operationName = "undefined" for _, d := range p.Definitions { if oper, ok := d.(*ast.OperationDefinition); ok { - operationType = oper.Operation + res.operationType = oper.Operation if oper.Name != nil { - operationName = oper.Name.Value + res.operationName = oper.Name.Value } - if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode { + if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode { cfg.Logger.Warning("Mutation blocked", m) if flag.Lookup("test.v") == nil { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } c.Status(403).SendString("The server is in read-only mode") - should_block = true + res.shouldBlock = true return } for _, dir := range oper.Directives { if dir.Name.Value == "cached" { - cacheRequest = true + res.cacheRequest = true for _, arg := range dir.Arguments { if arg.Name.Value == "ttl" { - cache_time, err = strconv.Atoi(arg.Value.GetValue().(string)) + res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string)) if err != nil { cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}) if flag.Lookup("test.v") == nil { @@ -135,15 +144,15 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache } } if arg.Name.Value == "refresh" { - cacheRequest = arg.Value.GetValue().(bool) + res.cacheRequest = arg.Value.GetValue().(bool) } } } } if cfg.Security.BlockIntrospection { - should_block = checkSelections(c, oper.GetSelectionSet().Selections) - if should_block { + res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections) + if res.shouldBlock { return } } @@ -171,7 +180,7 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { return false } -func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) { +func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) { whateverLower := strings.ToLower(whatever) got_exemption := false if _, exists := introspectionQuerySet[whateverLower]; exists { @@ -179,14 +188,14 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block b if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever}) got_exemption = true - should_block = false + shouldBlock = false } } if !got_exemption { - should_block = true + shouldBlock = true } } - if should_block { + if shouldBlock { if flag.Lookup("test.v") == nil { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } diff --git a/graphql_test.go b/graphql_test.go index 61900a2..4c86089 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -11,13 +11,13 @@ import ( func (suite *Tests) Test_parseGraphQLQuery() { type results struct { - op_name string - op_type string - cached_ttl int - returnCode int - is_cached bool - should_block bool - should_ignore bool + op_name string + op_type string + cached_ttl int + returnCode int + is_cached bool + shouldBlock bool + shouldIgnore bool } type queries struct { @@ -38,11 +38,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { headers: map[string]string{}, }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: true, - op_name: "", - op_type: "", + is_cached: false, + shouldBlock: false, + shouldIgnore: true, + op_name: "", + op_type: "", }, }, @@ -53,11 +53,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { headers: map[string]string{}, }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: true, - op_name: "", - op_type: "", + is_cached: false, + shouldBlock: false, + shouldIgnore: true, + op_name: "", + op_type: "", }, }, @@ -68,11 +68,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { headers: map[string]string{}, }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: true, - op_name: "", - op_type: "", + is_cached: false, + shouldBlock: false, + shouldIgnore: true, + op_name: "", + op_type: "", }, }, @@ -82,11 +82,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: false, - op_name: "MyQuery", - op_type: "query", + is_cached: false, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyQuery", + op_type: "query", }, }, @@ -96,11 +96,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", }, wantResults: results{ - is_cached: true, - should_block: false, - should_ignore: false, - op_name: "MyQuery", - op_type: "query", + is_cached: true, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyQuery", + op_type: "query", }, }, @@ -110,12 +110,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", }, wantResults: results{ - is_cached: true, - cached_ttl: 60, - should_block: false, - should_ignore: false, - op_name: "MyQuery", - op_type: "query", + is_cached: true, + cached_ttl: 60, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyQuery", + op_type: "query", }, }, @@ -125,12 +125,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", }, wantResults: results{ - is_cached: true, - cached_ttl: 0, - should_block: false, - should_ignore: false, - op_name: "MyQuery", - op_type: "query", + is_cached: true, + cached_ttl: 0, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyQuery", + op_type: "query", }, }, @@ -140,11 +140,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: false, - op_name: "MyMutation", - op_type: "mutation", + is_cached: false, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyMutation", + op_type: "mutation", }, }, @@ -158,12 +158,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", }, wantResults: results{ - is_cached: false, - should_block: true, - should_ignore: false, - op_name: "MyMutation", - op_type: "mutation", - returnCode: 403, + is_cached: false, + shouldBlock: true, + shouldIgnore: false, + op_name: "MyMutation", + op_type: "mutation", + returnCode: 403, }, }, @@ -173,11 +173,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: false, - op_name: "MyMutation", - op_type: "mutation", + is_cached: false, + shouldBlock: false, + shouldIgnore: false, + op_name: "MyMutation", + op_type: "mutation", }, }, @@ -191,12 +191,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", }, wantResults: results{ - is_cached: false, - should_block: true, - should_ignore: false, - op_name: "MyIntroQuery", - op_type: "query", - returnCode: 403, + is_cached: false, + shouldBlock: true, + shouldIgnore: false, + op_name: "MyIntroQuery", + op_type: "query", + returnCode: 403, }, }, @@ -213,12 +213,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", }, wantResults: results{ - is_cached: false, - should_block: true, - should_ignore: false, - op_name: "undefined", - op_type: "query", - returnCode: 403, + is_cached: false, + shouldBlock: true, + shouldIgnore: false, + op_name: "undefined", + op_type: "query", + returnCode: 403, }, }, @@ -235,12 +235,12 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: false, - op_name: "undefined", - op_type: "query", - returnCode: 200, + is_cached: false, + shouldBlock: false, + shouldIgnore: false, + op_name: "undefined", + op_type: "query", + returnCode: 200, }, }, @@ -250,11 +250,11 @@ func (suite *Tests) Test_parseGraphQLQuery() { body: "{\"query\":\"query MyQuery tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } \"}", }, wantResults: results{ - is_cached: false, - should_block: false, - should_ignore: true, - op_name: "", - op_type: "", + is_cached: false, + shouldBlock: false, + shouldIgnore: true, + op_name: "", + op_type: "", }, }, } @@ -298,14 +298,13 @@ func (suite *Tests) Test_parseGraphQLQuery() { cfg = &config{} }() - opType, opName, cacheFromQuery, cached_ttl, shouldBlock, should_ignore := parseGraphQLQuery(ctx) - - assert.Equal(tt.wantResults.op_type, opType, "Unexpected operation type", tt.name) - assert.Equal(tt.wantResults.op_name, opName, "Unexpected operation name", tt.name) - assert.Equal(tt.wantResults.is_cached, cacheFromQuery, "Unexpected cache value", tt.name) - assert.Equal(tt.wantResults.cached_ttl, cached_ttl, "Unexpected cache TTL value", tt.name) - assert.Equal(tt.wantResults.should_block, shouldBlock, "Unexpected block value", tt.name) - assert.Equal(tt.wantResults.should_ignore, should_ignore, "Unexpected ignore value", tt.name) + parseResult := parseGraphQLQuery(ctx) + assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type", tt.name) + assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name", tt.name) + assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value", tt.name) + assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value", tt.name) + assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value", tt.name) + assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value", tt.name) if tt.wantResults.returnCode > 0 { assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name) diff --git a/server.go b/server.go index e8c754c..08813a4 100644 --- a/server.go +++ b/server.go @@ -90,6 +90,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { } if checkIfUserIsBanned(c, extractedUserID) { + c.Status(403).SendString("User is banned") return nil } @@ -109,35 +110,35 @@ func processGraphQLRequest(c *fiber.Ctx) error { } } - opType, opName, cacheFromQuery, cache_time, shouldBlock, should_ignore := parseGraphQLQuery(c) - if shouldBlock { + parsedResult := parseGraphQLQuery(c) + if parsedResult.shouldBlock { c.Status(403).SendString("Request blocked") return nil } - if should_ignore { + if parsedResult.shouldIgnore { cfg.Logger.Debug("Request passed as-is - probably not a GraphQL") return proxyTheRequest(c) } - if cache_time > 0 { - cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cache_time": cache_time}) + if parsedResult.cacheTime > 0 { + cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cacheTime": parsedResult.cacheTime}) } else { // If not set via query, try setting via header cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query") if cacheQuery != nil { - cache_time, _ = strconv.Atoi(string(cacheQuery)) - cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cache_time": cache_time}) + parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery)) + cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cacheTime": parsedResult.cacheTime}) } else { - cache_time = cfg.Cache.CacheTTL + parsedResult.cacheTime = cfg.Cache.CacheTTL } } wasCached := false // Handling Cache Logic - if cacheFromQuery || cfg.Cache.CacheEnable { - cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cacheFromQuery, "via_env": cfg.Cache.CacheEnable}) + if parsedResult.cacheRequest || cfg.Cache.CacheEnable { + cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable}) queryCacheHash = calculateHash(c) if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil { @@ -146,7 +147,7 @@ func processGraphQLRequest(c *fiber.Ctx) error { wasCached = true } else { cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}) - proxyAndCacheTheRequest(c, queryCacheHash, cache_time) + proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime) } } else { proxyTheRequest(c) @@ -155,13 +156,13 @@ func processGraphQLRequest(c *fiber.Ctx) error { timeTaken := time.Since(startTime) // Logging & Monitoring - logAndMonitorRequest(c, extractedUserID, opType, opName, wasCached, timeTaken, startTime) + logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime) return nil } // Additional helper function to avoid code repetition -func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int) { +func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int) { err := proxyTheRequest(c) if err != nil { cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()}) @@ -169,7 +170,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int c.Status(500).SendString("Can't proxy the request - try again later") return } - cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cache_time)*time.Second) + cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second) c.Send(c.Response().Body()) }