Compare commits

...

4 Commits

Author SHA1 Message Date
lukaszraczylo 615836ab36 Update go.mod and go.sum
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-10-15 03:08:28 +00:00
lukaszraczylo a51f37c0a2 Update go.mod and go.sum
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-10-12 03:04:33 +00:00
lukaszraczylo 6b31e5c4c0 Little code cleanup. (#19) 2024-10-10 10:34:23 +01:00
lukaszraczylo d919a1df75 Update go.mod and go.sum
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-10-10 03:07:51 +00:00
12 changed files with 149 additions and 121 deletions
+3 -3
View File
@@ -16,14 +16,14 @@ require (
github.com/lukaszraczylo/ask v0.0.0-20240916204100-6e9ef53a62d9
github.com/lukaszraczylo/go-ratecounter v0.1.12
github.com/lukaszraczylo/go-simple-graphql v1.2.29
github.com/redis/go-redis/v9 v9.6.1
github.com/redis/go-redis/v9 v9.6.2
github.com/stretchr/testify v1.9.0
github.com/valyala/fasthttp v1.56.0
)
require (
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
@@ -31,7 +31,7 @@ require (
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/compress v1.17.10 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
+8 -6
View File
@@ -4,8 +4,8 @@ github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZp
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA=
github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0=
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
github.com/avast/retry-go/v4 v4.6.0/go.mod h1:gvWlPhBVsvBbLkVGDg/KwvBv0bEkCOLRRSHKIr2PyOE=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
@@ -44,8 +44,8 @@ github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0=
github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
@@ -65,8 +65,8 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T
github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/redis/go-redis/v9 v9.6.1 h1:HHDteefn6ZkTtY5fGUE8tj8uy85AHk6zP7CpzIAM0y4=
github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA=
github.com/redis/go-redis/v9 v9.6.2 h1:w0uvkRbc9KpgD98zcvo5IrVUsn0lXpRMuhNgiHDJzdk=
github.com/redis/go-redis/v9 v9.6.2/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJCQLgE9+RA=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
@@ -89,6 +89,8 @@ github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVS
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+1
View File
@@ -30,6 +30,7 @@ func prepareQueriesAndExemptions() {
for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(q)] = struct{}{}
}
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
}
+1 -1
View File
@@ -393,7 +393,7 @@ func (suite *Tests) Test_checkAllowedURLs() {
ctx.Request().SetRequestURI(tt.path)
ctx.Request().URI().SetPath(tt.path)
result := checkAllowedURLs(ctx)
assert.Equal(tt.expected, result)
assert.Equal(tt.expected, result, "Unexpected result in test case: "+tt.name)
})
}
}
+71 -65
View File
@@ -2,7 +2,6 @@ package libpack_logger
import (
"bytes"
"flag"
"fmt"
"io"
"os"
@@ -16,16 +15,14 @@ import (
)
const (
_ = iota
LEVEL_DEBUG
LEVEL_DEBUG = iota
LEVEL_INFO
LEVEL_WARN
LEVEL_ERROR
LEVEL_FATAL
)
var LevelNames = [...]string{
"none",
var levelNames = []string{
"debug",
"info",
"warn",
@@ -34,74 +31,103 @@ var LevelNames = [...]string{
}
const (
defaultFormat = time.RFC3339
defaultTimeFormat = time.RFC3339
defaultMinLevel = LEVEL_INFO
defaultShowCaller = false
)
var defaultOutput = os.Stdout
// Logger represents the logging object with configurations.
type Logger struct {
output io.Writer
format string
timeFormat string
minLogLevel int
showCaller bool
}
// LogMessage represents a log message with optional pairs.
type LogMessage struct {
output io.Writer
Pairs map[string]any
Pairs map[string]interface{}
Message string
}
func (m *LogMessage) String() string {
return m.Message
// bufferPool is used to reuse bytes.Buffer for efficiency.
var bufferPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// fieldNames allows customization of output field names.
var fieldNames = map[string]string{
"timestamp": "timestamp",
"level": "level",
"message": "message",
}
// New creates a new Logger with default settings.
func New() *Logger {
return &Logger{
format: defaultFormat,
timeFormat: defaultTimeFormat,
minLogLevel: defaultMinLevel,
output: defaultOutput,
output: os.Stdout,
showCaller: defaultShowCaller,
}
}
// SetOutput sets the output destination for the logger.
func (l *Logger) SetOutput(output io.Writer) *Logger {
l.output = output
return l
}
var bufferPool = sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
var defaultPairs = make(map[string]any)
// GetLogLevel returns the log level integer corresponding to the given level name.
func GetLogLevel(level string) int {
for i, name := range LevelNames {
if name == strings.ToLower(level) {
level = strings.ToLower(level)
for i, name := range levelNames {
if name == level {
return i
}
}
return defaultMinLevel
}
// SetTimeFormat sets the time format for the logger's timestamp field.
func (l *Logger) SetTimeFormat(format string) *Logger {
l.timeFormat = format
return l
}
// SetMinLogLevel sets the minimum log level for the logger.
func (l *Logger) SetMinLogLevel(level int) *Logger {
l.minLogLevel = level
return l
}
// SetFieldName allows customizing the field names in log output.
func (l *Logger) SetFieldName(field, name string) *Logger {
fieldNames[field] = name
return l
}
// SetShowCaller enables or disables including the caller information in log output.
func (l *Logger) SetShowCaller(show bool) *Logger {
l.showCaller = show
return l
}
// shouldLog determines if the message should be logged based on the logger's minimum log level.
func (l *Logger) shouldLog(level int) bool {
return level >= l.minLogLevel
}
// log writes the log message with the given level.
func (l *Logger) log(level int, m *LogMessage) {
if m.Pairs == nil {
m.Pairs = defaultPairs
m.Pairs = make(map[string]interface{})
}
m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.format)
m.Pairs[fieldNames["level"]] = LevelNames[level]
m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.timeFormat)
m.Pairs[fieldNames["level"]] = levelNames[level]
m.Pairs[fieldNames["message"]] = m.Message
if l.showCaller {
@@ -109,93 +135,73 @@ func (l *Logger) log(level int, m *LogMessage) {
}
buffer := bufferPool.Get().(*bytes.Buffer)
defer bufferPool.Put(buffer)
buffer.Reset()
defer bufferPool.Put(buffer)
var encoder = json.NewEncoder(buffer)
encoder := json.NewEncoder(buffer)
err := encoder.Encode(m.Pairs)
if err != nil {
fmt.Println("Error marshalling log message:", err)
fmt.Fprintln(os.Stderr, "Error marshalling log message:", err)
return
}
// if not running in test - use stderr and stdout, otherwise - use logger's output setting
if flag.Lookup("test.v") != nil {
m.output = os.Stdout
if level >= LEVEL_ERROR {
m.output = os.Stderr
}
_, err = l.output.Write(buffer.Bytes())
if err != nil {
fmt.Fprintln(os.Stderr, "Error writing log message:", err)
}
// Use logger's output setting instead of os.Stdout or os.Stderr
l.output.Write(buffer.Bytes())
}
// Debug logs a debug-level message.
func (l *Logger) Debug(m *LogMessage) {
if l.shouldLog(LEVEL_DEBUG) {
l.log(LEVEL_DEBUG, m)
}
}
// Info logs an info-level message.
func (l *Logger) Info(m *LogMessage) {
if l.shouldLog(LEVEL_INFO) {
l.log(LEVEL_INFO, m)
}
}
// Warn logs a warning-level message.
func (l *Logger) Warn(m *LogMessage) {
if l.shouldLog(LEVEL_WARN) {
l.log(LEVEL_WARN, m)
}
}
// Warning is an alias for Warn.
func (l *Logger) Warning(m *LogMessage) {
l.Warn(m)
}
// Error logs an error-level message.
func (l *Logger) Error(m *LogMessage) {
if l.shouldLog(LEVEL_ERROR) {
l.log(LEVEL_ERROR, m)
}
}
// Fatal logs a fatal-level message.
func (l *Logger) Fatal(m *LogMessage) {
if l.shouldLog(LEVEL_FATAL) {
l.log(LEVEL_FATAL, m)
}
}
// Critical logs a critical-level message and exits the application.
func (l *Logger) Critical(m *LogMessage) {
l.Fatal(m)
os.Exit(1)
}
func (l *Logger) shouldLog(level int) bool {
return level >= l.minLogLevel
}
func (l *Logger) SetFormat(format string) *Logger {
l.format = format
return l
}
func (l *Logger) SetMinLogLevel(level int) *Logger {
l.minLogLevel = level
return l
}
func (l *Logger) SetFieldName(field, name string) *Logger {
fieldNames[field] = name
return l
}
func (l *Logger) SetShowCaller(show bool) *Logger {
l.showCaller = show
return l
}
// getCaller retrieves the file and line number of the caller.
func getCaller() string {
_, file, line, ok := runtime.Caller(3)
// Skip 3 stack frames: getCaller -> log -> [Debug|Info|...]
const depth = 3
_, file, line, ok := runtime.Caller(depth)
if !ok {
return "unknown:0"
}
-5
View File
@@ -56,11 +56,6 @@ func Benchmark_NewLogger(b *testing.B) {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
got := New()
if tt.triggers.ModFormat.Format != "" {
got = got.SetFormat(tt.triggers.ModFormat.Format)
}
if tt.triggers.ModLevel.Level != 0 {
got = got.SetMinLogLevel(tt.triggers.ModLevel.Level)
}
+5 -5
View File
@@ -40,7 +40,7 @@ func (suite *LoggerTestSuite) Test_LogMessageString() {
Message: "test message",
}
assert.Equal("test message", msg.String())
assert.Equal("test message", msg.Message)
}
func callLoggerMethod(logger *Logger, methodName string, message *LogMessage) {
@@ -125,7 +125,7 @@ func (suite *LoggerTestSuite) Test_LogsLevelsPrint() {
// Set logger's minimum log level
logger.SetMinLogLevel(tt.loggerMinLevel)
fmt.Println("Logger min log level:", LevelNames[logger.minLogLevel])
fmt.Println("Logger min log level:", levelNames[logger.minLogLevel])
// Call the logging method
callLoggerMethod(logger, tt.method, msg)
@@ -143,7 +143,7 @@ func (suite *LoggerTestSuite) Test_LogsLevelsPrint() {
if !containsLogMessage(logOutput, tt.message) {
t.Errorf("Expected log message %q, but got %q", tt.message, logOutput)
}
assert.Equal(LevelNames[tt.messageLogLevel], loggedMessage["level"])
assert.Equal(levelNames[tt.messageLogLevel], loggedMessage["level"])
if tt.pairs != nil {
for k, v := range tt.pairs {
assert.Equal(v, loggedMessage[k])
@@ -161,9 +161,9 @@ func containsLogMessage(logOutput, expectedMessage string) bool {
}
func (suite *LoggerTestSuite) Test_SetFormat() {
logger := New().SetFormat(time.RFC3339Nano)
logger := New().SetTimeFormat(time.RFC3339Nano)
assert.Equal(time.RFC3339Nano, logger.format)
assert.Equal(time.RFC3339Nano, logger.timeFormat)
}
func (suite *LoggerTestSuite) Test_SetMinLogLevel() {
+25 -12
View File
@@ -20,45 +20,49 @@ var (
once sync.Once
)
// function get value from the env where the value can be anything
// getDetailsFromEnv retrieves the value from the environment or returns the default.
func getDetailsFromEnv[T any](key string, defaultValue T) T {
var result any
if _, ok := os.LookupEnv("GMP_" + key); ok {
key = "GMP_" + key
envKey := "GMP_" + key
if _, ok := os.LookupEnv(envKey); !ok {
envKey = key
}
switch v := any(defaultValue).(type) {
case string:
result = envutil.Getenv(key, v)
result = envutil.Getenv(envKey, v)
case int:
result = envutil.GetInt(key, v)
result = envutil.GetInt(envKey, v)
case bool:
result = envutil.GetBool(key, v)
result = envutil.GetBool(envKey, v)
default:
result = defaultValue
}
return result.(T)
}
// parseConfig loads and parses the configuration.
func parseConfig() {
libpack_config.PKG_NAME = "graphql_proxy"
c := config{}
// Server configurations
c.Server.PortGraphQL = getDetailsFromEnv("PORT_GRAPHQL", 8080)
c.Server.PortMonitoring = getDetailsFromEnv("MONITORING_PORT", 9393)
c.Server.HostGraphQL = getDetailsFromEnv("HOST_GRAPHQL", "http://localhost/")
c.Server.HostGraphQLReadOnly = getDetailsFromEnv("HOST_GRAPHQL_READONLY", "")
// Client configurations
c.Client.JWTUserClaimPath = getDetailsFromEnv("JWT_USER_CLAIM_PATH", "")
c.Client.JWTRoleClaimPath = getDetailsFromEnv("JWT_ROLE_CLAIM_PATH", "")
c.Client.RoleFromHeader = getDetailsFromEnv("ROLE_FROM_HEADER", "")
c.Client.RoleRateLimit = getDetailsFromEnv("ROLE_RATE_LIMIT", false)
/* in-memory cache */
// In-memory cache
c.Cache.CacheEnable = getDetailsFromEnv("ENABLE_GLOBAL_CACHE", false)
c.Cache.CacheTTL = getDetailsFromEnv("CACHE_TTL", 60)
/* redis cache */
// Redis cache
c.Cache.CacheRedisEnable = getDetailsFromEnv("ENABLE_REDIS_CACHE", false)
c.Cache.CacheRedisURL = getDetailsFromEnv("CACHE_REDIS_URL", "localhost:6379")
c.Cache.CacheRedisPassword = getDetailsFromEnv("CACHE_REDIS_PASSWORD", "")
c.Cache.CacheRedisDB = getDetailsFromEnv("CACHE_REDIS_DB", 0)
/* security */
// Security configurations
c.Security.BlockIntrospection = getDetailsFromEnv("BLOCK_SCHEMA_INTROSPECTION", false)
c.Security.IntrospectionAllowed = func() []string {
urls := getDetailsFromEnv("ALLOWED_INTROSPECTION", "")
@@ -68,10 +72,14 @@ func parseConfig() {
return strings.Split(urls, ",")
}()
c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
// Logger setup
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).
SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
// Health check
c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "")
c.Client.GQLClient = graphql.NewConnection()
c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL)
// Server modes
c.Server.AccessLog = getDetailsFromEnv("ENABLE_ACCESS_LOG", false)
c.Server.ReadOnlyMode = getDetailsFromEnv("READ_ONLY_MODE", false)
c.Server.AllowURLs = func() []string {
@@ -83,22 +91,26 @@ func parseConfig() {
}()
c.Client.ClientTimeout = getDetailsFromEnv("PROXIED_CLIENT_TIMEOUT", 120)
c.Client.FastProxyClient = createFasthttpClient(c.Client.ClientTimeout)
proxy.WithClient(c.Client.FastProxyClient) // setting the global proxy client here instead of per request
proxy.WithClient(c.Client.FastProxyClient) // Setting the global proxy client
// API configurations
c.Server.EnableApi = getDetailsFromEnv("ENABLE_API", false)
c.Server.ApiPort = getDetailsFromEnv("API_PORT", 9090)
c.Api.BannedUsersFile = getDetailsFromEnv("BANNED_USERS_FILE", "/go/src/app/banned_users.json")
c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false)
c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0)
// Hasura event cleaner
c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false)
c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1)
c.HasuraEventCleaner.EventMetadataDb = getDetailsFromEnv("HASURA_EVENT_METADATA_DB", "")
cfg = &c
// Initialize cache if enabled
if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
cacheConfig := &libpack_cache.CacheConfig{
Logger: cfg.Logger,
TTL: cfg.Cache.CacheTTL,
}
// Redis cache configurations
if cfg.Cache.CacheRedisEnable {
cacheConfig.Redis.Enable = true
cacheConfig.Redis.URL = cfg.Cache.CacheRedisURL
@@ -113,7 +125,7 @@ func parseConfig() {
go enableApi()
go enableHasuraEventCleaner()
})
prepareQueriesAndExemptions()
prepareQueriesAndExemptions() // Ensure this function is defined elsewhere
}
func main() {
@@ -123,6 +135,7 @@ func main() {
StartHTTPProxy()
}
// ifNotInTest checks if the program is not running in a test environment.
func ifNotInTest() bool {
return flag.Lookup("test.v") == nil
}
+5 -1
View File
@@ -4,8 +4,12 @@ import (
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
// StartMonitoringServer initializes and starts the monitoring server.
func StartMonitoringServer() {
cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{PurgeOnCrawl: cfg.Server.PurgeOnCrawl, PurgeEvery: cfg.Server.PurgeEvery})
cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{
PurgeOnCrawl: cfg.Server.PurgeOnCrawl,
PurgeEvery: cfg.Server.PurgeEvery,
})
cfg.Monitoring.AddMetricsPrefix("graphql_proxy")
cfg.Monitoring.RegisterDefaultMetrics()
}
+10 -5
View File
@@ -17,6 +17,7 @@ import (
"github.com/valyala/fasthttp"
)
// createFasthttpClient creates and configures a fasthttp client.
func createFasthttpClient(timeout int) *fasthttp.Client {
return &fasthttp.Client{
Name: "graphql_proxy",
@@ -33,6 +34,7 @@ func createFasthttpClient(timeout int) *fasthttp.Client {
}
}
// proxyTheRequest handles the request proxying logic.
func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
if !checkAllowedURLs(c) {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -51,7 +53,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
return fmt.Errorf("invalid URL: %v", err)
}
if cfg.LogLevel == "debug" {
if cfg.LogLevel == "DEBUG" {
logDebugRequest(c)
}
@@ -61,7 +63,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
if proxyErr != nil {
return proxyErr
}
if c.Response().StatusCode() != 200 {
if c.Response().StatusCode() != fiber.StatusOK {
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
}
return nil
@@ -94,11 +96,12 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
return fmt.Errorf("failed to proxy request: %v", err)
}
if cfg.LogLevel == "debug" {
if cfg.LogLevel == "DEBUG" {
logDebugResponse(c)
}
if c.Response().Header.Peek("Content-Encoding") != nil && string(c.Response().Header.Peek("Content-Encoding")) == "gzip" {
if bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
// Decompress gzip response
reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body()))
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -122,7 +125,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
c.Response().Header.Del("Content-Encoding")
}
if c.Response().StatusCode() != 200 {
if c.Response().StatusCode() != fiber.StatusOK {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
@@ -133,6 +136,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
return nil
}
// logDebugRequest logs the request details when in debug mode.
func logDebugRequest(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Proxying the request",
@@ -145,6 +149,7 @@ func logDebugRequest(c *fiber.Ctx) {
})
}
// logDebugResponse logs the response details when in debug mode.
func logDebugResponse(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Received proxied response",
+4 -1
View File
@@ -10,6 +10,7 @@ import (
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
// RateLimitConfig holds the rate limit configuration for a role
type RateLimitConfig struct {
RateCounterTicker *goratecounter.RateCounter
Interval time.Duration `json:"interval"`
@@ -21,6 +22,7 @@ var (
rateLimitMu sync.RWMutex
)
// loadRatelimitConfig loads the rate limit configurations from file
func loadRatelimitConfig() error {
paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"}
for _, path := range paths {
@@ -59,7 +61,7 @@ func loadConfigFromPath(path string) error {
Interval: value.Interval,
})
if cfg.LogLevel == "debug" {
if cfg.LogLevel == "DEBUG" {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Setting ratelimit config for role",
Pairs: map[string]interface{}{
@@ -83,6 +85,7 @@ func loadConfigFromPath(path string) error {
return nil
}
// rateLimitedRequest checks if a request should be rate-limited
func rateLimitedRequest(userID, userRole string) bool {
rateLimitMu.RLock()
roleConfig, ok := rateLimits[userRole]
+16 -17
View File
@@ -3,7 +3,6 @@ package main
import (
"fmt"
"strconv"
"sync"
"time"
"github.com/goccy/go-json"
@@ -21,14 +20,7 @@ const (
healthCheckQueryStr = `{ __typename }`
)
var (
ctxPool = sync.Pool{
New: func() interface{} {
return new(fiber.Ctx)
},
}
)
// StartHTTPProxy initializes and starts the HTTP proxy server.
func StartHTTPProxy() {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Starting the HTTP proxy",
@@ -71,15 +63,18 @@ func StartHTTPProxy() {
}
}
// proxyTheRequestToDefault proxies the request to the default GraphQL endpoint.
func proxyTheRequestToDefault(c *fiber.Ctx) error {
return proxyTheRequest(c, cfg.Server.HostGraphQL)
}
// AddRequestUUID adds a unique request UUID to the context.
func AddRequestUUID(c *fiber.Ctx) error {
c.Locals("request_uuid", uuid.NewString())
return c.Next()
}
// checkAllowedURLs checks if the requested URL is allowed.
func checkAllowedURLs(c *fiber.Ctx) bool {
if len(allowedUrls) == 0 {
return true
@@ -89,6 +84,7 @@ func checkAllowedURLs(c *fiber.Ctx) bool {
return ok
}
// healthCheck performs a health check on the GraphQL server.
func healthCheck(c *fiber.Ctx) error {
if len(cfg.Server.HealthcheckGraphQL) > 0 {
cfg.Logger.Debug(&libpack_logger.LogMessage{
@@ -103,16 +99,17 @@ func healthCheck(c *fiber.Ctx) error {
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query")
return c.Status(fiber.StatusInternalServerError).SendString("Can't reach the GraphQL server with {__typename} query")
}
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Health check returning OK",
})
return c.Status(200).SendString("Health check OK")
return c.Status(fiber.StatusOK).SendString("Health check OK")
}
// processGraphQLRequest handles the incoming GraphQL requests.
func processGraphQLRequest(c *fiber.Ctx) error {
startTime := time.Now()
@@ -124,7 +121,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
}
if checkIfUserIsBanned(c, extractedUserID) {
return c.Status(403).SendString("User is banned")
return c.Status(fiber.StatusForbidden).SendString("User is banned")
}
if cfg.Client.RoleFromHeader != "" {
@@ -139,13 +136,13 @@ func processGraphQLRequest(c *fiber.Ctx) error {
Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName},
})
if !rateLimitedRequest(extractedUserID, extractedRoleName) {
return c.Status(429).SendString("Rate limit exceeded, try again later")
return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later")
}
}
parsedResult := parseGraphQLQuery(c)
parsedResult := parseGraphQLQuery(c) // Ensure this function is defined elsewhere
if parsedResult.shouldBlock {
return c.Status(403).SendString("Request blocked")
return c.Status(fiber.StatusForbidden).SendString("Request blocked")
}
if parsedResult.shouldIgnore {
@@ -208,7 +205,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return c.Status(500).SendString("Can't proxy the request - try again later")
return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
}
@@ -217,6 +214,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
return nil
}
// proxyAndCacheTheRequest proxies and caches the request if needed.
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error {
if err := proxyTheRequest(c, currentEndpoint); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -224,7 +222,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int,
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return c.Status(500).SendString("Can't proxy the request - try again later")
return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
libpack_cache.CacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
@@ -232,6 +230,7 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int,
return c.Send(c.Response().Body())
}
// logAndMonitorRequest logs and monitors the request processing.
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
labels := map[string]string{
"op_type": opType,