Move results to the struct for ease of management.

This commit is contained in:
2024-02-15 09:50:51 +00:00
parent 4cb0d22874
commit 0bdea741bf
3 changed files with 131 additions and 122 deletions
+26 -17
View File
@@ -70,8 +70,17 @@ func prepareQueriesAndExemptions() {
}()
}
func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) {
should_ignore = true
type parseGraphQLQueryResult struct {
operationType string
operationName string
cacheRequest bool
cacheTime int
shouldBlock bool
shouldIgnore bool
}
func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res = &parseGraphQLQueryResult{shouldIgnore: true}
m := make(map[string]interface{})
err := json.Unmarshal(c.Body(), &m)
if err != nil {
@@ -100,32 +109,32 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
return
}
should_ignore = false
operationName = "undefined"
res.shouldIgnore = false
res.operationName = "undefined"
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
operationType = oper.Operation
res.operationType = oper.Operation
if oper.Name != nil {
operationName = oper.Name.Value
res.operationName = oper.Name.Value
}
if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode {
if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m)
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("The server is in read-only mode")
should_block = true
res.shouldBlock = true
return
}
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
cacheRequest = true
res.cacheRequest = true
for _, arg := range dir.Arguments {
if arg.Name.Value == "ttl" {
cache_time, err = strconv.Atoi(arg.Value.GetValue().(string))
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil {
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if flag.Lookup("test.v") == nil {
@@ -135,15 +144,15 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
}
}
if arg.Name.Value == "refresh" {
cacheRequest = arg.Value.GetValue().(bool)
res.cacheRequest = arg.Value.GetValue().(bool)
}
}
}
}
if cfg.Security.BlockIntrospection {
should_block = checkSelections(c, oper.GetSelectionSet().Selections)
if should_block {
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if res.shouldBlock {
return
}
}
@@ -171,7 +180,7 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) {
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
whateverLower := strings.ToLower(whatever)
got_exemption := false
if _, exists := introspectionQuerySet[whateverLower]; exists {
@@ -179,14 +188,14 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block b
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true
should_block = false
shouldBlock = false
}
}
if !got_exemption {
should_block = true
shouldBlock = true
}
}
if should_block {
if shouldBlock {
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}