Fix the introduced bug where RO endpoint could've been accidentally used. (#17)

* Fix the introduced bug where RO endpoint could've been accidentally used.
This commit is contained in:
2024-06-28 21:48:39 +01:00
committed by GitHub
parent 162c4acd7c
commit d141fe3c04
8 changed files with 151 additions and 108 deletions
+1 -1
View File
@@ -102,7 +102,7 @@ jobs:
fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: false
summary-always: true
# auto-push only if it's on main branch
auto-push: false
gh-pages-branch: "gh-pages"
+1 -1
View File
@@ -65,7 +65,7 @@ jobs:
fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: false
summary-always: true
# auto-push only if it's on main branch
auto-push: true
gh-pages-branch: "gh-pages"
+2 -2
View File
@@ -55,7 +55,7 @@ func EnableCache(cfg *CacheConfig) {
}
cacheStats = &CacheStats{}
if ShouldUseRedisCache(cfg) {
cfg.Logger.Info(&libpack_logger.LogMessage{
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using Redis cache",
})
cfg.Client = libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
@@ -64,7 +64,7 @@ func EnableCache(cfg *CacheConfig) {
RedisPassword: cfg.Redis.Password,
})
} else {
cfg.Logger.Info(&libpack_logger.LogMessage{
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using in-memory cache",
})
cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second)
+12 -14
View File
@@ -6,7 +6,6 @@ import (
"github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
@@ -38,17 +37,16 @@ func BenchmarkCacheLookupInMemory(b *testing.B) {
}
func BenchmarkCacheLookupRedis(b *testing.B) {
redis_server, _ := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redis_server.Addr(),
RedisDB: 0,
})
redis_server, err := miniredis.Run()
if err != nil {
panic(err)
}
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
config.Redis.DB = 0
config.Redis.URL = redis_server.Addr()
config.Redis.Enable = true
EnableCache(config)
@@ -88,17 +86,17 @@ func BenchmarkCacheStoreInMemory(b *testing.B) {
}
func BenchmarkCacheStoreRedis(b *testing.B) {
redis_server, _ := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redis_server.Addr(),
RedisDB: 0,
})
redis_server, err := miniredis.Run()
if err != nil {
panic(err)
}
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
config.Redis.DB = 0
config.Redis.URL = redis_server.Addr()
config.Redis.Enable = true
EnableCache(config)
+9 -18
View File
@@ -7,7 +7,6 @@ import (
"github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
@@ -61,23 +60,15 @@ func (suite *Tests) Test_cacheLookupInmemory() {
}
func (suite *Tests) Test_cacheLookupRedis() {
// redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379")
// config.Client = libpack_cache_redis.NewClient(&libpack_cache_redis.RedisClientConfig{
// RedisServer: redis_server,
// RedisPassword: "",
// RedisDB: 0,
// })
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redisMockServer.Addr(),
RedisDB: 0,
})
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
config.Redis.DB = 0
config.Redis.URL = redisMockServer.Addr()
config.Redis.Enable = true
EnableCache(config)
type args struct {
hash string
@@ -193,12 +184,12 @@ func (suite *Tests) Test_cacheRedisFailure() {
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: mr.Addr(),
RedisDB: 0,
}),
TTL: 5,
TTL: 5,
}
config.Redis.DB = 0
config.Redis.URL = mr.Addr()
config.Redis.Enable = true
EnableCache(config)
// Test normal operation
CacheStore("test-key", []byte("test-value"))
+28 -12
View File
@@ -53,7 +53,7 @@ type parseGraphQLQueryResult struct {
var (
queryPool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{}, 4)
return make(map[string]interface{}, 48)
},
}
resultPool = sync.Pool{
@@ -65,14 +65,15 @@ var (
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
res := resultPool.Get().(*parseGraphQLQueryResult)
defer resultPool.Put(res)
*res = parseGraphQLQueryResult{shouldIgnore: true}
*res = parseGraphQLQueryResult{shouldIgnore: true, activeEndpoint: cfg.Server.HostGraphQL}
m := queryPool.Get().(map[string]interface{})
defer queryPool.Put(m)
for k := range m {
delete(m, k)
}
defer func() {
for k := range m {
delete(m, k)
}
queryPool.Put(m)
}()
if err := json.Unmarshal(c.Body(), &m); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -82,6 +83,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
resultPool.Put(res)
return res
}
@@ -94,6 +96,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
resultPool.Put(res)
return res
}
@@ -106,16 +109,15 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
resultPool.Put(res)
return res
}
res.shouldIgnore = false
res.operationName = "undefined"
res.activeEndpoint = cfg.Server.HostGraphQL
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
// If we haven't set an operation type yet, use this one
if res.operationType == "" {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
@@ -123,13 +125,25 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
}
}
if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
if cfg.Server.HostGraphQLReadOnly != "" {
if res.operationType == "" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
} else if res.operationType != "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Endpoint selection",
Pairs: map[string]interface{}{
"operationType": res.operationType,
"selectedEndpoint": res.activeEndpoint,
},
})
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Mutation blocked",
Message: "Mutation blocked - server in read-only mode",
Pairs: map[string]interface{}{"query": query},
})
if ifNotInTest() {
@@ -137,6 +151,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
}
_ = c.Status(403).SendString("The server is in read-only mode")
res.shouldBlock = true
resultPool.Put(res)
return res
}
@@ -161,6 +176,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
if cfg.Security.BlockIntrospection {
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if res.shouldBlock {
resultPool.Put(res)
return res
}
}
+11 -2
View File
@@ -85,27 +85,36 @@ func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
func getSortedKeys(labels map[string]string) []string {
labelsKey := labelsToString(labels)
// Check if the sorted keys are already cached
if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok {
return keys.([]string)
}
// Compute the sorted keys
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
// Store the sorted keys in the cache
sortedLabelKeysCache.m.Store(labelsKey, keys)
return keys
}
func labelsToString(labels map[string]string) string {
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
for k, v := range labels {
for _, k := range keys {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(v)
sb.WriteString(labels[k])
sb.WriteByte(';')
}
return sb.String()
+87 -58
View File
@@ -1,8 +1,6 @@
package main
import (
"strings"
"github.com/valyala/fasthttp"
)
@@ -14,47 +12,63 @@ func (suite *Tests) Test_proxyTheRequest() {
}
tests := []struct {
headers map[string]string
name string
body string
host string
hostRO string
path string
wantErr bool
headers map[string]string
name string
body string
host string
hostRO string
path string
wantErr bool
wantEndpoint string
}{
{
name: "test_empty",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
name: "test_empty",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
},
{
name: "test_wrong_url",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/",
path: "/v1/wrongURL",
headers: supplied_headers,
wantErr: true,
name: "test_wrong_url",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/",
path: "/v1/wrongURL",
headers: supplied_headers,
wantErr: true,
wantEndpoint: "https://google.com/",
},
{
name: "Test read only mode",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/",
hostRO: "https://telegram-bot.app/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
name: "Test read only mode",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://google.com/",
hostRO: "https://telegram-bot.app/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
},
{
name: "Test read only mode wrong host",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
hostRO: "https://google.com/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: true,
name: "Test read only mode wrong host",
body: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
hostRO: "https://google.com/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: true,
wantEndpoint: "https://google.com/",
},
{
name: "Test mutation with endpoint flip",
body: `{"query":"mutation {\n __type(name: \"Query\") {\n name\n }\n }"}`,
host: "https://telegram-bot.app/",
hostRO: "https://google.com/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
wantEndpoint: "https://telegram-bot.app/",
},
}
@@ -94,34 +108,49 @@ func (suite *Tests) Test_proxyTheRequest() {
} else {
assert.Nil(err, "Error is not nil", tt.name)
}
assert.Equal(tt.wantEndpoint, res.activeEndpoint, "Unexpected endpoint", tt.name)
})
}
}
func (suite *Tests) Test_proxyTheRequestWithPayloads() {
allowedUrls = make(map[string]struct{})
allowedUrls["/"] = struct{}{}
suite.Run("Test with invalid URL", func() {
cfg.Server.HostGraphQL = "://invalid-url"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.NotNil(err)
})
tests := []struct {
name string
payload string
url string
wantErr bool
}{
{
name: "Test with invalid URL",
payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
url: "://invalid-url",
wantErr: true,
},
{
name: "Test with network error",
payload: `{"query":"query {\n __type(name: \"Query\") {\n name\n }\n }"}`,
url: "http://non-existent-host.invalid",
wantErr: true,
},
// {
// name: "Test with large payload",
// payload: strings.Repeat("a", 10*1024*1024), // 10MB payload
// url: "https://google.com/",
// wantErr: false,
// },
}
suite.Run("Test with network error", func() {
cfg.Server.HostGraphQL = "http://non-existent-host.invalid"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.NotNil(err)
})
suite.Run("Test with large payload", func() {
cfg.Server.HostGraphQL = "https://telegram-bot.app/"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
largePayload := strings.Repeat("a", 10*1024*1024) // 10MB payload
ctx.Context().Request.SetBody([]byte(largePayload))
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.Nil(err)
})
for _, tt := range tests {
suite.Run(tt.name, func() {
cfg.Server.HostGraphQL = tt.url
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
if tt.wantErr {
assert.NotNil(err)
} else {
assert.Nil(err)
}
})
}
}