Fix the introduced bug where RO endpoint could've been accidentally used. (#17)

* Fix the introduced bug where RO endpoint could've been accidentally used.
This commit is contained in:
2024-06-28 21:48:39 +01:00
committed by GitHub
parent 162c4acd7c
commit d141fe3c04
8 changed files with 151 additions and 108 deletions
+1 -1
View File
@@ -102,7 +102,7 @@ jobs:
fail-on-alert: true fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true comment-on-alert: true
summary-always: false summary-always: true
# auto-push only if it's on main branch # auto-push only if it's on main branch
auto-push: false auto-push: false
gh-pages-branch: "gh-pages" gh-pages-branch: "gh-pages"
+1 -1
View File
@@ -65,7 +65,7 @@ jobs:
fail-on-alert: true fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true comment-on-alert: true
summary-always: false summary-always: true
# auto-push only if it's on main branch # auto-push only if it's on main branch
auto-push: true auto-push: true
gh-pages-branch: "gh-pages" gh-pages-branch: "gh-pages"
+2 -2
View File
@@ -55,7 +55,7 @@ func EnableCache(cfg *CacheConfig) {
} }
cacheStats = &CacheStats{} cacheStats = &CacheStats{}
if ShouldUseRedisCache(cfg) { if ShouldUseRedisCache(cfg) {
cfg.Logger.Info(&libpack_logger.LogMessage{ cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using Redis cache", Message: "Using Redis cache",
}) })
cfg.Client = libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ cfg.Client = libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
@@ -64,7 +64,7 @@ func EnableCache(cfg *CacheConfig) {
RedisPassword: cfg.Redis.Password, RedisPassword: cfg.Redis.Password,
}) })
} else { } else {
cfg.Logger.Info(&libpack_logger.LogMessage{ cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using in-memory cache", Message: "Using in-memory cache",
}) })
cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second) cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second)
+12 -14
View File
@@ -6,7 +6,6 @@ import (
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory" libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
) )
@@ -38,17 +37,16 @@ func BenchmarkCacheLookupInMemory(b *testing.B) {
} }
func BenchmarkCacheLookupRedis(b *testing.B) { func BenchmarkCacheLookupRedis(b *testing.B) {
redis_server, _ := miniredis.Run() redis_server, err := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ if err != nil {
RedisServer: redis_server.Addr(), panic(err)
RedisDB: 0, }
})
config = &CacheConfig{ config = &CacheConfig{
Logger: libpack_logger.New(), Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5, TTL: 5,
} }
config.Redis.DB = 0
config.Redis.URL = redis_server.Addr()
config.Redis.Enable = true config.Redis.Enable = true
EnableCache(config) EnableCache(config)
@@ -88,17 +86,17 @@ func BenchmarkCacheStoreInMemory(b *testing.B) {
} }
func BenchmarkCacheStoreRedis(b *testing.B) { func BenchmarkCacheStoreRedis(b *testing.B) {
redis_server, _ := miniredis.Run() redis_server, err := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ if err != nil {
RedisServer: redis_server.Addr(), panic(err)
RedisDB: 0, }
})
config = &CacheConfig{ config = &CacheConfig{
Logger: libpack_logger.New(), Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5, TTL: 5,
} }
config.Redis.DB = 0
config.Redis.URL = redis_server.Addr()
config.Redis.Enable = true config.Redis.Enable = true
EnableCache(config) EnableCache(config)
+9 -18
View File
@@ -7,7 +7,6 @@ import (
"github.com/alicebob/miniredis/v2" "github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory" libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging" libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
) )
@@ -61,23 +60,15 @@ func (suite *Tests) Test_cacheLookupInmemory() {
} }
func (suite *Tests) Test_cacheLookupRedis() { func (suite *Tests) Test_cacheLookupRedis() {
// redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379")
// config.Client = libpack_cache_redis.NewClient(&libpack_cache_redis.RedisClientConfig{
// RedisServer: redis_server,
// RedisPassword: "",
// RedisDB: 0,
// })
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redisMockServer.Addr(),
RedisDB: 0,
})
config = &CacheConfig{ config = &CacheConfig{
Logger: libpack_logger.New(), Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5, TTL: 5,
} }
config.Redis.DB = 0
config.Redis.URL = redisMockServer.Addr()
config.Redis.Enable = true
EnableCache(config)
type args struct { type args struct {
hash string hash string
@@ -193,12 +184,12 @@ func (suite *Tests) Test_cacheRedisFailure() {
config = &CacheConfig{ config = &CacheConfig{
Logger: libpack_logger.New(), Logger: libpack_logger.New(),
Client: libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{ TTL: 5,
RedisServer: mr.Addr(),
RedisDB: 0,
}),
TTL: 5,
} }
config.Redis.DB = 0
config.Redis.URL = mr.Addr()
config.Redis.Enable = true
EnableCache(config)
// Test normal operation // Test normal operation
CacheStore("test-key", []byte("test-value")) CacheStore("test-key", []byte("test-value"))
+28 -12
View File
@@ -53,7 +53,7 @@ type parseGraphQLQueryResult struct {
var ( var (
queryPool = sync.Pool{ queryPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return make(map[string]interface{}, 4) return make(map[string]interface{}, 48)
}, },
} }
resultPool = sync.Pool{ resultPool = sync.Pool{
@@ -65,14 +65,15 @@ var (
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult { func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
res := resultPool.Get().(*parseGraphQLQueryResult) res := resultPool.Get().(*parseGraphQLQueryResult)
defer resultPool.Put(res) *res = parseGraphQLQueryResult{shouldIgnore: true, activeEndpoint: cfg.Server.HostGraphQL}
*res = parseGraphQLQueryResult{shouldIgnore: true}
m := queryPool.Get().(map[string]interface{}) m := queryPool.Get().(map[string]interface{})
defer queryPool.Put(m) defer func() {
for k := range m { for k := range m {
delete(m, k) delete(m, k)
} }
queryPool.Put(m)
}()
if err := json.Unmarshal(c.Body(), &m); err != nil { if err := json.Unmarshal(c.Body(), &m); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{ cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -82,6 +83,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
resultPool.Put(res)
return res return res
} }
@@ -94,6 +96,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
resultPool.Put(res)
return res return res
} }
@@ -106,16 +109,15 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
} }
resultPool.Put(res)
return res return res
} }
res.shouldIgnore = false res.shouldIgnore = false
res.operationName = "undefined" res.operationName = "undefined"
res.activeEndpoint = cfg.Server.HostGraphQL
for _, d := range p.Definitions { for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok { if oper, ok := d.(*ast.OperationDefinition); ok {
// If we haven't set an operation type yet, use this one
if res.operationType == "" { if res.operationType == "" {
res.operationType = strings.ToLower(oper.Operation) res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil { if oper.Name != nil {
@@ -123,13 +125,25 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
} }
} }
if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" { if cfg.Server.HostGraphQLReadOnly != "" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly if res.operationType == "" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
} else if res.operationType != "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
} }
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Endpoint selection",
Pairs: map[string]interface{}{
"operationType": res.operationType,
"selectedEndpoint": res.activeEndpoint,
},
})
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode { if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning(&libpack_logger.LogMessage{ cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Mutation blocked", Message: "Mutation blocked - server in read-only mode",
Pairs: map[string]interface{}{"query": query}, Pairs: map[string]interface{}{"query": query},
}) })
if ifNotInTest() { if ifNotInTest() {
@@ -137,6 +151,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
} }
_ = c.Status(403).SendString("The server is in read-only mode") _ = c.Status(403).SendString("The server is in read-only mode")
res.shouldBlock = true res.shouldBlock = true
resultPool.Put(res)
return res return res
} }
@@ -161,6 +176,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if cfg.Security.BlockIntrospection { if cfg.Security.BlockIntrospection {
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections) res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if res.shouldBlock { if res.shouldBlock {
resultPool.Put(res)
return res return res
} }
} }
+11 -2
View File
@@ -85,27 +85,36 @@ func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
func getSortedKeys(labels map[string]string) []string { func getSortedKeys(labels map[string]string) []string {
labelsKey := labelsToString(labels) labelsKey := labelsToString(labels)
// Check if the sorted keys are already cached
if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok { if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok {
return keys.([]string) return keys.([]string)
} }
// Compute the sorted keys
keys := make([]string, 0, len(labels)) keys := make([]string, 0, len(labels))
for k := range labels { for k := range labels {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
// Store the sorted keys in the cache
sortedLabelKeysCache.m.Store(labelsKey, keys) sortedLabelKeysCache.m.Store(labelsKey, keys)
return keys return keys
} }
func labelsToString(labels map[string]string) string { func labelsToString(labels map[string]string) string {
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder var sb strings.Builder
for k, v := range labels { for _, k := range keys {
sb.WriteString(k) sb.WriteString(k)
sb.WriteByte('=') sb.WriteByte('=')
sb.WriteString(v) sb.WriteString(labels[k])
sb.WriteByte(';') sb.WriteByte(';')
} }
return sb.String() return sb.String()
+87 -58
View File
@@ -1,8 +1,6 @@
package main package main
import ( import (
"strings"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@@ -14,47 +12,63 @@ func (suite *Tests) Test_proxyTheRequest() {
} }
tests := []struct { tests := []struct {
headers map[string]string headers map[string]string
name string name string
body string body string
host string host string
hostRO string hostRO string
path string path string
wantErr bool wantErr bool
wantEndpoint string
}{ }{
{ {
name: "test_empty", name: "test_empty",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/", host: "https://telegram-bot.app/",
path: "/v1/graphql", path: "/v1/graphql",
headers: supplied_headers, headers: supplied_headers,
wantErr: false, wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
}, },
{ {
name: "test_wrong_url", name: "test_wrong_url",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/", host: "https://google.com/",
path: "/v1/wrongURL", path: "/v1/wrongURL",
headers: supplied_headers, headers: supplied_headers,
wantErr: true, wantErr: true,
wantEndpoint: "https://google.com/",
}, },
{ {
name: "Test read only mode", name: "Test read only mode",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/", host: "https://google.com/",
hostRO: "https://telegram-bot.app/", hostRO: "https://telegram-bot.app/",
path: "/v1/graphql", path: "/v1/graphql",
headers: supplied_headers, headers: supplied_headers,
wantErr: false, wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
}, },
{ {
name: "Test read only mode wrong host", name: "Test read only mode wrong host",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`, body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/", host: "https://telegram-bot.app/",
hostRO: "https://google.com/", hostRO: "https://google.com/",
path: "/v1/graphql",
headers: supplied_headers, path: "/v1/graphql",
wantErr: true, headers: supplied_headers,
wantErr: true,
wantEndpoint: "https://google.com/",
},
{
name: "Test mutation with endpoint flip",
body: `{"query":"mutation {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
hostRO: "https://google.com/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
}, },
} }
@@ -94,34 +108,49 @@ func (suite *Tests) Test_proxyTheRequest() {
} else { } else {
assert.Nil(err, "Error is not nil", tt.name) assert.Nil(err, "Error is not nil", tt.name)
} }
assert.Equal(tt.wantEndpoint, res.activeEndpoint, "Unexpected endpoint", tt.name)
}) })
} }
} }
func (suite *Tests) Test_proxyTheRequestWithPayloads() { func (suite *Tests) Test_proxyTheRequestWithPayloads() {
allowedUrls = make(map[string]struct{})
allowedUrls["/"] = struct{}{}
suite.Run("Test with invalid URL", func() { tests := []struct {
cfg.Server.HostGraphQL = "://invalid-url" name string
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) payload string
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) url string
assert.NotNil(err) wantErr bool
}) }{
{
name: "Test with invalid URL",
payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
url: "://invalid-url",
wantErr: true,
},
{
name: "Test with network error",
payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
url: "http://non-existent-host.invalid",
wantErr: true,
},
// {
// name: "Test with large payload",
// payload: strings.Repeat("a", 10*1024*1024), // 10MB payload
// url: "https://google.com/",
// wantErr: false,
// },
}
suite.Run("Test with network error", func() { for _, tt := range tests {
cfg.Server.HostGraphQL = "http://non-existent-host.invalid" suite.Run(tt.name, func() {
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) cfg.Server.HostGraphQL = tt.url
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL) ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
assert.NotNil(err) err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
}) if tt.wantErr {
assert.NotNil(err)
suite.Run("Test with large payload", func() { } else {
cfg.Server.HostGraphQL = "https://telegram-bot.app/" assert.Nil(err)
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{}) }
largePayload := strings.Repeat("a", 10*1024*1024) // 10MB payload })
ctx.Context().Request.SetBody([]byte(largePayload)) }
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.Nil(err)
})
} }