Compare commits

...

24 Commits

Author SHA1 Message Date
lukaszraczylo 162c4acd7c fixup! fixup! fixup! fixup! fixup! fixup! Fix redis cache benchmark. 2024-06-28 18:05:17 +01:00
lukaszraczylo fde78a4ece fixup! fixup! fixup! fixup! fixup! Fix redis cache benchmark. 2024-06-28 17:58:46 +01:00
lukaszraczylo b1ffffd545 Create static.yml 2024-06-28 17:49:21 +01:00
lukaszraczylo 977554dd49 fixup! fixup! fixup! fixup! Fix redis cache benchmark. 2024-06-28 14:12:25 +01:00
lukaszraczylo 4ca8ce5751 fixup! fixup! fixup! Fix redis cache benchmark. 2024-06-28 13:57:43 +01:00
lukaszraczylo de55444012 fixup! fixup! Fix redis cache benchmark. 2024-06-28 13:50:58 +01:00
lukaszraczylo 3ec1c37f23 fixup! Fix redis cache benchmark. 2024-06-28 13:40:43 +01:00
lukaszraczylo eb9821dc3f Fix redis cache benchmark. 2024-06-28 13:37:31 +01:00
lukaszraczylo 3467cc5be0 Fix the cleanup routine. 2024-06-28 13:26:18 +01:00
lukaszraczylo b10a28bf52 General code optimisations. (#16)
* General code optimisations.
2024-06-28 12:31:01 +01:00
lukaszraczylo 1b1656c4b5 Update go.mod and go.sum
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-28 03:01:32 +00:00
lukaszraczylo b29733e435 fixup! fixup! Update go.mod and go.sum 2024-06-28 00:46:34 +01:00
lukaszraczylo f8a7b8ad83 fixup! Update go.mod and go.sum 2024-06-28 00:40:58 +01:00
lukaszraczylo 43c62d85dd Update go.mod and go.sum
Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2024-06-27 12:52:22 +00:00
lukaszraczylo 43b7ab7a77 fixup! fixup! fixup! fixup! fixup! fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 13:48:36 +01:00
lukaszraczylo d0c883a418 fixup! fixup! fixup! fixup! fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 13:42:43 +01:00
lukaszraczylo 33fc370ff5 fixup! fixup! fixup! fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 13:40:13 +01:00
lukaszraczylo 0a1fb50906 fixup! fixup! fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 13:31:44 +01:00
lukaszraczylo f348c07b60 fixup! fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 08:44:46 +01:00
lukaszraczylo 60b2f217d0 fixup! fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-27 08:44:14 +01:00
lukaszraczylo f7babe93d9 fixup! Disable caller as it's not necessary and generates slight delay. 2024-06-20 08:41:33 +01:00
lukaszraczylo 16844e325e Disable caller as it's not necessary and generates slight delay. 2024-06-19 23:40:44 +01:00
lukaszraczylo 61d7a45d00 Update cache library, use miniredis for testing, add additional benchmarks. (#14)
Update cache library,
Update logging library,
use miniredis for testing, add additional benchmarks.
2024-06-19 23:10:36 +01:00
Chris Clayton 12e4237997 divide long functions, replace strings.builder with bytes.buffer. (#13)
Co-authored-by: Chris Clayton <chris.clayton@contino.io>
2024-06-17 10:23:41 +01:00
41 changed files with 2589 additions and 1095 deletions
+73
View File
@@ -0,0 +1,73 @@
name: Autoupdate go.mod and go.sum
on:
workflow_dispatch:
schedule:
- cron: "0 3 * * *"
env:
GO_VERSION: ">=1.21"
jobs:
# This job is responsible for preparation of the build
# environment variables.
prepare:
name: Preparing build context
runs-on: ubuntu-latest
steps:
- name: Checkout repo
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
id: cache
with:
go-version: ${{env.GO_VERSION}}
cache-dependency-path: "**/*.sum"
- name: Go get dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
go get ./...
# This job is responsible for running tests and linting the codebase
test:
name: "Unit testing"
runs-on: ubuntu-latest
container: golang:1
needs: [prepare]
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Ensure full history is checked out
token: ${{ secrets.GHCR_TOKEN }}
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: ${{env.GO_VERSION}}
cache-dependency-path: "**/*.sum"
- name: Install dependencies
run: |
apt-get update
apt-get install ca-certificates make -y
update-ca-certificates
go mod tidy
go get -u -v ./...
go mod tidy -v
- name: Run unit tests
run: |
CI_RUN=${CI} make test
git config --global --add safe.directory /__w/graphql-monitoring-proxy/graphql-monitoring-proxy
- name: Commit changes
uses: stefanzweifel/git-auto-commit-action@v5
with:
commit_message: "Update go.mod and go.sum"
commit_options: "--no-verify --signoff"
file_pattern: "go.mod go.sum"
+39 -19
View File
@@ -15,6 +15,12 @@ on:
env:
GO_VERSION: ">=1.21"
permissions:
# deployments permission to deploy GitHub pages website
deployments: write
# contents permission to update benchmark contents in gh-pages branch
contents: write
jobs:
# This job is responsible for preparation of the build
# environment variables.
@@ -47,20 +53,20 @@ jobs:
# container: github/super-linter:v4
needs: [prepare]
services:
# Label used to access the service container
redis:
# Docker Hub image
image: redis
# Set health checks to wait until redis has started
options: >-
--health-cmd "redis-cli ping"
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
# Maps the container port to the host machine
- 6379:6379
# services:
# # Label used to access the service container
# redis:
# # Docker Hub image
# image: redis
# # Set health checks to wait until redis has started
# options: >-
# --health-cmd "redis-cli ping"
# --health-interval 10s
# --health-timeout 5s
# --health-retries 5
# ports:
# # Maps the container port to the host machine
# - 6379:6379
steps:
- name: Checkout repository
@@ -78,12 +84,26 @@ jobs:
apt-get install ca-certificates make -y
update-ca-certificates
go mod tidy
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Run unit tests
env:
REDIS_HOST: redis
REDIS_PORT: 6379
REDIS_SERVER: "redis:6379"
run: |
export REDIS_SERVER="$REDIS_HOST:$REDIS_PORT"
CI_RUN=${CI} make test
- name: Run benchmark
run: |
go test -bench=. -benchmem ./... -run=^# | tee output.txt
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
tool: "go"
output-file-path: output.txt
fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: false
# auto-push only if it's on main branch
auto-push: false
gh-pages-branch: "gh-pages"
benchmark-data-dir-path: "docs"
+56 -4
View File
@@ -4,11 +4,20 @@ on:
workflow_dispatch:
push:
paths-ignore:
- '**/**.md'
- '**/**.yaml'
- 'static/**'
- "**/**.md"
- "**/**.yaml"
- "static/**"
branches:
- 'main'
- "main"
env:
GO_VERSION: ">=1.21"
permissions:
# deployments permission to deploy GitHub pages website
deployments: write
# contents permission to update benchmark contents in gh-pages branch
contents: write
jobs:
shared:
@@ -18,3 +27,46 @@ jobs:
should-deploy: false
secrets:
ghcr-token: ${{ secrets.GHCR_TOKEN }}
test:
name: "Benchmarking the results"
needs: [shared]
runs-on: ubuntu-latest
container: golang:1
# container: github/super-linter:v4
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: ${{env.GO_VERSION}}
cache-dependency-path: "**/*.sum"
- name: Install dependencies
run: |
apt-get update
apt-get install ca-certificates make -y
update-ca-certificates
go mod tidy
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Run benchmark
run: |
go test -bench=. -benchmem ./... -run=^# | tee output.txt
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
tool: "go"
output-file-path: output.txt
fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: false
# auto-push only if it's on main branch
auto-push: true
gh-pages-branch: "gh-pages"
benchmark-data-dir-path: "docs"
+1
View File
@@ -1,2 +1,3 @@
graphql-proxy
test.sh
banned.json*
+5 -5
View File
@@ -1,9 +1,9 @@
CI_RUN?=false
ADDITIONAL_BUILD_FLAGS=""
# ADDITIONAL_BUILD_FLAGS=""
ifeq ($(CI_RUN), true)
ADDITIONAL_BUILD_FLAGS="-test.short"
endif
# ifeq ($(CI_RUN), true)
# ADDITIONAL_BUILD_FLAGS="-test.short"
# endif
.PHONY: help
help: ## display this help
@@ -19,7 +19,7 @@ build: ## build the binary
.PHONY: test
test: ## run tests on library
@LOG_LEVEL=debug go test $(ADDITIONAL_BUILD_FLAGS) -v -cover ./... -race
@LOG_LEVEL=info go test -v -cover -race ./...
.PHONY: test-packages
test-packages: ## run tests on packages
+1 -1
View File
@@ -2,7 +2,7 @@
Creates a passthrough proxy to a graphql endpoint(s), allowing you to analyse the queries and responses, producing the Prometheus metrics at a fraction of the cost - because, as we know - $0 is a fair price.
This project is in active use by [telegram-bot.app](https://telegram-bot.app), and was tested with 30k queries per second on a single instance, consuming 10 MB of RAM and 0.1% CPU.
This project is in active use by [telegram-bot.app](https://telegram-bot.app), and was tested with 30k queries per second on a single instance, consuming 10 MB of RAM and 0.1% CPU. [Benchmarks](https://lukaszraczylo.github.io/graphql-monitoring-proxy/dev/bench/) are available.
![Example of monitoring dashboard](static/monitoring-at-glance.png?raw=true)
+178 -80
View File
@@ -3,72 +3,94 @@ package main
import (
"fmt"
"os"
"sync"
"time"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
"github.com/gofrs/flock"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
var bannedUsersIDs map[string]string = make(map[string]string)
var (
bannedUsersIDs = make(map[string]string)
bannedUsersIDsMutex sync.RWMutex
)
func enableApi() {
if cfg.Server.EnableApi {
apiserver := fiber.New(fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
if !cfg.Server.EnableApi {
return
}
apiserver := fiber.New(fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
})
api := apiserver.Group("/api")
api.Post("/user-ban", apiBanUser)
api.Post("/user-unban", apiUnbanUser)
api.Post("/cache-clear", apiClearCache)
api.Get("/cache-stats", apiCacheStats)
go periodicallyReloadBannedUsers()
if err := apiserver.Listen(fmt.Sprintf(":%d", cfg.Server.ApiPort)); err != nil {
cfg.Logger.Critical(&libpack_logger.LogMessage{
Message: "Can't start the service",
Pairs: map[string]interface{}{"port": cfg.Server.ApiPort},
})
api := apiserver.Group("/api")
api.Post("/user-ban", apiBanUser)
api.Post("/user-unban", apiUnbanUser)
api.Post("/cache-clear", apiClearCache)
api.Get("/cache-stats", apiCacheStats)
go periodicallyReloadBannedUsers()
err := apiserver.Listen(fmt.Sprintf(":%d", cfg.Server.ApiPort))
if err != nil {
cfg.Logger.Critical("Can't start the service", map[string]interface{}{"error": err.Error()})
}
}
}
func periodicallyReloadBannedUsers() {
for {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
loadBannedUsers()
cfg.Logger.Debug("Banned users reloaded", map[string]interface{}{"users": bannedUsersIDs})
<-time.After(10 * time.Second)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Banned users reloaded",
Pairs: map[string]interface{}{"users": bannedUsersIDs},
})
}
}
func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool {
bannedUsersIDsMutex.RLock()
_, found := bannedUsersIDs[userID]
cfg.Logger.Debug("Checking if user is banned", map[string]interface{}{"user_id": userID, "found": found})
bannedUsersIDsMutex.RUnlock()
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Checking if user is banned",
Pairs: map[string]interface{}{"user_id": userID, "banned": found},
})
if found {
cfg.Logger.Info("User is banned", map[string]interface{}{"user_id": userID})
c.Status(403).SendString("User is banned")
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "User is banned",
Pairs: map[string]interface{}{"user_id": userID},
})
c.Status(fiber.StatusForbidden).SendString("User is banned")
}
return found
}
func apiClearCache(c *fiber.Ctx) error {
cfg.Logger.Debug("Clearing cache via API", nil)
cacheClear()
cfg.Logger.Info("Cache cleared via API", nil)
c.Status(200).SendString("OK: cache cleared")
return nil
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Clearing cache via API",
})
libpack_cache.CacheClear()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Cache cleared via API",
})
return c.SendString("OK: cache cleared")
}
func apiCacheStats(c *fiber.Ctx) error {
stats := getCacheStats()
cfg.Logger.Debug("Getting cache stats via API", map[string]interface{}{"stats": stats})
err := c.JSON(stats)
if err != nil {
cfg.Logger.Error("Can't marshal cache stats", map[string]interface{}{"error": err.Error()})
return err
}
return nil
return c.JSON(libpack_cache.GetCacheStats())
}
type apiBanUserRequest struct {
@@ -78,84 +100,160 @@ type apiBanUserRequest struct {
func apiBanUser(c *fiber.Ctx) error {
var req apiBanUserRequest
err := c.BodyParser(&req)
if err != nil {
cfg.Logger.Error("Can't parse the ban user request", map[string]interface{}{"error": err.Error()})
return err
if err := c.BodyParser(&req); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the ban user request",
Pairs: map[string]interface{}{"error": err.Error()},
})
return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload")
}
if req.UserID == "" || req.Reason == "" {
return c.Status(fiber.StatusBadRequest).SendString("user_id and reason are required")
}
bannedUsersIDsMutex.Lock()
bannedUsersIDs[req.UserID] = req.Reason
cfg.Logger.Info("Banned user", map[string]interface{}{"user_id": req.UserID, "reason": req.Reason})
storeBannedUsers()
c.Status(200).SendString("OK: user banned")
return nil
bannedUsersIDsMutex.Unlock()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Banned user",
Pairs: map[string]interface{}{"user_id": req.UserID, "reason": req.Reason},
})
if err := storeBannedUsers(); err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users")
}
return c.SendString("OK: user banned")
}
func apiUnbanUser(c *fiber.Ctx) error {
var req apiBanUserRequest
err := c.BodyParser(&req)
if err != nil {
cfg.Logger.Error("Can't parse the unban user request", map[string]interface{}{"error": err.Error()})
return err
if err := c.BodyParser(&req); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the unban user request",
Pairs: map[string]interface{}{"error": err.Error()},
})
return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload")
}
if req.UserID == "" {
return c.Status(fiber.StatusBadRequest).SendString("user_id is required")
}
bannedUsersIDsMutex.Lock()
delete(bannedUsersIDs, req.UserID)
cfg.Logger.Info("Unbanned user", map[string]interface{}{"user_id": req.UserID})
storeBannedUsers()
c.Status(200).SendString("OK: user unbanned")
return nil
bannedUsersIDsMutex.Unlock()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Unbanned user",
Pairs: map[string]interface{}{"user_id": req.UserID},
})
if err := storeBannedUsers(); err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users")
}
return c.SendString("OK: user unbanned")
}
func storeBannedUsers() {
func storeBannedUsers() error {
fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
err := fileLock.Lock()
if err != nil {
cfg.Logger.Error("Can't lock the file", map[string]interface{}{"error": err.Error()})
return
if err := lockFile(fileLock); err != nil {
return err
}
defer fileLock.Unlock()
bannedUsersIDsMutex.RLock()
data, err := json.Marshal(bannedUsersIDs)
bannedUsersIDsMutex.RUnlock()
if err != nil {
cfg.Logger.Error("Can't marshal banned users", map[string]interface{}{"error": err.Error()})
return
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't marshal banned users",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
err = os.WriteFile(cfg.Api.BannedUsersFile, data, 0644)
if err != nil {
cfg.Logger.Error("Can't write banned users to file", map[string]interface{}{"error": err.Error()})
return
if err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0644); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't write banned users to file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
}
func loadBannedUsers() {
if _, err := os.Stat(cfg.Api.BannedUsersFile); os.IsNotExist(err) {
cfg.Logger.Info("Banned users file doesn't exist - creating it", map[string]interface{}{"file": cfg.Api.BannedUsersFile})
_, err := os.Create(cfg.Api.BannedUsersFile)
if err != nil {
cfg.Logger.Error("Can't create the file", map[string]interface{}{"error": err.Error()})
return
}
// write empty json to the file
err = os.WriteFile(cfg.Api.BannedUsersFile, []byte("{}"), 0644)
if err != nil {
cfg.Logger.Error("Can't write to the file", map[string]interface{}{"error": err.Error()})
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Banned users file doesn't exist - creating it",
Pairs: map[string]interface{}{"file": cfg.Api.BannedUsersFile},
})
if err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{}"), 0644); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't create and write to the file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
}
fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
err := fileLock.RLock() // Use RLock for read lock
if err != nil {
cfg.Logger.Error("Can't lock the file [load]", map[string]interface{}{"error": err.Error()})
if err := lockFileRead(fileLock); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file [load]",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
defer fileLock.Unlock()
data, err := os.ReadFile(cfg.Api.BannedUsersFile)
if err != nil {
cfg.Logger.Error("Can't read banned users from file", map[string]interface{}{"error": err.Error()})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't read banned users from file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
err = json.Unmarshal(data, &bannedUsersIDs)
if err != nil {
cfg.Logger.Error("Can't unmarshal banned users", map[string]interface{}{"error": err.Error()})
var newBannedUsers map[string]string
if err := json.Unmarshal(data, &newBannedUsers); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't unmarshal banned users",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
bannedUsersIDsMutex.Lock()
bannedUsersIDs = newBannedUsers
bannedUsersIDsMutex.Unlock()
}
func lockFile(fileLock *flock.Flock) error {
if err := fileLock.Lock(); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
}
func lockFileRead(fileLock *flock.Flock) error {
if err := fileLock.RLock(); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file for reading",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
}
-95
View File
@@ -1,95 +0,0 @@
package main
import (
"time"
fiber "github.com/gofiber/fiber/v2"
"github.com/gookit/goutil/strutil"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
)
type CacheStats struct {
CachedQueries int `json:"cached_queries"`
CacheHits int `json:"cache_hits"`
CacheMisses int `json:"cache_misses"`
}
type CacheClient interface {
Set(key string, value []byte, ttl time.Duration)
Get(key string) ([]byte, bool)
Delete(key string)
Clear()
CountQueries() int
}
var (
cacheStats *CacheStats
)
func calculateHash(c *fiber.Ctx) string {
return strutil.Md5(c.Body())
}
func enableCache() {
cacheStats = &CacheStats{}
if shouldUseRedisCache() {
cfg.Logger.Info("Using Redis cache", nil)
cfg.Cache.Client = libpack_redis.NewClient(&libpack_redis.RedisClientConfig{
RedisDB: cfg.Cache.CacheRedisDB,
RedisServer: cfg.Cache.CacheRedisURL,
RedisPassword: cfg.Cache.CacheRedisPassword,
})
} else {
cfg.Logger.Info("Using in-memory cache", nil)
cfg.Cache.Client = libpack_cache.New(time.Duration(cfg.Cache.CacheTTL) * time.Second)
}
}
func cacheLookup(hash string) []byte {
obj, found := cfg.Cache.Client.Get(hash)
if found {
cacheStats.CacheHits++
return obj
}
cacheStats.CacheMisses++
return nil
}
func cacheDelete(hash string) {
cfg.Logger.Debug("Deleting data from cache", map[string]interface{}{"hash": hash})
cacheStats.CachedQueries--
cfg.Cache.Client.Delete(hash)
}
func cacheStore(hash string, data []byte) {
cfg.Logger.Debug("Storing data in cache", map[string]interface{}{"hash": hash})
cacheStats.CachedQueries++
cfg.Cache.Client.Set(hash, data, time.Duration(cfg.Cache.CacheTTL)*time.Second)
}
func cacheStoreWithTTL(hash string, data []byte, ttl time.Duration) {
cfg.Logger.Debug("Storing data in cache with TTL", map[string]interface{}{"hash": hash, "ttl": ttl})
cacheStats.CachedQueries++
cfg.Cache.Client.Set(hash, data, ttl)
}
func cacheGetQueries() int {
cfg.Logger.Debug("Counting cache queries", nil)
return cfg.Cache.Client.CountQueries()
}
func cacheClear() {
cfg.Cache.Client.Clear()
cacheStats = &CacheStats{}
}
func getCacheStats() *CacheStats {
cfg.Logger.Debug("Getting cache stats", nil)
cacheStats.CachedQueries = cacheGetQueries()
return cacheStats
}
func shouldUseRedisCache() bool {
return cfg.Cache.CacheRedisEnable
}
+134
View File
@@ -0,0 +1,134 @@
package libpack_cache
import (
"sync/atomic"
"time"
fiber "github.com/gofiber/fiber/v2"
"github.com/gookit/goutil/strutil"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
type CacheConfig struct {
Logger *libpack_logger.Logger
Client CacheClient
Redis struct {
URL string `json:"url"`
Password string `json:"password"`
DB int `json:"db"`
Enable bool `json:"enable"`
}
TTL int `json:"ttl"`
}
type CacheStats struct {
CachedQueries int64 `json:"cached_queries"`
CacheHits int64 `json:"cache_hits"`
CacheMisses int64 `json:"cache_misses"`
}
type CacheClient interface {
Set(key string, value []byte, ttl time.Duration)
Get(key string) ([]byte, bool)
Delete(key string)
Clear()
CountQueries() int64
}
var (
cacheStats *CacheStats
config *CacheConfig
)
func CalculateHash(c *fiber.Ctx) string {
return strutil.Md5(c.Body())
}
func EnableCache(cfg *CacheConfig) {
if cfg.Logger == nil {
cfg.Logger = libpack_logger.New()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Initializing in-module logger",
})
}
cacheStats = &CacheStats{}
if ShouldUseRedisCache(cfg) {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Using Redis cache",
})
cfg.Client = libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisDB: cfg.Redis.DB,
RedisServer: cfg.Redis.URL,
RedisPassword: cfg.Redis.Password,
})
} else {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Using in-memory cache",
})
cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second)
}
config = cfg
}
func CacheLookup(hash string) []byte {
obj, found := config.Client.Get(hash)
if found {
atomic.AddInt64(&cacheStats.CacheHits, 1)
return obj
}
atomic.AddInt64(&cacheStats.CacheMisses, 1)
return nil
}
func CacheDelete(hash string) {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Deleting data from cache",
Pairs: map[string]interface{}{"hash": hash},
})
atomic.AddInt64(&cacheStats.CachedQueries, -1)
config.Client.Delete(hash)
}
func CacheStore(hash string, data []byte) {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Storing data in cache",
Pairs: map[string]interface{}{"hash": hash},
})
atomic.AddInt64(&cacheStats.CachedQueries, 1)
config.Client.Set(hash, data, time.Duration(config.TTL)*time.Second)
}
func CacheStoreWithTTL(hash string, data []byte, ttl time.Duration) {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Storing data in cache with TTL",
Pairs: map[string]interface{}{"hash": hash, "ttl": ttl},
})
atomic.AddInt64(&cacheStats.CachedQueries, 1)
config.Client.Set(hash, data, ttl)
}
func CacheGetQueries() int64 {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Counting cache queries",
})
return config.Client.CountQueries()
}
func CacheClear() {
config.Client.Clear()
cacheStats = &CacheStats{}
}
func GetCacheStats() *CacheStats {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Getting cache stats",
})
cacheStats.CachedQueries = CacheGetQueries()
return cacheStats
}
func ShouldUseRedisCache(cfg *CacheConfig) bool {
return cfg.Redis.Enable
}
+116
View File
@@ -0,0 +1,116 @@
package libpack_cache
import (
"testing"
"time"
"github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
const (
Parallelism = 4
RequestPerSec = 10000
)
func BenchmarkCacheLookupInMemory(b *testing.B) {
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_memory.New(5 * time.Minute),
TTL: 5,
}
EnableCache(config)
hash := "00000000000000000000000000000000001337"
data := []byte("it's fine.")
CacheStore(hash, data)
b.SetParallelism(Parallelism)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
CacheLookup(hash)
}
})
}
func BenchmarkCacheLookupRedis(b *testing.B) {
redis_server, _ := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redis_server.Addr(),
RedisDB: 0,
})
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
config.Redis.Enable = true
EnableCache(config)
hash := "00000000000000000000000000000000001337"
data := []byte("it's fine.")
CacheStore(hash, data)
b.SetParallelism(Parallelism)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
CacheLookup(hash)
}
})
}
func BenchmarkCacheStoreInMemory(b *testing.B) {
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_memory.New(5 * time.Minute),
TTL: 5,
}
EnableCache(config)
hash := "00000000000000000000000000000000001337"
data := []byte("it's fine.")
b.SetParallelism(Parallelism)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
CacheStore(hash, data)
}
})
}
func BenchmarkCacheStoreRedis(b *testing.B) {
redis_server, _ := miniredis.Run()
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redis_server.Addr(),
RedisDB: 0,
})
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
config.Redis.Enable = true
EnableCache(config)
hash := "00000000000000000000000000000000001337"
data := []byte("it's fine.")
b.SetParallelism(Parallelism)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
CacheStore(hash, data)
}
})
}
+34
View File
@@ -0,0 +1,34 @@
package libpack_cache
import (
"testing"
"github.com/alicebob/miniredis/v2"
assertions "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type Tests struct {
suite.Suite
}
var (
assert *assertions.Assertions
redisMockServer, _ = miniredis.Run()
)
func (suite *Tests) BeforeTest(suiteName, testName string) {
}
func (suite *Tests) SetupTest() {
cacheStats = &CacheStats{}
assert = assertions.New(suite.T())
}
// TearDownTest is run after each test to clean up
func (suite *Tests) TearDownTest() {
}
func TestSuite(t *testing.T) {
suite.Run(t, new(Tests))
}
+215
View File
@@ -0,0 +1,215 @@
package libpack_cache
import (
"fmt"
"sync"
"time"
"github.com/alicebob/miniredis/v2"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
func (suite *Tests) Test_cacheLookupInmemory() {
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_memory.New(5 * time.Minute),
TTL: 5,
}
type args struct {
hash string
}
tests := []struct {
name string
args args
want []byte
addCache struct {
data []byte
}
}{
{
name: "test_non_existent",
args: args{
hash: "00000000000000000000000000000000000000",
},
want: nil,
},
{
name: "test_existent",
args: args{
hash: "00000000000000000000000000000000001337",
},
want: []byte("it's fine."),
addCache: struct {
data []byte
}{
data: []byte("it's fine."),
},
},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
if tt.addCache.data != nil {
CacheStore(tt.args.hash, tt.addCache.data)
}
got := CacheLookup(tt.args.hash)
assert.Equal(tt.want, got, "Unexpected cache lookup result")
})
}
}
func (suite *Tests) Test_cacheLookupRedis() {
// redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379")
// config.Client = libpack_cache_redis.NewClient(&libpack_cache_redis.RedisClientConfig{
// RedisServer: redis_server,
// RedisPassword: "",
// RedisDB: 0,
// })
mockedCache := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: redisMockServer.Addr(),
RedisDB: 0,
})
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: mockedCache,
TTL: 5,
}
type args struct {
hash string
}
tests := []struct {
name string
args args
want []byte
addCache struct {
data []byte
}
}{
{
name: "test_non_existent",
args: args{
hash: "00000000000000000000000000000000000000",
},
want: nil,
},
{
name: "test_existent",
args: args{
hash: "00000000000000000000000000000000001337",
},
want: []byte("it's fine."),
addCache: struct {
data []byte
}{
data: []byte("it's fine."),
},
},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
if tt.addCache.data != nil {
CacheStore(tt.args.hash, tt.addCache.data)
}
got := CacheLookup(tt.args.hash)
assert.Equal(tt.want, got, "Unexpected cache lookup result")
})
}
}
func (suite *Tests) Test_cacheConcurrency() {
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_memory.New(5 * time.Second),
TTL: 5,
}
const numGoroutines = 10
const numOperations = 1000
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
key := fmt.Sprintf("key-%d-%d", id, j)
value := []byte(fmt.Sprintf("value-%d-%d", id, j))
CacheStore(key, value)
retrieved := CacheLookup(key)
assert.Equal(string(value), string(retrieved), "Concurrent cache operation failed")
}
}(i)
}
wg.Wait()
}
// func (suite *Tests) Test_cacheEviction() {
// config = &CacheConfig{
// Logger: libpack_logger.New(),
// Client: libpack_cache_memory.New(3 * time.Second), // 3 seconds TTL
// TTL: 3,
// }
// // Fill the cache
// for i := 0; i < 20; i++ {
// key := fmt.Sprintf("key-%d", i)
// value := []byte(fmt.Sprintf("value-%d", i))
// CacheStore(key, value)
// time.Sleep(100 * time.Millisecond) // Ensure different creation times
// }
// // Wait for the TTL to expire for the first half of the items
// time.Sleep(3100 * time.Millisecond)
// // Check that the oldest items have been evicted
// for i := 0; i < 10; i++ {
// key := fmt.Sprintf("key-%d", i)
// retrieved := CacheLookup(key)
// assert.Nil(retrieved, fmt.Sprintf("Old item %s should have been evicted", key))
// }
// // Check that the newer items are still in the cache
// for i := 10; i < 20; i++ {
// key := fmt.Sprintf("key-%d", i)
// expected := []byte(fmt.Sprintf("value-%d", i))
// retrieved := CacheLookup(key)
// assert.Equal(expected, retrieved, fmt.Sprintf("Recent item %s should be in cache", key))
// }
// }
func (suite *Tests) Test_cacheRedisFailure() {
mr, err := miniredis.Run()
if err != nil {
suite.T().Fatal(err)
}
defer mr.Close()
config = &CacheConfig{
Logger: libpack_logger.New(),
Client: libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisServer: mr.Addr(),
RedisDB: 0,
}),
TTL: 5,
}
// Test normal operation
CacheStore("test-key", []byte("test-value"))
retrieved := CacheLookup("test-key")
assert.Equal([]byte("test-value"), retrieved)
// Simulate Redis failure
mr.Close()
// Operations should not panic, but should return errors or nil values
CacheStore("another-key", []byte("another-value"))
retrieved = CacheLookup("another-key")
assert.Nil(retrieved, "Lookup should return nil when Redis is down")
}
+19 -48
View File
@@ -1,4 +1,4 @@
package libpack_cache
package libpack_cache_memory
import (
"bytes"
@@ -19,7 +19,7 @@ type Cache struct {
decompressPool sync.Pool
entries sync.Map
globalTTL time.Duration
mu sync.RWMutex // Added sync.RWMutex field for locking
sync.RWMutex
}
func New(globalTTL time.Duration) *Cache {
@@ -52,9 +52,6 @@ func (c *Cache) cleanupRoutine(globalTTL time.Duration) {
}
func (c *Cache) Set(key string, value []byte, ttl time.Duration) {
c.lock()
defer c.unlock()
expiresAt := time.Now().Add(ttl)
compressedValue, err := c.compress(value)
@@ -71,15 +68,18 @@ func (c *Cache) Set(key string, value []byte, ttl time.Duration) {
}
func (c *Cache) Get(key string) ([]byte, bool) {
c.rlock()
defer c.runlock()
entry, ok := c.entries.Load(key)
if !ok || entry.(CacheEntry).ExpiresAt.Before(time.Now()) {
if !ok {
return nil, false
}
compressedValue := entry.(CacheEntry).Value
value, err := c.decompress(compressedValue)
cacheEntry := entry.(CacheEntry)
if cacheEntry.ExpiresAt.Before(time.Now()) {
c.entries.Delete(key)
return nil, false
}
value, err := c.decompress(cacheEntry.Value)
if err != nil {
log.Printf("Error decompressing value for key %s: %v", key, err)
return nil, false
@@ -88,39 +88,32 @@ func (c *Cache) Get(key string) ([]byte, bool) {
}
func (c *Cache) Delete(key string) {
c.lock()
defer c.unlock()
c.entries.Delete(key)
}
func (c *Cache) Clear() {
c.lock()
defer c.unlock()
c.entries.Range(func(key, value interface{}) bool {
c.entries.Delete(key)
return true
})
}
func (c *Cache) CountQueries() int {
c.rlock()
defer c.runlock()
func (c *Cache) CountQueries() int64 {
var count int
c.entries.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
return int64(count)
}
func (c *Cache) compress(data []byte) ([]byte, error) {
w := c.compressPool.Get().(*gzip.Writer)
defer c.compressPool.Put(w)
var buf bytes.Buffer
w := c.compressPool.Get().(*gzip.Writer)
defer func() {
w.Close()
c.compressPool.Put(w)
}()
w.Reset(&buf)
if _, err := w.Write(data); err != nil {
return nil, err
@@ -149,11 +142,7 @@ func (c *Cache) decompress(data []byte) ([]byte, error) {
c.decompressPool.Put(r)
}()
decompressedData, err := io.ReadAll(r)
if err != nil {
return nil, err
}
return decompressedData, nil
return io.ReadAll(r)
}
func (c *Cache) CleanExpiredEntries() {
@@ -166,21 +155,3 @@ func (c *Cache) CleanExpiredEntries() {
return true
})
}
// Private methods to handle locking
func (c *Cache) lock() {
c.mu.Lock()
}
func (c *Cache) unlock() {
c.mu.Unlock()
}
func (c *Cache) rlock() {
c.mu.RLock()
}
func (c *Cache) runlock() {
c.mu.RUnlock()
}
@@ -1,13 +1,14 @@
package libpack_cache
package libpack_cache_memory
import (
"fmt"
"testing"
"time"
)
// Assume that New function initializes the cache and it is defined somewhere in the libpack_cache package.
func BenchmarkCacheSet(b *testing.B) {
func BenchmarkMemCacheSet(b *testing.B) {
cache := New(30 * time.Second) // Initializing the cache with a TTL of 30 seconds
key := "benchmark-key"
value := []byte("benchmark-value")
@@ -19,7 +20,7 @@ func BenchmarkCacheSet(b *testing.B) {
}
}
func BenchmarkCacheGet(b *testing.B) {
func BenchmarkMemCacheGet(b *testing.B) {
cache := New(30 * time.Second) // Initializing the cache
key := "benchmark-key"
value := []byte("benchmark-value")
@@ -32,7 +33,7 @@ func BenchmarkCacheGet(b *testing.B) {
}
}
func BenchmarkCacheExpire(b *testing.B) {
func BenchmarkMemCacheExpire(b *testing.B) {
key := "benchmark-expire-key"
value := []byte("benchmark-value")
ttl := 5 * time.Millisecond // Setting a short TTL for quick expiration
@@ -45,10 +46,37 @@ func BenchmarkCacheExpire(b *testing.B) {
}
}
func BenchmarkCacheStats(b *testing.B) {
func BenchmarkMemCacheStats(b *testing.B) {
cache := New(30 * time.Second) // Initializing the cache
key := "benchmark-key"
value := []byte("benchmark-value")
cache.Set(key, value, 5*time.Second) // Pre-set a value to retrieve
cache.Get(key)
}
func BenchmarkCacheSet(b *testing.B) {
cache := New(5 * time.Second)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Set(fmt.Sprintf("key-%d", i), []byte("value"), 5*time.Second)
}
}
func BenchmarkCacheGet(b *testing.B) {
cache := New(5 * time.Second)
cache.Set("test-key", []byte("test-value"), 5*time.Second)
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache.Get("test-key")
}
}
func BenchmarkCacheDelete(b *testing.B) {
cache := New(5 * time.Second)
b.ResetTimer()
for i := 0; i < b.N; i++ {
key := fmt.Sprintf("key-%d", i)
cache.Set(key, []byte("value"), 5*time.Second)
cache.Delete(key)
}
}
+64 -8
View File
@@ -1,31 +1,33 @@
package libpack_cache
package libpack_cache_memory
import (
"fmt"
"sync"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
type CacheTestSuite struct {
type MemoryTestSuite struct {
suite.Suite
}
func (suite *CacheTestSuite) SetupTest() {
func (suite *MemoryTestSuite) SetupTest() {
}
func TestCachingTestSuite(t *testing.T) {
suite.Run(t, new(CacheTestSuite))
suite.Run(t, new(MemoryTestSuite))
}
func (suite *CacheTestSuite) Test_New() {
func (suite *MemoryTestSuite) Test_New() {
suite.T().Run("should return a new cache", func(t *testing.T) {
cache := New(2 * time.Second)
suite.NotNil(cache)
})
}
func (suite *CacheTestSuite) Test_CacheUse() {
func (suite *MemoryTestSuite) Test_CacheUse() {
cache := New(30 * time.Second)
tests := []struct {
name string
@@ -50,7 +52,7 @@ func (suite *CacheTestSuite) Test_CacheUse() {
}
}
func (suite *CacheTestSuite) Test_CacheDelete() {
func (suite *MemoryTestSuite) Test_CacheDelete() {
cache := New(30 * time.Second)
tests := []struct {
name string
@@ -79,7 +81,7 @@ func (suite *CacheTestSuite) Test_CacheDelete() {
}
}
func (suite *CacheTestSuite) Test_CacheExpire() {
func (suite *MemoryTestSuite) Test_CacheExpire() {
cache := New(30 * time.Second)
tests := []struct {
name string
@@ -110,3 +112,57 @@ func (suite *CacheTestSuite) Test_CacheExpire() {
})
}
}
func (suite *MemoryTestSuite) Test_ConcurrentReadWrite() {
cache := New(5 * time.Second)
const numGoroutines = 100
const numOperations = 1000
var wg sync.WaitGroup
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
key := fmt.Sprintf("key-%d-%d", id, j)
value := []byte(fmt.Sprintf("value-%d-%d", id, j))
if j%2 == 0 {
cache.Set(key, value, 5*time.Second)
} else {
_, _ = cache.Get(key)
}
}
}(i)
}
wg.Wait()
}
func (suite *MemoryTestSuite) Test_LargeItems() {
cache := New(5 * time.Second)
largeValue := make([]byte, 10*1024*1024) // 10MB
cache.Set("large-key", largeValue, 5*time.Second)
retrieved, found := cache.Get("large-key")
suite.Assert().True(found)
suite.Assert().Equal(largeValue, retrieved)
}
func (suite *MemoryTestSuite) Test_ZeroTTL() {
cache := New(5 * time.Second)
cache.Set("zero-ttl", []byte("value"), 0)
_, found := cache.Get("zero-ttl")
suite.Assert().False(found, "Item with zero TTL should not be stored")
}
func (suite *MemoryTestSuite) Test_LongTTL() {
cache := New(5 * time.Second)
cache.Set("long-ttl", []byte("value"), 24*365*time.Hour) // 1 year
retrieved, found := cache.Get("long-ttl")
suite.Assert().True(found)
suite.Assert().Equal([]byte("value"), retrieved)
}
-77
View File
@@ -1,77 +0,0 @@
package libpack_redis
import (
"context"
"time"
redis "github.com/redis/go-redis/v9"
)
var ()
type RedisConfig struct {
client *redis.Client
ctx context.Context
}
func prependKeyName(key string) string {
return "gmp_cache:" + key
}
type RedisClientConfig struct {
RedisServer string
RedisPassword string
RedisDB int
}
func NewClient(redisClientConfig *RedisClientConfig) *RedisConfig {
c := &RedisConfig{
client: redis.NewClient(&redis.Options{
Addr: redisClientConfig.RedisServer,
Password: redisClientConfig.RedisPassword,
DB: redisClientConfig.RedisDB,
}),
ctx: context.Background(),
}
_, err := c.client.Ping(c.ctx).Result()
if err != nil {
panic(err)
}
return c
}
func (c *RedisConfig) Set(key string, value []byte, ttl time.Duration) {
c.client.Set(c.ctx, prependKeyName(key), value, ttl)
}
func (c *RedisConfig) Get(key string) ([]byte, bool) {
val, err := c.client.Get(c.ctx, prependKeyName(key)).Result()
if err == redis.Nil || err != nil {
return nil, false
}
return []byte(val), true
}
func (c *RedisConfig) Delete(key string) {
c.client.Del(c.ctx, prependKeyName(key))
}
func (c *RedisConfig) Clear() {
c.client.FlushDB(c.ctx)
}
func (c *RedisConfig) CountQueries() int {
keys, err := c.client.Keys(c.ctx, prependKeyName("*")).Result()
if err != nil {
return 0
}
return len(keys)
}
func (c *RedisConfig) CountQueriesWithPattern(pattern string) int {
keys, err := c.client.Keys(c.ctx, prependKeyName(pattern)).Result()
if err != nil {
return 0
}
return len(keys)
}
+96
View File
@@ -0,0 +1,96 @@
package libpack_cache_redis
import (
"context"
"strings"
"time"
"sync"
redis "github.com/redis/go-redis/v9"
)
type RedisConfig struct {
ctx context.Context
client *redis.Client
builderPool *sync.Pool
prefix string
}
func (c *RedisConfig) prependKeyName(key string) string {
builder := c.builderPool.Get().(*strings.Builder)
defer c.builderPool.Put(builder)
builder.Reset()
builder.WriteString(c.prefix)
builder.WriteString(key)
return builder.String()
}
type RedisClientConfig struct {
RedisServer string
RedisPassword string
Prefix string
RedisDB int
}
func New(redisClientConfig *RedisClientConfig) *RedisConfig {
c := &RedisConfig{
client: redis.NewClient(&redis.Options{
Addr: redisClientConfig.RedisServer,
Password: redisClientConfig.RedisPassword,
DB: redisClientConfig.RedisDB,
}),
ctx: context.Background(),
prefix: redisClientConfig.Prefix,
builderPool: &sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
},
}
_, err := c.client.Ping(c.ctx).Result()
if err != nil {
panic(err)
}
return c
}
func (c *RedisConfig) Set(key string, value []byte, ttl time.Duration) {
c.client.Set(c.ctx, c.prependKeyName(key), value, ttl)
}
func (c *RedisConfig) Get(key string) ([]byte, bool) {
val, err := c.client.Get(c.ctx, c.prependKeyName(key)).Result()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return []byte(val), true
}
func (c *RedisConfig) Delete(key string) {
c.client.Del(c.ctx, c.prependKeyName(key))
}
func (c *RedisConfig) Clear() {
c.client.FlushDB(c.ctx)
}
func (c *RedisConfig) CountQueries() int64 {
keys, err := c.client.Keys(c.ctx, c.prependKeyName("*")).Result()
if err != nil {
return 0
}
return int64(len(keys))
}
func (c *RedisConfig) CountQueriesWithPattern(pattern string) int {
keys, err := c.client.Keys(c.ctx, c.prependKeyName(pattern)).Result()
if err != nil {
return 0
}
return len(keys)
}
+22 -17
View File
@@ -1,23 +1,24 @@
package libpack_redis
package libpack_cache_redis
import (
"testing"
"time"
"github.com/gookit/goutil/envutil"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type RedisConfigSuite struct {
suite.Suite
redisConfig *RedisConfig
redisConfig *RedisConfig
redis_server *miniredis.Miniredis
}
func (suite *RedisConfigSuite) SetupTest() {
redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379")
suite.redisConfig = NewClient(&RedisClientConfig{
RedisServer: redis_server,
suite.redis_server, _ = miniredis.Run()
suite.redisConfig = New(&RedisClientConfig{
RedisServer: suite.redis_server.Addr(),
RedisPassword: "",
RedisDB: 0,
})
@@ -29,7 +30,7 @@ func TestRedisConfigSuite(t *testing.T) {
}
func (suite *RedisConfigSuite) TestSet() {
key := "testkey"
key := "testkeyset"
value := []byte("testvalue")
suite.redisConfig.Delete(key) // Ensure the key is deleted before the test
@@ -50,9 +51,9 @@ func (suite *RedisConfigSuite) TestSet() {
}
func (suite *RedisConfigSuite) TestSetWithExpiry() {
key := "testkey"
value := []byte("testvalue")
expiry := 1 * time.Second
key := "testkey_with_expiry"
value := []byte("testvaluewithexpiry")
expiry := 2 * time.Second
suite.redisConfig.Delete(key) // Ensure the key is deleted before the test
// Test writing a new key-value pair
@@ -60,17 +61,19 @@ func (suite *RedisConfigSuite) TestSetWithExpiry() {
storedValue, found := suite.redisConfig.Get(key)
assert.True(suite.T(), found)
assert.Equal(suite.T(), value, storedValue)
_, found = suite.redisConfig.Get(key)
assert.True(suite.T(), found, "Key should exist")
// Test that key expires after the specified time
time.Sleep(2 * time.Second)
suite.redis_server.FastForward(3 * time.Second)
_, found = suite.redisConfig.Get(key)
assert.False(suite.T(), found)
assert.False(suite.T(), found, "Key should have expired after 2 seconds")
suite.redisConfig.Delete(key) // Clean up after the test
}
func (suite *RedisConfigSuite) TestGet() {
key := "testkey"
key := "testkeyget"
value := []byte("testvalue")
suite.redisConfig.Set(key, value, 0) // Set the key-value pair
storedValue, found := suite.redisConfig.Get(key)
@@ -79,7 +82,7 @@ func (suite *RedisConfigSuite) TestGet() {
}
func (suite *RedisConfigSuite) TestDeleteKey() {
key := "testkey"
key := "testkeydelete"
value := []byte("testvalue")
suite.redisConfig.Set(key, value, 0) // Set the key-value pair
suite.redisConfig.Delete(key)
@@ -89,7 +92,7 @@ func (suite *RedisConfigSuite) TestDeleteKey() {
func (suite *RedisConfigSuite) TestCheckIfKeyExists() {
ttl := time.Duration(10) * time.Second
key := "testkey"
key := "testkeyifexists"
value := []byte("testvalue")
suite.redisConfig.Set(key, value, ttl) // Set the key-value pair
_, found := suite.redisConfig.Get(key)
@@ -106,8 +109,8 @@ func (suite *RedisConfigSuite) TestGetKeys() {
suite.redisConfig.Set("testkey2", []byte("testvalue2"), ttl)
suite.redisConfig.Set("otherkey", []byte("othervalue"), ttl)
keys, _ := suite.redisConfig.client.Keys(suite.redisConfig.ctx, prependKeyName("testkey*")).Result()
expectedKeys := []string{prependKeyName("testkey1"), prependKeyName("testkey2")}
keys, _ := suite.redisConfig.client.Keys(suite.redisConfig.ctx, "testkey*").Result()
expectedKeys := []string{"testkey1", "testkey2"}
assert.ElementsMatch(suite.T(), expectedKeys, keys)
suite.redisConfig.client.Del(suite.redisConfig.ctx, "testkey1", "testkey2", "otherkey")
@@ -120,6 +123,8 @@ func (suite *RedisConfigSuite) TestGetKeysCount() {
suite.redisConfig.Set("otherkey", []byte("othervalue"), ttl)
assert.Equal(suite.T(), 2, suite.redisConfig.CountQueriesWithPattern("testkey*"))
assert.Equal(suite.T(), 1, suite.redisConfig.CountQueriesWithPattern("otherkey*"))
assert.Equal(suite.T(), int64(3), suite.redisConfig.CountQueries())
suite.redisConfig.client.Del(suite.redisConfig.ctx, "testkey1", "testkey2", "otherkey")
}
-99
View File
@@ -1,99 +0,0 @@
package main
import (
"github.com/gookit/goutil/envutil"
libpack_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
)
func (suite *Tests) Test_cacheLookupInmemory() {
type args struct {
hash string
}
tests := []struct {
name string
args args
want []byte
addCache struct {
data []byte
}
}{
{
name: "test_non_existent",
args: args{
hash: "00000000000000000000000000000000000000",
},
want: nil,
},
{
name: "test_existent",
args: args{
hash: "00000000000000000000000000000000001337",
},
want: []byte("it's fine."),
addCache: struct {
data []byte
}{
data: []byte("it's fine."),
},
},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
if tt.addCache.data != nil {
cacheStore(tt.args.hash, tt.addCache.data)
}
got := cacheLookup(tt.args.hash)
assert.Equal(tt.want, got, "Unexpected cache lookup result")
})
}
}
func (suite *Tests) Test_cacheLookupRedis() {
redis_server := envutil.Getenv("REDIS_SERVER", "localhost:6379")
cfg.Cache.Client = libpack_redis.NewClient(&libpack_redis.RedisClientConfig{
RedisServer: redis_server,
RedisPassword: "",
RedisDB: 0,
})
type args struct {
hash string
}
tests := []struct {
name string
args args
want []byte
addCache struct {
data []byte
}
}{
{
name: "test_non_existent",
args: args{
hash: "00000000000000000000000000000000000000",
},
want: nil,
},
{
name: "test_existent",
args: args{
hash: "00000000000000000000000000000000001337",
},
want: []byte("it's fine."),
addCache: struct {
data []byte
}{
data: []byte("it's fine."),
},
},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
if tt.addCache.data != nil {
cacheStore(tt.args.hash, tt.addCache.data)
}
got := cacheLookup(tt.args.hash)
assert.Equal(tt.want, got, "Unexpected cache lookup result")
})
}
}
+31 -19
View File
@@ -7,18 +7,18 @@ import (
"github.com/goccy/go-json"
"github.com/lukaszraczylo/ask"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
func extractClaimsFromJWTHeader(authorization string) (usr string, role string) {
usr, role = "-", "-"
const defaultValue = "-"
handleError := func(msg string, details map[string]interface{}) {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
cfg.Logger.Error(msg, details)
}
var emptyMetrics = map[string]string{}
tokenParts := strings.Split(authorization, ".")
func extractClaimsFromJWTHeader(authorization string) (usr, role string) {
usr, role = defaultValue, defaultValue
tokenParts := strings.SplitN(authorization, ".", 3)
if len(tokenParts) != 3 {
handleError("Can't split the token", map[string]interface{}{"token": authorization})
return
@@ -36,18 +36,30 @@ func extractClaimsFromJWTHeader(authorization string) (usr string, role string)
return
}
extractClaim := func(claimPath string, target *string, name string) {
if len(claimPath) > 0 {
var ok bool
*target, ok = ask.For(claimMap, claimPath).String("-")
if !ok {
handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath})
}
}
}
extractClaim(cfg.Client.JWTUserClaimPath, &usr, "user id")
extractClaim(cfg.Client.JWTRoleClaimPath, &role, "role")
usr = extractClaim(claimMap, cfg.Client.JWTUserClaimPath, "user id")
role = extractClaim(claimMap, cfg.Client.JWTRoleClaimPath, "role")
return
}
func extractClaim(claimMap map[string]interface{}, claimPath, name string) string {
if claimPath == "" {
return defaultValue
}
value, ok := ask.For(claimMap, claimPath).String(defaultValue)
if !ok {
handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": claimMap, "path": claimPath})
return defaultValue
}
return value
}
func handleError(msg string, details map[string]interface{}) {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, emptyMetrics)
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: msg,
Pairs: details,
})
}
+67 -37
View File
@@ -5,54 +5,84 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
func enableHasuraEventCleaner() {
if cfg.HasuraEventCleaner.Enable {
if cfg.HasuraEventCleaner.EventMetadataDb == "" {
cfg.Logger.Warning("Event metadata db URL not specified, event cleaner not active", nil)
return
}
const (
initialDelay = 60 * time.Second
cleanupInterval = 1 * time.Hour
)
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
cfg.Logger.Info("Event cleaner enabled", map[string]interface{}{"interval_in_days": cfg.HasuraEventCleaner.ClearOlderThan})
time.Sleep(60 * time.Second) // wait for everything to start and settle down
cfg.Logger.Info("Initial cleanup of old events", nil)
cleanEvents()
for {
select {
case <-ticker.C:
cfg.Logger.Info("Cleaning up old events", nil)
cleanEvents()
}
}
}
var delQueries = [...]string{
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - interval '%d days';",
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - interval '%d days';",
"DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL '%d days';",
"DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';",
"DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';",
}
func cleanEvents() {
conn, err := pgx.Connect(context.Background(), cfg.HasuraEventCleaner.EventMetadataDb)
if err != nil {
cfg.Logger.Error("Failed to connect to event metadata db", map[string]interface{}{"error": err})
func enableHasuraEventCleaner() {
if !cfg.HasuraEventCleaner.Enable {
return
}
defer conn.Close(context.Background())
delQueries := []string{
fmt.Sprintf("DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < now() - interval '%d days';", cfg.HasuraEventCleaner.ClearOlderThan),
fmt.Sprintf("DELETE FROM hdb_catalog.event_log WHERE created_at < now() - interval '%d days';", cfg.HasuraEventCleaner.ClearOlderThan),
fmt.Sprintf("DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan),
fmt.Sprintf("DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan),
fmt.Sprintf("DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';", cfg.HasuraEventCleaner.ClearOlderThan),
if cfg.HasuraEventCleaner.EventMetadataDb == "" {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Event metadata db URL not specified, event cleaner not active",
})
return
}
for _, query := range delQueries {
_, err := conn.Exec(context.Background(), query)
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Event cleaner enabled",
Pairs: map[string]interface{}{"interval_in_days": cfg.HasuraEventCleaner.ClearOlderThan},
})
go func() {
pool, err := pgxpool.New(context.Background(), cfg.HasuraEventCleaner.EventMetadataDb)
if err != nil {
cfg.Logger.Debug("Failed to execute query", map[string]interface{}{"query": query, "error": err})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to create connection pool",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
defer pool.Close()
time.Sleep(initialDelay)
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Initial cleanup of old events",
})
cleanEvents(pool)
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for range ticker.C {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Cleaning up old events",
})
cleanEvents(pool)
}
}()
}
func cleanEvents(pool *pgxpool.Pool) {
ctx := context.Background()
for _, query := range delQueries {
_, err := pool.Exec(ctx, fmt.Sprintf(query, cfg.HasuraEventCleaner.ClearOlderThan))
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to execute query",
Pairs: map[string]interface{}{"query": query, "error": err.Error()},
})
} else {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Successfully executed query",
Pairs: map[string]interface{}{"query": query},
})
}
}
}
+11 -10
View File
@@ -1,27 +1,28 @@
module github.com/lukaszraczylo/graphql-monitoring-proxy
go 1.21
go 1.22.4
require (
github.com/VictoriaMetrics/metrics v1.33.1
github.com/VictoriaMetrics/metrics v1.34.0
github.com/alicebob/miniredis/v2 v2.33.0
github.com/avast/retry-go/v4 v4.6.0
github.com/goccy/go-json v0.10.3
github.com/gofiber/fiber/v2 v2.52.4
github.com/gofrs/flock v0.8.1
github.com/gofrs/flock v0.9.0
github.com/google/uuid v1.6.0
github.com/gookit/goutil v0.6.15
github.com/graphql-go/graphql v0.8.1
github.com/jackc/pgx/v5 v5.6.0
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415
github.com/lukaszraczylo/go-ratecounter v0.1.8
github.com/lukaszraczylo/go-simple-graphql v1.2.14
github.com/lukaszraczylo/go-ratecounter v0.1.12
github.com/lukaszraczylo/go-simple-graphql v1.2.17
github.com/redis/go-redis/v9 v9.5.3
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.9.0
github.com/valyala/fasthttp v1.54.0
github.com/valyala/fasthttp v1.55.0
)
require (
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
@@ -29,21 +30,21 @@ require (
github.com/gookit/color v1.5.4 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/klauspost/compress v1.17.9 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/rs/zerolog v1.33.0 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fastrand v1.1.0 // indirect
github.com/valyala/histogram v1.2.0 // indirect
github.com/valyala/tcplisten v1.0.0 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect
golang.org/x/crypto v0.24.0 // indirect
golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
+19 -16
View File
@@ -1,5 +1,9 @@
github.com/VictoriaMetrics/metrics v1.33.1 h1:CNV3tfm2Kpv7Y9W3ohmvqgFWPR55tV2c7M2U6OIo+UM=
github.com/VictoriaMetrics/metrics v1.33.1/go.mod h1:r7hveu6xMdUACXvB8TYdAj8WEsKzWB0EkpJN+RDtOf8=
github.com/VictoriaMetrics/metrics v1.34.0 h1:0i8k/gdOJdSoZB4Z9pikVnVQXfhcIvnG7M7h2WaQW2w=
github.com/VictoriaMetrics/metrics v1.34.0/go.mod h1:r7hveu6xMdUACXvB8TYdAj8WEsKzWB0EkpJN+RDtOf8=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk=
github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA=
github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0=
github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M=
github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY=
github.com/avast/retry-go/v4 v4.6.0 h1:K9xNA+KeB8HHc2aWFuLb25Offp+0iVRXEvFx8IinRJA=
@@ -11,7 +15,6 @@ github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -22,8 +25,8 @@ github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PU
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofiber/fiber/v2 v2.52.4 h1:P+T+4iK7VaqUsq2PALYEfBBo6bJZ4q3FP8cZ84EggTM=
github.com/gofiber/fiber/v2 v2.52.4/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gofrs/flock v0.9.0 h1:QqEH0zKHPdEyY4YbJLleD9Il4ft7h6hn3gECO6Ss4rQ=
github.com/gofrs/flock v0.9.0/go.mod h1:O+L78Axre/Bc0Ya3RlNiGP+Rt0tFHWjtHTQ+B2uPZw8=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0=
@@ -48,10 +51,10 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415 h1:lvI8Wlbg4PxkRcg2f10wgoaRpfN19v+YdRek3+dLtlM=
github.com/lukaszraczylo/ask v0.0.0-20230927103145-2ff1123b4415/go.mod h1:M+UVdyqZs++xtEPrascaVmZdOMhCnxjZ2SgH+xHpR0c=
github.com/lukaszraczylo/go-ratecounter v0.1.8 h1:ZYm6Wkn58ZAlFWRmC7PaD4oAYHWcu8/0MUDWGe3PnJQ=
github.com/lukaszraczylo/go-ratecounter v0.1.8/go.mod h1:TqXEOCtFJStk1i0tkipprv1kiDHGon1MVUisjSTBSKM=
github.com/lukaszraczylo/go-simple-graphql v1.2.14 h1:Dth+yZ+1ialCpnslSb6UgHbXszExjDUu/I95QZbnWVU=
github.com/lukaszraczylo/go-simple-graphql v1.2.14/go.mod h1:pSKmm9OLGoS9pjmIvhBB/fo0+LganRrL29CN3fdkRPw=
github.com/lukaszraczylo/go-ratecounter v0.1.12 h1:VO6hHYGw/Jy9JUizXf/bS0AI2QX1ueWWAWckMFVJ/w4=
github.com/lukaszraczylo/go-ratecounter v0.1.12/go.mod h1:TqXEOCtFJStk1i0tkipprv1kiDHGon1MVUisjSTBSKM=
github.com/lukaszraczylo/go-simple-graphql v1.2.17 h1:XxUUgxcCIZSVLzI4UfhBDXoFoMlygcXHfAJwXxawr1s=
github.com/lukaszraczylo/go-simple-graphql v1.2.17/go.mod h1:pSKmm9OLGoS9pjmIvhBB/fo0+LganRrL29CN3fdkRPw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@@ -60,7 +63,6 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U=
github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -69,9 +71,8 @@ github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
@@ -82,8 +83,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.54.0 h1:cCL+ZZR3z3HPLMVfEYVUMtJqVaui0+gu7Lx63unHwS0=
github.com/valyala/fasthttp v1.54.0/go.mod h1:6dt4/8olwq9QARP/TDuPmWyWcl4byhpvTJ4AAtcz+QM=
github.com/valyala/fasthttp v1.55.0 h1:Zkefzgt6a7+bVKHnu/YaYSOPfNYNisSVBo/unVCf8k8=
github.com/valyala/fasthttp v1.55.0/go.mod h1:NkY9JtkrpPKmgwV3HTaS2HWaJss9RSIsRVfcxxoHiOM=
github.com/valyala/fastrand v1.1.0 h1:f+5HkLW4rsgzdNoleUOB69hyT9IlD2ZQh9GyDMfb5G8=
github.com/valyala/fastrand v1.1.0/go.mod h1:HWqCzkrkg6QXT8V2EXWvXCoow7vLwOFN002oeRzjapQ=
github.com/valyala/histogram v1.2.0 h1:wyYGAZZt3CpwUiIb9AU/Zbllg1llXyrtApRS815OLoQ=
@@ -92,10 +93,12 @@ github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVS
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E=
golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
+108 -99
View File
@@ -3,56 +3,40 @@ package main
import (
"strconv"
"strings"
"sync"
"unsafe"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
"github.com/graphql-go/graphql/language/ast"
"github.com/graphql-go/graphql/language/parser"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
var introspection_queries = []string{
"__schema",
"__type",
"__typename",
"__directive",
"__directivelocation",
"__field",
"__inputvalue",
"__enumvalue",
"__typekind",
"__fieldtype",
"__inputobjecttype",
"__enumtype",
"__uniontype",
"__scalars",
"__objects",
"__interfaces",
"__unions",
"__enums",
"__inputobjects",
"__directives",
}
// Saving the introspection queries as a map O(1) operation instead of O(n) for a slice.
var introspectionQuerySet = map[string]struct{}{}
var introspectionAllowedQueries = map[string]struct{}{}
var allowedUrls = map[string]struct{}{}
// Utility function to convert a slice of strings to a map for O(1) lookups.
func sliceToMap(slice []string) map[string]struct{} {
resultMap := make(map[string]struct{}, len(slice))
for _, item := range slice {
resultMap[strings.ToLower(item)] = struct{}{}
var (
introspectionQueries = map[string]struct{}{
"__schema": {}, "__type": {}, "__typename": {}, "__directive": {},
"__directivelocation": {}, "__field": {}, "__inputvalue": {},
"__enumvalue": {}, "__typekind": {}, "__fieldtype": {},
"__inputobjecttype": {}, "__enumtype": {}, "__uniontype": {},
"__scalars": {}, "__objects": {}, "__interfaces": {},
"__unions": {}, "__enums": {}, "__inputobjects": {}, "__directives": {},
}
return resultMap
}
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
mu sync.RWMutex
)
func prepareQueriesAndExemptions() {
introspectionQuerySet = sliceToMap(introspection_queries)
introspectionAllowedQueries = sliceToMap(cfg.Security.IntrospectionAllowed)
allowedUrls = sliceToMap(cfg.Server.AllowURLs)
mu.Lock()
defer mu.Unlock()
for _, q := range cfg.Security.IntrospectionAllowed {
introspectionAllowedQueries[strings.ToLower(q)] = struct{}{}
}
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
}
}
type parseGraphQLQueryResult struct {
@@ -66,34 +50,63 @@ type parseGraphQLQueryResult struct {
shouldIgnore bool
}
func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
res = &parseGraphQLQueryResult{shouldIgnore: true}
m := make(map[string]interface{})
err := json.Unmarshal(c.Body(), &m)
if err != nil {
cfg.Logger.Error("Can't unmarshal the request", map[string]interface{}{"error": err.Error(), "body": string(c.Body())})
var (
queryPool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{}, 4)
},
}
resultPool = sync.Pool{
New: func() interface{} {
return &parseGraphQLQueryResult{}
},
}
)
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
res := resultPool.Get().(*parseGraphQLQueryResult)
defer resultPool.Put(res)
*res = parseGraphQLQueryResult{shouldIgnore: true}
m := queryPool.Get().(map[string]interface{})
defer queryPool.Put(m)
for k := range m {
delete(m, k)
}
if err := json.Unmarshal(c.Body(), &m); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't unmarshal the request",
Pairs: map[string]interface{}{"error": err.Error(), "body": unsafeString(c.Body())},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
return res
}
// get the query
query, ok := m["query"].(string)
if !ok {
cfg.Logger.Error("Can't find the query", map[string]interface{}{"query": query, "m_val": m})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't find the query",
Pairs: map[string]interface{}{"m_val": m},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return
return res
}
p, err := parser.Parse(parser.ParseParams{Source: query})
if err != nil {
cfg.Logger.Error("Can't parse the query", map[string]interface{}{"query": query, "m_val": m})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the query",
Pairs: map[string]interface{}{"query": query, "m_val": m},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
return res
}
res.shouldIgnore = false
@@ -102,44 +115,44 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
res.operationName = oper.Name.Value
// If we haven't set an operation type yet, use this one
if res.operationType == "" {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
res.operationName = oper.Name.Value
}
}
// If the query is a mutation then direct it to the RW endpoint,
// otherwise direct it to the RO endpoint if it's set.
if cfg.Server.HostGraphQLReadOnly != "" && res.operationType != "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning("Mutation blocked", m)
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Mutation blocked",
Pairs: map[string]interface{}{"query": query},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("The server is in read-only mode")
_ = c.Status(403).SendString("The server is in read-only mode")
res.shouldBlock = true
return
return res
}
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
res.cacheRequest = true
for _, arg := range dir.Arguments {
if arg.Name.Value == "ttl" {
res.cacheTime, err = strconv.Atoi(arg.Value.GetValue().(string))
if err != nil {
cfg.Logger.Error("Can't parse the ttl, using global", map[string]interface{}{"bad_ttl": arg.Value.GetValue().(string)})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return
switch arg.Name.Value {
case "ttl":
if v, ok := arg.Value.GetValue().(string); ok {
res.cacheTime, _ = strconv.Atoi(v)
}
case "refresh":
if v, ok := arg.Value.GetValue().(bool); ok {
res.cacheRefresh = v
}
}
if arg.Name.Value == "refresh" {
res.cacheRefresh = arg.Value.GetValue().(bool)
}
}
}
@@ -148,26 +161,25 @@ func parseGraphQLQuery(c *fiber.Ctx) (res *parseGraphQLQueryResult) {
if cfg.Security.BlockIntrospection {
res.shouldBlock = checkSelections(c, oper.GetSelectionSet().Selections)
if res.shouldBlock {
return
return res
}
}
}
}
return
return res
}
func unsafeString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
for _, s := range selections {
field, ok := s.(*ast.Field)
if !ok {
continue // or handle the case where the type assertion fails
}
shouldBlock := checkIfContainsIntrospection(c, field.Name.Value)
if shouldBlock {
return true
}
if field.SelectionSet != nil {
if checkSelections(c, field.GetSelectionSet().Selections) {
if field, ok := s.(*ast.Field); ok {
if checkIfContainsIntrospection(c, field.Name.Value) {
return true
}
if field.SelectionSet != nil && checkSelections(c, field.GetSelectionSet().Selections) {
return true
}
}
@@ -175,29 +187,26 @@ func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) (shouldBlock bool) {
func checkIfContainsIntrospection(c *fiber.Ctx, whatever string) bool {
whateverLower := strings.ToLower(whatever)
got_exemption := false
mu.RLock()
defer mu.RUnlock()
// If the query is an introspection query, we need to check if it's allowed.
if _, exists := introspectionQuerySet[whateverLower]; exists {
if _, exists := introspectionQueries[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed_exists := introspectionAllowedQueries[whateverLower]; allowed_exists {
cfg.Logger.Debug("Introspection query allowed, passing through", map[string]interface{}{"query": whatever})
got_exemption = true
shouldBlock = false
if _, allowed := introspectionAllowedQueries[whateverLower]; allowed {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Introspection query allowed, passing through",
Pairs: map[string]interface{}{"query": whatever},
})
return false
}
}
if !got_exemption {
shouldBlock = true
}
}
if shouldBlock {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Introspection queries are not allowed")
_ = c.Status(403).SendString("Introspection queries are not allowed")
return true
}
return
return false
}
+112
View File
@@ -1,6 +1,10 @@
package main
import (
"fmt"
"strings"
fiber "github.com/gofiber/fiber/v2"
"github.com/valyala/fasthttp"
)
@@ -318,3 +322,111 @@ func (suite *Tests) Test_parseGraphQLQuery() {
})
}
}
func (suite *Tests) Test_parseGraphQLQuery_complex() {
// ... existing tests ...
// Add these new test cases
suite.Run("test complex query with multiple operations", func() {
query := `
query GetUser($id: ID!) {
user(id: $id) {
name
email
}
}
mutation UpdateUser($id: ID!, $name: String!) {
updateUser(id: $id, name: $name) {
id
name
}
}
`
body := fmt.Sprintf(`{"query": %q}`, query)
ctx := createTestContext(body)
result := parseGraphQLQuery(ctx)
assert.Equal("query", result.operationType)
assert.Equal("GetUser", result.operationName)
assert.False(result.shouldBlock)
})
suite.Run("test query with custom directives", func() {
query := `
query GetUser($id: ID!) @custom(directive: "value") {
user(id: $id) {
name
email
}
}
`
body := fmt.Sprintf(`{"query": %q}`, query)
ctx := createTestContext(body)
result := parseGraphQLQuery(ctx)
assert.Equal("query", result.operationType)
assert.Equal("GetUser", result.operationName)
assert.False(result.shouldBlock)
assert.False(result.shouldBlock)
})
}
func (suite *Tests) Test_checkAllowedURLs() {
tests := []struct {
name string
path string
allowed []string
expected bool
}{
{"allowed path", "/v1/graphql", []string{"/v1/graphql"}, true},
{"disallowed path", "/v2/graphql", []string{"/v1/graphql"}, false},
{"empty allowed list", "/v1/graphql", []string{}, true},
{"multiple allowed paths", "/v2/graphql", []string{"/v1/graphql", "/v2/graphql"}, true},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
allowedUrls = make(map[string]struct{})
for _, url := range tt.allowed {
allowedUrls[url] = struct{}{}
}
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().SetRequestURI(tt.path)
ctx.Request().URI().SetPath(tt.path)
result := checkAllowedURLs(ctx)
assert.Equal(tt.expected, result)
})
}
}
func (suite *Tests) Test_checkIfContainsIntrospection() {
tests := []struct {
name string
query string
allowed []string
expected bool
}{
{"allowed introspection", "__schema", []string{"__schema"}, false},
{"disallowed introspection", "__type", []string{"__schema"}, true},
{"non-introspection query", "normalQuery", []string{}, false},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
cfg.Security.IntrospectionAllowed = tt.allowed
introspectionAllowedQueries = make(map[string]struct{})
for _, q := range tt.allowed {
introspectionAllowedQueries[strings.ToLower(q)] = struct{}{}
}
ctx := createTestContext("")
result := checkIfContainsIntrospection(ctx, tt.query)
assert.Equal(tt.expected, result)
})
}
}
func createTestContext(body string) *fiber.Ctx {
app := fiber.New()
ctx := app.AcquireCtx(&fasthttp.RequestCtx{})
ctx.Request().SetBody([]byte(body))
return ctx
}
+204
View File
@@ -0,0 +1,204 @@
package libpack_logger
import (
"bytes"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/goccy/go-json"
)
const (
_ = iota
LEVEL_DEBUG
LEVEL_INFO
LEVEL_WARN
LEVEL_ERROR
LEVEL_FATAL
)
var LevelNames = [...]string{
"none",
"debug",
"info",
"warn",
"error",
"fatal",
}
const (
defaultFormat = time.RFC3339
defaultMinLevel = LEVEL_INFO
defaultShowCaller = false
)
var defaultOutput = os.Stdout
type Logger struct {
output io.Writer
format string
minLogLevel int
showCaller bool
}
type LogMessage struct {
output io.Writer
Pairs map[string]any
Message string
}
func (m *LogMessage) String() string {
return m.Message
}
var fieldNames = map[string]string{
"timestamp": "timestamp",
"level": "level",
"message": "message",
}
func New() *Logger {
return &Logger{
format: defaultFormat,
minLogLevel: defaultMinLevel,
output: defaultOutput,
showCaller: defaultShowCaller,
}
}
func (l *Logger) SetOutput(output io.Writer) *Logger {
l.output = output
return l
}
var bufferPool = sync.Pool{
New: func() any {
return new(bytes.Buffer)
},
}
var defaultPairs = make(map[string]any)
func GetLogLevel(level string) int {
for i, name := range LevelNames {
if name == strings.ToLower(level) {
return i
}
}
return defaultMinLevel
}
func (l *Logger) log(level int, m *LogMessage) {
if m.Pairs == nil {
m.Pairs = defaultPairs
}
m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.format)
m.Pairs[fieldNames["level"]] = LevelNames[level]
m.Pairs[fieldNames["message"]] = m.Message
if l.showCaller {
m.Pairs["caller"] = getCaller()
}
buffer := bufferPool.Get().(*bytes.Buffer)
defer bufferPool.Put(buffer)
buffer.Reset()
var encoder = json.NewEncoder(buffer)
err := encoder.Encode(m.Pairs)
if err != nil {
fmt.Println("Error marshalling log message:", err)
return
}
// if not running in test - use stderr and stdout, otherwise - use logger's output setting
if flag.Lookup("test.v") != nil {
m.output = os.Stdout
if level >= LEVEL_ERROR {
m.output = os.Stderr
}
}
// Use logger's output setting instead of os.Stdout or os.Stderr
l.output.Write(buffer.Bytes())
}
func (l *Logger) Debug(m *LogMessage) {
if l.shouldLog(LEVEL_DEBUG) {
l.log(LEVEL_DEBUG, m)
}
}
func (l *Logger) Info(m *LogMessage) {
if l.shouldLog(LEVEL_INFO) {
l.log(LEVEL_INFO, m)
}
}
func (l *Logger) Warn(m *LogMessage) {
if l.shouldLog(LEVEL_WARN) {
l.log(LEVEL_WARN, m)
}
}
func (l *Logger) Warning(m *LogMessage) {
l.Warn(m)
}
func (l *Logger) Error(m *LogMessage) {
if l.shouldLog(LEVEL_ERROR) {
l.log(LEVEL_ERROR, m)
}
}
func (l *Logger) Fatal(m *LogMessage) {
if l.shouldLog(LEVEL_FATAL) {
l.log(LEVEL_FATAL, m)
}
}
func (l *Logger) Critical(m *LogMessage) {
l.Fatal(m)
os.Exit(1)
}
func (l *Logger) shouldLog(level int) bool {
return level >= l.minLogLevel
}
func (l *Logger) SetFormat(format string) *Logger {
l.format = format
return l
}
func (l *Logger) SetMinLogLevel(level int) *Logger {
l.minLogLevel = level
return l
}
func (l *Logger) SetFieldName(field, name string) *Logger {
fieldNames[field] = name
return l
}
func (l *Logger) SetShowCaller(show bool) *Logger {
l.showCaller = show
return l
}
func getCaller() string {
_, file, line, ok := runtime.Caller(3)
if !ok {
return "unknown:0"
}
file = filepath.Base(file)
return fmt.Sprintf("%s:%d", file, line)
}
+140
View File
@@ -0,0 +1,140 @@
package libpack_logger
import (
"bytes"
"testing"
"time"
)
func Benchmark_NewLogger(b *testing.B) {
type triggers struct {
ModFormat struct {
Format string
}
ModLevel struct {
Level int
}
}
tests := []struct {
name string
triggers triggers
}{
{
name: "BenchmarkNew",
},
{
name: "BenchmarkNewChangeTimeFormat",
triggers: triggers{
ModFormat: struct{ Format string }{
Format: time.RFC3339Nano,
},
},
},
{
name: "BenchmarkNewChangeLogLevel",
triggers: triggers{
ModLevel: struct{ Level int }{
Level: LEVEL_DEBUG,
},
},
},
{
name: "BenchmarkNewChangeTimeFormatAndLogLevel",
triggers: triggers{
ModFormat: struct{ Format string }{
Format: time.RFC3339Nano,
},
ModLevel: struct{ Level int }{
Level: LEVEL_DEBUG,
},
},
},
}
for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
got := New()
if tt.triggers.ModFormat.Format != "" {
got = got.SetFormat(tt.triggers.ModFormat.Format)
}
if tt.triggers.ModLevel.Level != 0 {
got = got.SetMinLogLevel(tt.triggers.ModLevel.Level)
}
}
})
}
}
func Benchmark_Log_Debug(b *testing.B) {
output := &bytes.Buffer{}
logger := New().SetMinLogLevel(LEVEL_DEBUG).SetOutput(output)
msg := &LogMessage{
Message: "debug message",
Pairs: make(map[string]any),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Debug(msg)
}
}
func Benchmark_Log_Info(b *testing.B) {
output := &bytes.Buffer{}
logger := New().SetMinLogLevel(LEVEL_INFO).SetOutput(output)
msg := &LogMessage{
Message: "info message",
Pairs: make(map[string]any),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Info(msg)
}
}
func Benchmark_Log_Warn(b *testing.B) {
output := &bytes.Buffer{}
logger := New().SetMinLogLevel(LEVEL_WARN).SetOutput(output)
msg := &LogMessage{
Message: "warn message",
Pairs: make(map[string]any),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Warn(msg)
}
}
func Benchmark_Log_Error(b *testing.B) {
output := &bytes.Buffer{}
logger := New().SetMinLogLevel(LEVEL_ERROR).SetOutput(output)
msg := &LogMessage{
Message: "error message",
Pairs: map[string]any{"key": "value"},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Error(msg)
}
}
func Benchmark_Log_Fatal(b *testing.B) {
output := &bytes.Buffer{}
logger := New().SetMinLogLevel(LEVEL_FATAL).SetOutput(output)
msg := &LogMessage{
Message: "fatal message",
Pairs: make(map[string]any),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Fatal(msg)
}
}
+31
View File
@@ -0,0 +1,31 @@
package libpack_logger
import (
"testing"
assertions "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type LoggerTestSuite struct {
suite.Suite
}
var (
assert *assertions.Assertions
)
func (suite *LoggerTestSuite) BeforeTest(suiteName, testName string) {
}
func (suite *LoggerTestSuite) SetupTest() {
assert = assertions.New(suite.T())
}
// TearDownTest is run after each test to clean up
func (suite *LoggerTestSuite) TearDownTest() {
}
func TestSuite(t *testing.T) {
suite.Run(t, new(LoggerTestSuite))
}
+182
View File
@@ -0,0 +1,182 @@
package libpack_logger
import (
"bytes"
"fmt"
"os"
"reflect"
"testing"
"time"
"github.com/goccy/go-json"
)
func captureStderr(f func()) string {
originalStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
f()
w.Close()
var buf bytes.Buffer
buf.ReadFrom(r)
os.Stderr = originalStderr
return buf.String()
}
func captureStdOut(f func()) string {
originalStdout := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
f()
w.Close()
var buf bytes.Buffer
buf.ReadFrom(r)
os.Stdout = originalStdout
return buf.String()
}
func (suite *LoggerTestSuite) Test_LogMessageString() {
msg := &LogMessage{
Message: "test message",
}
assert.Equal("test message", msg.String())
}
func callLoggerMethod(logger *Logger, methodName string, message *LogMessage) {
// Get the method by name using reflection
method := reflect.ValueOf(logger).MethodByName(methodName)
if method.IsValid() {
// Call the method with the message as an argument
method.Call([]reflect.Value{reflect.ValueOf(message)})
} else {
fmt.Printf("Method %s does not exist on Logger\n", methodName)
}
}
func (suite *LoggerTestSuite) Test_LogsLevelsPrint() {
output := &bytes.Buffer{}
logger := New().SetOutput(output)
tests := []struct {
pairs map[string]any
name string
method string
message string
loggerMinLevel int
messageLogLevel int
wantOutput bool
}{
{
name: "Log: Debug, Level: Debug - no pairs",
method: "Debug",
loggerMinLevel: LEVEL_DEBUG,
messageLogLevel: LEVEL_DEBUG,
message: "debug message",
wantOutput: true,
},
{
name: "Log: Info, Level: Info - one pair",
method: "Info",
loggerMinLevel: LEVEL_INFO,
messageLogLevel: LEVEL_INFO,
message: "info message",
pairs: map[string]any{
"key": "value",
},
wantOutput: true,
},
{
name: "Log: Info, Level: Warn - with pairs",
method: "Info",
loggerMinLevel: LEVEL_WARN,
messageLogLevel: LEVEL_INFO,
message: "warn message",
pairs: map[string]any{
"key1": "value1",
"key2": "value2",
},
wantOutput: false,
},
{
name: "Log: Warn, Level: Info - with 500 pairs",
method: "Warn",
loggerMinLevel: LEVEL_INFO,
messageLogLevel: LEVEL_WARN,
message: "warn message with 500 pairs",
pairs: func() map[string]any {
pairs := make(map[string]any)
for i := 0; i < 500; i++ {
pairs[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
}
return pairs
}(),
wantOutput: true,
},
}
for _, tt := range tests {
suite.T().Run(tt.name, func(t *testing.T) {
msg := &LogMessage{
Message: tt.message,
Pairs: tt.pairs,
}
output.Reset()
// Set logger's minimum log level
logger.SetMinLogLevel(tt.loggerMinLevel)
fmt.Println("Logger min log level:", LevelNames[logger.minLogLevel])
// Call the logging method
callLoggerMethod(logger, tt.method, msg)
logOutput := output.String()
fmt.Println("Output:", logOutput)
if tt.wantOutput {
var loggedMessage map[string]any
err := json.Unmarshal([]byte(logOutput), &loggedMessage)
if err != nil {
t.Fatalf("Error unmarshalling log message: %v\nLog output: %s", err, logOutput)
}
if !containsLogMessage(logOutput, tt.message) {
t.Errorf("Expected log message %q, but got %q", tt.message, logOutput)
}
assert.Equal(LevelNames[tt.messageLogLevel], loggedMessage["level"])
if tt.pairs != nil {
for k, v := range tt.pairs {
assert.Equal(v, loggedMessage[k])
}
}
} else {
assert.Equal("", logOutput)
}
})
}
}
func containsLogMessage(logOutput, expectedMessage string) bool {
return bytes.Contains([]byte(logOutput), []byte(expectedMessage))
}
func (suite *LoggerTestSuite) Test_SetFormat() {
logger := New().SetFormat(time.RFC3339Nano)
assert.Equal(time.RFC3339Nano, logger.format)
}
func (suite *LoggerTestSuite) Test_SetMinLogLevel() {
logger := New().SetMinLogLevel(LEVEL_DEBUG)
assert.Equal(LEVEL_DEBUG, logger.minLogLevel)
}
func (suite *LoggerTestSuite) Test_ShouldLog() {
logger := New().SetMinLogLevel(LEVEL_WARN)
assert.True(logger.shouldLog(LEVEL_WARN))
assert.True(logger.shouldLog(LEVEL_ERROR))
assert.False(logger.shouldLog(LEVEL_INFO))
assert.False(logger.shouldLog(LEVEL_DEBUG))
}
-123
View File
@@ -1,123 +0,0 @@
package libpack_logging
import (
"io"
"os"
"sync"
"time"
"github.com/gookit/goutil/envutil"
"github.com/rs/zerolog"
)
type LogConfig struct {
logger zerolog.Logger
}
var (
baseLogger zerolog.Logger
eventPool = sync.Pool{
New: func() interface{} {
return new(zerolog.Event)
},
}
fieldMapPool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{})
},
}
)
func init() {
zerolog.TimeFieldFormat = time.RFC3339
zerolog.MessageFieldName = "short_message"
zerolog.TimestampFieldName = "timestamp"
zerolog.LevelFieldName = "level"
zerolog.LevelFatalValue = "critical"
baseLogger = zerolog.New(os.Stdout).With().Timestamp().Logger()
switch logLevel := envutil.Getenv("LOG_LEVEL", "info"); logLevel {
case "debug":
baseLogger = baseLogger.Level(zerolog.DebugLevel)
case "warn":
baseLogger = baseLogger.Level(zerolog.WarnLevel)
case "error":
baseLogger = baseLogger.Level(zerolog.ErrorLevel)
default:
baseLogger = baseLogger.Level(zerolog.InfoLevel)
}
}
func NewLogger() *LogConfig {
return &LogConfig{logger: baseLogger}
}
func (lw *LogConfig) log(w io.Writer, level zerolog.Level, message string, fields map[string]interface{}) {
logger := lw.logger.Output(w)
event := logger.WithLevel(level).CallerSkipFrame(3)
for k, val := range fields {
switch v := val.(type) {
case string:
event = event.Str(k, v)
case int:
event = event.Int(k, v)
case float64:
event = event.Float64(k, v)
default:
event = event.Interface(k, val)
}
}
event.Msg(message)
}
func (lw *LogConfig) logWithLevel(level zerolog.Level, message string, fields map[string]interface{}) {
if lw.logger.GetLevel() > level {
return
}
if lw.logger.GetLevel() <= level {
w := os.Stdout
if level >= zerolog.ErrorLevel {
w = os.Stderr
}
lw.log(w, level, message, fields)
}
}
func (lw *LogConfig) Debug(message string, fields map[string]interface{}) {
lw.logWithLevel(zerolog.DebugLevel, message, fields)
}
func (lw *LogConfig) Info(message string, fields map[string]interface{}) {
lw.logWithLevel(zerolog.InfoLevel, message, fields)
}
func (lw *LogConfig) Warning(message string, fields map[string]interface{}) {
lw.logWithLevel(zerolog.WarnLevel, message, fields)
}
func (lw *LogConfig) Error(message string, fields map[string]interface{}) {
lw.logWithLevel(zerolog.ErrorLevel, message, fields)
}
func (lw *LogConfig) Critical(message string, fields map[string]interface{}) {
lw.logWithLevel(zerolog.FatalLevel, message, fields)
os.Exit(1)
}
// Helper function to get a new fields map from the pool
func getFieldsMap() map[string]interface{} {
return fieldMapPool.Get().(map[string]interface{})
}
// Helper function to put a used fields map back into the pool
func putFieldsMap(fields map[string]interface{}) {
for k := range fields {
delete(fields, k)
}
fieldMapPool.Put(fields)
}
-32
View File
@@ -1,32 +0,0 @@
package libpack_logging
import (
"os"
"testing"
)
func BenchmarkNewLogger(b *testing.B) {
for i := 0; i < b.N; i++ {
NewLogger()
}
}
func BenchmarkInfoLog(b *testing.B) {
oldEnv := os.Getenv("LOG_LEVEL")
os.Setenv("LOG_LEVEL", "info")
oldStdout := os.Stdout
oldStderr := os.Stderr
os.Stdout, _ = os.Open(os.DevNull)
os.Stderr, _ = os.Open(os.DevNull)
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
os.Setenv("LOG_LEVEL", oldEnv)
}()
testsLogger := NewLogger()
b.ResetTimer()
for i := 0; i < b.N; i++ {
testsLogger.Info("test", map[string]interface{}{"test": "test"})
}
}
+16 -2
View File
@@ -9,6 +9,7 @@ import (
"github.com/gofiber/fiber/v2/middleware/proxy"
"github.com/gookit/goutil/envutil"
graphql "github.com/lukaszraczylo/go-simple-graphql"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
@@ -62,7 +63,8 @@ func parseConfig() {
}
return strings.Split(urls, ",")
}()
c.Logger = libpack_logging.NewLogger()
c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "")
c.Client.GQLClient = graphql.NewConnection()
c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL)
@@ -88,7 +90,19 @@ func parseConfig() {
c.HasuraEventCleaner.EventMetadataDb = getDetailsFromEnv("HASURA_EVENT_METADATA_DB", "")
cfg = &c
enableCache() // takes close to no resources, but can be used with dynamic query cache
if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
cacheConfig := &libpack_cache.CacheConfig{
Logger: cfg.Logger,
TTL: cfg.Cache.CacheTTL,
}
if cfg.Cache.CacheRedisEnable {
cacheConfig.Redis.Enable = true
cacheConfig.Redis.URL = cfg.Cache.CacheRedisURL
cacheConfig.Redis.Password = cfg.Cache.CacheRedisPassword
cacheConfig.Redis.DB = cfg.Cache.CacheRedisDB
}
libpack_cache.EnableCache(cacheConfig)
}
loadRatelimitConfig()
once.Do(func() {
go enableApi()
+28 -3
View File
@@ -34,14 +34,13 @@ func (suite *Tests) SetupTest() {
JSONDecoder: json.Unmarshal,
},
)
cacheStats = &CacheStats{}
// Initialize a simple in-memory cache client for testing purposes
cfg.Cache.Client = libpack_cache.New(5 * time.Minute)
libpack_cache.New(5 * time.Minute)
parseConfig()
enableApi()
StartMonitoringServer()
cfg.Logger = libpack_logging.NewLogger()
cfg.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(getDetailsFromEnv("LOG_LEVEL", "info")))
// Setup environment variables here if needed
os.Setenv("GMP_TEST_STRING", "testValue")
os.Setenv("GMP_TEST_INT", "123")
@@ -113,3 +112,29 @@ func (suite *Tests) Test_envVariableSetting() {
})
}
}
func (suite *Tests) Test_getDetailsFromEnv() {
tests := []struct {
name string
key string
defaultValue interface{}
envValue string
expected interface{}
}{
{"string value", "TEST_STRING", "default", "envValue", "envValue"},
{"int value", "TEST_INT", 0, "123", 123},
{"bool value", "TEST_BOOL", false, "true", true},
{"default value", "NON_EXISTENT", "default", "", "default"},
}
for _, tt := range tests {
suite.Run(tt.name, func() {
if tt.envValue != "" {
os.Setenv("GMP_"+tt.key, tt.envValue)
defer os.Unsetenv("GMP_" + tt.key)
}
result := getDetailsFromEnv(tt.key, tt.defaultValue)
assert.Equal(tt.expected, result)
})
}
}
+133 -89
View File
@@ -1,129 +1,173 @@
package libpack_monitoring
import (
"bytes"
"fmt"
"os"
"sort"
"strings"
"sync"
"unicode"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
)
func (ms *MetricsSetup) get_metrics_name(name string, labels map[string]string) (complete_name string) {
var sortedLabelKeysCache = struct {
m sync.Map
}{}
func (ms *MetricsSetup) get_metrics_name(name string, labels map[string]string) string {
const unknownPodName = "unknown"
var sb strings.Builder
var buf bytes.Buffer
// Prepare default labels without initializing a new map
podName := unknownPodName
if hn, err := os.Hostname(); err == nil {
podName = hn
}
podName := getPodName()
if labels == nil {
labels = map[string]string{
"microservice": libpack_config.PKG_NAME,
"pod": podName,
}
labels = defaultLabels(podName)
} else {
if _, exists := labels["microservice"]; !exists {
labels["microservice"] = libpack_config.PKG_NAME
}
if _, exists := labels["pod"]; !exists {
labels["pod"] = podName
}
ensureDefaultLabels(&labels, podName)
}
// Prefix handling
if ms.metrics_prefix != "" {
sb.WriteString(ms.metrics_prefix)
sb.WriteString("_")
buf.WriteString(ms.metrics_prefix)
buf.WriteByte('_')
}
sb.WriteString(name)
buf.WriteString(name)
// Append labels if any
if len(labels) > 0 {
sb.WriteString("{")
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
for i, k := range keys {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(k)
sb.WriteString("=\"")
sb.WriteString(labels[k])
sb.WriteString("\"")
}
sb.WriteString("}")
buf.WriteByte('{')
appendSortedLabels(&buf, labels)
buf.WriteByte('}')
}
return sb.String()
return buf.String()
}
// validate_metrics_name validates the name of the metric to adhere to the Prometheus naming conventions
// https://prometheus.io/docs/practices/naming/
func validate_metrics_name(name string) error {
var sb strings.Builder // Use strings.Builder for efficient string concatenation
// Track if the last character was an underscore to avoid duplicate underscores
lastWasUnderscore := false
for _, r := range name {
// Convert spaces to underscores and skip non-alphanumeric characters except underscores
if r == ' ' || (unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_') {
if r == ' ' || r == '_' {
if lastWasUnderscore {
continue // Skip if the previous character was also an underscore
}
r = '_' // Convert spaces to underscores
lastWasUnderscore = true
} else {
lastWasUnderscore = false
}
sb.WriteRune(r) // Add valid characters to the builder
}
func getPodName() string {
const unknownPodName = "unknown"
if hn, err := os.Hostname(); err == nil {
return hn
}
// Trim leading and trailing underscores
name_new := strings.Trim(sb.String(), "_")
// Check if the processed name matches the original input
if name_new != name {
return fmt.Errorf("Invalid metric name: %s, expected %s", name, name_new)
}
return nil
return unknownPodName
}
func compile_metrics_with_labels(name string, labels map[string]string) string {
var totalLength int
totalLength += len(name)
for k, v := range labels {
totalLength += len(k) + len(v) + 2
func defaultLabels(podName string) map[string]string {
return map[string]string{
"microservice": libpack_config.PKG_NAME,
"pod": podName,
}
}
func ensureDefaultLabels(labels *map[string]string, podName string) {
if *labels == nil {
*labels = make(map[string]string)
}
if _, exists := (*labels)["microservice"]; !exists {
(*labels)["microservice"] = libpack_config.PKG_NAME
}
if _, exists := (*labels)["pod"]; !exists {
(*labels)["pod"] = podName
}
}
func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
keys := getSortedKeys(labels)
for i, k := range keys {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(k)
buf.WriteString(`="`)
buf.WriteString(labels[k])
buf.WriteByte('"')
}
}
func getSortedKeys(labels map[string]string) []string {
labelsKey := labelsToString(labels)
if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok {
return keys.([]string)
}
var sb strings.Builder
sb.Grow(totalLength + 1)
sb.WriteString(name)
// Collect keys and sort them
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
// Append sorted key-value pairs to the builder
for _, k := range keys {
sb.WriteString("_")
sb.WriteString(k)
sb.WriteString("_")
sb.WriteString(labels[k])
}
sortedLabelKeysCache.m.Store(labelsKey, keys)
return keys
}
func labelsToString(labels map[string]string) string {
var sb strings.Builder
for k, v := range labels {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(v)
sb.WriteByte(';')
}
return sb.String()
}
func validate_metrics_name(name string) error {
cleanedName := clean_metric_name(name)
finalName := strings.Trim(cleanedName, "_")
if finalName != name {
return fmt.Errorf("invalid metric name: %s, expected %s", name, finalName)
}
return nil
}
func clean_metric_name(name string) string {
var buf bytes.Buffer
lastWasUnderscore := false
for _, r := range name {
if is_allowed_rune(r) {
if is_special_rune(r) {
if lastWasUnderscore {
continue
}
r = '_'
lastWasUnderscore = true
} else {
lastWasUnderscore = false
}
buf.WriteRune(r)
} else if !lastWasUnderscore {
buf.WriteByte('_')
lastWasUnderscore = true
}
}
return strings.Trim(buf.String(), "_")
}
func is_allowed_rune(r rune) bool {
return unicode.IsLetter(r) || unicode.IsDigit(r) || r == ' ' || r == '_'
}
func is_special_rune(r rune) bool {
return r == ' ' || r == '_'
}
func compile_metrics_with_labels(name string, labels map[string]string) string {
var buf bytes.Buffer
buf.WriteString(name)
keys := getSortedKeys(labels)
for _, k := range keys {
buf.WriteByte('_')
buf.WriteString(k)
buf.WriteByte('_')
buf.WriteString(labels[k])
}
return buf.String()
}
+91 -6
View File
@@ -1,7 +1,6 @@
package libpack_monitoring
import (
"os"
"testing"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
@@ -134,10 +133,96 @@ func TestValidateMetricsName(t *testing.T) {
}
}
func getPodName() string {
podName, err := os.Hostname()
if err != nil {
return "unknown"
func TestCleanMetricName(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"valid metric name", "valid_metric_name"},
{"valid@metric#name!", "valid_metric_name"},
{"__valid__metric__name__", "valid_metric_name"},
{" valid metric name ", "valid_metric_name"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
assert.Equal(t, tt.expected, clean_metric_name(tt.input))
})
}
}
func TestDefaultLabels(t *testing.T) {
podName := "test-pod"
libpack_config.PKG_NAME = "example_microservice"
expected := map[string]string{
"microservice": "example_microservice",
"pod": podName,
}
assert.Equal(t, expected, defaultLabels(podName))
}
func TestEnsureDefaultLabels(t *testing.T) {
podName := "test-pod"
libpack_config.PKG_NAME = "example_microservice"
tests := []struct {
inputLabels map[string]string
expectedLabels map[string]string
name string
}{
{
name: "Nil labels",
inputLabels: nil,
expectedLabels: map[string]string{"microservice": "example_microservice", "pod": podName},
},
{
name: "Empty labels",
inputLabels: map[string]string{},
expectedLabels: map[string]string{"microservice": "example_microservice", "pod": podName},
},
{
name: "Partial labels",
inputLabels: map[string]string{"microservice": "test_service"},
expectedLabels: map[string]string{"microservice": "test_service", "pod": podName},
},
{
name: "Complete labels",
inputLabels: map[string]string{"microservice": "test_service", "pod": "custom_pod"},
expectedLabels: map[string]string{"microservice": "test_service", "pod": "custom_pod"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ensureDefaultLabels(&tt.inputLabels, podName)
assert.Equal(t, tt.expectedLabels, tt.inputLabels)
})
}
}
func TestLabelsToString(t *testing.T) {
tests := []struct {
labels map[string]string
expected string
}{
{
labels: map[string]string{"key1": "value1", "key2": "value2"},
expected: "key1=value1;key2=value2;",
},
{
labels: map[string]string{"a": "1", "b": "2"},
expected: "a=1;b=2;",
},
{
labels: map[string]string{},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
assert.Equal(t, tt.expected, labelsToString(tt.labels))
})
}
return podName
}
+37 -26
View File
@@ -1,6 +1,3 @@
// Package `libpack_monitoring` provides and easy way to add prometheus metrics to your application.
// It also provides a way to add custom metrics to the already started prometheus registry.
package libpack_monitoring
import (
@@ -12,7 +9,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gookit/goutil/envutil"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
type MetricsSetup struct {
@@ -22,9 +19,7 @@ type MetricsSetup struct {
metrics_prefix string
}
var (
log *logging.LogConfig
)
var log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO)
type InitConfig struct {
PurgeOnCrawl bool
@@ -32,11 +27,11 @@ type InitConfig struct {
}
func NewMonitoring(ic *InitConfig) *MetricsSetup {
log = logging.NewLogger()
ms := &MetricsSetup{ic: ic}
ms.metrics_set = metrics.NewSet()
ms.metrics_set_custom = metrics.NewSet()
// if not testing, start the prometheus endpoint
ms := &MetricsSetup{
ic: ic,
metrics_set: metrics.NewSet(),
metrics_set_custom: metrics.NewSet(),
}
if flag.Lookup("test.v") == nil {
go ms.startPrometheusEndpoint()
@@ -60,9 +55,11 @@ func (ms *MetricsSetup) startPrometheusEndpoint() {
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
})
app.Get("/metrics", ms.metricsEndpoint)
err := app.Listen(fmt.Sprintf(":%d", envutil.GetInt("MONITORING_PORT", 9393)))
if err != nil {
fmt.Println("Can't start the service: ", err)
if err := app.Listen(fmt.Sprintf(":%d", envutil.GetInt("MONITORING_PORT", 9393))); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "Can't start the service",
Pairs: map[string]interface{}{"error": err},
})
}
}
@@ -85,19 +82,24 @@ func (ms *MetricsSetup) ListActiveMetrics() []string {
}
func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[string]string, val float64) *metrics.Gauge {
if validate_metrics_name(metric_name) != nil {
log.Critical("RegisterMetricsGauge() error", map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name})
if err := validate_metrics_name(metric_name); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "RegisterMetricsGauge() error",
Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name},
})
return nil
}
return ms.metrics_set_custom.GetOrCreateGauge(ms.get_metrics_name(metric_name, labels), func() float64 {
// get current value of the gauge and add val to it
return val
})
}
func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[string]string) *metrics.Counter {
if validate_metrics_name(metric_name) != nil {
log.Critical("RegisterMetricsCounter() error", map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name})
if err := validate_metrics_name(metric_name); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "RegisterMetricsCounter() error",
Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name},
})
return nil
}
if metric_name == MetricsSucceeded || metric_name == MetricsFailed || metric_name == MetricsSkipped {
@@ -107,24 +109,33 @@ func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[st
}
func (ms *MetricsSetup) RegisterFloatCounter(metric_name string, labels map[string]string) *metrics.FloatCounter {
if validate_metrics_name(metric_name) != nil {
log.Critical("RegisterFloatCounter() error", map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name})
if err := validate_metrics_name(metric_name); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "RegisterFloatCounter() error",
Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name},
})
return nil
}
return ms.metrics_set_custom.GetOrCreateFloatCounter(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) RegisterMetricsSummary(metric_name string, labels map[string]string) *metrics.Summary {
if validate_metrics_name(metric_name) != nil {
log.Critical("RegisterMetricsSummary() error", map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name})
if err := validate_metrics_name(metric_name); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "RegisterMetricsSummary() error",
Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name},
})
return nil
}
return ms.metrics_set_custom.GetOrCreateSummary(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) RegisterMetricsHistogram(metric_name string, labels map[string]string) *metrics.Histogram {
if validate_metrics_name(metric_name) != nil {
log.Critical("RegisterMetricsHistogram() error", map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name})
if err := validate_metrics_name(metric_name); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "RegisterMetricsHistogram() error",
Pairs: map[string]interface{}{"_error": "Invalid metric name", "_metric_name": metric_name},
})
return nil
}
return ms.metrics_set_custom.GetOrCreateHistogram(ms.get_metrics_name(metric_name, labels))
+77 -31
View File
@@ -3,15 +3,25 @@ package main
import (
"crypto/tls"
"fmt"
"net/url"
"time"
"github.com/avast/retry-go/v4"
fiber "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/proxy"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
"github.com/valyala/fasthttp"
)
var (
httpClient *fasthttp.Client
)
func init() {
httpClient = createFasthttpClient(30) // Assuming a default timeout of 30 seconds
}
func createFasthttpClient(timeout int) *fasthttp.Client {
return &fasthttp.Client{
Name: "graphql_proxy",
@@ -20,65 +30,101 @@ func createFasthttpClient(timeout int) *fasthttp.Client {
InsecureSkipVerify: true,
},
MaxConnsPerHost: 2048,
ReadTimeout: time.Second * time.Duration(timeout),
WriteTimeout: time.Second * time.Duration(timeout),
MaxIdleConnDuration: time.Second * time.Duration(timeout),
MaxConnDuration: time.Second * time.Duration(timeout),
ReadTimeout: time.Duration(timeout) * time.Second,
WriteTimeout: time.Duration(timeout) * time.Second,
MaxIdleConnDuration: time.Duration(timeout) * time.Second,
MaxConnDuration: time.Duration(timeout) * time.Second,
DisableHeaderNamesNormalizing: true,
}
}
func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
if !checkAllowedURLs(c) {
cfg.Logger.Error("Request blocked", map[string]interface{}{"path": c.Path()})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Request blocked",
Pairs: map[string]interface{}{"path": c.Path()},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
c.Status(403).SendString("Request blocked - not allowed URL")
return nil
return fmt.Errorf("request blocked - not allowed URL: %s", c.Path())
}
c.Request().Header.DisableNormalizing()
c.Request().Header.Add("X-Real-IP", c.IP())
c.Request().Header.Add(fiber.HeaderXForwardedFor, string(c.Request().Header.Peek("X-Forwarded-For")))
c.Request().Header.Del(fiber.HeaderAcceptEncoding)
cfg.Logger.Debug("Proxying the request", map[string]interface{}{"path": c.Path(), "body": string(c.Request().Body()), "headers": c.GetReqHeaders(), "request_uuid": c.Locals("request_uuid")})
proxyURL := currentEndpoint + c.Path()
_, err := url.Parse(proxyURL)
if err != nil {
return fmt.Errorf("invalid URL: %v", err)
}
err := retry.Do(
if cfg.LogLevel == "debug" {
logDebugRequest(c)
}
err = retry.Do(
func() error {
errInt := proxy.DoRedirects(c, currentEndpoint+c.Path(), 3, cfg.Client.FastProxyClient)
if errInt != nil {
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": errInt.Error()})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return errInt
}
return nil
return proxy.DoRedirects(c, proxyURL, 3, httpClient)
},
retry.OnRetry(func(n uint, err error) {
cfg.Logger.Warning("Retrying the request", map[string]interface{}{"path": c.Path(), "error": err.Error()})
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Retrying the request",
Pairs: map[string]interface{}{
"path": c.Path(),
"error": err.Error(),
},
})
}),
retry.Attempts(uint(3)),
retry.Attempts(3),
retry.DelayType(retry.BackOffDelay),
retry.Delay(time.Duration(250*time.Millisecond)),
retry.Delay(250*time.Millisecond),
retry.LastErrorOnly(true),
)
if err != nil {
cfg.Logger.Warning("Can't proxy the request", map[string]interface{}{"error": err.Error()})
return err
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("failed to proxy request: %v", err)
}
cfg.Logger.Debug("Received proxied response", map[string]interface{}{"path": c.Path(), "response_body": string(c.Response().Body()), "response_code": c.Response().StatusCode(), "headers": c.GetRespHeaders(), "request_uuid": c.Locals("request_uuid")})
if cfg.LogLevel == "debug" {
logDebugResponse(c)
}
if c.Response().StatusCode() != 200 {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("Received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
}
c.Response().Header.Del(fiber.HeaderServer)
return nil
}
func logDebugRequest(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Proxying the request",
Pairs: map[string]interface{}{
"path": c.Path(),
"body": string(c.Body()),
"headers": c.GetReqHeaders(),
"request_uuid": c.Locals("request_uuid"),
},
})
}
func logDebugResponse(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Received proxied response",
Pairs: map[string]interface{}{
"path": c.Path(),
"response_body": string(c.Response().Body()),
"response_code": c.Response().StatusCode(),
"headers": c.GetRespHeaders(),
"request_uuid": c.Locals("request_uuid"),
},
})
}
+30
View File
@@ -1,6 +1,8 @@
package main
import (
"strings"
"github.com/valyala/fasthttp"
)
@@ -95,3 +97,31 @@ func (suite *Tests) Test_proxyTheRequest() {
})
}
}
func (suite *Tests) Test_proxyTheRequestWithPayloads() {
allowedUrls = make(map[string]struct{})
allowedUrls["/"] = struct{}{}
suite.Run("Test with invalid URL", func() {
cfg.Server.HostGraphQL = "://invalid-url"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.NotNil(err)
})
suite.Run("Test with network error", func() {
cfg.Server.HostGraphQL = "http://non-existent-host.invalid"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.NotNil(err)
})
suite.Run("Test with large payload", func() {
cfg.Server.HostGraphQL = "https://telegram-bot.app/"
ctx := suite.app.AcquireCtx(&fasthttp.RequestCtx{})
largePayload := strings.Repeat("a", 10*1024*1024) // 10MB payload
ctx.Context().Request.SetBody([]byte(largePayload))
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
assert.Nil(err)
})
}
+59 -50
View File
@@ -2,94 +2,97 @@ package main
import (
"os"
"sync"
"time"
"github.com/goccy/go-json"
goratecounter "github.com/lukaszraczylo/go-ratecounter"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
type RateLimitConfig struct {
RateCounterTicker *goratecounter.RateCounter
Interval string `json:"interval"`
Req int `json:"req"`
Interval time.Duration `json:"interval"`
Req int `json:"req"`
}
var rateLimits map[string]RateLimitConfig
var ratelimit_intervals = map[string]time.Duration{
"milli": time.Millisecond,
"micro": time.Microsecond,
"nano": time.Nanosecond,
"second": time.Second,
"minute": time.Minute,
"hour": time.Hour,
"day": time.Hour * 24,
}
var (
rateLimits = make(map[string]RateLimitConfig)
rateLimitMu sync.RWMutex
)
func loadRatelimitConfig() error {
paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"}
for _, path := range paths {
err := loadConfigFromPath(path)
if err == nil {
if err := loadConfigFromPath(path); err == nil {
return nil
}
cfg.Logger.Debug("Failed to load config", map[string]interface{}{"path": path, "error": err})
}
cfg.Logger.Error("Rate limit config not found", map[string]interface{}{"paths": paths})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Rate limit config not found",
Pairs: map[string]interface{}{"paths": paths},
})
return os.ErrNotExist
}
func loadConfigFromPath(path string) error {
file, err := os.Open(path)
file, err := os.ReadFile(path)
if err != nil {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Failed to load config",
Pairs: map[string]interface{}{"path": path, "error": err},
})
return err
}
defer file.Close()
config := struct {
var config struct {
RateLimit map[string]RateLimitConfig `json:"ratelimit"`
}{}
}
decoder := json.NewDecoder(file)
if err := decoder.Decode(&config); err != nil {
if err := json.Unmarshal(file, &config); err != nil {
return err
}
newRateLimits := make(map[string]RateLimitConfig, len(config.RateLimit))
for key, value := range config.RateLimit {
value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{
Interval: time.Duration(value.Req) * ratelimit_intervals[value.Interval],
Interval: value.Interval,
})
cfg.Logger.Debug("Setting ratelimit config for role", map[string]interface{}{
"role": key,
"interval_provided": value.Interval,
"interval_used": ratelimit_intervals[value.Interval],
"ratelimit": value.Req,
})
config.RateLimit[key] = value
if cfg.LogLevel == "debug" {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Setting ratelimit config for role",
Pairs: map[string]interface{}{
"role": key,
"interval_used": value.Interval,
"ratelimit": value.Req,
},
})
}
newRateLimits[key] = value
}
rateLimits = config.RateLimit
cfg.Logger.Debug("Rate limit config loaded", map[string]interface{}{"ratelimit": rateLimits})
rateLimitMu.Lock()
rateLimits = newRateLimits
rateLimitMu.Unlock()
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit config loaded",
Pairs: map[string]interface{}{"ratelimit": rateLimits},
})
return nil
}
func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) {
if rateLimits == nil {
cfg.Logger.Debug("Rate limit config not found", map[string]interface{}{"user_role": userRole})
return true
}
// Fetch role config once to avoid multiple map lookups
func rateLimitedRequest(userID, userRole string) bool {
rateLimitMu.RLock()
roleConfig, ok := rateLimits[userRole]
if !ok {
cfg.Logger.Warning("Rate limit role not found", map[string]interface{}{"user_role": userRole})
return true
}
rateLimitMu.RUnlock()
if roleConfig.RateCounterTicker == nil {
cfg.Logger.Warning("Rate limit ticker not found", map[string]interface{}{"user_role": userRole})
if !ok || roleConfig.RateCounterTicker == nil {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit role not found or ticker not initialized",
Pairs: map[string]interface{}{"user_role": userRole},
})
return true
}
@@ -104,10 +107,16 @@ func rateLimitedRequest(userID string, userRole string) (shouldAllow bool) {
"interval": roleConfig.Interval,
}
cfg.Logger.Debug("Rate limit ticker", logDetails)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit ticker",
Pairs: map[string]interface{}{"log_details": logDetails},
})
if tickerRate > float64(roleConfig.Req) {
cfg.Logger.Debug("Rate limit exceeded", logDetails)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit exceeded",
Pairs: map[string]interface{}{"log_details": logDetails},
})
return false
}
+125 -92
View File
@@ -3,6 +3,7 @@ package main
import (
"fmt"
"strconv"
"sync"
"time"
"github.com/goccy/go-json"
@@ -10,14 +11,30 @@ import (
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/google/uuid"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
// StartHTTPProxy starts the HTTP and points it to the GraphQL server.
const (
healthCheckQueryStr = `{ __typename }`
)
var (
ctxPool = sync.Pool{
New: func() interface{} {
return new(fiber.Ctx)
},
}
)
func StartHTTPProxy() {
cfg.Logger.Debug("Starting the HTTP proxy", nil)
server := fiber.New(fiber.Config{
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Starting the HTTP proxy",
})
serverConfig := fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
IdleTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second * 2,
@@ -25,13 +42,14 @@ func StartHTTPProxy() {
WriteTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second * 2,
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
})
}
server := fiber.New(serverConfig)
server.Use(cors.New(cors.Config{
AllowOrigins: "*",
}))
// add middleware to check if the request is a GraphQL query
server.Use(AddRequestUUID)
server.Get("/healthz", healthCheck)
@@ -40,10 +58,16 @@ func StartHTTPProxy() {
server.Post("/*", processGraphQLRequest)
server.Get("/*", proxyTheRequestToDefault)
cfg.Logger.Info("GraphQL query proxy started", map[string]interface{}{"port": cfg.Server.PortGraphQL})
err := server.Listen(fmt.Sprintf(":%d", cfg.Server.PortGraphQL))
if err != nil {
cfg.Logger.Critical("Can't start the service", map[string]interface{}{"error": err.Error()})
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "GraphQL proxy started",
Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL},
})
if err := server.Listen(fmt.Sprintf(":%d", cfg.Server.PortGraphQL)); err != nil {
cfg.Logger.Critical(&libpack_logger.LogMessage{
Message: "Can't start the service",
Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL, "error": err.Error()},
})
}
}
@@ -60,82 +84,82 @@ func checkAllowedURLs(c *fiber.Ctx) bool {
if len(allowedUrls) == 0 {
return true
}
_, ok := allowedUrls[c.Path()]
path := c.OriginalURL()
_, ok := allowedUrls[path]
return ok
}
func healthCheck(c *fiber.Ctx) error {
if len(cfg.Server.HealthcheckGraphQL) > 0 {
cfg.Logger.Debug("Health check enabled", map[string]interface{}{"url": cfg.Server.HealthcheckGraphQL})
query := `{ __typename }`
_, err := cfg.Client.GQLClient.Query(query, nil, nil)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Health check enabled",
Pairs: map[string]interface{}{"url": cfg.Server.HealthcheckGraphQL},
})
_, err := cfg.Client.GQLClient.Query(healthCheckQueryStr, nil, nil)
if err != nil {
cfg.Logger.Error("Can't reach the GraphQL server", map[string]interface{}{"error": err.Error()})
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't reach the GraphQL server",
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query")
return err
return c.Status(500).SendString("Can't reach the GraphQL server with {__typename} query")
}
}
cfg.Logger.Debug("Health check returning OK", nil)
c.Status(200).SendString("Health check OK")
return nil
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Health check returning OK",
})
return c.Status(200).SendString("Health check OK")
}
func processGraphQLRequest(c *fiber.Ctx) error {
startTime := time.Now()
// Initialize variables with default values
extractedUserID := "-"
extractedRoleName := "-"
var queryCacheHash string
authorization := c.Request().Header.Peek("Authorization")
if authorization != nil && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(string(authorization))
if authorization := c.Get("Authorization"); authorization != "" && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization)
}
if checkIfUserIsBanned(c, extractedUserID) {
c.Status(403).SendString("User is banned")
return nil
return c.Status(403).SendString("User is banned")
}
if len(cfg.Client.RoleFromHeader) > 0 {
extractedRoleName = string(c.Request().Header.Peek(cfg.Client.RoleFromHeader))
if extractedRoleName == "" {
extractedRoleName = "-"
if cfg.Client.RoleFromHeader != "" {
if role := c.Get(cfg.Client.RoleFromHeader); role != "" {
extractedRoleName = role
}
}
// Implementing rate limiting if enabled
if cfg.Client.RoleRateLimit {
cfg.Logger.Debug("Rate limiting enabled", map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName})
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limiting enabled",
Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName},
})
if !rateLimitedRequest(extractedUserID, extractedRoleName) {
c.Status(429).SendString("Rate limit exceeded, try again later")
return nil
return c.Status(429).SendString("Rate limit exceeded, try again later")
}
}
parsedResult := parseGraphQLQuery(c)
if parsedResult.shouldBlock {
c.Status(403).SendString("Request blocked")
return nil
return c.Status(403).SendString("Request blocked")
}
if parsedResult.shouldIgnore {
cfg.Logger.Debug("Request passed as-is - probably not a GraphQL", nil)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Request passed as-is - probably not a GraphQL",
})
return proxyTheRequest(c, parsedResult.activeEndpoint)
}
calculatedQueryHash := calculateHash(c)
calculatedQueryHash := libpack_cache.CalculateHash(c)
if parsedResult.cacheTime > 0 {
cfg.Logger.Debug("Cache time set via query", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
} else {
// If not set via query, try setting via header
cacheQuery := c.Request().Header.Peek("X-Cache-Graphql-Query")
if cacheQuery != nil {
parsedResult.cacheTime, _ = strconv.Atoi(string(cacheQuery))
cfg.Logger.Debug("Cache time set via header", map[string]interface{}{"cacheTime": parsedResult.cacheTime})
if parsedResult.cacheTime == 0 {
if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" {
parsedResult.cacheTime, _ = strconv.Atoi(cacheQuery)
} else {
parsedResult.cacheTime = cfg.Cache.CacheTTL
}
@@ -144,81 +168,90 @@ func processGraphQLRequest(c *fiber.Ctx) error {
wasCached := false
if parsedResult.cacheRefresh {
cfg.Logger.Debug("Cache refresh requested via query", map[string]interface{}{"user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
cacheDelete(calculatedQueryHash)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache refresh requested via query",
Pairs: map[string]interface{}{"user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
libpack_cache.CacheDelete(calculatedQueryHash)
}
// Handling Cache Logic
if parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
cfg.Logger.Debug("Cache enabled", map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable})
queryCacheHash = calculatedQueryHash
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache enabled",
Pairs: map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable},
})
if cachedResponse := cacheLookup(queryCacheHash); cachedResponse != nil {
if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
cfg.Logger.Debug("Cache hit", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
c.Request().Header.Add("X-Cache-Hit", "true")
err := c.Send(cachedResponse)
if err != nil {
cfg.Logger.Error("Can't send the cached response", map[string]interface{}{"error": err.Error()})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
c.Status(500).SendString("Can't send the cached response - try again later")
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache hit",
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
c.Set("X-Cache-Hit", "true")
wasCached = true
} else {
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
cfg.Logger.Debug("Cache miss", map[string]interface{}{"hash": queryCacheHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")})
proxyAndCacheTheRequest(c, queryCacheHash, parsedResult.cacheTime, parsedResult.activeEndpoint)
return c.Send(cachedResponse)
}
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache miss",
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil {
return err
}
} else {
err := proxyTheRequest(c, parsedResult.activeEndpoint)
if err != nil {
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()})
if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
c.Status(500).SendString("Can't proxy the request - try again later")
return nil
return c.Status(500).SendString("Can't proxy the request - try again later")
}
}
timeTaken := time.Since(startTime)
// Logging & Monitoring
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, timeTaken, startTime)
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
return nil
}
// Additional helper function to avoid code repetition
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) {
err := proxyTheRequest(c, currentEndpoint)
if err != nil {
cfg.Logger.Error("Can't proxy the request", map[string]interface{}{"error": err.Error()})
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error {
if err := proxyTheRequest(c, currentEndpoint); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
c.Status(500).SendString("Can't proxy the request - try again later")
return
return c.Status(500).SendString("Can't proxy the request - try again later")
}
cacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
libpack_cache.CacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
cfg.Monitoring.Increment(libpack_monitoring.MetricsQueriesCached, nil)
c.Send(c.Response().Body())
return c.Send(c.Response().Body())
}
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
labels := map[string]string{
"op_type": opType,
"op_name": opName,
"cached": fmt.Sprintf("%t", wasCached),
"cached": strconv.FormatBool(wasCached),
"user_id": userID,
}
if cfg.Server.AccessLog {
cfg.Logger.Info("Request processed", map[string]interface{}{
"ip": c.IP(),
"fwd-ip": string(c.Request().Header.Peek("X-Forwarded-For")),
"user_id": userID,
"op_type": opType,
"op_name": opName,
"time": duration,
"cache": wasCached,
"request_uuid": c.Locals("request_uuid"),
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Request processed",
Pairs: map[string]interface{}{
"ip": c.IP(),
"fwd-ip": c.Get("X-Forwarded-For"),
"user_id": userID,
"op_type": opType,
"op_name": opName,
"time": duration,
"cache": wasCached,
"request_uuid": c.Locals("request_uuid"),
},
})
}
+2 -2
View File
@@ -9,7 +9,8 @@ import (
// config is a struct that holds the configuration of the application.
type config struct {
Logger *libpack_logging.LogConfig
Logger *libpack_logging.Logger
LogLevel string
Monitoring *libpack_monitoring.MetricsSetup
Api struct{ BannedUsersFile string }
Client struct {
@@ -32,7 +33,6 @@ type config struct {
Enable bool
}
Cache struct {
Client CacheClient
CacheRedisURL string
CacheRedisPassword string
CacheTTL int