Fixes the code for additional test cases.

This commit is contained in:
2024-12-06 12:54:36 +00:00
parent e54bbe8249
commit ac84c69812
2 changed files with 50 additions and 49 deletions
+31 -29
View File
@@ -27,12 +27,17 @@ var (
)
func prepareQueriesAndExemptions() {
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
// Process allowed introspection queries
for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
}
// Process allowed URLs
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
allowedUrls[u] = struct{}{}
}
}
@@ -172,36 +177,33 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
for _, s := range selections {
switch sel := s.(type) {
case *ast.Field:
fieldName := strings.ToLower(sel.Name.Value)
if _, exists := introspectionQueries[fieldName]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
// If this field is allowed, don't block and continue checking other fields
if _, allowed := introspectionAllowedQueries[fieldName]; allowed {
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
switch sel := s.(type) {
case *ast.Field:
fieldName := strings.ToLower(sel.Name.Value)
if _, exists := introspectionQueries[fieldName]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
_, allowed := introspectionAllowedQueries[fieldName]
if !allowed {
return true // Block if this field isn't allowed
}
// Even if this field is allowed, we need to check its nested selections
} else {
return true // Block if no allowlist exists
}
}
// Always check nested selections
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
continue
}
return true
}
return true
}
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
}
}
return false
}