Fix blocking the introspection + add unit tests.

This commit is contained in:
2023-11-18 02:11:38 +00:00
parent a71b3950db
commit 1390e7cdd1
6 changed files with 400 additions and 47 deletions
+1 -1
View File
@@ -11,7 +11,7 @@ help: ## display this help
.PHONY: run
run: build ## run application
@LOG_LEVEL=debug PURGE_METRICS_ON_CRAWL=true BLOCK_SCHEMA_INTROSPECTION=false CACHE_TTL=10 JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/ HEALTHCHECK_GRAPHQL_URL=https://hasura8.lan/v1/graphql ./graphql-proxy
@LOG_LEVEL=debug PURGE_METRICS_ON_CRAWL=true BLOCK_SCHEMA_INTROSPECTION=true CACHE_TTL=10 JWT_ROLE_RATE_LIMIT=false JWT_ROLE_CLAIM_PATH="Hasura.x-hasura-default-role" JWT_USER_CLAIM_PATH="Hasura.x-hasura-user-id" HOST_GRAPHQL=https://hasura8.lan/ HEALTHCHECK_GRAPHQL_URL=https://hasura8.lan/v1/graphql ./graphql-proxy
.PHONY: build
build: ## build the binary
+1
View File
@@ -24,6 +24,7 @@ require (
github.com/avast/retry-go/v4 v4.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/gookit/color v1.5.4 // indirect
github.com/k0kubun/pp v3.0.1+incompatible // indirect
github.com/klauspost/compress v1.17.2 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
+2
View File
@@ -27,6 +27,8 @@ github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuM
github.com/graphql-go/graphql v0.8.1/go.mod h1:nKiHzRM0qopJEwCITUuIsxk9PlVlwIiiI8pnJEhordQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40=
github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg=
github.com/klauspost/compress v1.17.2 h1:RlWWUY/Dr4fL8qk9YG7DTZ7PDgME2V4csBXA8L/ixi4=
github.com/klauspost/compress v1.17.2/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+93 -39
View File
@@ -1,6 +1,7 @@
package main
import (
"flag"
"strconv"
"strings"
@@ -10,7 +11,7 @@ import (
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
var retrospection_queries = []string{
var introspection_queries = []string{
"__schema",
"__type",
"__typename",
@@ -34,13 +35,29 @@ var retrospection_queries = []string{
}
// Saving the introspection queries as a map O(1) operation instead of O(n) for a slice.
var retrospectionQuerySet = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(retrospection_queries))
for _, query := range retrospection_queries {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
var introspectionQuerySet = map[string]struct{}{}
var introspectionAllowedQueries = map[string]struct{}{}
func prepareQueriesAndExemptions() {
introspectionQuerySet = map[string]struct{}{}
introspectionQuerySet = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(introspection_queries))
for _, query := range introspection_queries {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
introspectionAllowedQueries = map[string]struct{}{}
introspectionAllowedQueries = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(cfg.Security.IntrospectionAllowed))
for _, query := range cfg.Security.IntrospectionAllowed {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
}
func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cacheRequest bool, cache_time int, should_block bool, should_ignore bool) {
should_ignore = true
@@ -48,21 +65,27 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
err := json.Unmarshal(c.Body(), &m)
if err != nil {
cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())})
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
}
// get the query
query, ok := m["query"].(string)
if !ok {
cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m})
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
}
p, err := parser.Parse(parser.ParseParams{Source: query})
if err != nil {
cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
}
@@ -71,19 +94,21 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
operationType = oper.Operation
if oper.Name != nil {
operationName = oper.Name.Value
}
if strings.ToLower(operationType) == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m)
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("The server is in read-only mode")
should_block = true
return
}
if oper.Name != nil {
operationName = oper.Name.Value
} else {
operationName = "undefined"
}
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
cacheRequest = true
@@ -91,38 +116,67 @@ func parseGraphQLQuery(c *fiber.Ctx) (operationType, operationName string, cache
if arg.Name.Value == "ttl" {
cache_time, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil {
cfg.Logger.Error("Can't parse the ttl", map[string]interface{}{"ttl": arg.Value.GetValue().(string)})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
}
}
}
}
}
if cfg.Security.BlockIntrospection {
for _, s := range oper.SelectionSet.Selections {
for _, s2 := range s.GetSelectionSet().Selections {
if _, exists := retrospectionQuerySet[strings.ToLower(s2.(*ast.Field).Name.Value)]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
for _, introspectionQueryAllowed := range cfg.Security.IntrospectionAllowed {
if strings.EqualFold(strings.ToLower(introspectionQueryAllowed), strings.ToLower(s2.(*ast.Field).Name.Value)) {
cfg.Logger.Debug("Introspection query allowed, passing through", m)
return
}
}
}
cfg.Logger.Warning("Introspection query blocked", m)
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
c.Status(403).SendString("Introspection queries are not allowed")
should_block = true
return
}
}
if cfg.Security.BlockIntrospection {
should_block = checkSelections(c, oper.GetSelectionSet().Selections)
if should_block {
return
}
}
}
}
return
}
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
for _, s := range selections {
field, ok := s.(*ast.Field)
if !ok {
continue // or handle the case where the type assertion fails
}
shouldBlock := checkIfContainsIntrospection(c, field.Name.Value)
if shouldBlock {
return true
}
if field.SelectionSet != nil {
if checkSelections(c, field.GetSelectionSet().Selections) {
return true
}
}
}
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (should_block bool) {
whateverLower := strings.ToLower(whatever)
got_exemption := false
if _, exists := introspectionQuerySet[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true
should_block = false
}
}
if !got_exemption {
should_block = true
}
}
if should_block {
if flag.Lookup("test.v") == nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Introspection queries are not allowed")
}
return
}
+301
View File
@@ -0,0 +1,301 @@
package main
import (
"testing"
fiber "github.com/gofiber/fiber/v2"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
"github.com/valyala/fasthttp"
)
func (suite *Tests) Test_parseGraphQLQuery() {
type results struct {
is_cached bool
cached_ttl int
should_block bool
should_ignore bool
op_name string
op_type string
returnCode int
}
type queries struct {
body string
headers map[string]string
}
tests := []struct {
name string
suppliedSettings *config
suppliedQuery queries
wantResults results
}{
{
name: "test empty body",
suppliedQuery: queries{
body: "",
headers: map[string]string{},
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: true,
op_name: "",
op_type: "",
},
},
{
name: "test empty json",
suppliedQuery: queries{
body: "{}",
headers: map[string]string{},
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: true,
op_name: "",
op_type: "",
},
},
{
name: "test empty with some random garbage",
suppliedQuery: queries{
body: "{\"variables\": {\"id\": \"1\"}}",
headers: map[string]string{},
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: true,
op_name: "",
op_type: "",
},
},
{
name: "test valid query with op name",
suppliedQuery: queries{
body: "{\"query\":\"query MyQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: false,
op_name: "MyQuery",
op_type: "query",
},
},
{
name: "test valid query with op name, variables and cache",
suppliedQuery: queries{
body: "{\"query\":\"query MyQuery @cached { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
},
wantResults: results{
is_cached: true,
should_block: false,
should_ignore: false,
op_name: "MyQuery",
op_type: "query",
},
},
{
name: "test valid query with op name, cache and ttl",
suppliedQuery: queries{
body: "{\"query\":\"query MyQuery @cached(ttl: 60) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
},
wantResults: results{
is_cached: true,
cached_ttl: 60,
should_block: false,
should_ignore: false,
op_name: "MyQuery",
op_type: "query",
},
},
{
name: "test valid query with op name, cache and INVALID ttl",
suppliedQuery: queries{
body: "{\"query\":\"query MyQuery @cached(ttl: nope) { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\", \"variables\": {\"id\": \"1\"}}",
},
wantResults: results{
is_cached: true,
cached_ttl: 0,
should_block: false,
should_ignore: false,
op_name: "MyQuery",
op_type: "query",
},
},
{
name: "test mutation query with op name",
suppliedQuery: queries{
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: false,
op_name: "MyMutation",
op_type: "mutation",
},
},
{
name: "test mutation query with config: read only",
suppliedSettings: func() *config {
cfg.Server.ReadOnlyMode = true
return cfg
}(),
suppliedQuery: queries{
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __typename } }\"}",
},
wantResults: results{
is_cached: false,
should_block: true,
should_ignore: false,
op_name: "MyMutation",
op_type: "mutation",
returnCode: 403,
},
},
{
name: "test simple query with introspection __schema",
suppliedQuery: queries{
body: "{\"query\":\"mutation MyMutation { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: false,
op_name: "MyMutation",
op_type: "mutation",
},
},
{
name: "test simple query with introspection __schema config: block introspection",
suppliedSettings: func() *config {
cfg.Security.BlockIntrospection = true
return cfg
}(),
suppliedQuery: queries{
body: "{\"query\":\"query MyIntroQuery { tg_users(where: {handle: {_eq: \\\"tozuo\\\"}}) { id __schema } }\"}",
},
wantResults: results{
is_cached: false,
should_block: true,
should_ignore: false,
op_name: "MyIntroQuery",
op_type: "query",
returnCode: 403,
},
},
{
name: "test user supplied query with introspection #1 - config: block",
suppliedSettings: func() *config {
parseConfig()
cfg.Security.BlockIntrospection = true
cfg.Security.IntrospectionAllowed = []string{}
prepareQueriesAndExemptions()
return cfg
}(),
suppliedQuery: queries{
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
},
wantResults: results{
is_cached: false,
should_block: true,
should_ignore: false,
op_name: "undefined",
op_type: "query",
returnCode: 403,
},
},
{
name: "test user supplied query with introspection #1 - config: block & allow __schema",
suppliedSettings: func() *config {
parseConfig()
cfg.Security.BlockIntrospection = true
cfg.Security.IntrospectionAllowed = []string{"__schema"}
prepareQueriesAndExemptions()
return cfg
}(),
suppliedQuery: queries{
body: "{\"query\":\"{__schema {queryType {fields {name description}}}}\"}",
},
wantResults: results{
is_cached: false,
should_block: false,
should_ignore: false,
op_name: "undefined",
op_type: "query",
returnCode: 200,
},
},
}
for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) {
cfg = &config{}
cfg.Logger = libpack_logging.NewLogger()
defer func() {
cfg = &config{}
}()
app := fiber.New()
ctx_headers := func() *fasthttp.RequestHeader {
h := fasthttp.RequestHeader{}
for k, v := range tt.suppliedQuery.headers {
h.Add(k, v)
}
return &h
}()
ctx_request := fasthttp.Request{
Header: *ctx_headers,
}
ctx_request.AppendBody([]byte(tt.suppliedQuery.body))
ctx := app.AcquireCtx(&fasthttp.RequestCtx{
Request: ctx_request,
})
defer app.ReleaseCtx(ctx)
assert.NotNil(ctx, "Fiber context is nil")
if tt.suppliedSettings != nil {
cfg = tt.suppliedSettings
}
defer func() {
cfg = &config{}
}()
opType, opName, cacheFromQuery, cached_ttl, shouldBlock, should_ignore := parseGraphQLQuery(ctx)
assert.Equal(tt.wantResults.op_type, opType, "Unexpected operation type", tt.name)
assert.Equal(tt.wantResults.op_name, opName, "Unexpected operation name", tt.name)
assert.Equal(tt.wantResults.is_cached, cacheFromQuery, "Unexpected cache value", tt.name)
assert.Equal(tt.wantResults.cached_ttl, cached_ttl, "Unexpected cache TTL value", tt.name)
assert.Equal(tt.wantResults.should_block, shouldBlock, "Unexpected block value", tt.name)
assert.Equal(tt.wantResults.should_ignore, should_ignore, "Unexpected ignore value", tt.name)
if tt.wantResults.returnCode > 0 {
assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name)
}
})
}
}
+2 -7
View File
@@ -11,15 +11,9 @@ import (
var cfg *config
func init() {
for _, query := range retrospection_queries {
retrospectionQuerySet[query] = struct{}{}
}
}
func parseConfig() {
libpack_config.PKG_NAME = "graphql_proxy"
var c config
c := config{}
c.Server.PortGraphQL = envutil.GetInt("PORT_GRAPHQL", 8080)
c.Server.PortMonitoring = envutil.GetInt("MONITORING_PORT", 9393)
c.Server.HostGraphQL = envutil.Getenv("HOST_GRAPHQL", "http://localhost/")
@@ -61,6 +55,7 @@ func parseConfig() {
enableCache() // takes close to no resources, but can be used with dynamic query cache
loadRatelimitConfig()
enableApi()
prepareQueriesAndExemptions()
}
func main() {