Move results to the struct for ease of management.

This commit is contained in:
2024-02-15 09:50:51 +00:00
parent 4cb0d22874
commit 0bdea741bf
3 changed files with 131 additions and 122 deletions
+26 -17
View File
@@ -70,8 +70,17 @@ func prepareQueriesAndExemptions() {
}() }()
} }
func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) { type parseGraphQLQueryResult struct {
should_ignore = true operationType string
operationName string
cacheRequest bool
cacheTime int
shouldBlock bool
shouldIgnore bool
}
func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res = &parseGraphQLQueryResult{shouldIgnore: true}
m := make(map[string]interface{}) m := make(map[string]interface{})
err := json.Unmarshal(c.Body(), &m) err := json.Unmarshal(c.Body(), &m)
if err != nil { if err != nil {
@@ -100,32 +109,32 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
return return
} }
should_ignore = false res.shouldIgnore = false
operationName = "undefined" res.operationName = "undefined"
for _, d := range p.Definitions { for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok { if oper, ok := d.(*ast.OperationDefinition); ok {
operationType = oper.Operation res.operationType = oper.Operation
if oper.Name != nil { if oper.Name != nil {
operationName = oper.Name.Value res.operationName = oper.Name.Value
} }
if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode { if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m) cfg.Logger.Warning("Mutation blocked", m)
if flag.Lookup("test.v") == nil { if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
c.Status(403).SendString("The server is in read-only mode") c.Status(403).SendString("The server is in read-only mode")
should_block = true res.shouldBlock = true
return return
} }
for _, dir := range oper.Directives { for _, dir := range oper.Directives {
if dir.Name.Value == "cached" { if dir.Name.Value == "cached" {
cacheRequest = true res.cacheRequest = true
for _, arg := range dir.Arguments { for _, arg := range dir.Arguments {
if arg.Name.Value == "ttl" { if arg.Name.Value == "ttl" {
cache_time, err = strconv.Atoi(arg.Value.GetValue().(string)) res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil { if err != nil {
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}) cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if flag.Lookup("test.v") == nil { if flag.Lookup("test.v") == nil {
@@ -135,15 +144,15 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
} }
} }
if arg.Name.Value == "refresh" { if arg.Name.Value == "refresh" {
cacheRequest = arg.Value.GetValue().(bool) res.cacheRequest = arg.Value.GetValue().(bool)
} }
} }
} }
} }
if cfg.Security.BlockIntrospection { if cfg.Security.BlockIntrospection {
should_block = checkSelections(c, oper.GetSelectionSet().Selections) res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if should_block { if res.shouldBlock {
return return
} }
} }
@@ -171,7 +180,7 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
return false return false
} }
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) { func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
whateverLower := strings.ToLower(whatever) whateverLower := strings.ToLower(whatever)
got_exemption := false got_exemption := false
if _, exists := introspectionQuerySet[whateverLower]; exists { if _, exists := introspectionQuerySet[whateverLower]; exists {
@@ -179,14 +188,14 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block b
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever}) cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true got_exemption = true
should_block = false shouldBlock = false
} }
} }
if !got_exemption { if !got_exemption {
should_block = true shouldBlock = true
} }
} }
if should_block { if shouldBlock {
if flag.Lookup("test.v") == nil { if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
+90 -91
View File
@@ -11,13 +11,13 @@ import (
func (suite *Tests) Test_parseGraphQLQuery() { func (suite *Tests) Test_parseGraphQLQuery() {
type results struct { type results struct {
op_name string op_name string
op_type string op_type string
cached_ttl int cached_ttl int
returnCode int returnCode int
is_cached bool is_cached bool
should_block bool shouldBlock bool
should_ignore bool shouldIgnore bool
} }
type queries struct { type queries struct {
@@ -38,11 +38,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
headers: map[string]string{}, headers: map[string]string{},
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: true, shouldIgnore: true,
op_name: "", op_name: "",
op_type: "", op_type: "",
}, },
}, },
@@ -53,11 +53,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
headers: map[string]string{}, headers: map[string]string{},
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: true, shouldIgnore: true,
op_name: "", op_name: "",
op_type: "", op_type: "",
}, },
}, },
@@ -68,11 +68,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
headers: map[string]string{}, headers: map[string]string{},
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: true, shouldIgnore: true,
op_name: "", op_name: "",
op_type: "", op_type: "",
}, },
}, },
@@ -82,11 +82,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyQuery", op_name: "MyQuery",
op_type: "query", op_type: "query",
}, },
}, },
@@ -96,11 +96,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
}, },
wantResults: results{ wantResults: results{
is_cached: true, is_cached: true,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyQuery", op_name: "MyQuery",
op_type: "query", op_type: "query",
}, },
}, },
@@ -110,12 +110,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
}, },
wantResults: results{ wantResults: results{
is_cached: true, is_cached: true,
cached_ttl: 60, cached_ttl: 60,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyQuery", op_name: "MyQuery",
op_type: "query", op_type: "query",
}, },
}, },
@@ -125,12 +125,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
}, },
wantResults: results{ wantResults: results{
is_cached: true, is_cached: true,
cached_ttl: 0, cached_ttl: 0,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyQuery", op_name: "MyQuery",
op_type: "query", op_type: "query",
}, },
}, },
@@ -140,11 +140,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyMutation", op_name: "MyMutation",
op_type: "mutation", op_type: "mutation",
}, },
}, },
@@ -158,12 +158,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: true, shouldBlock: true,
should_ignore: false, shouldIgnore: false,
op_name: "MyMutation", op_name: "MyMutation",
op_type: "mutation", op_type: "mutation",
returnCode: 403, returnCode: 403,
}, },
}, },
@@ -173,11 +173,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "MyMutation", op_name: "MyMutation",
op_type: "mutation", op_type: "mutation",
}, },
}, },
@@ -191,12 +191,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: true, shouldBlock: true,
should_ignore: false, shouldIgnore: false,
op_name: "MyIntroQuery", op_name: "MyIntroQuery",
op_type: "query", op_type: "query",
returnCode: 403, returnCode: 403,
}, },
}, },
@@ -213,12 +213,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: true, shouldBlock: true,
should_ignore: false, shouldIgnore: false,
op_name: "undefined", op_name: "undefined",
op_type: "query", op_type: "query",
returnCode: 403, returnCode: 403,
}, },
}, },
@@ -235,12 +235,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: false, shouldIgnore: false,
op_name: "undefined", op_name: "undefined",
op_type: "query", op_type: "query",
returnCode: 200, returnCode: 200,
}, },
}, },
@@ -250,11 +250,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
body: "{\"query\":\"query MyQuery tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } \"}", body: "{\"query\":\"query MyQuery tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } \"}",
}, },
wantResults: results{ wantResults: results{
is_cached: false, is_cached: false,
should_block: false, shouldBlock: false,
should_ignore: true, shouldIgnore: true,
op_name: "", op_name: "",
op_type: "", op_type: "",
}, },
}, },
} }
@@ -298,14 +298,13 @@ func (suite *Tests) Test_parseGraphQLQuery() {
cfg = &config{} cfg = &config{}
}() }()
opType, opName, cacheFromQuery, cached_ttl, shouldBlock, should_ignore := parseGraphQLQuery(ctx) parseResult := parseGraphQLQuery(ctx)
assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type", tt.name)
assert.Equal(tt.wantResults.op_type, opType, "Unexpected operation type", tt.name) assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name", tt.name)
assert.Equal(tt.wantResults.op_name, opName, "Unexpected operation name", tt.name) assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value", tt.name)
assert.Equal(tt.wantResults.is_cached, cacheFromQuery, "Unexpected cache value", tt.name) assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value", tt.name)
assert.Equal(tt.wantResults.cached_ttl, cached_ttl, "Unexpected cache TTL value", tt.name) assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value", tt.name)
assert.Equal(tt.wantResults.should_block, shouldBlock, "Unexpected block value", tt.name) assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value", tt.name)
assert.Equal(tt.wantResults.should_ignore, should_ignore, "Unexpected ignore value", tt.name)
if tt.wantResults.returnCode > 0 { if tt.wantResults.returnCode > 0 {
assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name) assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name)
+15 -14
View File
@@ -90,6 +90,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
} }
if checkIfUserIsBanned(c, extractedUserID) { if checkIfUserIsBanned(c, extractedUserID) {
c.Status(403).SendString("User is banned")
return nil return nil
} }
@@ -109,35 +110,35 @@ func processGraphQLRequest(c *fiber.Ctx) error {
} }
} }
opType, opName, cacheFromQuery, cache_time, shouldBlock, should_ignore := parseGraphQLQuery(c) parsedResult := parseGraphQLQuery(c)
if shouldBlock { if parsedResult.shouldBlock {
c.Status(403).SendString("Request blocked") c.Status(403).SendString("Request blocked")
return nil return nil
} }
if should_ignore { if parsedResult.shouldIgnore {
cfg.Logger.Debug("Request passed as-is - probably not a GraphQL") cfg.Logger.Debug("Request passed as-is - probably not a GraphQL")
return proxyTheRequest(c) return proxyTheRequest(c)
} }
if cache_time > 0 { if parsedResult.cacheTime > 0 {
cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cache_time": cache_time}) cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
} else { } else {
// If not set via query, try setting via header // If not set via query, try setting via header
cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query") cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query")
if cacheQuery != nil { if cacheQuery != nil {
cache_time, _ = strconv.Atoi(string(cacheQuery)) parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery))
cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cache_time": cache_time}) cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
} else { } else {
cache_time = cfg.Cache.CacheTTL parsedResult.cacheTime = cfg.Cache.CacheTTL
} }
} }
wasCached := false wasCached := false
// Handling Cache Logic // Handling Cache Logic
if cacheFromQuery || cfg.Cache.CacheEnable { if parsedResult.cacheRequest || cfg.Cache.CacheEnable {
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cacheFromQuery, "via_env": cfg.Cache.CacheEnable}) cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable})
queryCacheHash = calculateHash(c) queryCacheHash = calculateHash(c)
if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil { if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil {
@@ -146,7 +147,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
wasCached = true wasCached = true
} else { } else {
cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}) cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
proxyAndCacheTheRequest(c, queryCacheHash, cache_time) proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime)
} }
} else { } else {
proxyTheRequest(c) proxyTheRequest(c)
@@ -155,13 +156,13 @@ func processGraphQLRequest(c *fiber.Ctx) error {
timeTaken := time.Since(startTime) timeTaken := time.Since(startTime)
// Logging & Monitoring // Logging & Monitoring
logAndMonitorRequest(c, extractedUserID, opType, opName, wasCached, timeTaken, startTime) logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime)
return nil return nil
} }
// Additional helper function to avoid code repetition // Additional helper function to avoid code repetition
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int) { func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int) {
err := proxyTheRequest(c) err := proxyTheRequest(c)
if err != nil { if err != nil {
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()}) cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()})
@@ -169,7 +170,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int
c.Status(500).SendString("Can't proxy the request - try again later") c.Status(500).SendString("Can't proxy the request - try again later")
return return
} }
cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cache_time)*time.Second) cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
c.Send(c.Response().Body()) c.Send(c.Response().Body())
} }