Fixes the code for additional test cases.

This commit is contained in:
2024-12-06 12:54:36 +00:00
parent e54bbe8249
commit ac84c69812
2 changed files with 50 additions and 49 deletions
+31 -29
View File
@@ -27,12 +27,17 @@ var (
) )
func prepareQueriesAndExemptions() { func prepareQueriesAndExemptions() {
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
// Process allowed introspection queries
for _, q := range cfg.Security.IntrospectionAllowed { for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{} introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
} }
// Process allowed URLs
for _, u := range cfg.Server.AllowURLs { for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{} allowedUrls[u] = struct{}{}
} }
} }
@@ -172,36 +177,33 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool { func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
for _, s := range selections { for _, s := range selections {
switch sel := s.(type) { switch sel := s.(type) {
case *ast.Field: case *ast.Field:
fieldName := strings.ToLower(sel.Name.Value) fieldName := strings.ToLower(sel.Name.Value)
if _, exists := introspectionQueries[fieldName]; exists { if _, exists := introspectionQueries[fieldName]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 { if len(cfg.Security.IntrospectionAllowed) > 0 {
// If this field is allowed, don't block and continue checking other fields _, allowed := introspectionAllowedQueries[fieldName]
if _, allowed := introspectionAllowedQueries[fieldName]; allowed { if !allowed {
if sel.SelectionSet != nil { return true // Block if this field isn't allowed
if checkSelections(c, sel.GetSelectionSet().Selections) { }
return true // Even if this field is allowed, we need to check its nested selections
} else {
return true // Block if no allowlist exists
}
}
// Always check nested selections
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
} }
}
continue
} }
return true
}
return true
} }
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
}
} }
return false return false
} }
+19 -20
View File
@@ -141,7 +141,7 @@ func (suite *Tests) Test_getDetailsFromEnv() {
} }
} }
func TestIntrospectionEnvironmentConfig(t *testing.T) { func (suite *Tests) TestIntrospectionEnvironmentConfig() {
// Save original env vars // Save original env vars
oldEnv := make(map[string]string) oldEnv := make(map[string]string)
varsToSave := []string{ varsToSave := []string{
@@ -215,10 +215,26 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
}`, }`,
wantBlocked: false, wantBlocked: false,
}, },
{
name: "multiple allowed queries with one of them blocked",
envVars: map[string]string{
"BLOCK_SCHEMA_INTROSPECTION": "true",
"ALLOWED_INTROSPECTION": "__schema",
},
query: `{
__schema {
types {
name
__typename
}
}
}`,
wantBlocked: true,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
// Set test env vars // Set test env vars
for k, v := range tt.envVars { for k, v := range tt.envVars {
os.Setenv(k, v) os.Setenv(k, v)
@@ -236,27 +252,10 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
ctx.Request().SetBody([]byte(fmt.Sprintf(`{"query": %q}`, tt.query))) ctx.Request().SetBody([]byte(fmt.Sprintf(`{"query": %q}`, tt.query)))
result := parseGraphQLQuery(ctx) result := parseGraphQLQuery(ctx)
assert.Equal(tt.wantBlocked, result.shouldBlock)
if result.shouldBlock != tt.wantBlocked {
t.Errorf("query blocked = %v, want %v", result.shouldBlock, tt.wantBlocked)
}
// Clean up test env vars
for k := range tt.envVars { for k := range tt.envVars {
os.Unsetenv(k) os.Unsetenv(k)
} }
}) })
} }
} }
func TestMain(m *testing.M) {
// Setup test environment
os.Setenv("LOG_LEVEL", "error") // Reduce noise in tests
// Run tests
code := m.Run()
// Cleanup
os.Unsetenv("LOG_LEVEL")
os.Exit(code)
}