From fc9bab47fbfccdca24a965f5ea5dafe6b401bd23 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 6 Dec 2024 11:04:26 +0000 Subject: [PATCH] Fix query introspection blocking on deeply nested types. --- events.go | 2 +- graphql.go | 243 ++++++++++++++++++++++++------------------------ graphql_test.go | 2 + 3 files changed, 126 insertions(+), 121 deletions(-) diff --git a/events.go b/events.go index e32cbaf..4f053a2 100644 --- a/events.go +++ b/events.go @@ -96,7 +96,7 @@ func cleanEvents(pool *pgxpool.Pool) { Message: "Failed to execute some queries", Pairs: map[string]interface{}{ "failed_queries": failedQueries, - "errors": errMsgs, + "errors": errMsgs, }, }) } diff --git a/graphql.go b/graphql.go index acf3338..a7a8113 100644 --- a/graphql.go +++ b/graphql.go @@ -66,163 +66,166 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { m := queryPool.Get().(map[string]interface{}) defer func() { - for k := range m { - delete(m, k) - } - queryPool.Put(m) + for k := range m { + delete(m, k) + } + queryPool.Put(m) }() if err := json.Unmarshal(c.Body(), &m); err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't unmarshal the request", - Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - if res.shouldBlock { - resultPool.Put(res) + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't unmarshal the request", + Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } return res - } - return res } query, ok := m["query"].(string) if !ok { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't find the query", - Pairs: map[string]interface{}{"m_val": m}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - resultPool.Put(res) - return res + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't find the query", + Pairs: map[string]interface{}{"m_val": m}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + return res } p, err := parser.Parse(parser.ParseParams{Source: query}) if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't parse the query", - Pairs: map[string]interface{}{"query": query, "m_val": m}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - } - resultPool.Put(res) - return res + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't parse the query", + Pairs: map[string]interface{}{"query": query, "m_val": m}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } + return res } res.shouldIgnore = false res.operationName = "undefined" for _, d := range p.Definitions { - if oper, ok := d.(*ast.OperationDefinition); ok { - if res.operationType == "" { - res.operationType = strings.ToLower(oper.Operation) - if oper.Name != nil { - res.operationName = oper.Name.Value - } - } - - if cfg.Server.HostGraphQLReadOnly != "" { - if res.operationType == "" { - res.activeEndpoint = cfg.Server.HostGraphQLReadOnly - } else if res.operationType != "mutation" { - res.activeEndpoint = cfg.Server.HostGraphQLReadOnly - } - } - - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Endpoint selection", - Pairs: map[string]interface{}{ - "operationType": res.operationType, - "selectedEndpoint": res.activeEndpoint, - }, - }) - - if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { - cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "Mutation blocked - server in read-only mode", - Pairs: map[string]interface{}{"query": query}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - _ = c.Status(403).SendString("The server is in read-only mode") - res.shouldBlock = true - resultPool.Put(res) - return res - } - - for _, dir := range oper.Directives { - if dir.Name.Value == "cached" { - res.cacheRequest = true - for _, arg := range dir.Arguments { - switch arg.Name.Value { - case "ttl": - if v, ok := arg.Value.GetValue().(string); ok { - res.cacheTime, _ = strconv.Atoi(v) + if oper, ok := d.(*ast.OperationDefinition); ok { + if res.operationType == "" { + res.operationType = strings.ToLower(oper.Operation) + if oper.Name != nil { + res.operationName = oper.Name.Value } - case "refresh": - if v, ok := arg.Value.GetValue().(bool); ok { - res.cacheRefresh = v - } - } } - } - } - if cfg.Security.BlockIntrospection { - res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections) - if res.shouldBlock { - resultPool.Put(res) - return res - } + if cfg.Server.HostGraphQLReadOnly != "" { + if res.operationType == "" || res.operationType != "mutation" { + res.activeEndpoint = cfg.Server.HostGraphQLReadOnly + } + } + + if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { + cfg.Logger.Warning(&libpack_logger.LogMessage{ + Message: "Mutation blocked - server in read-only mode", + Pairs: map[string]interface{}{"query": query}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + _ = c.Status(403).SendString("The server is in read-only mode") + res.shouldBlock = true + resultPool.Put(res) + return res + } + + for _, dir := range oper.Directives { + if dir.Name.Value == "cached" { + res.cacheRequest = true + for _, arg := range dir.Arguments { + switch arg.Name.Value { + case "ttl": + if v, ok := arg.Value.GetValue().(string); ok { + res.cacheTime, _ = strconv.Atoi(v) + } + case "refresh": + if v, ok := arg.Value.GetValue().(bool); ok { + res.cacheRefresh = v + } + } + } + } + } + + if cfg.Security.BlockIntrospection { + if checkSelections(c, oper.GetSelectionSet().Selections) { + _ = c.Status(403).SendString("Introspection queries are not allowed") + res.shouldBlock = true + resultPool.Put(res) + return res + } + } } - } } return res } func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { - stack := make([]ast.Selection, len(selections)) - copy(stack, selections) - - for len(stack) > 0 { - var s ast.Selection - s, stack = stack[len(stack)-1], stack[:len(stack)-1] - - if field, ok := s.(*ast.Field); ok { - if checkIfContainsIntrospection(c, field.Name.Value) { - return true + for _, s := range selections { + switch sel := s.(type) { + case *ast.Field: + fieldName := strings.ToLower(sel.Name.Value) + if _, exists := introspectionQueries[fieldName]; exists { + if len(cfg.Security.IntrospectionAllowed) > 0 { + if _, allowed := introspectionAllowedQueries[fieldName]; !allowed { + return true + } + } else { + return true + } + } + if sel.SelectionSet != nil { + if checkSelections(c, sel.GetSelectionSet().Selections) { + return true + } + } } - if field.SelectionSet != nil { - stack = append(stack, field.GetSelectionSet().Selections...) - } - } } return false } -func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) bool { - whateverLower := strings.ToLower(whatever) - - if _, exists := introspectionQueries[whateverLower]; exists { - if len(cfg.Security.IntrospectionAllowed) > 0 { - if _, allowed := introspectionAllowedQueries[whateverLower]; allowed { - cfg.Logger.Debug(&libpack_logger.LogMessage{ - Message: "Introspection query allowed, passing through", - Pairs: map[string]interface{}{"query": whatever}, - }) - return false +func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool { + blocked := false + // Try parsing as a complete query first + p, err := parser.Parse(parser.ParseParams{Source: query}) + if err == nil { + // It's a complete query, check all selections + for _, def := range p.Definitions { + if op, ok := def.(*ast.OperationDefinition); ok { + if op.SelectionSet != nil { + blocked = checkSelections(c, op.GetSelectionSet().Selections) + } } } + } else { + // Not a complete query, check as a field name + whateverLower := strings.ToLower(query) + if _, exists := introspectionQueries[whateverLower]; exists { + if len(cfg.Security.IntrospectionAllowed) > 0 { + if _, allowed := introspectionAllowedQueries[whateverLower]; !allowed { + blocked = true + } + } else { + blocked = true + } + } + } + + if blocked { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } _ = c.Status(403).SendString("Introspection queries are not allowed") - return true } - return false + return blocked } diff --git a/graphql_test.go b/graphql_test.go index 73ef42a..92d7416 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -408,6 +408,8 @@ func (suite *Tests) Test_checkIfContainsIntrospection() { {"allowed introspection", "__schema", []string{"__schema"}, false}, {"disallowed introspection", "__type", []string{"__schema"}, true}, {"non-introspection query", "normalQuery", []string{}, false}, + {"allowed introspection with deep nesting of __typename", "{__schema {queryType {fields {name description __typename}}}}", []string{"__schema", "__typename"}, false}, + {"disallowed introspection with deep nesting of __typename", "{__type {queryType {fields {name description __typename}}}}", []string{"__type"}, true}, } for _, tt := range tests {