Fixes the code for additional test cases.

This commit is contained in:
2024-12-06 12:54:36 +00:00
parent e54bbe8249
commit ac84c69812
2 changed files with 50 additions and 49 deletions
+13 -11
View File
@@ -27,10 +27,15 @@ var (
)
func prepareQueriesAndExemptions() {
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
// Process allowed introspection queries
for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(strings.TrimSpace(q))] = struct{}{}
}
// Process allowed URLs
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
}
@@ -177,19 +182,16 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
fieldName := strings.ToLower(sel.Name.Value)
if _, exists := introspectionQueries[fieldName]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
// If this field is allowed, don't block and continue checking other fields
if _, allowed := introspectionAllowedQueries[fieldName]; allowed {
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
_, allowed := introspectionAllowedQueries[fieldName]
if !allowed {
return true // Block if this field isn't allowed
}
// Even if this field is allowed, we need to check its nested selections
} else {
return true // Block if no allowlist exists
}
}
continue
}
return true
}
return true
}
// Always check nested selections
if sel.SelectionSet != nil {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
+19 -20
View File
@@ -141,7 +141,7 @@ func (suite *Tests) Test_getDetailsFromEnv() {
}
}
func TestIntrospectionEnvironmentConfig(t *testing.T) {
func (suite *Tests) TestIntrospectionEnvironmentConfig() {
// Save original env vars
oldEnv := make(map[string]string)
varsToSave := []string{
@@ -215,10 +215,26 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
}`,
wantBlocked: false,
},
{
name: "multiple allowed queries with one of them blocked",
envVars: map[string]string{
"BLOCK_SCHEMA_INTROSPECTION": "true",
"ALLOWED_INTROSPECTION": "__schema",
},
query: `{
__schema {
types {
name
__typename
}
}
}`,
wantBlocked: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
suite.Run(tt.name, func() {
// Set test env vars
for k, v := range tt.envVars {
os.Setenv(k, v)
@@ -236,27 +252,10 @@ func TestIntrospectionEnvironmentConfig(t *testing.T) {
ctx.Request().SetBody([]byte(fmt.Sprintf(`{"query": %q}`, tt.query)))
result := parseGraphQLQuery(ctx)
if result.shouldBlock != tt.wantBlocked {
t.Errorf("query blocked = %v, want %v", result.shouldBlock, tt.wantBlocked)
}
// Clean up test env vars
assert.Equal(tt.wantBlocked, result.shouldBlock)
for k := range tt.envVars {
os.Unsetenv(k)
}
})
}
}
func TestMain(m *testing.M) {
// Setup test environment
os.Setenv("LOG_LEVEL", "error") // Reduce noise in tests
// Run tests
code := m.Run()
// Cleanup
os.Unsetenv("LOG_LEVEL")
os.Exit(code)
}