From 1390e7cdd1888016d0c2f06bb514b904505f9d05 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Sat, 18 Nov 2023 02:11:38 +0000 Subject: [PATCH] Fix blocking the introspection + add unit tests. --- Makefile | 2 +- go.mod | 1 + go.sum | 2 + graphql.go | 132 ++++++++++++++------- graphql_test.go | 301 ++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 9 +- 6 files changed, 400 insertions(+), 47 deletions(-) create mode 100644 graphql_test.go diff --git a/Makefile b/Makefile index eaaeb09..a9fe336 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ help: ## display this help .PHONY: run run: build ## run application - @LOG_LEVEL=debug PURGE_METRICS_ON_CRAWL=true BLOCK_SCHEMA_INTROSPECTION=false CACHE_TTL=10 JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/ HEALTHCHECK_GRAPHQL_URL=https://hasura8.lan/v1/graphql ./graphql-proxy + @LOG_LEVEL=debug PURGE_METRICS_ON_CRAWL=true BLOCK_SCHEMA_INTROSPECTION=true CACHE_TTL=10 JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/ HEALTHCHECK_GRAPHQL_URL=https://hasura8.lan/v1/graphql ./graphql-proxy .PHONY: build build: ## build the binary diff --git a/go.mod b/go.mod index 44085b4..4b1348b 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/avast/retry-go/v4 v4.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gookit/color v1.5.4 // indirect + github.com/k0kubun/pp v3.0.1+incompatible // indirect github.com/klauspost/compress v1.17.2 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 54af5fa..17b4699 100644 --- a/go.sum +++ b/go.sum @@ -27,6 +27,8 @@ github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuM github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40= +github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg= github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4= github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= diff --git a/graphql.go b/graphql.go index 03c96c3..7d4716e 100644 --- a/graphql.go +++ b/graphql.go @@ -1,6 +1,7 @@ package main import ( + "flag" "strconv" "strings" @@ -10,7 +11,7 @@ import ( libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring" ) -var retrospection_queries = []string{ +var introspection_queries = []string{ "__schema", "__type", "__typename", @@ -34,13 +35,29 @@ var retrospection_queries = []string{ } // Saving the introspection queries as a map O(1) operation instead of O(n) for a slice. -var retrospectionQuerySet = func() map[string]struct{} { - rsqs := make(map[string]struct{}, len(retrospection_queries)) - for _, query := range retrospection_queries { - rsqs[strings.ToLower(query)] = struct{}{} - } - return rsqs -}() + +var introspectionQuerySet = map[string]struct{}{} +var introspectionAllowedQueries = map[string]struct{}{} + +func prepareQueriesAndExemptions() { + introspectionQuerySet = map[string]struct{}{} + introspectionQuerySet = func() map[string]struct{} { + rsqs := make(map[string]struct{}, len(introspection_queries)) + for _, query := range introspection_queries { + rsqs[strings.ToLower(query)] = struct{}{} + } + return rsqs + }() + + introspectionAllowedQueries = map[string]struct{}{} + introspectionAllowedQueries = func() map[string]struct{} { + rsqs := make(map[string]struct{}, len(cfg.Security.IntrospectionAllowed)) + for _, query := range cfg.Security.IntrospectionAllowed { + rsqs[strings.ToLower(query)] = struct{}{} + } + return rsqs + }() +} func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) { should_ignore = true @@ -48,21 +65,27 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache err := json.Unmarshal(c.Body(), &m) if err != nil { cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } return } // get the query query, ok := m["query"].(string) if !ok { cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } return } p, err := parser.Parse(parser.ParseParams{Source: query}) if err != nil { cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } return } @@ -71,19 +94,21 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache for _, d := range p.Definitions { if oper, ok := d.(*ast.OperationDefinition); ok { operationType = oper.Operation + + if oper.Name != nil { + operationName = oper.Name.Value + } + if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode { cfg.Logger.Warning("Mutation blocked", m) - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } c.Status(403).SendString("The server is in read-only mode") should_block = true return } - if oper.Name != nil { - operationName = oper.Name.Value - } else { - operationName = "undefined" - } for _, dir := range oper.Directives { if dir.Name.Value == "cached" { cacheRequest = true @@ -91,38 +116,67 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache if arg.Name.Value == "ttl" { cache_time, err = strconv.Atoi(arg.Value.GetValue().(string)) if err != nil { - cfg.Logger.Error("Can't parse the ttl", map[string]interface{}{"ttl": arg.Value.GetValue().(string)}) - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}) + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } return } } } } } - if cfg.Security.BlockIntrospection { - for _, s := range oper.SelectionSet.Selections { - for _, s2 := range s.GetSelectionSet().Selections { - if _, exists := retrospectionQuerySet[strings.ToLower(s2.(*ast.Field).Name.Value)]; exists { - if len(cfg.Security.IntrospectionAllowed) > 0 { - for _, introspectionQueryAllowed := range cfg.Security.IntrospectionAllowed { - if strings.EqualFold(strings.ToLower(introspectionQueryAllowed), strings.ToLower(s2.(*ast.Field).Name.Value)) { - cfg.Logger.Debug("Introspection query allowed, passing through", m) - return - } - } - } - cfg.Logger.Warning("Introspection query blocked", m) - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - c.Status(403).SendString("Introspection queries are not allowed") - should_block = true - return - } - } + if cfg.Security.BlockIntrospection { + should_block = checkSelections(c, oper.GetSelectionSet().Selections) + if should_block { + return } } } } - + return +} + +func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { + for _, s := range selections { + field, ok := s.(*ast.Field) + if !ok { + continue // or handle the case where the type assertion fails + } + shouldBlock := checkIfContainsIntrospection(c, field.Name.Value) + if shouldBlock { + return true + } + if field.SelectionSet != nil { + if checkSelections(c, field.GetSelectionSet().Selections) { + return true + } + } + } + return false +} + +func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) { + whateverLower := strings.ToLower(whatever) + got_exemption := false + if _, exists := introspectionQuerySet[whateverLower]; exists { + if len(cfg.Security.IntrospectionAllowed) > 0 { + if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { + cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever}) + got_exemption = true + should_block = false + } + } + if !got_exemption { + should_block = true + } + } + if should_block { + if flag.Lookup("test.v") == nil { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + c.Status(403).SendString("Introspection queries are not allowed") + } return } diff --git a/graphql_test.go b/graphql_test.go new file mode 100644 index 0000000..6956cd7 --- /dev/null +++ b/graphql_test.go @@ -0,0 +1,301 @@ +package main + +import ( + "testing" + + fiber "github.com/gofiber/fiber/v2" + libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" + "github.com/valyala/fasthttp" +) + +func (suite *Tests) Test_parseGraphQLQuery() { + + type results struct { + is_cached bool + cached_ttl int + should_block bool + should_ignore bool + op_name string + op_type string + returnCode int + } + + type queries struct { + body string + headers map[string]string + } + + tests := []struct { + name string + suppliedSettings *config + suppliedQuery queries + wantResults results + }{ + { + name: "test empty body", + suppliedQuery: queries{ + body: "", + headers: map[string]string{}, + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: true, + op_name: "", + op_type: "", + }, + }, + + { + name: "test empty json", + suppliedQuery: queries{ + body: "{}", + headers: map[string]string{}, + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: true, + op_name: "", + op_type: "", + }, + }, + + { + name: "test empty with some random garbage", + suppliedQuery: queries{ + body: "{\"variables\": {\"id\": \"1\"}}", + headers: map[string]string{}, + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: true, + op_name: "", + op_type: "", + }, + }, + + { + name: "test valid query with op name", + suppliedQuery: queries{ + body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: false, + op_name: "MyQuery", + op_type: "query", + }, + }, + + { + name: "test valid query with op name, variables and cache", + suppliedQuery: queries{ + body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", + }, + wantResults: results{ + is_cached: true, + should_block: false, + should_ignore: false, + op_name: "MyQuery", + op_type: "query", + }, + }, + + { + name: "test valid query with op name, cache and ttl", + suppliedQuery: queries{ + body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", + }, + wantResults: results{ + is_cached: true, + cached_ttl: 60, + should_block: false, + should_ignore: false, + op_name: "MyQuery", + op_type: "query", + }, + }, + + { + name: "test valid query with op name, cache and INVALID ttl", + suppliedQuery: queries{ + body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}", + }, + wantResults: results{ + is_cached: true, + cached_ttl: 0, + should_block: false, + should_ignore: false, + op_name: "MyQuery", + op_type: "query", + }, + }, + + { + name: "test mutation query with op name", + suppliedQuery: queries{ + body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: false, + op_name: "MyMutation", + op_type: "mutation", + }, + }, + + { + name: "test mutation query with config: read only", + suppliedSettings: func() *config { + cfg.Server.ReadOnlyMode = true + return cfg + }(), + suppliedQuery: queries{ + body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}", + }, + wantResults: results{ + is_cached: false, + should_block: true, + should_ignore: false, + op_name: "MyMutation", + op_type: "mutation", + returnCode: 403, + }, + }, + + { + name: "test simple query with introspection __schema", + suppliedQuery: queries{ + body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: false, + op_name: "MyMutation", + op_type: "mutation", + }, + }, + + { + name: "test simple query with introspection __schema config: block introspection", + suppliedSettings: func() *config { + cfg.Security.BlockIntrospection = true + return cfg + }(), + suppliedQuery: queries{ + body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}", + }, + wantResults: results{ + is_cached: false, + should_block: true, + should_ignore: false, + op_name: "MyIntroQuery", + op_type: "query", + returnCode: 403, + }, + }, + + { + name: "test user supplied query with introspection #1 - config: block", + suppliedSettings: func() *config { + parseConfig() + cfg.Security.BlockIntrospection = true + cfg.Security.IntrospectionAllowed = []string{} + prepareQueriesAndExemptions() + return cfg + }(), + suppliedQuery: queries{ + body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", + }, + wantResults: results{ + is_cached: false, + should_block: true, + should_ignore: false, + op_name: "undefined", + op_type: "query", + returnCode: 403, + }, + }, + + { + name: "test user supplied query with introspection #1 - config: block & allow __schema", + suppliedSettings: func() *config { + parseConfig() + cfg.Security.BlockIntrospection = true + cfg.Security.IntrospectionAllowed = []string{"__schema"} + prepareQueriesAndExemptions() + return cfg + }(), + suppliedQuery: queries{ + body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}", + }, + wantResults: results{ + is_cached: false, + should_block: false, + should_ignore: false, + op_name: "undefined", + op_type: "query", + returnCode: 200, + }, + }, + } + + for _, tt := range tests { + suite.T().Run(tt.name, func(t *testing.T) { + cfg = &config{} + cfg.Logger = libpack_logging.NewLogger() + defer func() { + cfg = &config{} + }() + + app := fiber.New() + + ctx_headers := func() *fasthttp.RequestHeader { + h := fasthttp.RequestHeader{} + for k, v := range tt.suppliedQuery.headers { + h.Add(k, v) + } + return &h + }() + + ctx_request := fasthttp.Request{ + Header: *ctx_headers, + } + + ctx_request.AppendBody([]byte(tt.suppliedQuery.body)) + + ctx := app.AcquireCtx(&fasthttp.RequestCtx{ + Request: ctx_request, + }) + + defer app.ReleaseCtx(ctx) + assert.NotNil(ctx, "Fiber context is nil") + + if tt.suppliedSettings != nil { + cfg = tt.suppliedSettings + } + + defer func() { + cfg = &config{} + }() + + opType, opName, cacheFromQuery, cached_ttl, shouldBlock, should_ignore := parseGraphQLQuery(ctx) + + assert.Equal(tt.wantResults.op_type, opType, "Unexpected operation type", tt.name) + assert.Equal(tt.wantResults.op_name, opName, "Unexpected operation name", tt.name) + assert.Equal(tt.wantResults.is_cached, cacheFromQuery, "Unexpected cache value", tt.name) + assert.Equal(tt.wantResults.cached_ttl, cached_ttl, "Unexpected cache TTL value", tt.name) + assert.Equal(tt.wantResults.should_block, shouldBlock, "Unexpected block value", tt.name) + assert.Equal(tt.wantResults.should_ignore, should_ignore, "Unexpected ignore value", tt.name) + + if tt.wantResults.returnCode > 0 { + assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name) + } + }) + } +} diff --git a/main.go b/main.go index b91c785..b4b65bc 100644 --- a/main.go +++ b/main.go @@ -11,15 +11,9 @@ import ( var cfg *config -func init() { - for _, query := range retrospection_queries { - retrospectionQuerySet[query] = struct{}{} - } -} - func parseConfig() { libpack_config.PKG_NAME = "graphql_proxy" - var c config + c := config{} c.Server.PortGraphQL = envutil.GetInt("PORT_GRAPHQL", 8080) c.Server.PortMonitoring = envutil.GetInt("MONITORING_PORT", 9393) c.Server.HostGraphQL = envutil.Getenv("HOST_GRAPHQL", "http://localhost/") @@ -61,6 +55,7 @@ func parseConfig() { enableCache() // takes close to no resources, but can be used with dynamic query cache loadRatelimitConfig() enableApi() + prepareQueriesAndExemptions() } func main() {