mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
Move results to the struct for ease of management.
This commit is contained in:
+26
-17
@@ -70,8 +70,17 @@ func prepareQueriesAndExemptions() {
|
||||
}()
|
||||
}
|
||||
|
||||
func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) {
|
||||
should_ignore = true
|
||||
type parseGraphQLQueryResult struct {
|
||||
operationType string
|
||||
operationName string
|
||||
cacheRequest bool
|
||||
cacheTime int
|
||||
shouldBlock bool
|
||||
shouldIgnore bool
|
||||
}
|
||||
|
||||
func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
|
||||
res = &parseGraphQLQueryResult{shouldIgnore: true}
|
||||
m := make(map[string]interface{})
|
||||
err := json.Unmarshal(c.Body(), &m)
|
||||
if err != nil {
|
||||
@@ -100,32 +109,32 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
|
||||
return
|
||||
}
|
||||
|
||||
should_ignore = false
|
||||
operationName = "undefined"
|
||||
res.shouldIgnore = false
|
||||
res.operationName = "undefined"
|
||||
for _, d := range p.Definitions {
|
||||
if oper, ok := d.(*ast.OperationDefinition); ok {
|
||||
operationType = oper.Operation
|
||||
res.operationType = oper.Operation
|
||||
|
||||
if oper.Name != nil {
|
||||
operationName = oper.Name.Value
|
||||
res.operationName = oper.Name.Value
|
||||
}
|
||||
|
||||
if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode {
|
||||
if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode {
|
||||
cfg.Logger.Warning("Mutation blocked", m)
|
||||
if flag.Lookup("test.v") == nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
c.Status(403).SendString("The server is in read-only mode")
|
||||
should_block = true
|
||||
res.shouldBlock = true
|
||||
return
|
||||
}
|
||||
|
||||
for _, dir := range oper.Directives {
|
||||
if dir.Name.Value == "cached" {
|
||||
cacheRequest = true
|
||||
res.cacheRequest = true
|
||||
for _, arg := range dir.Arguments {
|
||||
if arg.Name.Value == "ttl" {
|
||||
cache_time, err = strconv.Atoi(arg.Value.GetValue().(string))
|
||||
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
|
||||
if err != nil {
|
||||
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
|
||||
if flag.Lookup("test.v") == nil {
|
||||
@@ -135,15 +144,15 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
|
||||
}
|
||||
}
|
||||
if arg.Name.Value == "refresh" {
|
||||
cacheRequest = arg.Value.GetValue().(bool)
|
||||
res.cacheRequest = arg.Value.GetValue().(bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Security.BlockIntrospection {
|
||||
should_block = checkSelections(c, oper.GetSelectionSet().Selections)
|
||||
if should_block {
|
||||
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
|
||||
if res.shouldBlock {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -171,7 +180,7 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) {
|
||||
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
|
||||
whateverLower := strings.ToLower(whatever)
|
||||
got_exemption := false
|
||||
if _, exists := introspectionQuerySet[whateverLower]; exists {
|
||||
@@ -179,14 +188,14 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block b
|
||||
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
|
||||
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
|
||||
got_exemption = true
|
||||
should_block = false
|
||||
shouldBlock = false
|
||||
}
|
||||
}
|
||||
if !got_exemption {
|
||||
should_block = true
|
||||
shouldBlock = true
|
||||
}
|
||||
}
|
||||
if should_block {
|
||||
if shouldBlock {
|
||||
if flag.Lookup("test.v") == nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
|
||||
+90
-91
@@ -11,13 +11,13 @@ import (
|
||||
func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
|
||||
type results struct {
|
||||
op_name string
|
||||
op_type string
|
||||
cached_ttl int
|
||||
returnCode int
|
||||
is_cached bool
|
||||
should_block bool
|
||||
should_ignore bool
|
||||
op_name string
|
||||
op_type string
|
||||
cached_ttl int
|
||||
returnCode int
|
||||
is_cached bool
|
||||
shouldBlock bool
|
||||
shouldIgnore bool
|
||||
}
|
||||
|
||||
type queries struct {
|
||||
@@ -38,11 +38,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
headers: map[string]string{},
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -53,11 +53,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
headers: map[string]string{},
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -68,11 +68,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
headers: map[string]string{},
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -82,11 +82,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -96,11 +96,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: true,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
is_cached: true,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -110,12 +110,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: true,
|
||||
cached_ttl: 60,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
is_cached: true,
|
||||
cached_ttl: 60,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -125,12 +125,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: true,
|
||||
cached_ttl: 0,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
is_cached: true,
|
||||
cached_ttl: 0,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyQuery",
|
||||
op_type: "query",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -140,11 +140,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -158,12 +158,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: true,
|
||||
should_ignore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
returnCode: 403,
|
||||
is_cached: false,
|
||||
shouldBlock: true,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
returnCode: 403,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -173,11 +173,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyMutation",
|
||||
op_type: "mutation",
|
||||
},
|
||||
},
|
||||
|
||||
@@ -191,12 +191,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: true,
|
||||
should_ignore: false,
|
||||
op_name: "MyIntroQuery",
|
||||
op_type: "query",
|
||||
returnCode: 403,
|
||||
is_cached: false,
|
||||
shouldBlock: true,
|
||||
shouldIgnore: false,
|
||||
op_name: "MyIntroQuery",
|
||||
op_type: "query",
|
||||
returnCode: 403,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -213,12 +213,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: true,
|
||||
should_ignore: false,
|
||||
op_name: "undefined",
|
||||
op_type: "query",
|
||||
returnCode: 403,
|
||||
is_cached: false,
|
||||
shouldBlock: true,
|
||||
shouldIgnore: false,
|
||||
op_name: "undefined",
|
||||
op_type: "query",
|
||||
returnCode: 403,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -235,12 +235,12 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: false,
|
||||
op_name: "undefined",
|
||||
op_type: "query",
|
||||
returnCode: 200,
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: false,
|
||||
op_name: "undefined",
|
||||
op_type: "query",
|
||||
returnCode: 200,
|
||||
},
|
||||
},
|
||||
|
||||
@@ -250,11 +250,11 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
body: "{\"query\":\"query MyQuery tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } \"}",
|
||||
},
|
||||
wantResults: results{
|
||||
is_cached: false,
|
||||
should_block: false,
|
||||
should_ignore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
is_cached: false,
|
||||
shouldBlock: false,
|
||||
shouldIgnore: true,
|
||||
op_name: "",
|
||||
op_type: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -298,14 +298,13 @@ func (suite *Tests) Test_parseGraphQLQuery() {
|
||||
cfg = &config{}
|
||||
}()
|
||||
|
||||
opType, opName, cacheFromQuery, cached_ttl, shouldBlock, should_ignore := parseGraphQLQuery(ctx)
|
||||
|
||||
assert.Equal(tt.wantResults.op_type, opType, "Unexpected operation type", tt.name)
|
||||
assert.Equal(tt.wantResults.op_name, opName, "Unexpected operation name", tt.name)
|
||||
assert.Equal(tt.wantResults.is_cached, cacheFromQuery, "Unexpected cache value", tt.name)
|
||||
assert.Equal(tt.wantResults.cached_ttl, cached_ttl, "Unexpected cache TTL value", tt.name)
|
||||
assert.Equal(tt.wantResults.should_block, shouldBlock, "Unexpected block value", tt.name)
|
||||
assert.Equal(tt.wantResults.should_ignore, should_ignore, "Unexpected ignore value", tt.name)
|
||||
parseResult := parseGraphQLQuery(ctx)
|
||||
assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type", tt.name)
|
||||
assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name", tt.name)
|
||||
assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value", tt.name)
|
||||
assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value", tt.name)
|
||||
assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value", tt.name)
|
||||
assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value", tt.name)
|
||||
|
||||
if tt.wantResults.returnCode > 0 {
|
||||
assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name)
|
||||
|
||||
@@ -90,6 +90,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
if checkIfUserIsBanned(c, extractedUserID) {
|
||||
c.Status(403).SendString("User is banned")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -109,35 +110,35 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
opType, opName, cacheFromQuery, cache_time, shouldBlock, should_ignore := parseGraphQLQuery(c)
|
||||
if shouldBlock {
|
||||
parsedResult := parseGraphQLQuery(c)
|
||||
if parsedResult.shouldBlock {
|
||||
c.Status(403).SendString("Request blocked")
|
||||
return nil
|
||||
}
|
||||
|
||||
if should_ignore {
|
||||
if parsedResult.shouldIgnore {
|
||||
cfg.Logger.Debug("Request passed as-is - probably not a GraphQL")
|
||||
return proxyTheRequest(c)
|
||||
}
|
||||
|
||||
if cache_time > 0 {
|
||||
cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cache_time": cache_time})
|
||||
if parsedResult.cacheTime > 0 {
|
||||
cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
|
||||
} else {
|
||||
// If not set via query, try setting via header
|
||||
cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query")
|
||||
if cacheQuery != nil {
|
||||
cache_time, _ = strconv.Atoi(string(cacheQuery))
|
||||
cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cache_time": cache_time})
|
||||
parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery))
|
||||
cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
|
||||
} else {
|
||||
cache_time = cfg.Cache.CacheTTL
|
||||
parsedResult.cacheTime = cfg.Cache.CacheTTL
|
||||
}
|
||||
}
|
||||
|
||||
wasCached := false
|
||||
|
||||
// Handling Cache Logic
|
||||
if cacheFromQuery || cfg.Cache.CacheEnable {
|
||||
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": cacheFromQuery, "via_env": cfg.Cache.CacheEnable})
|
||||
if parsedResult.cacheRequest || cfg.Cache.CacheEnable {
|
||||
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable})
|
||||
queryCacheHash = calculateHash(c)
|
||||
|
||||
if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil {
|
||||
@@ -146,7 +147,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
wasCached = true
|
||||
} else {
|
||||
cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
|
||||
proxyAndCacheTheRequest(c, queryCacheHash, cache_time)
|
||||
proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime)
|
||||
}
|
||||
} else {
|
||||
proxyTheRequest(c)
|
||||
@@ -155,13 +156,13 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
timeTaken := time.Since(startTime)
|
||||
|
||||
// Logging & Monitoring
|
||||
logAndMonitorRequest(c, extractedUserID, opType, opName, wasCached, timeTaken, startTime)
|
||||
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Additional helper function to avoid code repetition
|
||||
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int) {
|
||||
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int) {
|
||||
err := proxyTheRequest(c)
|
||||
if err != nil {
|
||||
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()})
|
||||
@@ -169,7 +170,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cache_time int
|
||||
c.Status(500).SendString("Can't proxy the request - try again later")
|
||||
return
|
||||
}
|
||||
cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cache_time)*time.Second)
|
||||
cfg.Cache.CacheClient.Set(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
|
||||
c.Send(c.Response().Body())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user