Improve stats gathering and tests improvements. (#8)

This commit is contained in:
2024-03-05 22:40:06 +00:00
committed by GitHub
parent b6c284b66d
commit 3a18e0e935
16 changed files with 284 additions and 103 deletions
+22 -34
View File
@@ -1,7 +1,6 @@
package main
import (
"flag"
"strconv"
"strings"
@@ -41,40 +40,26 @@ var introspectionQuerySet = map[string]struct{}{}
var introspectionAllowedQueries = map[string]struct{}{}
var allowedUrls = map[string]struct{}{}
// Utility function to convert a slice of strings to a map for O(1) lookups.
func sliceToMap(slice []string) map[string]struct{} {
resultMap := make(map[string]struct{}, len(slice))
for _, item := range slice {
resultMap[strings.ToLower(item)] = struct{}{}
}
return resultMap
}
func prepareQueriesAndExemptions() {
introspectionQuerySet = map[string]struct{}{}
introspectionQuerySet = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(introspection_queries))
for _, query := range introspection_queries {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
introspectionAllowedQueries = map[string]struct{}{}
introspectionAllowedQueries = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(cfg.Security.IntrospectionAllowed))
for _, query := range cfg.Security.IntrospectionAllowed {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
allowedUrls = map[string]struct{}{}
allowedUrls = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(cfg.Server.AllowURLs))
for _, query := range cfg.Server.AllowURLs {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
introspectionQuerySet = sliceToMap(introspection_queries)
introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed)
allowedUrls = sliceToMap(cfg.Server.AllowURLs)
}
type parseGraphQLQueryResult struct {
operationType string
operationName string
cacheRequest bool
cacheTime int
cacheRequest bool
cacheRefresh bool
shouldBlock bool
shouldIgnore bool
@@ -86,7 +71,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
err := json.Unmarshal(c.Body(), &m)
if err != nil {
cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())})
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
@@ -95,7 +80,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
query, ok := m["query"].(string)
if !ok {
cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m})
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
@@ -104,7 +89,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
p, err := parser.Parse(parser.ParseParams{Source: query})
if err != nil {
cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m})
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
@@ -122,7 +107,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m)
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("The server is in read-only mode")
@@ -138,7 +123,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil {
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
@@ -184,8 +169,11 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
whateverLower := strings.ToLower(whatever)
got_exemption := false
// If the query is an introspection query, we need to check if it's allowed.
if _, exists := introspectionQuerySet[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true
@@ -197,7 +185,7 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bo
}
}
if shouldBlock {
if flag.Lookup("test.v") == nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Introspection queries are not allowed")