diff --git a/graphql.go b/graphql.go index f23f91c..c8aaa19 100644 --- a/graphql.go +++ b/graphql.go @@ -28,11 +28,11 @@ var ( func prepareQueriesAndExemptions() { for _, q := range cfg.Security.IntrospectionAllowed { - introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} + introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{} } for _, u := range cfg.Server.AllowURLs { - allowedUrls[u] = struct{}{} + allowedUrls[u] = struct{}{} } } @@ -66,142 +66,142 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { m := queryPool.Get().(map[string]interface{}) defer func() { - for k := range m { - delete(m, k) - } - queryPool.Put(m) + for k := range m { + delete(m, k) + } + queryPool.Put(m) }() if err := json.Unmarshal(c.Body(), &m); err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't unmarshal the request", - Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - return res + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't unmarshal the request", + Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + return res } query, ok := m["query"].(string) if !ok { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't find the query", - Pairs: map[string]interface{}{"m_val": m}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - return res + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't find the query", + Pairs: map[string]interface{}{"m_val": m}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + return res } p, err := parser.Parse(parser.ParseParams{Source: query}) if err != nil { - cfg.Logger.Error(&libpack_logger.LogMessage{ - Message: "Can't parse the query", - Pairs: map[string]interface{}{"query": query, "m_val": m}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) - } - return res + cfg.Logger.Error(&libpack_logger.LogMessage{ + Message: "Can't parse the query", + Pairs: map[string]interface{}{"query": query, "m_val": m}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) + } + return res } res.shouldIgnore = false res.operationName = "undefined" for _, d := range p.Definitions { - if oper, ok := d.(*ast.OperationDefinition); ok { - if res.operationType == "" { - res.operationType = strings.ToLower(oper.Operation) - if oper.Name != nil { - res.operationName = oper.Name.Value - } - } - - if cfg.Server.HostGraphQLReadOnly != "" { - if res.operationType == "" || res.operationType != "mutation" { - res.activeEndpoint = cfg.Server.HostGraphQLReadOnly - } - } - - if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { - cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "Mutation blocked - server in read-only mode", - Pairs: map[string]interface{}{"query": query}, - }) - if ifNotInTest() { - cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) - } - _ = c.Status(403).SendString("The server is in read-only mode") - res.shouldBlock = true - resultPool.Put(res) - return res - } - - for _, dir := range oper.Directives { - if dir.Name.Value == "cached" { - res.cacheRequest = true - for _, arg := range dir.Arguments { - switch arg.Name.Value { - case "ttl": - if v, ok := arg.Value.GetValue().(string); ok { - res.cacheTime, _ = strconv.Atoi(v) - } - case "refresh": - if v, ok := arg.Value.GetValue().(bool); ok { - res.cacheRefresh = v - } - } - } - } - } - - if cfg.Security.BlockIntrospection { - if checkSelections(c, oper.GetSelectionSet().Selections) { - _ = c.Status(403).SendString("Introspection queries are not allowed") - res.shouldBlock = true - resultPool.Put(res) - return res - } - } + if oper, ok := d.(*ast.OperationDefinition); ok { + if res.operationType == "" { + res.operationType = strings.ToLower(oper.Operation) + if oper.Name != nil { + res.operationName = oper.Name.Value + } } + + if cfg.Server.HostGraphQLReadOnly != "" { + if res.operationType == "" || res.operationType != "mutation" { + res.activeEndpoint = cfg.Server.HostGraphQLReadOnly + } + } + + if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { + cfg.Logger.Warning(&libpack_logger.LogMessage{ + Message: "Mutation blocked - server in read-only mode", + Pairs: map[string]interface{}{"query": query}, + }) + if ifNotInTest() { + cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) + } + _ = c.Status(403).SendString("The server is in read-only mode") + res.shouldBlock = true + resultPool.Put(res) + return res + } + + for _, dir := range oper.Directives { + if dir.Name.Value == "cached" { + res.cacheRequest = true + for _, arg := range dir.Arguments { + switch arg.Name.Value { + case "ttl": + if v, ok := arg.Value.GetValue().(string); ok { + res.cacheTime, _ = strconv.Atoi(v) + } + case "refresh": + if v, ok := arg.Value.GetValue().(bool); ok { + res.cacheRefresh = v + } + } + } + } + } + + if cfg.Security.BlockIntrospection { + if checkSelections(c, oper.GetSelectionSet().Selections) { + _ = c.Status(403).SendString("Introspection queries are not allowed") + res.shouldBlock = true + resultPool.Put(res) + return res + } + } + } } return res } func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { for _, s := range selections { - switch sel := s.(type) { - case *ast.Field: - fieldName := strings.ToLower(sel.Name.Value) - if _, exists := introspectionQueries[fieldName]; exists { - if len(cfg.Security.IntrospectionAllowed) > 0 { - // If this field is allowed, don't block and continue checking other fields - if _, allowed := introspectionAllowedQueries[fieldName]; allowed { - if sel.SelectionSet != nil { - if checkSelections(c, sel.GetSelectionSet().Selections) { - return true - } - } - continue - } - return true - } - return true - } - if sel.SelectionSet != nil { + switch sel := s.(type) { + case *ast.Field: + fieldName := strings.ToLower(sel.Name.Value) + if _, exists := introspectionQueries[fieldName]; exists { + if len(cfg.Security.IntrospectionAllowed) > 0 { + // If this field is allowed, don't block and continue checking other fields + if _, allowed := introspectionAllowedQueries[fieldName]; allowed { + if sel.SelectionSet != nil { if checkSelections(c, sel.GetSelectionSet().Selections) { - return true - } - } - case *ast.InlineFragment: - if sel.SelectionSet != nil { - if checkSelections(c, sel.GetSelectionSet().Selections) { - return true + return true } + } + continue } + return true + } + return true } + if sel.SelectionSet != nil { + if checkSelections(c, sel.GetSelectionSet().Selections) { + return true + } + } + case *ast.InlineFragment: + if sel.SelectionSet != nil { + if checkSelections(c, sel.GetSelectionSet().Selections) { + return true + } + } + } } return false } diff --git a/graphql_test.go b/graphql_test.go index f41a3ea..8c50693 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -3,9 +3,12 @@ package main import ( "fmt" "strings" + "testing" "github.com/goccy/go-json" fiber "github.com/gofiber/fiber/v2" + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/parser" "github.com/valyala/fasthttp" ) @@ -436,69 +439,173 @@ func createTestContext(body string) *fiber.Ctx { func (suite *Tests) Test_DeepIntrospectionQueries() { tests := []struct { - name string - query string - allowed []string - expected bool + name string + query string + allowed []string + expected bool }{ - { - name: "deeply nested single introspection", - query: "query { users { profiles { settings { preferences { __typename } } } } }", - allowed: []string{}, - expected: true, - }, - { - name: "multiple nested introspections", - query: "query { users { __typename profiles { __schema settings { __type } } } }", - allowed: []string{}, - expected: true, - }, - { - name: "nested with selective allowlist", - query: "query { users { __typename profiles { __schema settings { __type } } } }", - allowed: []string{"__typename"}, - expected: true, - }, - { - name: "deeply nested with full allowlist", - query: "query { users { __typename profiles { __schema settings { __type } } } }", - allowed: []string{"__typename", "__schema", "__type"}, - expected: false, - }, - { - name: "deeply nested with repeated item from allowlist", - query: "query PreloadStaticData {\n scenario {\n id\n name\n __typename\n }\n impact {\n id\n description\n __typename\n }\n likelihood {\n id\n description\n __typename\n }\n consequence {\n name\n __typename\n }\n risk_categories {\n name\n abbreviation\n __typename\n }\n mitigation {\n name\n __typename\n }\n}", - allowed: []string{"__type", "__typename"}, - expected: false, - }, - { - name: "deeply nested with repeated item denied", - query: "query PreloadStaticData {\n scenario {\n id\n name\n __typename\n }\n impact {\n id\n description\n __typename\n }\n likelihood {\n id\n description\n __typename\n }\n consequence {\n name\n __typename\n }\n risk_categories {\n name\n abbreviation\n __typename\n }\n mitigation {\n name\n __typename\n }\n}", - allowed: []string{}, - expected: true, - }, + { + name: "deeply nested single introspection", + query: "query { users { profiles { settings { preferences { __typename } } } } }", + allowed: []string{}, + expected: true, + }, + { + name: "multiple nested introspections", + query: "query { users { __typename profiles { __schema settings { __type } } } }", + allowed: []string{}, + expected: true, + }, + { + name: "nested with selective allowlist", + query: "query { users { __typename profiles { __schema settings { __type } } } }", + allowed: []string{"__typename"}, + expected: true, + }, + { + name: "deeply nested with full allowlist", + query: "query { users { __typename profiles { __schema settings { __type } } } }", + allowed: []string{"__typename", "__schema", "__type"}, + expected: false, + }, + { + name: "deeply nested with repeated item from allowlist", + query: "query PreloadStaticData {\n scenario {\n id\n name\n __typename\n }\n impact {\n id\n description\n __typename\n }\n likelihood {\n id\n description\n __typename\n }\n consequence {\n name\n __typename\n }\n risk_categories {\n name\n abbreviation\n __typename\n }\n mitigation {\n name\n __typename\n }\n}", + allowed: []string{"__type", "__typename"}, + expected: false, + }, + { + name: "deeply nested with repeated item denied", + query: "query PreloadStaticData {\n scenario {\n id\n name\n __typename\n }\n impact {\n id\n description\n __typename\n }\n likelihood {\n id\n description\n __typename\n }\n consequence {\n name\n __typename\n }\n risk_categories {\n name\n abbreviation\n __typename\n }\n mitigation {\n name\n __typename\n }\n}", + allowed: []string{}, + expected: true, + }, } for _, tt := range tests { - suite.Run(tt.name, func() { - cfg.Security.BlockIntrospection = true - cfg.Security.IntrospectionAllowed = tt.allowed - introspectionAllowedQueries = make(map[string]struct{}) - for _, q := range tt.allowed { - introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} - } - body := map[string]interface{}{ - "query": tt.query, - } - bodyBytes, _ := json.Marshal(body) - ctx := fiber.New().AcquireCtx(&fasthttp.RequestCtx{}) - ctx.Request().SetBody(bodyBytes) - parseGraphQLQuery(ctx) - if tt.expected { - suite.Equal(403, ctx.Response().StatusCode()) - } else { - suite.Equal(200, ctx.Response().StatusCode()) - } - }) + suite.Run(tt.name, func() { + cfg.Security.BlockIntrospection = true + cfg.Security.IntrospectionAllowed = tt.allowed + introspectionAllowedQueries = make(map[string]struct{}) + for _, q := range tt.allowed { + introspectionAllowedQueries[strings.ToLower(q)] = struct{}{} + } + body := map[string]interface{}{ + "query": tt.query, + } + bodyBytes, _ := json.Marshal(body) + ctx := fiber.New().AcquireCtx(&fasthttp.RequestCtx{}) + ctx.Request().SetBody(bodyBytes) + parseGraphQLQuery(ctx) + if tt.expected { + suite.Equal(403, ctx.Response().StatusCode()) + } else { + suite.Equal(200, ctx.Response().StatusCode()) + } + }) } -} \ No newline at end of file +} + +func TestIntrospectionQueryHandling(t *testing.T) { + tests := []struct { + name string + blockIntrospection bool + allowedQueries []string + query string + wantBlocked bool + }{ + { + name: "allows __typename when in allowed list", + blockIntrospection: true, + allowedQueries: []string{"__typename"}, + query: `{ + users { + id + name + __typename + } + }`, + wantBlocked: false, + }, + { + name: "case insensitive matching for allowed queries", + blockIntrospection: true, + allowedQueries: []string{"__TYPENAME"}, + query: `{ + users { + __typename + } + }`, + wantBlocked: false, + }, + { + name: "blocks other introspection queries", + blockIntrospection: true, + allowedQueries: []string{"__typename"}, + query: `{ + __schema { + types { + name + } + } + }`, + wantBlocked: true, + }, + { + name: "allows multiple __typename occurrences", + blockIntrospection: true, + allowedQueries: []string{"__typename"}, + query: `{ + users { + __typename + posts { + __typename + } + } + }`, + wantBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup config + cfg = &config{ + Security: struct { + IntrospectionAllowed []string + BlockIntrospection bool + }{ + IntrospectionAllowed: tt.allowedQueries, + BlockIntrospection: tt.blockIntrospection, + }, + } + + // Initialize allowed queries + prepareQueriesAndExemptions() + + // Parse query + p, err := parser.Parse(parser.ParseParams{Source: tt.query}) + if err != nil { + t.Fatalf("failed to parse query: %v", err) + } + + // Create mock fiber context + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + + // Check selections + var blocked bool + for _, def := range p.Definitions { + if op, ok := def.(*ast.OperationDefinition); ok { + blocked = checkSelections(ctx, op.GetSelectionSet().Selections) + break + } + } + + if blocked != tt.wantBlocked { + t.Errorf("checkSelections() blocked = %v, want %v", blocked, tt.wantBlocked) + } + }) + } +} diff --git a/main_test.go b/main_test.go index e2ea74c..63acb15 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "os" "testing" "time" @@ -11,6 +12,7 @@ import ( libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" assertions "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + "github.com/valyala/fasthttp" ) type Tests struct { @@ -138,3 +140,123 @@ func (suite *Tests) Test_getDetailsFromEnv() { }) } } + +func TestIntrospectionEnvironmentConfig(t *testing.T) { + // Save original env vars + oldEnv := make(map[string]string) + varsToSave := []string{ + "BLOCK_SCHEMA_INTROSPECTION", + "ALLOWED_INTROSPECTION", + "GMP_BLOCK_SCHEMA_INTROSPECTION", + "GMP_ALLOWED_INTROSPECTION", + } + for _, env := range varsToSave { + if val, exists := os.LookupEnv(env); exists { + oldEnv[env] = val + os.Unsetenv(env) + } + } + defer func() { + // Restore original env vars + for k, v := range oldEnv { + os.Setenv(k, v) + } + }() + + tests := []struct { + name string + envVars map[string]string + query string + wantBlocked bool + wantEndpoint string + }{ + { + name: "basic typename allowed", + envVars: map[string]string{ + "BLOCK_SCHEMA_INTROSPECTION": "true", + "ALLOWED_INTROSPECTION": "__typename", + }, + query: `{ + users { + id + __typename + } + }`, + wantBlocked: false, + }, + { + name: "GMP prefix takes precedence", + envVars: map[string]string{ + "BLOCK_SCHEMA_INTROSPECTION": "false", + "GMP_BLOCK_SCHEMA_INTROSPECTION": "true", + "ALLOWED_INTROSPECTION": "__type", + "GMP_ALLOWED_INTROSPECTION": "__typename", + }, + query: `{ + users { + __typename + } + }`, + wantBlocked: false, + }, + { + name: "multiple allowed queries", + envVars: map[string]string{ + "BLOCK_SCHEMA_INTROSPECTION": "true", + "ALLOWED_INTROSPECTION": "__typename,__schema", + }, + query: `{ + __schema { + types { + name + __typename + } + } + }`, + wantBlocked: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set test env vars + for k, v := range tt.envVars { + os.Setenv(k, v) + } + + // Reset global config + cfg = nil + parseConfig() + + // Create test request + app := fiber.New() + ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) + defer app.ReleaseCtx(ctx) + ctx.Request().Header.SetMethod("POST") + ctx.Request().SetBody([]byte(fmt.Sprintf(`{"query": %q}`, tt.query))) + + result := parseGraphQLQuery(ctx) + + if result.shouldBlock != tt.wantBlocked { + t.Errorf("query blocked = %v, want %v", result.shouldBlock, tt.wantBlocked) + } + + // Clean up test env vars + for k := range tt.envVars { + os.Unsetenv(k) + } + }) + } +} + +func TestMain(m *testing.M) { + // Setup test environment + os.Setenv("LOG_LEVEL", "error") // Reduce noise in tests + + // Run tests + code := m.Run() + + // Cleanup + os.Unsetenv("LOG_LEVEL") + os.Exit(code) +}