Improve stats gathering and tests improvements. (#8)

This commit is contained in:
2024-03-05 22:40:06 +00:00
committed by GitHub
parent b6c284b66d
commit 3a18e0e935
16 changed files with 284 additions and 103 deletions
+82
View File
@@ -0,0 +1,82 @@
name: Run tests on PR
on:
pull_request:
branches:
- "main"
push:
paths-ignore:
- "**/**.md"
- "**/**.yaml"
- "static/**"
branches:
- "!main"
env:
GO_VERSION: ">=1.21"
jobs:
jobs:
# This job is responsible for preparation of the build
# environment variables.
prepare:
name: Preparing build context
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
id: cache
with:
go-version: ${{env.GO_VERSION}}
cache-dependency-path: "**/*.sum"
- name: Go get dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
go get ./...
# This job is responsible for running tests and linting the codebase
test:
name: "Unit testing"
# needs: [prepare]
runs-on: ubuntu-latest
# container: github/super-linter:v4
needs: [prepare]
services:
# Label used to access the service container
redis:
# Docker Hub image
image: redis
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: ${{env.GO_VERSION}}
cache-dependency-path: "**/*.sum"
- name: Install dependencies
run: |
go mod tidy
- name: Run unit tests
env:
REDIS_HOST: redis
REDIS_PORT: 6379
run: |
export REDIS_SERVER="$REDIS_HOST:$REDIS_PORT"
CI_RUN=${CI} make test
+1 -1
View File
@@ -13,7 +13,7 @@ func calculateHash(c *fiber.Ctx) string {
} }
func enableCache() { func enableCache() {
cfg.Cache.CacheClient = libpack_cache.New(time.Duration(cfg.Cache.CacheTTL) * time.Second * 100) cfg.Cache.CacheClient = libpack_cache.New(time.Duration(cfg.Cache.CacheTTL) * time.Second)
} }
func cacheLookup(hash string) []byte { func cacheLookup(hash string) []byte {
+1 -2
View File
@@ -1,7 +1,6 @@
package main package main
import ( import (
"testing"
"time" "time"
) )
@@ -38,7 +37,7 @@ func (suite *Tests) Test_cacheLookup() {
}, },
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
if tt.addCache.data != nil { if tt.addCache.data != nil {
cfg.Cache.CacheClient.Set(tt.args.hash, tt.addCache.data, time.Duration(90*time.Second)) cfg.Cache.CacheClient.Set(tt.args.hash, tt.addCache.data, time.Duration(90*time.Second))
} }
+1 -3
View File
@@ -1,7 +1,5 @@
package main package main
import "testing"
func (suite *Tests) Test_extractClaimsFromJWTHeader() { func (suite *Tests) Test_extractClaimsFromJWTHeader() {
jwt_token_for_tests := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiSGFzdXJhIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsiZ3Vlc3QiLCJ1c2VyIiwiZ3JvdXBhZG1pbiIsInBheWFkbWluIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6Imd1ZXN0IiwieC1oYXN1cmEtdXNlci1pZCI6IjE2NyIsIngtaGFzdXJhLXVzZXItdXVpZCI6ImRkM2U2ZTM1LTA0MDktNDNiMC1iZmYxLWNlZjNjNmVkNWYxMCJ9LCJpc3MiOiJBdXRoU2VydmljZSIsImV4cCI6MTY5NjgwMTcyNiwibmJmIjoxNjk2NTg1NzI2LCJpYXQiOjE2OTY1ODU3MjZ9.dsJ5JKzG5tXOlqeZ_Gfe2XC-vyrcwtYwOGfhvt8q9UY" jwt_token_for_tests := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ0b2tlbl90eXBlIjoiYWNjZXNzIiwiSGFzdXJhIjp7IngtaGFzdXJhLWFsbG93ZWQtcm9sZXMiOlsiZ3Vlc3QiLCJ1c2VyIiwiZ3JvdXBhZG1pbiIsInBheWFkbWluIl0sIngtaGFzdXJhLWRlZmF1bHQtcm9sZSI6Imd1ZXN0IiwieC1oYXN1cmEtdXNlci1pZCI6IjE2NyIsIngtaGFzdXJhLXVzZXItdXVpZCI6ImRkM2U2ZTM1LTA0MDktNDNiMC1iZmYxLWNlZjNjNmVkNWYxMCJ9LCJpc3MiOiJBdXRoU2VydmljZSIsImV4cCI6MTY5NjgwMTcyNiwibmJmIjoxNjk2NTg1NzI2LCJpYXQiOjE2OTY1ODU3MjZ9.dsJ5JKzG5tXOlqeZ_Gfe2XC-vyrcwtYwOGfhvt8q9UY"
@@ -68,7 +66,7 @@ func (suite *Tests) Test_extractClaimsFromJWTHeader() {
}, },
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
if len(tt.jwt_token_path) > 0 { if len(tt.jwt_token_path) > 0 {
cfg.Client.JWTUserClaimPath = tt.jwt_token_path cfg.Client.JWTUserClaimPath = tt.jwt_token_path
} }
+22 -34
View File
@@ -1,7 +1,6 @@
package main package main
import ( import (
"flag"
"strconv" "strconv"
"strings" "strings"
@@ -41,40 +40,26 @@ var introspectionQuerySet = map[string]struct{}{}
var introspectionAllowedQueries = map[string]struct{}{} var introspectionAllowedQueries = map[string]struct{}{}
var allowedUrls = map[string]struct{}{} var allowedUrls = map[string]struct{}{}
// Utility function to convert a slice of strings to a map for O(1) lookups.
func sliceToMap(slice []string) map[string]struct{} {
resultMap := make(map[string]struct{}, len(slice))
for _, item := range slice {
resultMap[strings.ToLower(item)] = struct{}{}
}
return resultMap
}
func prepareQueriesAndExemptions() { func prepareQueriesAndExemptions() {
introspectionQuerySet = map[string]struct{}{} introspectionQuerySet = sliceToMap(introspection_queries)
introspectionQuerySet = func() map[string]struct{} { introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed)
rsqs := make(map[string]struct{}, len(introspection_queries)) allowedUrls = sliceToMap(cfg.Server.AllowURLs)
for _, query := range introspection_queries {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
introspectionAllowedQueries = map[string]struct{}{}
introspectionAllowedQueries = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(cfg.Security.IntrospectionAllowed))
for _, query := range cfg.Security.IntrospectionAllowed {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
allowedUrls = map[string]struct{}{}
allowedUrls = func() map[string]struct{} {
rsqs := make(map[string]struct{}, len(cfg.Server.AllowURLs))
for _, query := range cfg.Server.AllowURLs {
rsqs[strings.ToLower(query)] = struct{}{}
}
return rsqs
}()
} }
type parseGraphQLQueryResult struct { type parseGraphQLQueryResult struct {
operationType string operationType string
operationName string operationName string
cacheRequest bool
cacheTime int cacheTime int
cacheRequest bool
cacheRefresh bool cacheRefresh bool
shouldBlock bool shouldBlock bool
shouldIgnore bool shouldIgnore bool
@@ -86,7 +71,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
err := json.Unmarshal(c.Body(), &m) err := json.Unmarshal(c.Body(), &m)
if err != nil { if err != nil {
cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())}) cfg.Logger.Debug("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())})
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
return return
@@ -95,7 +80,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
query, ok := m["query"].(string) query, ok := m["query"].(string)
if !ok { if !ok {
cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m}) cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m})
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
return return
@@ -104,7 +89,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
p, err := parser.Parse(parser.ParseParams{Source: query}) p, err := parser.Parse(parser.ParseParams{Source: query})
if err != nil { if err != nil {
cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m}) cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m})
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
} }
return return
@@ -122,7 +107,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode { if strings.ToLower(res.operationType) == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m) cfg.Logger.Warning("Mutation blocked", m)
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
c.Status(403).SendString("The server is in read-only mode") c.Status(403).SendString("The server is in read-only mode")
@@ -138,7 +123,7 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string)) res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil { if err != nil {
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)}) cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
} }
return return
@@ -184,8 +169,11 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) { func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
whateverLower := strings.ToLower(whatever) whateverLower := strings.ToLower(whatever)
got_exemption := false got_exemption := false
// If the query is an introspection query, we need to check if it's allowed.
if _, exists := introspectionQuerySet[whateverLower]; exists { if _, exists := introspectionQuerySet[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 { if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists { if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever}) cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true got_exemption = true
@@ -197,7 +185,7 @@ func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bo
} }
} }
if shouldBlock { if shouldBlock {
if flag.Lookup("test.v") == nil { if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
} }
c.Status(403).SendString("Introspection queries are not allowed") c.Status(403).SendString("Introspection queries are not allowed")
+18 -27
View File
@@ -1,10 +1,6 @@
package main package main
import ( import (
"testing"
fiber "github.com/gofiber/fiber/v2"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@@ -166,6 +162,7 @@ func (suite *Tests) Test_parseGraphQLQuery() {
{ {
name: "test mutation query with config: read only", name: "test mutation query with config: read only",
suppliedSettings: func() *config { suppliedSettings: func() *config {
parseConfig()
cfg.Server.ReadOnlyMode = true cfg.Server.ReadOnlyMode = true
return cfg return cfg
}(), }(),
@@ -199,6 +196,7 @@ func (suite *Tests) Test_parseGraphQLQuery() {
{ {
name: "test simple query with introspection __schema config: block introspection", name: "test simple query with introspection __schema config: block introspection",
suppliedSettings: func() *config { suppliedSettings: func() *config {
parseConfig()
cfg.Security.BlockIntrospection = true cfg.Security.BlockIntrospection = true
return cfg return cfg
}(), }(),
@@ -221,7 +219,6 @@ func (suite *Tests) Test_parseGraphQLQuery() {
parseConfig() parseConfig()
cfg.Security.BlockIntrospection = true cfg.Security.BlockIntrospection = true
cfg.Security.IntrospectionAllowed = []string{} cfg.Security.IntrospectionAllowed = []string{}
prepareQueriesAndExemptions()
return cfg return cfg
}(), }(),
suppliedQuery: queries{ suppliedQuery: queries{
@@ -243,7 +240,6 @@ func (suite *Tests) Test_parseGraphQLQuery() {
parseConfig() parseConfig()
cfg.Security.BlockIntrospection = true cfg.Security.BlockIntrospection = true
cfg.Security.IntrospectionAllowed = []string{"__schema"} cfg.Security.IntrospectionAllowed = []string{"__schema"}
prepareQueriesAndExemptions()
return cfg return cfg
}(), }(),
suppliedQuery: queries{ suppliedQuery: queries{
@@ -275,15 +271,9 @@ func (suite *Tests) Test_parseGraphQLQuery() {
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
cfg = &config{} cfg = &config{}
cfg.Logger = libpack_logging.NewLogger() parseConfig()
defer func() {
cfg = &config{}
}()
app := fiber.New()
ctx_headers := func() *fasthttp.RequestHeader { ctx_headers := func() *fasthttp.RequestHeader {
h := fasthttp.RequestHeader{} h := fasthttp.RequestHeader{}
for k, v := range tt.suppliedQuery.headers { for k, v := range tt.suppliedQuery.headers {
@@ -298,28 +288,29 @@ func (suite *Tests) Test_parseGraphQLQuery() {
ctx_request.AppendBody([]byte(tt.suppliedQuery.body)) ctx_request.AppendBody([]byte(tt.suppliedQuery.body))
ctx := app.AcquireCtx(&fasthttp.RequestCtx{ ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{
Request: ctx_request, Request: ctx_request,
}) })
defer app.ReleaseCtx(ctx) // defer func() {
// cfg = &config{}
// parseConfig()
// suite.app.ReleaseCtx(ctx)
// }()
assert.NotNil(ctx, "Fiber context is nil") assert.NotNil(ctx, "Fiber context is nil")
if tt.suppliedSettings != nil { if tt.suppliedSettings != nil {
cfg = tt.suppliedSettings cfg = tt.suppliedSettings
} }
prepareQueriesAndExemptions()
defer func() {
cfg = &config{}
}()
parseResult := parseGraphQLQuery(ctx) parseResult := parseGraphQLQuery(ctx)
assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type", tt.name) assert.Equal(tt.wantResults.op_type, parseResult.operationType, "Unexpected operation type "+tt.name)
assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name", tt.name) assert.Equal(tt.wantResults.op_name, parseResult.operationName, "Unexpected operation name "+tt.name)
assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value", tt.name) assert.Equal(tt.wantResults.is_cached, parseResult.cacheRequest, "Unexpected cache value "+tt.name)
assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value", tt.name) assert.Equal(tt.wantResults.cached_ttl, parseResult.cacheTime, "Unexpected cache TTL value "+tt.name)
assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value", tt.name) assert.Equal(tt.wantResults.shouldBlock, parseResult.shouldBlock, "Unexpected block value "+tt.name)
assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value", tt.name) assert.Equal(tt.wantResults.shouldIgnore, parseResult.shouldIgnore, "Unexpected ignore value "+tt.name)
if tt.wantResults.returnCode > 0 { if tt.wantResults.returnCode > 0 {
assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name) assert.Equal(tt.wantResults.returnCode, ctx.Response().StatusCode(), "Unexpected return code", tt.name)
+6
View File
@@ -59,10 +59,16 @@ func (lw *LogConfig) log(w io.Writer, level zerolog.Level, message string, v map
} }
func (lw *LogConfig) Debug(message string, v ...map[string]interface{}) { func (lw *LogConfig) Debug(message string, v ...map[string]interface{}) {
if !lw.logger.Debug().Enabled() {
return
}
lw.log(os.Stdout, zerolog.DebugLevel, message, mergeMaps(v)) lw.log(os.Stdout, zerolog.DebugLevel, message, mergeMaps(v))
} }
func (lw *LogConfig) Info(message string, v ...map[string]interface{}) { func (lw *LogConfig) Info(message string, v ...map[string]interface{}) {
if !lw.logger.Info().Enabled() {
return
}
lw.log(os.Stdout, zerolog.InfoLevel, message, mergeMaps(v)) lw.log(os.Stdout, zerolog.InfoLevel, message, mergeMaps(v))
} }
+2 -2
View File
@@ -183,7 +183,7 @@ func (suite *LoggingTestSuite) TestLogConfig_AllHandlers() {
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
if tt.envMinLogLevel != "" { if tt.envMinLogLevel != "" {
os.Setenv("LOG_LEVEL", tt.envMinLogLevel) os.Setenv("LOG_LEVEL", tt.envMinLogLevel)
defer os.Unsetenv("LOG_LEVEL") defer os.Unsetenv("LOG_LEVEL")
@@ -274,7 +274,7 @@ func (suite *LoggingTestSuite) TestFullMessage() {
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
if tt.envMinLogLevel != "" { if tt.envMinLogLevel != "" {
os.Setenv("LOG_LEVEL", tt.envMinLogLevel) os.Setenv("LOG_LEVEL", tt.envMinLogLevel)
defer os.Unsetenv("LOG_LEVEL") defer os.Unsetenv("LOG_LEVEL")
+5
View File
@@ -1,6 +1,7 @@
package main package main
import ( import (
"flag"
"os" "os"
"strings" "strings"
@@ -86,3 +87,7 @@ func main() {
StartMonitoringServer() StartMonitoringServer()
StartHTTPProxy() StartHTTPProxy()
} }
func ifNotInTest() bool {
return flag.Lookup("test.v") == nil
}
+17 -3
View File
@@ -4,12 +4,16 @@ import (
"os" "os"
"testing" "testing"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
assertions "github.com/stretchr/testify/assert" assertions "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
) )
type Tests struct { type Tests struct {
suite.Suite suite.Suite
app *fiber.App
} }
var ( var (
@@ -21,6 +25,16 @@ func (suite *Tests) BeforeTest(suiteName, testName string) {
func (suite *Tests) SetupTest() { func (suite *Tests) SetupTest() {
assert = assertions.New(suite.T()) assert = assertions.New(suite.T())
suite.app = fiber.New(
fiber.Config{
DisableStartupMessage: true,
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
},
)
parseConfig()
StartMonitoringServer()
cfg.Logger = libpack_logging.NewLogger()
// Setup environment variables here if needed // Setup environment variables here if needed
os.Setenv("GMP_TEST_STRING", "testValue") os.Setenv("GMP_TEST_STRING", "testValue")
os.Setenv("GMP_TEST_INT", "123") os.Setenv("GMP_TEST_INT", "123")
@@ -48,10 +62,10 @@ func TestSuite(t *testing.T) {
func (suite *Tests) Test_envVariableSetting() { func (suite *Tests) Test_envVariableSetting() {
tests := []struct { tests := []struct {
name string
envKey string
defaultValue any defaultValue any
expected any expected any
name string
envKey string
}{ }{
{ {
name: "test_string", name: "test_string",
@@ -86,7 +100,7 @@ func (suite *Tests) Test_envVariableSetting() {
} }
for _, tt := range tests { for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) { suite.Run(tt.name, func() {
result := getDetailsFromEnv(tt.envKey, tt.defaultValue) result := getDetailsFromEnv(tt.envKey, tt.defaultValue)
assert.Equal(tt.expected, result) assert.Equal(tt.expected, result)
}) })
+1 -1
View File
@@ -5,7 +5,7 @@ import (
) )
func StartMonitoringServer() { func StartMonitoringServer() {
cfg.Monitoring = libpack_monitoring.NewMonitoring(cfg.Server.PurgeOnCrawl, cfg.Server.PurgeEvery) cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{PurgeOnCrawl: cfg.Server.PurgeOnCrawl, PurgeEvery: cfg.Server.PurgeEvery})
cfg.Monitoring.AddMetricsPrefix("graphql_proxy") cfg.Monitoring.AddMetricsPrefix("graphql_proxy")
cfg.Monitoring.RegisterDefaultMetrics() cfg.Monitoring.RegisterDefaultMetrics()
} }
+23 -16
View File
@@ -4,6 +4,7 @@
package libpack_monitoring package libpack_monitoring
import ( import (
"flag"
"fmt" "fmt"
"time" "time"
@@ -17,31 +18,37 @@ import (
type MetricsSetup struct { type MetricsSetup struct {
metrics_set *metrics.Set metrics_set *metrics.Set
metrics_set_custom *metrics.Set metrics_set_custom *metrics.Set
ic *InitConfig
metrics_prefix string metrics_prefix string
} }
var ( var (
log *logging.LogConfig log *logging.LogConfig
purgeMetricsOnCrawl bool
purgeMetricsEvery int
) )
func NewMonitoring(purgeOnCrawl bool, purgeEvery int) *MetricsSetup { type InitConfig struct {
purgeMetricsOnCrawl = purgeOnCrawl PurgeOnCrawl bool
purgeMetricsEvery = purgeEvery PurgeEvery int
}
func NewMonitoring(ic *InitConfig) *MetricsSetup {
log = logging.NewLogger() log = logging.NewLogger()
ms := &MetricsSetup{} ms := &MetricsSetup{ic: ic}
ms.metrics_set = metrics.NewSet() ms.metrics_set = metrics.NewSet()
ms.metrics_set_custom = metrics.NewSet() ms.metrics_set_custom = metrics.NewSet()
go ms.startPrometheusEndpoint() // if not testing, start the prometheus endpoint
if purgeEvery > 0 { if flag.Lookup("test.v") == nil {
ticker := time.NewTicker(time.Duration(purgeEvery) * time.Second) go ms.startPrometheusEndpoint()
go func() {
for range ticker.C { if ic.PurgeEvery > 0 {
ms.PurgeMetrics() ticker := time.NewTicker(time.Duration(ic.PurgeEvery) * time.Second)
} go func() {
}() for range ticker.C {
ms.PurgeMetrics()
}
}()
}
} }
return ms return ms
@@ -63,7 +70,7 @@ func (ms *MetricsSetup) metricsEndpoint(c *fiber.Ctx) error {
ms.metrics_set.WritePrometheus(c.Response().BodyWriter()) ms.metrics_set.WritePrometheus(c.Response().BodyWriter())
ms.metrics_set_custom.WritePrometheus(c.Response().BodyWriter()) ms.metrics_set_custom.WritePrometheus(c.Response().BodyWriter())
if purgeMetricsOnCrawl && purgeMetricsEvery == 0 { if ms.ic.PurgeOnCrawl && ms.ic.PurgeEvery == 0 {
ms.PurgeMetrics() ms.PurgeMetrics()
} }
return nil return nil
+6 -4
View File
@@ -1,8 +1,10 @@
package libpack_monitoring package libpack_monitoring
const ( const (
MetricsSucceeded = "requests_succesful" MetricsSucceeded = "requests_succesful"
MetricsFailed = "requests_failed" MetricsFailed = "requests_failed"
MetricsDuration = "requests_duration" MetricsDuration = "requests_duration"
MetricsSkipped = "requests_skipped" MetricsSkipped = "requests_skipped"
MetricsExecutedQuery = "executed_query"
MetricsTimedQuery = "timed_query"
) )
+13 -7
View File
@@ -31,7 +31,9 @@ func createFasthttpClient(timeout int) *fasthttp.Client {
func proxyTheRequest(c *fiber.Ctx) error { func proxyTheRequest(c *fiber.Ctx) error {
if !checkAllowedURLs(c) { if !checkAllowedURLs(c) {
cfg.Logger.Error("Request blocked", map[string]interface{}{"path": c.Path()}) cfg.Logger.Error("Request blocked", map[string]interface{}{"path": c.Path()})
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil) if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Request blocked - not allowed URL") c.Status(403).SendString("Request blocked - not allowed URL")
return nil return nil
} }
@@ -44,11 +46,13 @@ func proxyTheRequest(c *fiber.Ctx) error {
err := retry.Do( err := retry.Do(
func() error { func() error {
err := proxy.DoRedirects(c, cfg.Server.HostGraphQL+c.Path(), 3, cfg.Client.FastProxyClient) errInt := proxy.DoRedirects(c, cfg.Server.HostGraphQL+c.Path(), 3, cfg.Client.FastProxyClient)
if err != nil { if errInt != nil {
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()}) cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": errInt.Error()})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) if ifNotInTest() {
return err cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return errInt
} }
return nil return nil
}, },
@@ -69,7 +73,9 @@ func proxyTheRequest(c *fiber.Ctx) error {
cfg.Logger.Debug("Received proxied response", map[string]interface{}{"path": c.Path(), "response_body": string(c.Response().Body()), "response_code": c.Response().StatusCode(), "headers": c.GetRespHeaders(), "request_uuid": c.Locals("request_uuid")}) cfg.Logger.Debug("Received proxied response", map[string]interface{}{"path": c.Path(), "response_body": string(c.Response().Body()), "response_code": c.Response().StatusCode(), "headers": c.GetRespHeaders(), "request_uuid": c.Locals("request_uuid")})
if c.Response().StatusCode() != 200 { if c.Response().StatusCode() != 200 {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil) if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("Received non-200 response from the GraphQL server: %d", c.Response().StatusCode()) return fmt.Errorf("Received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
} }
+82
View File
@@ -0,0 +1,82 @@
package main
import (
"github.com/valyala/fasthttp"
)
func (suite *Tests) Test_proxyTheRequest() {
supplied_headers := map[string]string{
"X-Forwarded-For": "127.0.0.1",
"Content-Type": "application/json",
}
tests := []struct {
name string
query string
host string
path string
headers map[string]string
wantErr bool
}{
{
name: "test_empty",
query: `query {
__type(name: "Query") {
name
}
}`,
host: "https://telegram-bot.app/",
path: "/v1/graphql",
headers: supplied_headers,
wantErr: false,
},
{
name: "test_wrong_url",
query: `query {
__type(name: "Query") {
name
}
}`,
host: "https://google.com/",
path: "/v1/wrongURL",
headers: supplied_headers,
wantErr: true,
},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
cfg = &config{}
parseConfig()
cfg.Server.HostGraphQL = tt.host
ctx_headers := func() *fasthttp.RequestHeader {
h := fasthttp.RequestHeader{}
for k, v := range tt.headers {
h.Add(k, v)
}
return &h
}()
ctx_request := fasthttp.Request{
Header: *ctx_headers,
}
ctx_request.SetRequestURI(tt.path)
ctx_request.Header.SetMethod("POST")
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{
Request: ctx_request,
})
assert.NotNil(ctx, "Fiber context is nil", tt.name)
err := proxyTheRequest(ctx)
if tt.wantErr {
assert.NotNil(err, "Error is nil", tt.name)
} else {
assert.Nil(err, "Error is not nil", tt.name)
}
})
}
}
+4 -3
View File
@@ -148,6 +148,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil { if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil {
cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")}) cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
c.Request().Header.Add("X-Cache-Hit", "true")
c.Send(cachedResponse) c.Send(cachedResponse)
wasCached = true wasCached = true
} else { } else {
@@ -201,10 +202,10 @@ func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached
} }
cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil) cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil)
cfg.Monitoring.Increment("executed_query", labels) cfg.Monitoring.Increment(libpack_monitoring.MetricsExecutedQuery, labels)
if !wasCached { if !wasCached {
cfg.Monitoring.UpdateDuration("timed_query", labels, startTime) cfg.Monitoring.UpdateDuration(libpack_monitoring.MetricsTimedQuery, labels, startTime)
cfg.Monitoring.Update("timed_query", labels, float64(duration.Milliseconds())) cfg.Monitoring.Update(libpack_monitoring.MetricsTimedQuery, labels, float64(duration.Milliseconds()))
} }
} }