From d141fe3c041a7e6d2dc1e08feb53669908f82ef0 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Fri, 28 Jun 2024 21:48:39 +0100 Subject: [PATCH] Fix the introduced bug where RO endpoint could've been accidentally used. (#17) * Fix the introduced bug where RO endpoint could've been accidentally used. --- .github/workflows/pr.yaml | 2 +- .github/workflows/test.yaml | 2 +- cache/cache.go | 4 +- cache/cache_bench_test.go | 26 +++---- cache/cache_test.go | 27 +++---- graphql.go | 40 +++++++--- monitoring/helpers.go | 13 +++- proxy_test.go | 145 +++++++++++++++++++++--------------- 8 files changed, 151 insertions(+), 108 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 6930ba3..41e36e6 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -102,7 +102,7 @@ jobs: fail-on-alert: true github-token: ${{ secrets.GITHUB_TOKEN }} comment-on-alert: true - summary-always: false + summary-always: true # auto-push only if it's on main branch auto-push: false gh-pages-branch: "gh-pages" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 075be55..8f60a9e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -65,7 +65,7 @@ jobs: fail-on-alert: true github-token: ${{ secrets.GITHUB_TOKEN }} comment-on-alert: true - summary-always: false + summary-always: true # auto-push only if it's on main branch auto-push: true gh-pages-branch: "gh-pages" diff --git a/cache/cache.go b/cache/cache.go index 443bbf6..1c665cd 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -55,7 +55,7 @@ func EnableCache(cfg *CacheConfig) { } cacheStats = &CacheStats{} if ShouldUseRedisCache(cfg) { - cfg.Logger.Info(&libpack_logger.LogMessage{ + cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Using Redis cache", }) cfg.Client = libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ @@ -64,7 +64,7 @@ func EnableCache(cfg *CacheConfig) { RedisPassword: cfg.Redis.Password, }) } else { - cfg.Logger.Info(&libpack_logger.LogMessage{ + cfg.Logger.Debug(&libpack_logger.LogMessage{ Message: "Using in-memory cache", }) cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second) diff --git a/cache/cache_bench_test.go b/cache/cache_bench_test.go index 271fdd5..0c08098 100644 --- a/cache/cache_bench_test.go +++ b/cache/cache_bench_test.go @@ -6,7 +6,6 @@ import ( "github.com/alicebob/miniredis/v2" libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory" - libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) @@ -38,17 +37,16 @@ func BenchmarkCacheLookupInMemory(b *testing.B) { } func BenchmarkCacheLookupRedis(b *testing.B) { - redis_server, _ := miniredis.Run() - mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ - RedisServer: redis_server.Addr(), - RedisDB: 0, - }) - + redis_server, err := miniredis.Run() + if err != nil { + panic(err) + } config = &CacheConfig{ Logger: libpack_logger.New(), - Client: mockedCache, TTL: 5, } + config.Redis.DB = 0 + config.Redis.URL = redis_server.Addr() config.Redis.Enable = true EnableCache(config) @@ -88,17 +86,17 @@ func BenchmarkCacheStoreInMemory(b *testing.B) { } func BenchmarkCacheStoreRedis(b *testing.B) { - redis_server, _ := miniredis.Run() - mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ - RedisServer: redis_server.Addr(), - RedisDB: 0, - }) + redis_server, err := miniredis.Run() + if err != nil { + panic(err) + } config = &CacheConfig{ Logger: libpack_logger.New(), - Client: mockedCache, TTL: 5, } + config.Redis.DB = 0 + config.Redis.URL = redis_server.Addr() config.Redis.Enable = true EnableCache(config) diff --git a/cache/cache_test.go b/cache/cache_test.go index b6851d2..9964604 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -7,7 +7,6 @@ import ( "github.com/alicebob/miniredis/v2" libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory" - libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" ) @@ -61,23 +60,15 @@ func (suite *Tests) Test_cacheLookupInmemory() { } func (suite *Tests) Test_cacheLookupRedis() { - // redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379") - // config.Client = libpack_cache_redis.NewClient(&libpack_cache_redis.RedisClientConfig{ - // RedisServer: redis_server, - // RedisPassword: "", - // RedisDB: 0, - // }) - - mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ - RedisServer: redisMockServer.Addr(), - RedisDB: 0, - }) config = &CacheConfig{ Logger: libpack_logger.New(), - Client: mockedCache, TTL: 5, } + config.Redis.DB = 0 + config.Redis.URL = redisMockServer.Addr() + config.Redis.Enable = true + EnableCache(config) type args struct { hash string @@ -193,12 +184,12 @@ func (suite *Tests) Test_cacheRedisFailure() { config = &CacheConfig{ Logger: libpack_logger.New(), - Client: libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ - RedisServer: mr.Addr(), - RedisDB: 0, - }), - TTL: 5, + TTL: 5, } + config.Redis.DB = 0 + config.Redis.URL = mr.Addr() + config.Redis.Enable = true + EnableCache(config) // Test normal operation CacheStore("test-key", []byte("test-value")) diff --git a/graphql.go b/graphql.go index 2072427..0934fd9 100644 --- a/graphql.go +++ b/graphql.go @@ -53,7 +53,7 @@ type parseGraphQLQueryResult struct { var ( queryPool = sync.Pool{ New: func() interface{} { - return make(map[string]interface{}, 4) + return make(map[string]interface{}, 48) }, } resultPool = sync.Pool{ @@ -65,14 +65,15 @@ var ( func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { res := resultPool.Get().(*parseGraphQLQueryResult) - defer resultPool.Put(res) - *res = parseGraphQLQueryResult{shouldIgnore: true} + *res = parseGraphQLQueryResult{shouldIgnore: true, activeEndpoint: cfg.Server.HostGraphQL} m := queryPool.Get().(map[string]interface{}) - defer queryPool.Put(m) - for k := range m { - delete(m, k) - } + defer func() { + for k := range m { + delete(m, k) + } + queryPool.Put(m) + }() if err := json.Unmarshal(c.Body(), &m); err != nil { cfg.Logger.Error(&libpack_logger.LogMessage{ @@ -82,6 +83,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } + resultPool.Put(res) return res } @@ -94,6 +96,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) } + resultPool.Put(res) return res } @@ -106,16 +109,15 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { if ifNotInTest() { cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) } + resultPool.Put(res) return res } res.shouldIgnore = false res.operationName = "undefined" - res.activeEndpoint = cfg.Server.HostGraphQL for _, d := range p.Definitions { if oper, ok := d.(*ast.OperationDefinition); ok { - // If we haven't set an operation type yet, use this one if res.operationType == "" { res.operationType = strings.ToLower(oper.Operation) if oper.Name != nil { @@ -123,13 +125,25 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { } } - if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" { - res.activeEndpoint = cfg.Server.HostGraphQLReadOnly + if cfg.Server.HostGraphQLReadOnly != "" { + if res.operationType == "" { + res.activeEndpoint = cfg.Server.HostGraphQLReadOnly + } else if res.operationType != "mutation" { + res.activeEndpoint = cfg.Server.HostGraphQLReadOnly + } } + cfg.Logger.Debug(&libpack_logger.LogMessage{ + Message: "Endpoint selection", + Pairs: map[string]interface{}{ + "operationType": res.operationType, + "selectedEndpoint": res.activeEndpoint, + }, + }) + if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { cfg.Logger.Warning(&libpack_logger.LogMessage{ - Message: "Mutation blocked", + Message: "Mutation blocked - server in read-only mode", Pairs: map[string]interface{}{"query": query}, }) if ifNotInTest() { @@ -137,6 +151,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { } _ = c.Status(403).SendString("The server is in read-only mode") res.shouldBlock = true + resultPool.Put(res) return res } @@ -161,6 +176,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { if cfg.Security.BlockIntrospection { res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections) if res.shouldBlock { + resultPool.Put(res) return res } } diff --git a/monitoring/helpers.go b/monitoring/helpers.go index 4773866..9ca8557 100644 --- a/monitoring/helpers.go +++ b/monitoring/helpers.go @@ -85,27 +85,36 @@ func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) { func getSortedKeys(labels map[string]string) []string { labelsKey := labelsToString(labels) + // Check if the sorted keys are already cached if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok { return keys.([]string) } + // Compute the sorted keys keys := make([]string, 0, len(labels)) for k := range labels { keys = append(keys, k) } sort.Strings(keys) + // Store the sorted keys in the cache sortedLabelKeysCache.m.Store(labelsKey, keys) return keys } func labelsToString(labels map[string]string) string { + keys := make([]string, 0, len(labels)) + for k := range labels { + keys = append(keys, k) + } + sort.Strings(keys) + var sb strings.Builder - for k, v := range labels { + for _, k := range keys { sb.WriteString(k) sb.WriteByte('=') - sb.WriteString(v) + sb.WriteString(labels[k]) sb.WriteByte(';') } return sb.String() diff --git a/proxy_test.go b/proxy_test.go index afead6b..c06c69f 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -1,8 +1,6 @@ package main import ( - "strings" - "github.com/valyala/fasthttp" ) @@ -14,47 +12,63 @@ func (suite *Tests) Test_proxyTheRequest() { } tests := []struct { - headers map[string]string - name string - body string - host string - hostRO string - path string - wantErr bool + headers map[string]string + name string + body string + host string + hostRO string + path string + wantErr bool + wantEndpoint string }{ { - name: "test_empty", - body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, - host: "https://telegram-bot.app/", - path: "/v1/graphql", - headers: supplied_headers, - wantErr: false, + name: "test_empty", + body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + host: "https://telegram-bot.app/", + path: "/v1/graphql", + headers: supplied_headers, + wantErr: false, + wantEndpoint: "https://telegram-bot.app/", }, { - name: "test_wrong_url", - body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, - host: "https://google.com/", - path: "/v1/wrongURL", - headers: supplied_headers, - wantErr: true, + name: "test_wrong_url", + body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + host: "https://google.com/", + path: "/v1/wrongURL", + headers: supplied_headers, + wantErr: true, + wantEndpoint: "https://google.com/", }, { - name: "Test read only mode", - body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, - host: "https://google.com/", - hostRO: "https://telegram-bot.app/", - path: "/v1/graphql", - headers: supplied_headers, - wantErr: false, + name: "Test read only mode", + body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + host: "https://google.com/", + hostRO: "https://telegram-bot.app/", + path: "/v1/graphql", + headers: supplied_headers, + wantErr: false, + wantEndpoint: "https://telegram-bot.app/", }, { - name: "Test read only mode wrong host", - body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, - host: "https://telegram-bot.app/", - hostRO: "https://google.com/", - path: "/v1/graphql", - headers: supplied_headers, - wantErr: true, + name: "Test read only mode wrong host", + body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + host: "https://telegram-bot.app/", + hostRO: "https://google.com/", + + path: "/v1/graphql", + headers: supplied_headers, + wantErr: true, + wantEndpoint: "https://google.com/", + }, + { + name: "Test mutation with endpoint flip", + body: `{"query":"mutation {\n __type(name: \"Query\") {\n name\n }\n }"}`, + host: "https://telegram-bot.app/", + hostRO: "https://google.com/", + path: "/v1/graphql", + headers: supplied_headers, + wantErr: false, + wantEndpoint: "https://telegram-bot.app/", }, } @@ -94,34 +108,49 @@ func (suite *Tests) Test_proxyTheRequest() { } else { assert.Nil(err, "Error is not nil", tt.name) } + assert.Equal(tt.wantEndpoint, res.activeEndpoint, "Unexpected endpoint", tt.name) }) } } func (suite *Tests) Test_proxyTheRequestWithPayloads() { - allowedUrls = make(map[string]struct{}) - allowedUrls["/"] = struct{}{} - suite.Run("Test with invalid URL", func() { - cfg.Server.HostGraphQL = "://invalid-url" - ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) - err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) - assert.NotNil(err) - }) + tests := []struct { + name string + payload string + url string + wantErr bool + }{ + { + name: "Test with invalid URL", + payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + url: "://invalid-url", + wantErr: true, + }, + { + name: "Test with network error", + payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, + url: "http://non-existent-host.invalid", + wantErr: true, + }, + // { + // name: "Test with large payload", + // payload: strings.Repeat("a", 10*1024*1024), // 10MB payload + // url: "https://google.com/", + // wantErr: false, + // }, + } - suite.Run("Test with network error", func() { - cfg.Server.HostGraphQL = "http://non-existent-host.invalid" - ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) - err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) - assert.NotNil(err) - }) - - suite.Run("Test with large payload", func() { - cfg.Server.HostGraphQL = "https://telegram-bot.app/" - ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) - largePayload := strings.Repeat("a", 10*1024*1024) // 10MB payload - ctx.Context().Request.SetBody([]byte(largePayload)) - err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) - assert.Nil(err) - }) + for _, tt := range tests { + suite.Run(tt.name, func() { + cfg.Server.HostGraphQL = tt.url + ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) + err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) + if tt.wantErr { + assert.NotNil(err) + } else { + assert.Nil(err) + } + }) + } }