General code optimisations. (#16)

* General code optimisations.
This commit is contained in:
2024-06-28 12:31:01 +01:00
committed by GitHub
parent 1b1656c4b5
commit b10a28bf52
20 changed files with 917 additions and 534 deletions
+88 -98
View File
@@ -3,6 +3,8 @@ package main
import (
"strconv"
"strings"
"sync"
"unsafe"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
@@ -12,48 +14,29 @@ import (
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
var introspection_queries = []string{
"__schema",
"__type",
"__typename",
"__directive",
"__directivelocation",
"__field",
"__inputvalue",
"__enumvalue",
"__typekind",
"__fieldtype",
"__inputobjecttype",
"__enumtype",
"__uniontype",
"__scalars",
"__objects",
"__interfaces",
"__unions",
"__enums",
"__inputobjects",
"__directives",
}
// Saving the introspection queries as a map O(1) operation instead of O(n) for a slice.
var introspectionQuerySet = map[string]struct{}{}
var introspectionAllowedQueries = map[string]struct{}{}
var allowedUrls = map[string]struct{}{}
// Utility function to convert a slice of strings to a map for O(1) lookups.
func sliceToMap(slice []string) map[string]struct{} {
resultMap := make(map[string]struct{}, len(slice))
for _, item := range slice {
resultMap[strings.ToLower(item)] = struct{}{}
var (
introspectionQueries = map[string]struct{}{
"__schema": {}, "__type": {}, "__typename": {}, "__directive": {},
"__directivelocation": {}, "__field": {}, "__inputvalue": {},
"__enumvalue": {}, "__typekind": {}, "__fieldtype": {},
"__inputobjecttype": {}, "__enumtype": {}, "__uniontype": {},
"__scalars": {}, "__objects": {}, "__interfaces": {},
"__unions": {}, "__enums": {}, "__inputobjects": {}, "__directives": {},
}
return resultMap
}
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
mu sync.RWMutex
)
func prepareQueriesAndExemptions() {
introspectionQuerySet = sliceToMap(introspection_queries)
introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed)
allowedUrls = sliceToMap(cfg.Server.AllowURLs)
mu.Lock()
defer mu.Unlock()
for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(q)] = struct{}{}
}
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
}
}
type parseGraphQLQueryResult struct {
@@ -67,21 +50,41 @@ type parseGraphQLQueryResult struct {
shouldIgnore bool
}
func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res = &parseGraphQLQueryResult{shouldIgnore: true}
m := make(map[string]interface{})
err := json.Unmarshal(c.Body(), &m)
if err != nil {
var (
queryPool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{}, 4)
},
}
resultPool = sync.Pool{
New: func() interface{} {
return &parseGraphQLQueryResult{}
},
}
)
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
res := resultPool.Get().(*parseGraphQLQueryResult)
defer resultPool.Put(res)
*res = parseGraphQLQueryResult{shouldIgnore: true}
m := queryPool.Get().(map[string]interface{})
defer queryPool.Put(m)
for k := range m {
delete(m, k)
}
if err := json.Unmarshal(c.Body(), &m); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't unmarshal the request",
Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())},
Pairs: map[string]interface{}{"error": err.Error(), "body": unsafeString(c.Body())},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
return res
}
// get the query
query, ok := m["query"].(string)
if !ok {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -91,7 +94,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
return res
}
p, err := parser.Parse(parser.ParseParams{Source: query})
@@ -103,7 +106,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
return res
}
res.shouldIgnore = false
@@ -112,14 +115,14 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
res.operationName = oper.Name.Value
// If we haven't set an operation type yet, use this one
if res.operationType == "" {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
res.operationName = oper.Name.Value
}
}
// If the query is a mutation then direct it to the RW endpoint,
// otherwise direct it to the RO endpoint if it's set.
if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
@@ -132,30 +135,24 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("The server is in read-only mode")
_ = c.Status(403).SendString("The server is in read-only mode")
res.shouldBlock = true
return
return res
}
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
res.cacheRequest = true
for _, arg := range dir.Arguments {
if arg.Name.Value == "ttl" {
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the ttl, using global",
Pairs: map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
switch arg.Name.Value {
case "ttl":
if v, ok := arg.Value.GetValue().(string); ok {
res.cacheTime, _ = strconv.Atoi(v)
}
case "refresh":
if v, ok := arg.Value.GetValue().(bool); ok {
res.cacheRefresh = v
}
}
if arg.Name.Value == "refresh" {
res.cacheRefresh = arg.Value.GetValue().(bool)
}
}
}
@@ -164,26 +161,25 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if cfg.Security.BlockIntrospection {
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if res.shouldBlock {
return
return res
}
}
}
}
return
return res
}
func unsafeString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
for _, s := range selections {
field, ok := s.(*ast.Field)
if !ok {
continue // or handle the case where the type assertion fails
}
shouldBlock := checkIfContainsIntrospection(c, field.Name.Value)
if shouldBlock {
return true
}
if field.SelectionSet != nil {
if checkSelections(c, field.GetSelectionSet().Selections) {
if field, ok := s.(*ast.Field); ok {
if checkIfContainsIntrospection(c, field.Name.Value) {
return true
}
if field.SelectionSet != nil && checkSelections(c, field.GetSelectionSet().Selections) {
return true
}
}
@@ -191,32 +187,26 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) bool {
whateverLower := strings.ToLower(whatever)
got_exemption := false
mu.RLock()
defer mu.RUnlock()
// If the query is an introspection query, we need to check if it's allowed.
if _, exists := introspectionQuerySet[whateverLower]; exists {
if _, exists := introspectionQueries[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
if _, allowed := introspectionAllowedQueries[whateverLower]; allowed {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Introspection query allowed, passing through",
Pairs: map[string]interface{}{"query": whatever},
})
got_exemption = true
shouldBlock = false
return false
}
}
if !got_exemption {
shouldBlock = true
}
}
if shouldBlock {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Introspection queries are not allowed")
_ = c.Status(403).SendString("Introspection queries are not allowed")
return true
}
return
return false
}