mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
Fixes the code for additional test cases.
This commit is contained in:
+31
-29
@@ -27,12 +27,17 @@ var (
|
||||
)
|
||||
|
||||
func prepareQueriesAndExemptions() {
|
||||
introspectionAllowedQueries = make(map[string]struct{})
|
||||
allowedUrls = make(map[string]struct{})
|
||||
|
||||
// Process allowed introspection queries
|
||||
for _, q := range cfg.Security.IntrospectionAllowed {
|
||||
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
|
||||
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
|
||||
}
|
||||
|
||||
// Process allowed URLs
|
||||
for _, u := range cfg.Server.AllowURLs {
|
||||
allowedUrls[u] = struct{}{}
|
||||
allowedUrls[u] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,36 +177,33 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
|
||||
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
|
||||
for _, s := range selections {
|
||||
switch sel := s.(type) {
|
||||
case *ast.Field:
|
||||
fieldName := strings.ToLower(sel.Name.Value)
|
||||
if _, exists := introspectionQueries[fieldName]; exists {
|
||||
if len(cfg.Security.IntrospectionAllowed) > 0 {
|
||||
// If this field is allowed, don't block and continue checking other fields
|
||||
if _, allowed := introspectionAllowedQueries[fieldName]; allowed {
|
||||
if sel.SelectionSet != nil {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
switch sel := s.(type) {
|
||||
case *ast.Field:
|
||||
fieldName := strings.ToLower(sel.Name.Value)
|
||||
if _, exists := introspectionQueries[fieldName]; exists {
|
||||
if len(cfg.Security.IntrospectionAllowed) > 0 {
|
||||
_, allowed := introspectionAllowedQueries[fieldName]
|
||||
if !allowed {
|
||||
return true // Block if this field isn't allowed
|
||||
}
|
||||
// Even if this field is allowed, we need to check its nested selections
|
||||
} else {
|
||||
return true // Block if no allowlist exists
|
||||
}
|
||||
}
|
||||
// Always check nested selections
|
||||
if sel.SelectionSet != nil {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case *ast.InlineFragment:
|
||||
if sel.SelectionSet != nil {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
return true
|
||||
}
|
||||
return true
|
||||
}
|
||||
if sel.SelectionSet != nil {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case *ast.InlineFragment:
|
||||
if sel.SelectionSet != nil {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
+19
-20
@@ -141,7 +141,7 @@ func (suite *Tests) Test_getDetailsFromEnv() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntrospectionEnvironmentConfig(t *testing.T) {
|
||||
func (suite *Tests) TestIntrospectionEnvironmentConfig() {
|
||||
// Save original env vars
|
||||
oldEnv := make(map[string]string)
|
||||
varsToSave := []string{
|
||||
@@ -215,10 +215,26 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
|
||||
}`,
|
||||
wantBlocked: false,
|
||||
},
|
||||
{
|
||||
name: "multiple allowed queries with one of them blocked",
|
||||
envVars: map[string]string{
|
||||
"BLOCK_SCHEMA_INTROSPECTION": "true",
|
||||
"ALLOWED_INTROSPECTION": "__schema",
|
||||
},
|
||||
query: `{
|
||||
__schema {
|
||||
types {
|
||||
name
|
||||
__typename
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantBlocked: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
suite.Run(tt.name, func() {
|
||||
// Set test env vars
|
||||
for k, v := range tt.envVars {
|
||||
os.Setenv(k, v)
|
||||
@@ -236,27 +252,10 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
|
||||
ctx.Request().SetBody([]byte(fmt.Sprintf(`{"query": %q}`, tt.query)))
|
||||
|
||||
result := parseGraphQLQuery(ctx)
|
||||
|
||||
if result.shouldBlock != tt.wantBlocked {
|
||||
t.Errorf("query blocked = %v, want %v", result.shouldBlock, tt.wantBlocked)
|
||||
}
|
||||
|
||||
// Clean up test env vars
|
||||
assert.Equal(tt.wantBlocked, result.shouldBlock)
|
||||
for k := range tt.envVars {
|
||||
os.Unsetenv(k)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Setup test environment
|
||||
os.Setenv("LOG_LEVEL", "error") // Reduce noise in tests
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup
|
||||
os.Unsetenv("LOG_LEVEL")
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user