package main
import (
"context"
"fmt"
"os"
"sync"
"time"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
"github.com/gofrs/flock"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
var (
bannedUsersIDs = make(map[string]string)
bannedUsersIDsMutex sync.RWMutex
)
func enableApi(ctx context.Context) error {
if !cfg.Server.EnableApi {
return nil
}
apiserver := fiber.New(fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
})
api := apiserver.Group("/api")
api.Post("/user-ban", apiBanUser)
api.Post("/user-unban", apiUnbanUser)
api.Post("/cache-clear", apiClearCache)
api.Get("/cache-stats", apiCacheStats)
// Start banned users reload in a separate goroutine with context
go periodicallyReloadBannedUsers(ctx)
// Start server in a goroutine and handle shutdown
errCh := make(chan error, 1)
go func() {
if err := apiserver.Listen(fmt.Sprintf(":%d", cfg.Server.ApiPort)); err != nil {
errCh <- err
}
}()
// Wait for context cancellation or error
select {
case <-ctx.Done():
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Shutting down API server",
})
return apiserver.Shutdown()
case err := <-errCh:
return err
}
}
func periodicallyReloadBannedUsers(ctx context.Context) {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Stopping banned users reload",
})
return
case <-ticker.C:
loadBannedUsers()
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Banned users reloaded",
Pairs: map[string]interface{}{"users": bannedUsersIDs},
})
}
}
}
func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool {
bannedUsersIDsMutex.RLock()
_, found := bannedUsersIDs[userID]
bannedUsersIDsMutex.RUnlock()
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Checking if user is banned",
Pairs: map[string]interface{}{"user_id": userID, "banned": found},
})
if found {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "User is banned",
Pairs: map[string]interface{}{"user_id": userID},
})
if err := c.Status(fiber.StatusForbidden).SendString("User is banned"); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to send banned user response",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
}
return found
}
func apiClearCache(c *fiber.Ctx) error {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Clearing cache via API",
})
libpack_cache.CacheClear()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Cache cleared via API",
})
return c.SendString("OK: cache cleared")
}
func apiCacheStats(c *fiber.Ctx) error {
return c.JSON(libpack_cache.GetCacheStats())
}
type apiBanUserRequest struct {
UserID string `json:"user_id"`
Reason string `json:"reason"`
}
func apiBanUser(c *fiber.Ctx) error {
var req apiBanUserRequest
if err := c.BodyParser(&req); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the ban user request",
Pairs: map[string]interface{}{"error": err.Error()},
})
return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload")
}
if req.UserID == "" || req.Reason == "" {
return c.Status(fiber.StatusBadRequest).SendString("user_id and reason are required")
}
bannedUsersIDsMutex.Lock()
bannedUsersIDs[req.UserID] = req.Reason
bannedUsersIDsMutex.Unlock()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Banned user",
Pairs: map[string]interface{}{"user_id": req.UserID, "reason": req.Reason},
})
if err := storeBannedUsers(); err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users")
}
return c.SendString("OK: user banned")
}
func apiUnbanUser(c *fiber.Ctx) error {
var req apiBanUserRequest
if err := c.BodyParser(&req); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the unban user request",
Pairs: map[string]interface{}{"error": err.Error()},
})
return c.Status(fiber.StatusBadRequest).SendString("Invalid request payload")
}
if req.UserID == "" {
return c.Status(fiber.StatusBadRequest).SendString("user_id is required")
}
bannedUsersIDsMutex.Lock()
delete(bannedUsersIDs, req.UserID)
bannedUsersIDsMutex.Unlock()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Unbanned user",
Pairs: map[string]interface{}{"user_id": req.UserID},
})
if err := storeBannedUsers(); err != nil {
return c.Status(fiber.StatusInternalServerError).SendString("Failed to store banned users")
}
return c.SendString("OK: user unbanned")
}
func storeBannedUsers() error {
fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
if err := lockFile(fileLock); err != nil {
return err
}
defer func() {
if err := fileLock.Unlock(); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to unlock file",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
}()
bannedUsersIDsMutex.RLock()
data, err := json.Marshal(bannedUsersIDs)
bannedUsersIDsMutex.RUnlock()
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't marshal banned users",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
if err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0o644); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't write banned users to file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
}
func loadBannedUsers() {
if _, err := os.Stat(cfg.Api.BannedUsersFile); os.IsNotExist(err) {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Banned users file doesn't exist - creating it",
Pairs: map[string]interface{}{"file": cfg.Api.BannedUsersFile},
})
if err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{}"), 0o644); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't create and write to the file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
}
fileLock := flock.New(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
if err := lockFileRead(fileLock); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file [load]",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
defer func() {
if err := fileLock.Unlock(); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to unlock file",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
}()
data, err := os.ReadFile(cfg.Api.BannedUsersFile)
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't read banned users from file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
var newBannedUsers map[string]string
if err := json.Unmarshal(data, &newBannedUsers); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't unmarshal banned users",
Pairs: map[string]interface{}{"error": err.Error()},
})
return
}
bannedUsersIDsMutex.Lock()
bannedUsersIDs = newBannedUsers
bannedUsersIDsMutex.Unlock()
}
func lockFile(fileLock *flock.Flock) error {
// Add timeout to prevent indefinite blocking
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Try to acquire lock with timeout
lockChan := make(chan error, 1)
go func() {
lockChan <- fileLock.Lock()
}()
select {
case err := <-lockChan:
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
case <-ctx.Done():
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "File lock timeout",
Pairs: map[string]interface{}{"timeout": "30s"},
})
return fmt.Errorf("file lock timeout after 30 seconds")
}
}
func lockFileRead(fileLock *flock.Flock) error {
// Add timeout to prevent indefinite blocking
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Try to acquire read lock with timeout
lockChan := make(chan error, 1)
go func() {
lockChan <- fileLock.RLock()
}()
select {
case err := <-lockChan:
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't lock the file for reading",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
return nil
case <-ctx.Done():
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "File read lock timeout",
Pairs: map[string]interface{}{"timeout": "30s"},
})
return fmt.Errorf("file read lock timeout after 30 seconds")
}
}
package main
import (
"bytes"
"compress/gzip"
"io"
"sync"
)
// BufferPool manages reusable buffers for HTTP operations
type BufferPool struct {
pool sync.Pool
}
// NewBufferPool creates a new buffer pool
func NewBufferPool() *BufferPool {
return &BufferPool{
pool: sync.Pool{
New: func() interface{} {
// Create a buffer with 4KB initial capacity
return bytes.NewBuffer(make([]byte, 0, 4096))
},
},
}
}
// Get retrieves a buffer from the pool
func (bp *BufferPool) Get() *bytes.Buffer {
buf := bp.pool.Get().(*bytes.Buffer)
buf.Reset()
return buf
}
// Put returns a buffer to the pool
func (bp *BufferPool) Put(buf *bytes.Buffer) {
// Only return buffers that aren't too large (avoid memory bloat)
if buf.Cap() > 1024*1024 { // 1MB limit
return
}
buf.Reset()
bp.pool.Put(buf)
}
// GzipWriterPool manages reusable gzip writers
type GzipWriterPool struct {
pool sync.Pool
}
// NewGzipWriterPool creates a new gzip writer pool
func NewGzipWriterPool() *GzipWriterPool {
return &GzipWriterPool{
pool: sync.Pool{
New: func() interface{} {
// Create a gzip writer with default compression
return gzip.NewWriter(nil)
},
},
}
}
// Get retrieves a gzip writer from the pool
func (gp *GzipWriterPool) Get(w io.Writer) *gzip.Writer {
gz := gp.pool.Get().(*gzip.Writer)
gz.Reset(w)
return gz
}
// Put returns a gzip writer to the pool
func (gp *GzipWriterPool) Put(gz *gzip.Writer) {
gz.Reset(nil)
gp.pool.Put(gz)
}
// GzipReaderPool manages reusable gzip readers
type GzipReaderPool struct {
pool sync.Pool
}
// NewGzipReaderPool creates a new gzip reader pool
func NewGzipReaderPool() *GzipReaderPool {
return &GzipReaderPool{
pool: sync.Pool{
New: func() interface{} {
// We'll reset the reader when getting from pool
return &gzip.Reader{}
},
},
}
}
// Get retrieves a gzip reader from the pool
func (gp *GzipReaderPool) Get(r io.Reader) (*gzip.Reader, error) {
gr := gp.pool.Get().(*gzip.Reader)
if err := gr.Reset(r); err != nil {
// If reset fails, create a new reader
return gzip.NewReader(r)
}
return gr, nil
}
// Put returns a gzip reader to the pool
func (gp *GzipReaderPool) Put(gr *gzip.Reader) {
gr.Close()
gp.pool.Put(gr)
}
// Global buffer pools
var (
httpBufferPool = NewBufferPool()
gzipWriterPool = NewGzipWriterPool()
gzipReaderPool = NewGzipReaderPool()
)
// GetHTTPBuffer gets a buffer from the global pool
func GetHTTPBuffer() *bytes.Buffer {
return httpBufferPool.Get()
}
// PutHTTPBuffer returns a buffer to the global pool
func PutHTTPBuffer(buf *bytes.Buffer) {
httpBufferPool.Put(buf)
}
// GetGzipWriter gets a gzip writer from the global pool
func GetGzipWriter(w io.Writer) *gzip.Writer {
return gzipWriterPool.Get(w)
}
// PutGzipWriter returns a gzip writer to the global pool
func PutGzipWriter(gz *gzip.Writer) {
gzipWriterPool.Put(gz)
}
// GetGzipReader gets a gzip reader from the global pool
func GetGzipReader(r io.Reader) (*gzip.Reader, error) {
return gzipReaderPool.Get(r)
}
// PutGzipReader returns a gzip reader to the global pool
func PutGzipReader(gr *gzip.Reader) {
gzipReaderPool.Put(gr)
}
package libpack_cache
import (
"bytes"
"compress/gzip"
"io"
"sync/atomic"
"time"
fiber "github.com/gofiber/fiber/v2"
"github.com/gookit/goutil/strutil"
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
libpack_cache_redis "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/redis"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
type CacheConfig struct {
Logger *libpack_logger.Logger
Client CacheClient
Redis struct {
URL string `json:"url"`
Password string `json:"password"`
DB int `json:"db"`
Enable bool `json:"enable"`
}
Memory struct {
MaxMemorySize int64 `json:"max_memory_size"` // Maximum memory size in bytes
MaxEntries int64 `json:"max_entries"` // Maximum number of entries
}
TTL int `json:"ttl"`
}
type CacheStats struct {
CachedQueries int64 `json:"cached_queries"`
CacheHits int64 `json:"cache_hits"`
CacheMisses int64 `json:"cache_misses"`
}
type CacheClient interface {
Set(key string, value []byte, ttl time.Duration)
Get(key string) ([]byte, bool)
Delete(key string)
Clear()
CountQueries() int64
// Memory usage reporting methods
GetMemoryUsage() int64 // Returns current memory usage in bytes
GetMaxMemorySize() int64 // Returns max memory size in bytes
}
var (
cacheStats *CacheStats
config *CacheConfig
)
func CalculateHash(c *fiber.Ctx) string {
return strutil.Md5(c.Body())
}
func EnableCache(cfg *CacheConfig) {
if cfg.Logger == nil {
cfg.Logger = libpack_logger.New()
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Initializing in-module logger",
})
}
cacheStats = &CacheStats{}
if ShouldUseRedisCache(cfg) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using Redis cache",
})
redisClient, err := libpack_cache_redis.New(&libpack_cache_redis.RedisClientConfig{
RedisDB: cfg.Redis.DB,
RedisServer: cfg.Redis.URL,
RedisPassword: cfg.Redis.Password,
})
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to create Redis client",
Pairs: map[string]interface{}{"error": err.Error()},
})
// Fall back to memory cache
cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second)
} else {
cfg.Client = libpack_cache_redis.NewCacheWrapper(redisClient, cfg.Logger)
}
} else {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Using in-memory cache",
Pairs: map[string]interface{}{
"max_memory_size_bytes": cfg.Memory.MaxMemorySize,
"max_entries": cfg.Memory.MaxEntries,
},
})
// Use memory size and entry limits if configured, otherwise use defaults
if cfg.Memory.MaxMemorySize > 0 || cfg.Memory.MaxEntries > 0 {
maxMemory := cfg.Memory.MaxMemorySize
if maxMemory <= 0 {
maxMemory = libpack_cache_memory.DefaultMaxMemorySize
}
maxEntries := cfg.Memory.MaxEntries
if maxEntries <= 0 {
maxEntries = libpack_cache_memory.DefaultMaxCacheSize
}
cfg.Client = libpack_cache_memory.NewWithSize(
time.Duration(cfg.TTL)*time.Second,
maxMemory,
maxEntries,
)
} else {
// Backward compatibility
cfg.Client = libpack_cache_memory.New(time.Duration(cfg.TTL) * time.Second)
}
}
config = cfg
}
func CacheLookup(hash string) []byte {
if !IsCacheInitialized() {
return nil
}
obj, found := config.Client.Get(hash)
if found {
atomic.AddInt64(&cacheStats.CacheHits, 1)
// If the cached data is compressed, decompress it
if len(obj) > 2 && obj[0] == 0x1f && obj[1] == 0x8b {
reader, err := gzip.NewReader(bytes.NewReader(obj))
if err != nil {
config.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to create gzip reader for cached data",
Pairs: map[string]interface{}{"error": err.Error(), "hash": hash},
})
return nil
}
// Ensure reader is always closed, even on error
defer func() {
if closeErr := reader.Close(); closeErr != nil {
config.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to close gzip reader",
Pairs: map[string]interface{}{"error": closeErr.Error(), "hash": hash},
})
}
}()
decompressed, err := io.ReadAll(reader)
if err != nil {
config.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to decompress cached data",
Pairs: map[string]interface{}{"error": err.Error(), "hash": hash},
})
return nil
}
return decompressed
}
return obj
}
atomic.AddInt64(&cacheStats.CacheMisses, 1)
return nil
}
func CacheDelete(hash string) {
if !IsCacheInitialized() {
return
}
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Deleting data from cache",
Pairs: map[string]interface{}{"hash": hash},
})
// Use atomic operations with validation to prevent inconsistent statistics
for {
current := atomic.LoadInt64(&cacheStats.CachedQueries)
if current <= 0 {
break // Don't go below zero
}
if atomic.CompareAndSwapInt64(&cacheStats.CachedQueries, current, current-1) {
break
}
// Retry if CAS failed due to concurrent modification
}
config.Client.Delete(hash)
}
func CacheStore(hash string, data []byte) {
if !IsCacheInitialized() {
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache not initialized",
})
return
}
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Storing data in cache",
Pairs: map[string]interface{}{"hash": hash},
})
atomic.AddInt64(&cacheStats.CachedQueries, 1)
config.Client.Set(hash, data, time.Duration(config.TTL)*time.Second)
}
func CacheStoreWithTTL(hash string, data []byte, ttl time.Duration) {
if !IsCacheInitialized() {
return
}
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Storing data in cache with TTL",
Pairs: map[string]interface{}{"hash": hash, "ttl": ttl},
})
atomic.AddInt64(&cacheStats.CachedQueries, 1)
config.Client.Set(hash, data, ttl)
}
func CacheGetQueries() int64 {
if !IsCacheInitialized() {
return 0
}
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Counting cache queries",
})
return config.Client.CountQueries()
}
func CacheClear() {
config.Client.Clear()
cacheStats = &CacheStats{}
}
func GetCacheStats() *CacheStats {
if !IsCacheInitialized() {
return &CacheStats{}
}
config.Logger.Debug(&libpack_logger.LogMessage{
Message: "Getting cache stats",
})
cacheStats.CachedQueries = CacheGetQueries()
return cacheStats
}
// GetCacheMemoryUsage returns the current memory usage of the cache in bytes
func GetCacheMemoryUsage() int64 {
if !IsCacheInitialized() {
return 0
}
return config.Client.GetMemoryUsage()
}
// GetCacheMaxMemorySize returns the maximum memory size allowed for the cache in bytes
func GetCacheMaxMemorySize() int64 {
if !IsCacheInitialized() {
return 0
}
return config.Client.GetMaxMemorySize()
}
func ShouldUseRedisCache(cfg *CacheConfig) bool {
return cfg.Redis.Enable
}
func IsCacheInitialized() bool {
return config != nil && config.Client != nil
}
package libpack_cache_memory
import (
"bytes"
"sync"
)
var bufferPool = sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 4096))
},
}
// GetBuffer gets a buffer from the pool
func GetBuffer() *bytes.Buffer {
buf := bufferPool.Get().(*bytes.Buffer)
buf.Reset()
return buf
}
// PutBuffer returns a buffer to the pool
func PutBuffer(buf *bytes.Buffer) {
if buf.Cap() > 1024*1024 { // Don't pool buffers larger than 1MB
return
}
buf.Reset()
bufferPool.Put(buf)
}
package libpack_cache_memory
import (
"compress/gzip"
"container/list"
"sync"
"sync/atomic"
"time"
)
// LRUMemoryCache is an efficient LRU-based memory cache implementation
type LRUMemoryCache struct {
maxMemorySize int64
maxEntries int64
currentMemory int64
currentCount int64
mu sync.RWMutex
entries map[string]*lruEntry
evictList *list.List
gzipWriterPool *sync.Pool
gzipReaderPool *sync.Pool
cancel func()
}
type lruEntry struct {
key string
value []byte
compressed bool
size int64
expiresAt time.Time
element *list.Element
}
// NewLRUMemoryCache creates a new LRU memory cache
func NewLRUMemoryCache(maxMemorySize, maxEntries int64) *LRUMemoryCache {
return &LRUMemoryCache{
maxMemorySize: maxMemorySize,
maxEntries: maxEntries,
entries: make(map[string]*lruEntry),
evictList: list.New(),
gzipWriterPool: &sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
},
},
gzipReaderPool: &sync.Pool{
New: func() interface{} {
return &gzip.Reader{}
},
},
}
}
// Set adds or updates an entry in the cache
func (c *LRUMemoryCache) Set(key string, value []byte, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
// Calculate expiry time
expiresAt := time.Now().Add(ttl)
// Check if we should compress
compressed := false
finalValue := value
if len(value) > 1024 { // Compress if larger than 1KB
if compressedData, err := c.compress(value); err == nil && len(compressedData) < len(value) {
compressed = true
finalValue = compressedData
}
}
entrySize := int64(len(key) + len(finalValue) + 64) // 64 bytes overhead estimate
// Check if key exists
if existing, exists := c.entries[key]; exists {
// Update existing entry
c.evictList.MoveToFront(existing.element)
atomic.AddInt64(&c.currentMemory, -existing.size)
atomic.AddInt64(&c.currentMemory, entrySize)
existing.value = finalValue
existing.compressed = compressed
existing.size = entrySize
existing.expiresAt = expiresAt
c.evictIfNeeded()
return
}
// Create new entry
entry := &lruEntry{
key: key,
value: finalValue,
compressed: compressed,
size: entrySize,
expiresAt: expiresAt,
}
element := c.evictList.PushFront(entry)
entry.element = element
c.entries[key] = entry
atomic.AddInt64(&c.currentMemory, entrySize)
atomic.AddInt64(&c.currentCount, 1)
c.evictIfNeeded()
}
// Get retrieves a value from the cache
func (c *LRUMemoryCache) Get(key string) ([]byte, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, exists := c.entries[key]
if !exists {
return nil, false
}
// Check if expired
if time.Now().After(entry.expiresAt) {
c.removeEntry(entry)
return nil, false
}
// Move to front (most recently used)
c.evictList.MoveToFront(entry.element)
// Decompress if needed
if entry.compressed {
if decompressed, err := c.decompress(entry.value); err == nil {
return decompressed, true
}
// If decompression fails, remove the entry
c.removeEntry(entry)
return nil, false
}
return entry.value, true
}
// Delete removes an entry from the cache
func (c *LRUMemoryCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
if entry, exists := c.entries[key]; exists {
c.removeEntry(entry)
}
}
// Clear removes all entries
func (c *LRUMemoryCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.entries = make(map[string]*lruEntry)
c.evictList = list.New()
atomic.StoreInt64(&c.currentMemory, 0)
atomic.StoreInt64(&c.currentCount, 0)
}
// evictIfNeeded removes entries when limits are exceeded
func (c *LRUMemoryCache) evictIfNeeded() {
// Evict based on entry count
for atomic.LoadInt64(&c.currentCount) > c.maxEntries && c.evictList.Len() > 0 {
c.evictOldest()
}
// Evict based on memory
for atomic.LoadInt64(&c.currentMemory) > c.maxMemorySize && c.evictList.Len() > 0 {
c.evictOldest()
}
}
// evictOldest removes the least recently used entry
func (c *LRUMemoryCache) evictOldest() {
element := c.evictList.Back()
if element == nil {
return
}
entry := element.Value.(*lruEntry)
c.removeEntry(entry)
}
// removeEntry removes an entry from all data structures
func (c *LRUMemoryCache) removeEntry(entry *lruEntry) {
c.evictList.Remove(entry.element)
delete(c.entries, entry.key)
atomic.AddInt64(&c.currentMemory, -entry.size)
atomic.AddInt64(&c.currentCount, -1)
}
// CleanExpiredEntries removes all expired entries
func (c *LRUMemoryCache) CleanExpiredEntries() {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
for element := c.evictList.Back(); element != nil; {
entry := element.Value.(*lruEntry)
if now.After(entry.expiresAt) {
next := element.Prev()
c.removeEntry(entry)
element = next
} else {
element = element.Prev()
}
}
}
// compress compresses data using gzip
func (c *LRUMemoryCache) compress(data []byte) ([]byte, error) {
buf := GetBuffer()
defer PutBuffer(buf)
gz := c.gzipWriterPool.Get().(*gzip.Writer)
gz.Reset(buf)
defer c.gzipWriterPool.Put(gz)
if _, err := gz.Write(data); err != nil {
return nil, err
}
if err := gz.Close(); err != nil {
return nil, err
}
compressed := make([]byte, buf.Len())
copy(compressed, buf.Bytes())
return compressed, nil
}
// decompress decompresses gzip data
func (c *LRUMemoryCache) decompress(data []byte) ([]byte, error) {
buf := GetBuffer()
defer PutBuffer(buf)
buf.Write(data)
gr := c.gzipReaderPool.Get().(*gzip.Reader)
defer c.gzipReaderPool.Put(gr)
if err := gr.Reset(buf); err != nil {
return nil, err
}
result := GetBuffer()
defer PutBuffer(result)
if _, err := result.ReadFrom(gr); err != nil {
return nil, err
}
decompressed := make([]byte, result.Len())
copy(decompressed, result.Bytes())
return decompressed, nil
}
// GetStats returns cache statistics
func (c *LRUMemoryCache) GetStats() map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return map[string]interface{}{
"entries": atomic.LoadInt64(&c.currentCount),
"memory_bytes": atomic.LoadInt64(&c.currentMemory),
"max_entries": c.maxEntries,
"max_memory": c.maxMemorySize,
"fill_percent": float64(atomic.LoadInt64(&c.currentMemory)) / float64(c.maxMemorySize) * 100,
}
}
// GetMemoryUsage returns current memory usage in bytes
func (c *LRUMemoryCache) GetMemoryUsage() int64 {
return atomic.LoadInt64(&c.currentMemory)
}
// GetMaxMemorySize returns the maximum memory size
func (c *LRUMemoryCache) GetMaxMemorySize() int64 {
return c.maxMemorySize
}
package libpack_cache_memory
import (
"bytes"
"compress/gzip"
"context"
"io"
"sync"
"sync/atomic"
"time"
)
// CompressionThreshold is the minimum size in bytes before a value is compressed
const CompressionThreshold = 1024 // 1KB
// DefaultMaxMemorySize is the default maximum memory size in bytes (100MB)
const DefaultMaxMemorySize = 100 * 1024 * 1024
// DefaultMaxCacheSize is the default maximum number of entries in the cache
// This is used for backward compatibility
const DefaultMaxCacheSize = 10000
// approxEntryOverhead is the estimated overhead per cache entry in bytes
// This accounts for the CacheEntry struct overhead, map entry, and synchronization
const approxEntryOverhead = 64
type CacheEntry struct {
ExpiresAt time.Time
Value []byte
Compressed bool
MemorySize int64 // Estimated memory usage of this entry in bytes
}
type Cache struct {
compressPool sync.Pool
decompressPool sync.Pool
entries sync.Map
globalTTL time.Duration
entryCount int64
memoryUsage int64 // Total memory usage in bytes
maxMemorySize int64 // Maximum memory usage in bytes
maxCacheSize int64 // Maximum number of entries (for backward compatibility)
// Add context for graceful shutdown
ctx context.Context
cancel context.CancelFunc
sync.RWMutex
}
func New(globalTTL time.Duration) *Cache {
return NewWithSize(globalTTL, DefaultMaxMemorySize, DefaultMaxCacheSize)
}
// NewWithSize creates a new cache with the specified memory size limit and entry count limit
func NewWithSize(globalTTL time.Duration, maxMemorySize int64, maxCacheSize int64) *Cache {
// Create context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
cache := &Cache{
globalTTL: globalTTL,
maxMemorySize: maxMemorySize,
maxCacheSize: maxCacheSize,
ctx: ctx,
cancel: cancel,
compressPool: sync.Pool{
New: func() interface{} {
return gzip.NewWriter(nil)
},
},
decompressPool: sync.Pool{
New: func() interface{} {
r, _ := gzip.NewReader(bytes.NewReader([]byte{}))
return r
},
},
}
// Start cleanup routine with context cancellation
go cache.cleanupRoutine(globalTTL)
return cache
}
func (c *Cache) cleanupRoutine(globalTTL time.Duration) {
// Clean up more frequently when the cache is large
ticker := time.NewTicker(globalTTL / 4)
defer ticker.Stop()
for {
select {
case <-c.ctx.Done():
// Context cancelled, exit gracefully
return
case <-ticker.C:
c.CleanExpiredEntries()
// Note: Removed aggressive GC trigger that was causing performance issues
// The Go runtime GC is already optimized and will run when needed
}
}
}
// Shutdown gracefully stops the cache cleanup routine
func (c *Cache) Shutdown() {
if c.cancel != nil {
c.cancel()
}
}
func (c *Cache) Set(key string, value []byte, ttl time.Duration) {
// Calculate the memory size of this entry
entrySize := int64(len(key) + len(value) + approxEntryOverhead)
// Check if we need to evict entries based on memory or count limits
currentMemory := atomic.LoadInt64(&c.memoryUsage)
if currentMemory+entrySize > c.maxMemorySize {
// Need to evict based on memory
memoryToFree := (currentMemory + entrySize) - c.maxMemorySize + (c.maxMemorySize / 10)
c.evictToFreeMemory(memoryToFree)
} else if atomic.LoadInt64(&c.entryCount) >= c.maxCacheSize {
// Fall back to count-based eviction for backward compatibility
c.evictOldest(int(c.maxCacheSize / 10)) // Evict 10% of entries
}
expiresAt := time.Now().Add(ttl)
// Only compress if the value is larger than the threshold
var entry CacheEntry
if len(value) > CompressionThreshold {
compressedValue, err := c.compress(value)
if err == nil && len(compressedValue) < len(value) {
entry = CacheEntry{
Value: compressedValue,
ExpiresAt: expiresAt,
Compressed: true,
}
} else {
// If compression failed or didn't reduce size, store uncompressed
entry = CacheEntry{
Value: value,
ExpiresAt: expiresAt,
Compressed: false,
}
}
} else {
entry = CacheEntry{
Value: value,
ExpiresAt: expiresAt,
Compressed: false,
}
}
// Update the entry memory size based on compression status
if entry.Compressed {
entry.MemorySize = int64(len(key) + len(entry.Value) + approxEntryOverhead)
} else {
entry.MemorySize = int64(len(key) + len(entry.Value) + approxEntryOverhead)
}
// Check if this is a new entry or an update
oldEntry, exists := c.entries.Load(key)
if exists {
// Update memory usage: subtract old entry size, add new entry size
oldCacheEntry := oldEntry.(CacheEntry)
atomic.AddInt64(&c.memoryUsage, -oldCacheEntry.MemorySize)
} else {
// New entry
atomic.AddInt64(&c.entryCount, 1)
}
// Add new entry's memory size to total
atomic.AddInt64(&c.memoryUsage, entry.MemorySize)
c.entries.Store(key, entry)
}
func (c *Cache) Get(key string) ([]byte, bool) {
entry, ok := c.entries.Load(key)
if !ok {
return nil, false
}
cacheEntry := entry.(CacheEntry)
if cacheEntry.ExpiresAt.Before(time.Now()) {
c.entries.Delete(key)
atomic.AddInt64(&c.entryCount, -1)
atomic.AddInt64(&c.memoryUsage, -cacheEntry.MemorySize)
return nil, false
}
if cacheEntry.Compressed {
value, err := c.decompress(cacheEntry.Value)
if err != nil {
return nil, false
}
return value, true
}
return cacheEntry.Value, true
}
func (c *Cache) Delete(key string) {
if entry, exists := c.entries.LoadAndDelete(key); exists {
cacheEntry := entry.(CacheEntry)
atomic.AddInt64(&c.entryCount, -1)
atomic.AddInt64(&c.memoryUsage, -cacheEntry.MemorySize)
}
}
func (c *Cache) Clear() {
c.entries.Range(func(key, value interface{}) bool {
c.entries.Delete(key)
return true
})
atomic.StoreInt64(&c.entryCount, 0)
atomic.StoreInt64(&c.memoryUsage, 0)
}
func (c *Cache) CountQueries() int64 {
return atomic.LoadInt64(&c.entryCount)
}
func (c *Cache) compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
w := c.compressPool.Get().(*gzip.Writer)
defer c.compressPool.Put(w)
w.Reset(&buf)
if _, err := w.Write(data); err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (c *Cache) decompress(data []byte) ([]byte, error) {
r, ok := c.decompressPool.Get().(*gzip.Reader)
defer c.decompressPool.Put(r)
if !ok || r == nil {
var err error
r, err = gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, err
}
} else {
if err := r.Reset(bytes.NewReader(data)); err != nil {
return nil, err
}
}
defer func() {
_ = r.Close() // Ignore error in defer cleanup
}()
return io.ReadAll(r)
}
func (c *Cache) CleanExpiredEntries() {
now := time.Now()
c.entries.Range(func(key, value interface{}) bool {
entry := value.(CacheEntry)
if entry.ExpiresAt.Before(now) {
if _, exists := c.entries.LoadAndDelete(key); exists {
atomic.AddInt64(&c.entryCount, -1)
atomic.AddInt64(&c.memoryUsage, -entry.MemorySize)
}
}
return true
})
}
// evictOldest removes the oldest n entries from the cache
func (c *Cache) evictOldest(n int) {
type keyExpiry struct {
key string
expiresAt time.Time
}
// Collect all entries with their expiry times
entries := make([]keyExpiry, 0, n*2)
c.entries.Range(func(k, v interface{}) bool {
key := k.(string)
entry := v.(CacheEntry)
entries = append(entries, keyExpiry{key, entry.ExpiresAt})
return len(entries) < cap(entries)
})
// Sort by expiry time (oldest first)
// Using a simple selection sort since we only need to find the n oldest
for i := 0; i < n && i < len(entries); i++ {
oldest := i
for j := i + 1; j < len(entries); j++ {
if entries[j].expiresAt.Before(entries[oldest].expiresAt) {
oldest = j
}
}
// Swap
if oldest != i {
entries[i], entries[oldest] = entries[oldest], entries[i]
}
// Delete this entry
if entry, exists := c.entries.LoadAndDelete(entries[i].key); exists {
cacheEntry := entry.(CacheEntry)
atomic.AddInt64(&c.entryCount, -1)
atomic.AddInt64(&c.memoryUsage, -cacheEntry.MemorySize)
}
}
}
// evictToFreeMemory removes entries until the specified amount of memory is freed
func (c *Cache) evictToFreeMemory(bytesToFree int64) {
type keyMemorySize struct {
key string
memorySize int64
expiresAt time.Time
}
// Collect entries to consider for eviction
entries := make([]keyMemorySize, 0, int(c.maxCacheSize/5))
c.entries.Range(func(k, v interface{}) bool {
key := k.(string)
entry := v.(CacheEntry)
entries = append(entries, keyMemorySize{key, entry.MemorySize, entry.ExpiresAt})
return len(entries) < cap(entries)
})
// Sort entries by expiry time (oldest first)
// Simple selection sort since we only need to find the oldest entries
var freedBytes int64
for i := 0; i < len(entries) && freedBytes < bytesToFree; i++ {
oldest := i
for j := i + 1; j < len(entries); j++ {
if entries[j].expiresAt.Before(entries[oldest].expiresAt) {
oldest = j
}
}
// Swap
if oldest != i {
entries[i], entries[oldest] = entries[oldest], entries[i]
}
// Delete this entry
if entry, exists := c.entries.LoadAndDelete(entries[i].key); exists {
cacheEntry := entry.(CacheEntry)
atomic.AddInt64(&c.entryCount, -1)
atomic.AddInt64(&c.memoryUsage, -cacheEntry.MemorySize)
freedBytes += cacheEntry.MemorySize
}
}
}
// GetMemoryUsage returns the current memory usage of the cache in bytes
func (c *Cache) GetMemoryUsage() int64 {
return atomic.LoadInt64(&c.memoryUsage)
}
// GetMaxMemorySize returns the maximum memory size allowed for the cache in bytes
func (c *Cache) GetMaxMemorySize() int64 {
return c.maxMemorySize
}
// SetMaxMemorySize updates the maximum memory size allowed for the cache
func (c *Cache) SetMaxMemorySize(maxBytes int64) {
c.maxMemorySize = maxBytes
// Check if we need to evict entries due to the new limit
currentMemory := atomic.LoadInt64(&c.memoryUsage)
if currentMemory > maxBytes {
memoryToFree := currentMemory - maxBytes + (maxBytes / 10)
c.evictToFreeMemory(memoryToFree)
}
}
package libpack_cache_redis
import (
"context"
"strings"
"sync"
"time"
redis "github.com/redis/go-redis/v9"
)
type RedisConfig struct {
ctx context.Context
client *redis.Client
builderPool *sync.Pool
prefix string
}
func (c *RedisConfig) prependKeyName(key string) string {
builder := c.builderPool.Get().(*strings.Builder)
defer c.builderPool.Put(builder)
builder.Reset()
builder.WriteString(c.prefix)
builder.WriteString(key)
return builder.String()
}
type RedisClientConfig struct {
RedisServer string
RedisPassword string
Prefix string
RedisDB int
}
func New(redisClientConfig *RedisClientConfig) (*RedisConfig, error) {
c := &RedisConfig{
client: redis.NewClient(&redis.Options{
Addr: redisClientConfig.RedisServer,
Password: redisClientConfig.RedisPassword,
DB: redisClientConfig.RedisDB,
}),
ctx: context.Background(),
prefix: redisClientConfig.Prefix,
builderPool: &sync.Pool{
New: func() interface{} {
return &strings.Builder{}
},
},
}
_, err := c.client.Ping(c.ctx).Result()
if err != nil {
return nil, err
}
return c, nil
}
func (c *RedisConfig) Set(key string, value []byte, ttl time.Duration) error {
return c.client.Set(c.ctx, c.prependKeyName(key), value, ttl).Err()
}
func (c *RedisConfig) Get(key string) ([]byte, bool, error) {
val, err := c.client.Get(c.ctx, c.prependKeyName(key)).Result()
if err == redis.Nil {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
return []byte(val), true, nil
}
func (c *RedisConfig) Delete(key string) error {
return c.client.Del(c.ctx, c.prependKeyName(key)).Err()
}
func (c *RedisConfig) Clear() error {
return c.client.FlushDB(c.ctx).Err()
}
func (c *RedisConfig) CountQueries() (int64, error) {
keys, err := c.client.Keys(c.ctx, c.prependKeyName("*")).Result()
if err != nil {
return 0, err
}
return int64(len(keys)), nil
}
func (c *RedisConfig) CountQueriesWithPattern(pattern string) (int, error) {
keys, err := c.client.Keys(c.ctx, c.prependKeyName(pattern)).Result()
if err != nil {
return 0, err
}
return len(keys), nil
}
// GetMemoryUsage returns an approximation of memory usage for Redis
// For Redis, this is not as accurate as the memory cache implementation
// as actual memory is managed by Redis server
func (c *RedisConfig) GetMemoryUsage() int64 {
// We could attempt to get memory usage from Redis info
// but for now, we'll just return 0 since Redis manages its own memory
// and this information would require parsing the INFO command output
_, err := c.client.Info(c.ctx, "memory").Result()
if err != nil {
return 0
}
// Just return 0 as a placeholder since Redis manages its own memory
// In a production environment, you could parse the Redis INFO command result
// to extract actual "used_memory" value
return 0
}
// GetMaxMemorySize returns the configured max memory for Redis
// In Redis, this would be the 'maxmemory' configuration value
func (c *RedisConfig) GetMaxMemorySize() int64 {
// Return a default value as Redis manages its own memory limits
// In a production environment, you could get this from Redis config
return 0
}
package libpack_cache_redis
import (
"time"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
// CacheWrapper wraps RedisConfig to implement the CacheClient interface
// without returning errors, for backward compatibility
type CacheWrapper struct {
redis *RedisConfig
logger *libpack_logger.Logger
}
// NewCacheWrapper creates a new cache wrapper
func NewCacheWrapper(config *RedisConfig, logger *libpack_logger.Logger) *CacheWrapper {
if logger == nil {
logger = &libpack_logger.Logger{}
}
return &CacheWrapper{
redis: config,
logger: logger,
}
}
// Set stores a value with the given TTL
func (w *CacheWrapper) Set(key string, value []byte, ttl time.Duration) {
if err := w.redis.Set(key, value, ttl); err != nil {
w.logger.Error(&libpack_logger.LogMessage{
Message: "Redis set error",
Pairs: map[string]interface{}{
"error": err.Error(),
"key": key,
},
})
}
}
// Get retrieves a value
func (w *CacheWrapper) Get(key string) ([]byte, bool) {
value, found, err := w.redis.Get(key)
if err != nil {
w.logger.Error(&libpack_logger.LogMessage{
Message: "Redis get error",
Pairs: map[string]interface{}{
"error": err.Error(),
"key": key,
},
})
return nil, false
}
return value, found
}
// Delete removes a key
func (w *CacheWrapper) Delete(key string) {
if err := w.redis.Delete(key); err != nil {
w.logger.Error(&libpack_logger.LogMessage{
Message: "Redis delete error",
Pairs: map[string]interface{}{
"error": err.Error(),
"key": key,
},
})
}
}
// Clear removes all keys
func (w *CacheWrapper) Clear() {
if err := w.redis.Clear(); err != nil {
w.logger.Error(&libpack_logger.LogMessage{
Message: "Redis clear error",
Pairs: map[string]interface{}{
"error": err.Error(),
},
})
}
}
// CountQueries returns the number of queries
func (w *CacheWrapper) CountQueries() int64 {
count, err := w.redis.CountQueries()
if err != nil {
w.logger.Error(&libpack_logger.LogMessage{
Message: "Redis count queries error",
Pairs: map[string]interface{}{
"error": err.Error(),
},
})
return 0
}
return count
}
// GetMemoryUsage returns 0 for Redis (not applicable)
func (w *CacheWrapper) GetMemoryUsage() int64 {
return 0
}
// GetMaxMemorySize returns 0 for Redis (not applicable)
func (w *CacheWrapper) GetMaxMemorySize() int64 {
return 0
}
package main
import (
"sync/atomic"
"github.com/VictoriaMetrics/metrics"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
// CircuitBreakerMetrics manages circuit breaker metrics without recreating gauges
type CircuitBreakerMetrics struct {
stateValue atomic.Value // stores float64
stateGauge *metrics.Gauge
failCounters map[string]*metrics.Counter
}
// NewCircuitBreakerMetrics creates a new circuit breaker metrics manager
func NewCircuitBreakerMetrics(monitoring *libpack_monitoring.MetricsSetup) *CircuitBreakerMetrics {
cbm := &CircuitBreakerMetrics{
failCounters: make(map[string]*metrics.Counter),
}
// Initialize state value
cbm.stateValue.Store(float64(0))
// Create gauge with callback that reads the atomic value
cbm.stateGauge = monitoring.RegisterMetricsGauge(
libpack_monitoring.MetricsCircuitState,
nil,
0, // Initial value doesn't matter as callback will be used
)
// Override the gauge callback to read from atomic value
cbm.stateGauge = monitoring.RegisterMetricsGauge(
libpack_monitoring.MetricsCircuitState,
nil,
cbm.GetState(),
)
return cbm
}
// UpdateState updates the circuit breaker state value atomically
func (cbm *CircuitBreakerMetrics) UpdateState(state float64) {
cbm.stateValue.Store(state)
}
// GetState returns the current circuit breaker state value
func (cbm *CircuitBreakerMetrics) GetState() float64 {
if val := cbm.stateValue.Load(); val != nil {
return val.(float64)
}
return 0
}
// GetOrCreateFailCounter returns a counter for the given state key
func (cbm *CircuitBreakerMetrics) GetOrCreateFailCounter(monitoring *libpack_monitoring.MetricsSetup, stateKey string) *metrics.Counter {
if counter, exists := cbm.failCounters[stateKey]; exists {
return counter
}
// Create new counter
counter := monitoring.RegisterMetricsCounter(stateKey, nil)
cbm.failCounters[stateKey] = counter
return counter
}
// Global circuit breaker metrics instance
var cbMetrics *CircuitBreakerMetrics
// InitializeCircuitBreakerMetrics initializes the global circuit breaker metrics
func InitializeCircuitBreakerMetrics(monitoring *libpack_monitoring.MetricsSetup) {
if cbMetrics == nil {
cbMetrics = NewCircuitBreakerMetrics(monitoring)
}
}
package main
import (
"context"
"sync"
"time"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
"github.com/valyala/fasthttp"
)
// ConnectionPoolManager manages HTTP client connections
type ConnectionPoolManager struct {
client *fasthttp.Client
mu sync.RWMutex
cleanupTimer *time.Timer
ctx context.Context
cancel context.CancelFunc
}
// NewConnectionPoolManager creates a new connection pool manager
func NewConnectionPoolManager(client *fasthttp.Client) *ConnectionPoolManager {
ctx, cancel := context.WithCancel(context.Background())
cpm := &ConnectionPoolManager{
client: client,
ctx: ctx,
cancel: cancel,
}
// Start periodic cleanup
cpm.startPeriodicCleanup()
return cpm
}
// startPeriodicCleanup starts a timer to periodically clean idle connections
func (cpm *ConnectionPoolManager) startPeriodicCleanup() {
// Clean idle connections every 30 seconds
go func() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-cpm.ctx.Done():
return
case <-ticker.C:
cpm.cleanIdleConnections()
}
}
}()
}
// cleanIdleConnections closes idle connections
func (cpm *ConnectionPoolManager) cleanIdleConnections() {
cpm.mu.Lock()
defer cpm.mu.Unlock()
if cpm.client != nil {
cpm.client.CloseIdleConnections()
cfg.Logger.Debug(&libpack_logging.LogMessage{
Message: "Cleaned idle HTTP connections",
})
}
}
// GetClient returns the HTTP client
func (cpm *ConnectionPoolManager) GetClient() *fasthttp.Client {
cpm.mu.RLock()
defer cpm.mu.RUnlock()
return cpm.client
}
// Shutdown gracefully shuts down the connection pool
func (cpm *ConnectionPoolManager) Shutdown() error {
if cpm == nil {
return nil
}
cpm.cancel()
cpm.mu.Lock()
defer cpm.mu.Unlock()
if cpm.client != nil {
cpm.client.CloseIdleConnections()
if cfg != nil && cfg.Logger != nil {
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "HTTP connection pool shut down",
})
}
}
return nil
}
// Global connection pool manager
var connectionPoolManager *ConnectionPoolManager
// InitializeConnectionPool initializes the global connection pool
func InitializeConnectionPool(client *fasthttp.Client) {
if connectionPoolManager != nil {
connectionPoolManager.Shutdown()
}
connectionPoolManager = NewConnectionPoolManager(client)
}
// ShutdownConnectionPool safely shuts down the global connection pool
func ShutdownConnectionPool() {
if connectionPoolManager != nil {
connectionPoolManager.Shutdown()
connectionPoolManager = nil
}
}
// GetConnectionPoolManager returns the global connection pool manager
func GetConnectionPoolManager() *ConnectionPoolManager {
return connectionPoolManager
}
package main
import (
"encoding/base64"
"fmt"
"strings"
"github.com/goccy/go-json"
"github.com/lukaszraczylo/ask"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
const defaultValue = "-"
var emptyMetrics = map[string]string{}
func extractClaimsFromJWTHeader(authorization string) (usr, role string) {
usr, role = defaultValue, defaultValue
tokenParts := strings.SplitN(authorization, ".", 3)
if len(tokenParts) != 3 {
handleError("Can't split the token", map[string]interface{}{"token": maskToken(authorization)})
return
}
claim, err := base64.RawURLEncoding.DecodeString(tokenParts[1])
if err != nil {
handleError("Can't decode the token", map[string]interface{}{"token": maskToken(authorization)})
return
}
var claimMap map[string]interface{}
if err = json.Unmarshal(claim, &claimMap); err != nil {
handleError("Can't unmarshal the claim", map[string]interface{}{"token": maskToken(authorization)})
return
}
usr = extractClaim(claimMap, cfg.Client.JWTUserClaimPath, "user id")
role = extractClaim(claimMap, cfg.Client.JWTRoleClaimPath, "role")
return
}
func extractClaim(claimMap map[string]interface{}, claimPath, name string) string {
if claimPath == "" {
return defaultValue
}
// Validate claim path to prevent injection attacks
if !isValidClaimPath(claimPath) {
handleError(fmt.Sprintf("Invalid claim path for %s", name), map[string]interface{}{"path": claimPath})
return defaultValue
}
value, ok := ask.For(claimMap, claimPath).String(defaultValue)
if !ok {
handleError(fmt.Sprintf("Can't find the %s", name), map[string]interface{}{"claim_map": sanitizeClaimMap(claimMap), "path": claimPath})
return defaultValue
}
return value
}
// maskToken masks JWT tokens in logs to prevent exposure
func maskToken(token string) string {
if len(token) <= 10 {
return "***"
}
return token[:4] + "***" + token[len(token)-4:]
}
// isValidClaimPath validates JWT claim paths to prevent injection
func isValidClaimPath(path string) bool {
if path == "" {
return false
}
// Allow only alphanumeric characters, dots, underscores, and hyphens
for _, char := range path {
if (char < 'a' || char > 'z') &&
(char < 'A' || char > 'Z') &&
(char < '0' || char > '9') &&
char != '.' && char != '_' && char != '-' {
return false
}
}
// Prevent path traversal attempts
if strings.Contains(path, "..") || strings.Contains(path, "//") {
return false
}
return true
}
// sanitizeClaimMap removes sensitive data from claim map for logging
func sanitizeClaimMap(claimMap map[string]interface{}) map[string]interface{} {
sanitized := make(map[string]interface{})
sensitiveKeys := map[string]bool{
"password": true, "secret": true, "token": true, "key": true,
"auth": true, "credential": true, "private": true,
}
for k, v := range claimMap {
lowerKey := strings.ToLower(k)
if sensitiveKeys[lowerKey] {
sanitized[k] = "***"
} else {
sanitized[k] = v
}
}
return sanitized
}
func handleError(msg string, details map[string]interface{}) {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, emptyMetrics)
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: msg,
Pairs: details,
})
}
package main
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5/pgxpool"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
const (
initialDelay = 60 * time.Second
cleanupInterval = 1 * time.Hour
)
var delQueries = [...]string{
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - interval '%d days';",
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - interval '%d days';",
"DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL '%d days';",
"DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';",
"DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL '%d days';",
}
func enableHasuraEventCleaner(ctx context.Context) error {
cfgMutex.RLock()
if !cfg.HasuraEventCleaner.Enable {
cfgMutex.RUnlock()
return nil
}
eventMetadataDb := cfg.HasuraEventCleaner.EventMetadataDb
if eventMetadataDb == "" {
logger := cfg.Logger
cfgMutex.RUnlock()
logger.Warning(&libpack_logger.LogMessage{
Message: "Event metadata db URL not specified, event cleaner not active",
})
return nil
}
clearOlderThan := cfg.HasuraEventCleaner.ClearOlderThan
logger := cfg.Logger
cfgMutex.RUnlock()
logger.Info(&libpack_logger.LogMessage{
Message: "Event cleaner enabled",
Pairs: map[string]interface{}{"interval_in_days": clearOlderThan},
})
// Parse pool configuration
poolConfig, err := pgxpool.ParseConfig(eventMetadataDb)
if err != nil {
return err
}
// Set connection pool limits
poolConfig.MaxConns = 10
poolConfig.MinConns = 2
poolConfig.MaxConnLifetime = time.Hour
poolConfig.MaxConnIdleTime = 30 * time.Minute
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
logger.Error(&libpack_logger.LogMessage{
Message: "Failed to create connection pool",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
go func() {
defer pool.Close()
// Wait for initial delay or context cancellation
select {
case <-ctx.Done():
return
case <-time.After(initialDelay):
}
logger.Info(&libpack_logger.LogMessage{
Message: "Initial cleanup of old events",
})
cleanEvents(ctx, pool, clearOlderThan, logger)
ticker := time.NewTicker(cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
logger.Info(&libpack_logger.LogMessage{
Message: "Stopping event cleaner",
})
return
case <-ticker.C:
logger.Info(&libpack_logger.LogMessage{
Message: "Cleaning up old events",
})
cleanEvents(ctx, pool, clearOlderThan, logger)
}
}
}()
return nil
}
func cleanEvents(ctx context.Context, pool *pgxpool.Pool, clearOlderThan int, logger *libpack_logger.Logger) {
var errors []error
var failedQueries []string
for _, query := range delQueries {
_, err := pool.Exec(ctx, fmt.Sprintf(query, clearOlderThan))
if err != nil {
errors = append(errors, err)
failedQueries = append(failedQueries, query)
} else {
logger.Debug(&libpack_logger.LogMessage{
Message: "Successfully executed query",
Pairs: map[string]interface{}{"query": query},
})
}
}
if len(errors) > 0 {
var errMsgs []string
for _, err := range errors {
errMsgs = append(errMsgs, err.Error())
}
logger.Error(&libpack_logger.LogMessage{
Message: "Failed to execute some queries",
Pairs: map[string]interface{}{
"failed_queries": failedQueries,
"errors": errMsgs,
},
})
}
}
package main
import (
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
"github.com/graphql-go/graphql/language/ast"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/source"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
var (
introspectionQueries = map[string]struct{}{
"__schema": {}, "__type": {}, "__typename": {}, "__directive": {},
"__directivelocation": {}, "__field": {}, "__inputvalue": {},
"__enumvalue": {}, "__typekind": {}, "__fieldtype": {},
"__inputobjecttype": {}, "__enumtype": {}, "__uniontype": {},
"__scalars": {}, "__objects": {}, "__interfaces": {},
"__unions": {}, "__enums": {}, "__inputobjects": {}, "__directives": {},
}
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
// Cache for parsed GraphQL queries to avoid reparsing
parsedQueryCache *LRUCache
// Maximum size for parsed query cache
maxQueryCacheSize = 1000
currentCacheSize int64 // Use atomic operations for this
)
func prepareQueriesAndExemptions() {
introspectionAllowedQueries = make(map[string]struct{})
allowedUrls = make(map[string]struct{})
// Process allowed introspection queries
for _, q := range cfg.Security.IntrospectionAllowed {
cleanQuery := strings.Trim(strings.TrimSpace(q), `"`)
introspectionAllowedQueries[strings.ToLower(cleanQuery)] = struct{}{}
}
// Process allowed URLs
for _, u := range cfg.Server.AllowURLs {
allowedUrls[u] = struct{}{}
}
}
type parseGraphQLQueryResult struct {
operationType string
operationName string
activeEndpoint string
cacheTime int
cacheRequest bool
cacheRefresh bool
shouldBlock bool
shouldIgnore bool
}
// AST node pools to reduce GC pressure
var (
// Pool for request/response maps during unmarshaling
queryPool = sync.Pool{
New: func() interface{} {
return make(map[string]interface{}, 48)
},
}
// Pool for parse result objects
resultPool = sync.Pool{
New: func() interface{} {
return &parseGraphQLQueryResult{}
},
}
// Mutex for allocation tracking
allocsMutex = sync.Mutex{}
)
// The following variables are reserved for future GraphQL parsing optimization
// and are not currently in use:
// - fieldPool (Field object pool)
// - operationPool (OperationDefinition object pool)
// - namePool (Name object pool)
// - documentPool (Document object pool)
// - allocsCounter (for tracking allocation counts)
// - allocationsSamp (for memory usage histograms)
// Initialize the query parse cache with a fixed size
func initGraphQLParsing() {
// Set cache size based on available memory
maxQueryCacheSize = runtime.GOMAXPROCS(0) * 250
// Initialize LRU cache with entry limit and 50MB size limit
parsedQueryCache = NewLRUCache(maxQueryCacheSize, 50*1024*1024)
}
// Store a parsed document in the cache with LRU eviction
func cacheQuery(queryText string, document *ast.Document) {
if parsedQueryCache == nil {
return
}
// Store the document in the cache with timestamp for LRU
cacheEntry := &CachedQuery{
Document: document,
Timestamp: time.Now(),
}
// The LRU cache handles eviction automatically
parsedQueryCache.Set(queryText, cacheEntry, int64(len(queryText)))
atomic.AddInt64(¤tCacheSize, 1)
}
// CachedQuery represents a cached GraphQL query with timestamp for LRU
type CachedQuery struct {
Document *ast.Document
Timestamp time.Time
}
// evictOldestQueries is no longer needed with LRU cache
// The LRU cache handles eviction automatically
// Check if we have a cached parsed query
func getCachedQuery(queryText string) *ast.Document {
if parsedQueryCache == nil {
return nil
}
if entry, found := parsedQueryCache.Get(queryText); found {
if cachedQuery, ok := entry.(*CachedQuery); ok {
if cfg != nil && cfg.Monitoring != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLCacheHit, nil)
}
return cachedQuery.Document
}
}
if cfg != nil && cfg.Monitoring != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLCacheMiss, nil)
}
return nil
}
// Track and report memory allocations for GraphQL parsing
func trackParsingAllocations() func() {
var m1 runtime.MemStats
runtime.ReadMemStats(&m1)
return func() {
var m2 runtime.MemStats
runtime.ReadMemStats(&m2)
// Calculate allocations
allocsMutex.Lock()
allocsDelta := int(m2.Mallocs - m1.Mallocs)
// Note: allocsCounter variable is currently unused but will be used in future
// allocsCounter += allocsDelta
allocsMutex.Unlock()
// Record allocation count metrics
if cfg != nil && cfg.Monitoring != nil {
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingAllocs, nil, float64(allocsDelta))
}
}
}
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
startTime := time.Now()
// Set up allocation tracking
trackAllocs := trackParsingAllocations()
defer trackAllocs()
// Get a result object from the pool and initialize it
res := resultPool.Get().(*parseGraphQLQueryResult)
*res = parseGraphQLQueryResult{shouldIgnore: true}
// Ensure we return the result to the pool on function exit
defer func() {
resultPool.Put(res)
}()
// Default to using the write endpoint
res.activeEndpoint = cfg.Server.HostGraphQL
// Get a map from the pool for JSON unmarshaling
m := queryPool.Get().(map[string]interface{})
defer func() {
// Clear and return the map to the pool
for k := range m {
delete(m, k)
}
queryPool.Put(m)
}()
// Add comprehensive input validation
bodySize := len(c.Body())
// Validate query size to prevent DoS attacks
if bodySize > 1024*1024 { // 1MB limit
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Validate minimum size
if bodySize < 2 { // At least "{}"
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Unmarshal the request body
if err := json.Unmarshal(c.Body(), &m); err != nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Extract the query string
query, ok := m["query"].(string)
if !ok {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Try to get the query from cache first
var p *ast.Document
cachedDoc := getCachedQuery(query)
if cachedDoc != nil {
// Use the cached document
p = cachedDoc
} else {
// Parse the GraphQL query with improved source handling
src := source.NewSource(&source.Source{
Body: []byte(query),
Name: "GraphQL request",
})
var err error
p, err = parser.Parse(parser.ParseParams{Source: src})
if err != nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLParsingErrors, nil)
}
return res
}
// Cache the successful parse result for future use
cacheQuery(query, p)
}
// Mark as a valid GraphQL query
res.shouldIgnore = false
res.operationName = "undefined"
// First scan for mutations - they take priority
hasMutation := false
var mutationName string
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
operationType := strings.ToLower(oper.Operation)
if operationType == "mutation" {
hasMutation = true
res.operationType = "mutation"
if oper.Name != nil {
mutationName = oper.Name.Value
// Use mutation name immediately
res.operationName = mutationName
}
break // Found a mutation, no need to continue first pass
}
}
}
// Now process all definitions for other information
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
operationType := strings.ToLower(oper.Operation)
// If we already found a mutation, only update name if needed
if hasMutation {
// We already set operation type to mutation in first pass
// Only set name if we didn't find a mutation name earlier
if res.operationName == "undefined" && oper.Name != nil {
res.operationName = oper.Name.Value
}
} else {
// No mutation found, use the normal logic
if res.operationType == "" {
res.operationType = operationType
}
if res.operationName == "undefined" && oper.Name != nil {
res.operationName = oper.Name.Value
}
}
// Handle endpoint routing - always use write endpoint for mutations
if res.operationType == "mutation" {
res.activeEndpoint = cfg.Server.HostGraphQL
} else if cfg.Server.HostGraphQLReadOnly != "" {
// Use read-only endpoint for non-mutation operations
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
// Block mutations in read-only mode
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
_ = c.Status(403).SendString("The server is in read-only mode")
res.shouldBlock = true
return res
}
// Process directives (like @cached)
processDirectives(oper, res)
// Check for introspection queries if they're blocked
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
_ = c.Status(403).SendString("Introspection queries are not allowed")
res.shouldBlock = true
return res
}
}
}
// Track parsing time
if ifNotInTest() && cfg.Monitoring != nil {
parseTime := float64(time.Since(startTime).Milliseconds())
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingTime, nil, parseTime)
}
return res
}
// processDirectives extracts caching directives from the operation
func processDirectives(oper *ast.OperationDefinition, res *parseGraphQLQueryResult) {
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
res.cacheRequest = true
for _, arg := range dir.Arguments {
switch arg.Name.Value {
case "ttl":
if v, ok := arg.Value.GetValue().(string); ok {
res.cacheTime, _ = strconv.Atoi(v)
}
case "refresh":
if v, ok := arg.Value.GetValue().(bool); ok {
res.cacheRefresh = v
}
}
}
}
}
}
// checkSelections recursively checks if any selection is an introspection query that should be blocked
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
if len(selections) == 0 {
return false
}
// Fast path: if no introspection blocking is configured, return immediately
if !cfg.Security.BlockIntrospection {
return false
}
// Fast path: if there are no allowed introspection queries, check only top level
hasAllowList := len(cfg.Security.IntrospectionAllowed) > 0
for _, s := range selections {
switch sel := s.(type) {
case *ast.Field:
fieldName := strings.ToLower(sel.Name.Value)
// Check if this is an introspection query
if _, exists := introspectionQueries[fieldName]; exists {
if hasAllowList {
// Check if it's in the allowed list
if _, allowed := introspectionAllowedQueries[fieldName]; !allowed {
return true // Block if not allowed
}
} else {
return true // Block if no allowlist exists
}
}
// Check nested selections if present
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
// Check nested selections in fragments
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
}
}
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool {
startTime := time.Now()
blocked := false
// Enable introspection blocking for tests
if !cfg.Security.BlockIntrospection {
cfg.Security.BlockIntrospection = true
}
// Try to get cached parse result first
var p *ast.Document
cachedDoc := getCachedQuery(query)
if cachedDoc != nil {
p = cachedDoc
} else {
// Try parsing as a complete query
src := source.NewSource(&source.Source{
Body: []byte(query),
Name: "GraphQL introspection check",
})
var err error
p, err = parser.Parse(parser.ParseParams{Source: src})
if err == nil && p != nil {
// Cache the successful parse
cacheQuery(query, p)
}
}
if p != nil {
// It's a complete query, check all selections
for _, def := range p.Definitions {
if op, ok := def.(*ast.OperationDefinition); ok {
if op.SelectionSet != nil {
blocked = checkSelections(c, op.GetSelectionSet().Selections)
break
}
}
}
} else {
// Not a complete query, check as a field name
whateverLower := strings.ToLower(query)
if _, exists := introspectionQueries[whateverLower]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
if _, allowed := introspectionAllowedQueries[whateverLower]; !allowed {
blocked = true
}
} else {
blocked = true
}
}
}
if blocked {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
_ = c.Status(403).SendString("Introspection queries are not allowed")
}
// Track parsing time
if ifNotInTest() && cfg.Monitoring != nil {
parseTime := float64(time.Since(startTime).Milliseconds())
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingTime, nil, parseTime)
}
return blocked
}
// NOTE: The clearQueryCache function has been removed as it was unused.
// This functionality will be exposed through an API endpoint in a future release.
package libpack_logger
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/goccy/go-json"
)
const (
LEVEL_DEBUG = iota
LEVEL_INFO
LEVEL_WARN
LEVEL_ERROR
LEVEL_FATAL
)
var levelNames = []string{
"debug",
"info",
"warn",
"error",
"fatal",
}
const (
defaultTimeFormat = time.RFC3339
defaultMinLevel = LEVEL_INFO
defaultShowCaller = false
)
// Logger represents the logging object with configurations.
type Logger struct {
output io.Writer
timeFormat string
minLogLevel int
showCaller bool
mu sync.Mutex // Mutex to protect concurrent access to output
}
// LogMessage represents a log message with optional pairs.
type LogMessage struct {
Pairs map[string]interface{}
Message string
}
// bufferPool is used to reuse bytes.Buffer for efficiency.
var bufferPool = sync.Pool{
New: func() interface{} {
return new(bytes.Buffer)
},
}
// fieldNames allows customization of output field names.
var fieldNames = map[string]string{
"timestamp": "timestamp",
"level": "level",
"message": "message",
}
// osExit is a variable to allow mocking os.Exit in tests
var osExit = os.Exit
// exitMutex ensures thread-safe access to osExit
var exitMutex sync.RWMutex
// New creates a new Logger with default settings.
func New() *Logger {
return &Logger{
timeFormat: defaultTimeFormat,
minLogLevel: defaultMinLevel,
output: os.Stdout,
showCaller: defaultShowCaller,
}
}
// SetOutput sets the output destination for the logger.
func (l *Logger) SetOutput(output io.Writer) *Logger {
l.mu.Lock()
l.output = output
l.mu.Unlock()
return l
}
// GetLogLevel returns the log level integer corresponding to the given level name.
func GetLogLevel(level string) int {
level = strings.ToLower(level)
for i, name := range levelNames {
if name == level {
return i
}
}
return defaultMinLevel
}
// SetTimeFormat sets the time format for the logger's timestamp field.
func (l *Logger) SetTimeFormat(format string) *Logger {
l.timeFormat = format
return l
}
// SetMinLogLevel sets the minimum log level for the logger.
func (l *Logger) SetMinLogLevel(level int) *Logger {
l.minLogLevel = level
return l
}
// SetFieldName allows customizing the field names in log output.
func (l *Logger) SetFieldName(field, name string) *Logger {
fieldNames[field] = name
return l
}
// SetShowCaller enables or disables including the caller information in log output.
func (l *Logger) SetShowCaller(show bool) *Logger {
l.showCaller = show
return l
}
// shouldLog determines if the message should be logged based on the logger's minimum log level.
func (l *Logger) shouldLog(level int) bool {
return level >= l.minLogLevel
}
// log writes the log message with the given level.
func (l *Logger) log(level int, m *LogMessage) {
if m.Pairs == nil {
m.Pairs = make(map[string]interface{})
}
m.Pairs[fieldNames["timestamp"]] = time.Now().Format(l.timeFormat)
m.Pairs[fieldNames["level"]] = levelNames[level]
m.Pairs[fieldNames["message"]] = m.Message
if l.showCaller {
m.Pairs["caller"] = getCaller()
}
buffer := bufferPool.Get().(*bytes.Buffer)
buffer.Reset()
defer bufferPool.Put(buffer)
encoder := json.NewEncoder(buffer)
err := encoder.Encode(m.Pairs)
if err != nil {
fmt.Fprintln(os.Stderr, "Error marshalling log message:", err)
return
}
// Lock the mutex before writing to the output to prevent race conditions
l.mu.Lock()
_, err = l.output.Write(buffer.Bytes())
l.mu.Unlock()
if err != nil {
fmt.Fprintln(os.Stderr, "Error writing log message:", err)
}
}
// Debug logs a debug-level message.
func (l *Logger) Debug(m *LogMessage) {
if l.shouldLog(LEVEL_DEBUG) {
l.log(LEVEL_DEBUG, m)
}
}
// Info logs an info-level message.
func (l *Logger) Info(m *LogMessage) {
if l.shouldLog(LEVEL_INFO) {
l.log(LEVEL_INFO, m)
}
}
// Warn logs a warning-level message.
func (l *Logger) Warn(m *LogMessage) {
if l.shouldLog(LEVEL_WARN) {
l.log(LEVEL_WARN, m)
}
}
// Warning is an alias for Warn.
func (l *Logger) Warning(m *LogMessage) {
l.Warn(m)
}
// Error logs an error-level message.
func (l *Logger) Error(m *LogMessage) {
if l.shouldLog(LEVEL_ERROR) {
l.log(LEVEL_ERROR, m)
}
}
// Fatal logs a fatal-level message.
func (l *Logger) Fatal(m *LogMessage) {
if l.shouldLog(LEVEL_FATAL) {
l.log(LEVEL_FATAL, m)
}
}
// Critical logs a critical-level message and exits the application.
func (l *Logger) Critical(m *LogMessage) {
l.Fatal(m)
exitMutex.RLock()
defer exitMutex.RUnlock()
osExit(1)
}
// getCaller retrieves the file and line number of the caller.
func getCaller() string {
// Skip 3 stack frames: getCaller -> log -> [Debug|Info|...]
const depth = 3
_, file, line, ok := runtime.Caller(depth)
if !ok {
return "unknown:0"
}
file = filepath.Base(file)
return fmt.Sprintf("%s:%d", file, line)
}
package main
import (
"container/list"
"sync"
"time"
)
// LRUCacheEntry represents a cache entry with metadata
type LRUCacheEntry struct {
key string
value interface{}
size int64
timestamp time.Time
element *list.Element
}
// LRUCache implements a thread-safe LRU cache with O(1) operations
type LRUCache struct {
mu sync.RWMutex
maxEntries int
maxSize int64
currentSize int64
entries map[string]*LRUCacheEntry
evictList *list.List
}
// NewLRUCache creates a new LRU cache
func NewLRUCache(maxEntries int, maxSize int64) *LRUCache {
return &LRUCache{
maxEntries: maxEntries,
maxSize: maxSize,
entries: make(map[string]*LRUCacheEntry),
evictList: list.New(),
}
}
// Get retrieves a value from the cache
func (c *LRUCache) Get(key string) (interface{}, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, exists := c.entries[key]
if !exists {
return nil, false
}
// Move to front (most recently used)
c.evictList.MoveToFront(entry.element)
entry.timestamp = time.Now()
return entry.value, true
}
// Set adds or updates a value in the cache
func (c *LRUCache) Set(key string, value interface{}, size int64) {
c.mu.Lock()
defer c.mu.Unlock()
// Check if key already exists
if entry, exists := c.entries[key]; exists {
// Update existing entry
c.currentSize -= entry.size
c.currentSize += size
entry.value = value
entry.size = size
entry.timestamp = time.Now()
c.evictList.MoveToFront(entry.element)
// Check if we need to evict due to size
c.evictIfNeeded()
return
}
// Create new entry
entry := &LRUCacheEntry{
key: key,
value: value,
size: size,
timestamp: time.Now(),
}
// Add to front of list
element := c.evictList.PushFront(entry)
entry.element = element
c.entries[key] = entry
c.currentSize += size
// Evict if necessary
c.evictIfNeeded()
}
// evictIfNeeded removes entries when cache limits are exceeded
func (c *LRUCache) evictIfNeeded() {
// Evict based on entry count
for c.evictList.Len() > c.maxEntries {
c.evictOldest()
}
// Evict based on size
for c.currentSize > c.maxSize && c.evictList.Len() > 0 {
c.evictOldest()
}
}
// evictOldest removes the least recently used entry
func (c *LRUCache) evictOldest() {
element := c.evictList.Back()
if element == nil {
return
}
entry := element.Value.(*LRUCacheEntry)
c.removeEntry(entry)
}
// removeEntry removes an entry from the cache
func (c *LRUCache) removeEntry(entry *LRUCacheEntry) {
c.evictList.Remove(entry.element)
delete(c.entries, entry.key)
c.currentSize -= entry.size
}
// Delete removes a key from the cache
func (c *LRUCache) Delete(key string) {
c.mu.Lock()
defer c.mu.Unlock()
entry, exists := c.entries[key]
if !exists {
return
}
c.removeEntry(entry)
}
// Clear removes all entries from the cache
func (c *LRUCache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
c.entries = make(map[string]*LRUCacheEntry)
c.evictList = list.New()
c.currentSize = 0
}
// Len returns the number of entries in the cache
func (c *LRUCache) Len() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.evictList.Len()
}
// Size returns the current size of the cache in bytes
func (c *LRUCache) Size() int64 {
c.mu.RLock()
defer c.mu.RUnlock()
return c.currentSize
}
// CleanupExpired removes entries older than the given duration
func (c *LRUCache) CleanupExpired(maxAge time.Duration) int {
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
removed := 0
// Iterate from back (oldest) to front (newest)
for element := c.evictList.Back(); element != nil; {
entry := element.Value.(*LRUCacheEntry)
// If entry is not expired, we can stop (entries are ordered by access time)
if now.Sub(entry.timestamp) <= maxAge {
break
}
// Remove expired entry
next := element.Prev()
c.removeEntry(entry)
removed++
element = next
}
return removed
}
// GetStats returns cache statistics
func (c *LRUCache) GetStats() map[string]interface{} {
c.mu.RLock()
defer c.mu.RUnlock()
return map[string]interface{}{
"entries": c.evictList.Len(),
"size_bytes": c.currentSize,
"max_entries": c.maxEntries,
"max_size": c.maxSize,
"fill_percent": float64(c.currentSize) / float64(c.maxSize) * 100,
}
}
package main
import (
"context"
"flag"
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/gofiber/fiber/v2/middleware/proxy"
"github.com/gookit/goutil/envutil"
graphql "github.com/lukaszraczylo/go-simple-graphql"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
libpack_tracing "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing"
)
var (
cfg *config
cfgMutex sync.RWMutex
once sync.Once
tracer *libpack_tracing.TracingSetup
shutdownManager *ShutdownManager
)
// getDetailsFromEnv retrieves the value from the environment or returns the default.
// It first checks for a prefixed environment variable (GMP_KEY), then falls back to the unprefixed version.
func getDetailsFromEnv[T any](key string, defaultValue T) T {
prefixedKey := "GMP_" + key
switch v := any(defaultValue).(type) {
case string:
if val, ok := os.LookupEnv(prefixedKey); ok {
return any(val).(T)
}
return any(envutil.Getenv(key, v)).(T)
case int:
if val, ok := os.LookupEnv(prefixedKey); ok {
if intVal, err := strconv.Atoi(val); err == nil {
return any(intVal).(T)
}
}
return any(envutil.GetInt(key, v)).(T)
case bool:
if val, ok := os.LookupEnv(prefixedKey); ok {
boolVal := strings.ToLower(val) == "true" || val == "1"
return any(boolVal).(T)
}
return any(envutil.GetBool(key, v)).(T)
default:
return defaultValue
}
}
// parseConfig loads and parses the configuration.
func parseConfig() {
libpack_config.PKG_NAME = "graphql_proxy"
c := config{}
// Server configurations
c.Server.PortGraphQL = getDetailsFromEnv("PORT_GRAPHQL", 8080)
c.Server.PortMonitoring = getDetailsFromEnv("MONITORING_PORT", 9393)
c.Server.HostGraphQL = getDetailsFromEnv("HOST_GRAPHQL", "http://localhost/")
c.Server.HostGraphQLReadOnly = getDetailsFromEnv("HOST_GRAPHQL_READONLY", "")
// Client configurations
c.Client.JWTUserClaimPath = getDetailsFromEnv("JWT_USER_CLAIM_PATH", "")
c.Client.JWTRoleClaimPath = getDetailsFromEnv("JWT_ROLE_CLAIM_PATH", "")
c.Client.RoleFromHeader = getDetailsFromEnv("ROLE_FROM_HEADER", "")
c.Client.RoleRateLimit = getDetailsFromEnv("ROLE_RATE_LIMIT", false)
// In-memory cache
c.Cache.CacheEnable = getDetailsFromEnv("ENABLE_GLOBAL_CACHE", false)
c.Cache.CacheTTL = getDetailsFromEnv("CACHE_TTL", 60)
c.Cache.CacheMaxMemorySize = getDetailsFromEnv("CACHE_MAX_MEMORY_SIZE", 100) // Default 100MB
c.Cache.CacheMaxEntries = getDetailsFromEnv("CACHE_MAX_ENTRIES", 10000) // Default 10000 entries
// Redis cache
c.Cache.CacheRedisEnable = getDetailsFromEnv("ENABLE_REDIS_CACHE", false)
c.Cache.CacheRedisURL = getDetailsFromEnv("CACHE_REDIS_URL", "localhost:6379")
c.Cache.CacheRedisPassword = getDetailsFromEnv("CACHE_REDIS_PASSWORD", "")
c.Cache.CacheRedisDB = getDetailsFromEnv("CACHE_REDIS_DB", 0)
// Security configurations
c.Security.BlockIntrospection = getDetailsFromEnv("BLOCK_SCHEMA_INTROSPECTION", false)
c.Security.IntrospectionAllowed = func() []string {
urls := getDetailsFromEnv("ALLOWED_INTROSPECTION", "")
if urls == "" {
return nil
}
return strings.Split(urls, ",")
}()
c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))
// Logger setup
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).
SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
// Health check
c.Server.HealthcheckGraphQL = getDetailsFromEnv("HEALTHCHECK_GRAPHQL_URL", "")
c.Client.GQLClient = graphql.NewConnection()
c.Client.GQLClient.SetEndpoint(c.Server.HealthcheckGraphQL)
// Server modes
c.Server.AccessLog = getDetailsFromEnv("ENABLE_ACCESS_LOG", false)
c.Server.ReadOnlyMode = getDetailsFromEnv("READ_ONLY_MODE", false)
c.Server.AllowURLs = func() []string {
urls := getDetailsFromEnv("ALLOWED_URLS", "")
if urls == "" {
return nil
}
return strings.Split(urls, ",")
}()
// Client timeout and connection configurations with bounds checking
clientTimeout := getDetailsFromEnv("PROXIED_CLIENT_TIMEOUT", 120)
if clientTimeout < 1 || clientTimeout > 3600 { // 1 second to 1 hour max
c.Logger.Warning(&libpack_logging.LogMessage{
Message: "Invalid client timeout, using default",
Pairs: map[string]interface{}{"requested": clientTimeout, "default": 120},
})
clientTimeout = 120
}
c.Client.ClientTimeout = clientTimeout
// Configure HTTP connection pool and timeouts with sensible defaults
// MaxConnsPerHost limits parallel connections to prevent overwhelming backends
maxConns := getDetailsFromEnv("MAX_CONNS_PER_HOST", 1024)
if maxConns < 1 || maxConns > 10000 { // Reasonable bounds
c.Logger.Warning(&libpack_logging.LogMessage{
Message: "Invalid max connections per host, using default",
Pairs: map[string]interface{}{"requested": maxConns, "default": 1024},
})
maxConns = 1024
}
c.Client.MaxConnsPerHost = maxConns
// Configure distinct timeout values for more granular control with bounds checking
readTimeout := getDetailsFromEnv("CLIENT_READ_TIMEOUT", c.Client.ClientTimeout)
if readTimeout < 1 || readTimeout > 3600 {
readTimeout = c.Client.ClientTimeout
}
c.Client.ReadTimeout = readTimeout
writeTimeout := getDetailsFromEnv("CLIENT_WRITE_TIMEOUT", c.Client.ClientTimeout)
if writeTimeout < 1 || writeTimeout > 3600 {
writeTimeout = c.Client.ClientTimeout
}
c.Client.WriteTimeout = writeTimeout
// MaxIdleConnDuration controls how long connections stay in the pool
idleDuration := getDetailsFromEnv("CLIENT_MAX_IDLE_CONN_DURATION", 300)
if idleDuration < 1 || idleDuration > 7200 { // 1 second to 2 hours max
idleDuration = 300
}
c.Client.MaxIdleConnDuration = idleDuration
// Secure by default: TLS verification is enabled unless explicitly disabled
c.Client.DisableTLSVerify = getDetailsFromEnv("CLIENT_DISABLE_TLS_VERIFY", false)
// Create HTTP client with the optimized parameters
c.Client.FastProxyClient = createFasthttpClient(&c)
proxy.WithClient(c.Client.FastProxyClient) // Setting the global proxy client
// API configurations
c.Server.EnableApi = getDetailsFromEnv("ENABLE_API", false)
c.Server.ApiPort = getDetailsFromEnv("API_PORT", 9090)
// Validate and sanitize banned users file path to prevent path traversal
bannedUsersFile := getDetailsFromEnv("BANNED_USERS_FILE", "/go/src/app/banned_users.json")
if validatedPath, err := validateFilePath(bannedUsersFile); err != nil {
c.Logger.Error(&libpack_logging.LogMessage{
Message: "Invalid banned users file path, using default",
Pairs: map[string]interface{}{"requested": bannedUsersFile, "error": err.Error()},
})
c.Api.BannedUsersFile = "/go/src/app/banned_users.json"
} else {
c.Api.BannedUsersFile = validatedPath
}
c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false)
c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0)
// Hasura event cleaner
c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false)
c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1)
c.HasuraEventCleaner.EventMetadataDb = getDetailsFromEnv("HASURA_EVENT_METADATA_DB", "")
// Tracing configuration
c.Tracing.Enable = getDetailsFromEnv("ENABLE_TRACE", false)
c.Tracing.Endpoint = getDetailsFromEnv("TRACE_ENDPOINT", "localhost:4317")
// Circuit Breaker configuration
c.CircuitBreaker.Enable = getDetailsFromEnv("ENABLE_CIRCUIT_BREAKER", false)
c.CircuitBreaker.MaxFailures = getDetailsFromEnv("CIRCUIT_MAX_FAILURES", 5)
c.CircuitBreaker.Timeout = getDetailsFromEnv("CIRCUIT_TIMEOUT_SECONDS", 30)
c.CircuitBreaker.MaxRequestsInHalfOpen = getDetailsFromEnv("CIRCUIT_MAX_HALF_OPEN_REQUESTS", 2)
c.CircuitBreaker.ReturnCachedOnOpen = getDetailsFromEnv("CIRCUIT_RETURN_CACHED_ON_OPEN", true)
c.CircuitBreaker.TripOnTimeouts = getDetailsFromEnv("CIRCUIT_TRIP_ON_TIMEOUTS", true)
c.CircuitBreaker.TripOn5xx = getDetailsFromEnv("CIRCUIT_TRIP_ON_5XX", true)
cfgMutex.Lock()
cfg = &c
cfgMutex.Unlock()
// Initialize tracing if enabled
if cfg.Tracing.Enable {
if cfg.Tracing.Endpoint == "" {
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Tracing endpoint not configured, using default localhost:4317",
})
cfg.Tracing.Endpoint = "localhost:4317"
}
var err error
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
tracer, err = libpack_tracing.NewTracing(ctx, cfg.Tracing.Endpoint)
if err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Failed to initialize tracing",
Pairs: map[string]interface{}{"error": err.Error()},
})
} else {
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Tracing initialized",
Pairs: map[string]interface{}{"endpoint": cfg.Tracing.Endpoint},
})
}
}
// Initialize cache if enabled
if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
cacheConfig := &libpack_cache.CacheConfig{
Logger: cfg.Logger,
TTL: cfg.Cache.CacheTTL,
}
// Redis cache configurations
if cfg.Cache.CacheRedisEnable {
cacheConfig.Redis.Enable = true
cacheConfig.Redis.URL = cfg.Cache.CacheRedisURL
cacheConfig.Redis.Password = cfg.Cache.CacheRedisPassword
cacheConfig.Redis.DB = cfg.Cache.CacheRedisDB
} else {
// Memory cache configurations
cacheConfig.Memory.MaxMemorySize = int64(cfg.Cache.CacheMaxMemorySize) * 1024 * 1024 // Convert MB to bytes
cacheConfig.Memory.MaxEntries = int64(cfg.Cache.CacheMaxEntries)
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Configuring memory cache with limits",
Pairs: map[string]interface{}{
"max_memory_mb": cfg.Cache.CacheMaxMemorySize,
"max_entries": cfg.Cache.CacheMaxEntries,
},
})
}
libpack_cache.EnableCache(cacheConfig)
// Start memory monitoring for in-memory cache if it's not Redis
// Will be started with context in main()
}
// Initialize circuit breaker if enabled
if cfg.CircuitBreaker.Enable {
initCircuitBreaker(cfg)
}
// Load rate limit configuration with improved error handling
if err := loadRatelimitConfig(); err != nil {
// Log the error with clear guidance
detailedError := err.Error()
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Failed to start service due to rate limit configuration error",
Pairs: map[string]interface{}{
"error": detailedError,
},
})
// If we're not in a test environment, print to stderr and exit if config error
if ifNotInTest() {
fmt.Fprintln(os.Stderr, "⚠️ CRITICAL ERROR: Rate limit configuration problem detected")
fmt.Fprintln(os.Stderr, detailedError)
os.Exit(1)
}
}
// API and event cleaner will be started with context in main()
prepareQueriesAndExemptions()
// Initialize GraphQL parsing optimizations
initGraphQLParsing()
}
func main() {
// Parse configuration
parseConfig()
// Setup graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Initialize shutdown manager
shutdownManager = NewShutdownManager(ctx)
// Create a wait group to manage goroutines
var wg sync.WaitGroup
// Setup signal handling for graceful shutdown
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Shutdown signal received, stopping services...",
})
cancel()
}()
// Start background services with context
once.Do(func() {
// Start API server
shutdownManager.RunGoroutine("api-server", func(ctx context.Context) {
if err := enableApi(ctx); err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "API server error",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
})
// Start event cleaner
shutdownManager.RunGoroutine("event-cleaner", func(ctx context.Context) {
if err := enableHasuraEventCleaner(ctx); err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Event cleaner error",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
})
// Start cache memory monitoring if not using Redis
if cfg.Cache.CacheEnable && !cfg.Cache.CacheRedisEnable {
shutdownManager.RunGoroutine("cache-memory-monitoring", startCacheMemoryMonitoring)
}
})
// Register connection pool for cleanup
shutdownManager.RegisterComponent("http-connection-pool", func(ctx context.Context) error {
if connectionPoolManager != nil {
return connectionPoolManager.Shutdown()
}
return nil
})
// Cache shutdown is handled internally by the cache implementation
// Start monitoring server
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Starting monitoring server...",
Pairs: map[string]interface{}{"port": cfg.Server.PortMonitoring},
})
// Start monitoring server in a goroutine
wg.Add(1)
monitoringErrCh := make(chan error, 1)
go func() {
defer wg.Done()
if err := StartMonitoringServer(); err != nil {
monitoringErrCh <- err
}
}()
// Give monitoring server time to initialize
select {
case err := <-monitoringErrCh:
cfg.Logger.Critical(&libpack_logging.LogMessage{
Message: "Failed to start monitoring server",
Pairs: map[string]interface{}{
"error": err.Error(),
"port": cfg.Server.PortMonitoring,
},
})
os.Exit(1)
case <-time.After(2 * time.Second):
// Continue if no error received within timeout
}
// Start HTTP proxy
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Starting HTTP proxy server...",
Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL},
})
// Start HTTP proxy in a goroutine
wg.Add(1)
proxyErrCh := make(chan error, 1)
go func() {
defer wg.Done()
if err := StartHTTPProxy(); err != nil {
proxyErrCh <- err
}
}()
// Block for a moment to check for immediate startup errors
select {
case err := <-proxyErrCh:
cfg.Logger.Critical(&libpack_logging.LogMessage{
Message: "Failed to start HTTP proxy server",
Pairs: map[string]interface{}{
"error": err.Error(),
"port": cfg.Server.PortGraphQL,
},
})
os.Exit(1)
case <-time.After(1 * time.Second):
// Continue if no error received within timeout
}
// Wait for context cancellation
<-ctx.Done()
// Perform cleanup
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Shutting down services...",
})
// Register tracer shutdown
if tracer != nil {
shutdownManager.RegisterComponent("tracer", func(ctx context.Context) error {
return tracer.Shutdown(ctx)
})
}
// Perform graceful shutdown of all components
if err := shutdownManager.Shutdown(30 * time.Second); err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Error during shutdown",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
// Wait for all goroutines to finish (with timeout)
waitCh := make(chan struct{})
go func() {
wg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "All services shut down gracefully",
})
case <-time.After(10 * time.Second):
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Some services didn't shut down gracefully within timeout",
})
}
}
// startCacheMemoryMonitoring polls memory cache usage and updates metrics
func startCacheMemoryMonitoring(ctx context.Context) {
// Check every few seconds (more frequent than cleanup routine)
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Starting memory cache monitoring",
})
// Use mutex to protect concurrent access to metrics registration
var metricsMutex sync.Mutex
// Create initial metrics with proper synchronization
metricsMutex.Lock()
cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryLimit, nil,
float64(libpack_cache.GetCacheMaxMemorySize()))
metricsMutex.Unlock()
for {
select {
case <-ctx.Done():
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Stopping cache memory monitoring",
})
return
case <-ticker.C:
// Skip if monitoring not initialized or cache not initialized
if cfg.Monitoring == nil || !libpack_cache.IsCacheInitialized() {
continue
}
// Get current memory usage atomically
memoryUsage := libpack_cache.GetCacheMemoryUsage()
memoryLimit := libpack_cache.GetCacheMaxMemorySize()
// Update metrics with proper synchronization
metricsMutex.Lock()
cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryUsage, nil,
float64(memoryUsage))
cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryLimit, nil,
float64(memoryLimit))
// Calculate percentage (protect against division by zero)
var percentUsed float64
if memoryLimit > 0 {
percentUsed = float64(memoryUsage) / float64(memoryLimit) * 100.0
}
cfg.Monitoring.RegisterMetricsGauge(libpack_monitoring.MetricsCacheMemoryPercent, nil,
percentUsed)
metricsMutex.Unlock()
// Log if memory usage is high (over 80%)
if percentUsed > 80.0 {
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Memory cache usage is high",
Pairs: map[string]interface{}{
"memory_usage_bytes": memoryUsage,
"memory_limit_bytes": memoryLimit,
"percent_used": percentUsed,
},
})
}
}
}
}
// validateFilePath validates and sanitizes file paths to prevent path traversal attacks
func validateFilePath(path string) (string, error) {
if path == "" {
return "", fmt.Errorf("empty file path")
}
// Check for path traversal attempts
if strings.Contains(path, "..") {
return "", fmt.Errorf("path traversal detected")
}
// Check for null bytes
if strings.Contains(path, "\x00") {
return "", fmt.Errorf("null byte in path")
}
// Ensure path is absolute or within allowed directories
allowedPrefixes := []string{
"/go/src/app/",
"./",
"/tmp/",
"/var/tmp/",
}
isAllowed := false
for _, prefix := range allowedPrefixes {
if strings.HasPrefix(path, prefix) {
isAllowed = true
break
}
}
if !isAllowed {
return "", fmt.Errorf("path not in allowed directories")
}
return path, nil
}
// ifNotInTest checks if the program is not running in a test environment.
func ifNotInTest() bool {
return flag.Lookup("test.v") == nil
}
package main
import (
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
// StartMonitoringServer initializes and starts the monitoring server.
func StartMonitoringServer() error {
cfg.Monitoring = libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{
PurgeOnCrawl: cfg.Server.PurgeOnCrawl,
PurgeEvery: cfg.Server.PurgeEvery,
})
cfg.Monitoring.AddMetricsPrefix("graphql_proxy")
cfg.Monitoring.RegisterDefaultMetrics()
// Currently, the monitoring server initialization doesn't throw errors,
// but we return nil to maintain the interface contract
return nil
}
package libpack_monitoring
func (ms *MetricsSetup) RegisterDefaultMetrics() {
ms.RegisterMetricsCounter(MetricsSucceeded, nil)
ms.RegisterMetricsCounter(MetricsFailed, nil)
ms.RegisterMetricsCounter(MetricsSkipped, nil)
ms.RegisterMetricsHistogram(MetricsDuration, nil)
ms.RegisterMetricsCounter(MetricsCacheHit, nil)
ms.RegisterMetricsCounter(MetricsCacheMiss, nil)
ms.RegisterMetricsCounter(MetricsQueriesCached, nil)
}
func (ms *MetricsSetup) RegisterGoMetrics() {
// TODO: metrics.WriteProcessMetrics(ms.metrics_set)
}
package libpack_monitoring
import (
"bytes"
"fmt"
"os"
"sort"
"strings"
"sync"
"unicode"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
)
var sortedLabelKeysCache = struct {
m sync.Map
}{}
func (ms *MetricsSetup) get_metrics_name(name string, labels map[string]string) string {
var buf bytes.Buffer
podName := getPodName()
if labels == nil {
labels = defaultLabels(podName)
} else {
ensureDefaultLabels(&labels, podName)
}
if ms.metrics_prefix != "" {
buf.WriteString(ms.metrics_prefix)
buf.WriteByte('_')
}
buf.WriteString(name)
if len(labels) > 0 {
buf.WriteByte('{')
appendSortedLabels(&buf, labels)
buf.WriteByte('}')
}
return buf.String()
}
func getPodName() string {
const unknownPodName = "unknown"
if hn, err := os.Hostname(); err == nil {
return hn
}
return unknownPodName
}
func defaultLabels(podName string) map[string]string {
return map[string]string{
"microservice": libpack_config.PKG_NAME,
"pod": podName,
}
}
func ensureDefaultLabels(labels *map[string]string, podName string) {
if *labels == nil {
*labels = make(map[string]string)
}
if _, exists := (*labels)["microservice"]; !exists {
(*labels)["microservice"] = libpack_config.PKG_NAME
}
if _, exists := (*labels)["pod"]; !exists {
(*labels)["pod"] = podName
}
}
func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
keys := getSortedKeys(labels)
for i, k := range keys {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(k)
buf.WriteString(`="`)
buf.WriteString(labels[k])
buf.WriteByte('"')
}
}
func getSortedKeys(labels map[string]string) []string {
labelsKey := labelsToString(labels)
// Check if the sorted keys are already cached
if keys, ok := sortedLabelKeysCache.m.Load(labelsKey); ok {
return keys.([]string)
}
// Compute the sorted keys
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
// Store the sorted keys in the cache
sortedLabelKeysCache.m.Store(labelsKey, keys)
return keys
}
func labelsToString(labels map[string]string) string {
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
for _, k := range keys {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(labels[k])
sb.WriteByte(';')
}
return sb.String()
}
func validate_metrics_name(name string) error {
cleanedName := clean_metric_name(name)
finalName := strings.Trim(cleanedName, "_")
if finalName != name {
return fmt.Errorf("invalid metric name: %s, expected %s", name, finalName)
}
return nil
}
func clean_metric_name(name string) string {
var buf bytes.Buffer
lastWasUnderscore := false
for _, r := range name {
if is_allowed_rune(r) {
if is_special_rune(r) {
if lastWasUnderscore {
continue
}
r = '_'
lastWasUnderscore = true
} else {
lastWasUnderscore = false
}
buf.WriteRune(r)
} else if !lastWasUnderscore {
buf.WriteByte('_')
lastWasUnderscore = true
}
}
return strings.Trim(buf.String(), "_")
}
func is_allowed_rune(r rune) bool {
return unicode.IsLetter(r) || unicode.IsDigit(r) || r == ' ' || r == '_'
}
func is_special_rune(r rune) bool {
return r == ' ' || r == '_'
}
func compile_metrics_with_labels(name string, labels map[string]string) string {
var buf bytes.Buffer
buf.WriteString(name)
keys := getSortedKeys(labels)
for _, k := range keys {
buf.WriteByte('_')
buf.WriteString(k)
buf.WriteByte('_')
buf.WriteString(labels[k])
}
return buf.String()
}
package libpack_monitoring
import (
"flag"
"fmt"
"time"
"github.com/VictoriaMetrics/metrics"
"github.com/gofiber/fiber/v2"
"github.com/gookit/goutil/envutil"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
type MetricsSetup struct {
metrics_set *metrics.Set
metrics_set_custom *metrics.Set
ic *InitConfig
metrics_prefix string
}
var log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO)
type InitConfig struct {
PurgeOnCrawl bool
PurgeEvery int
}
func NewMonitoring(ic *InitConfig) *MetricsSetup {
ms := &MetricsSetup{
ic: ic,
metrics_set: metrics.NewSet(),
metrics_set_custom: metrics.NewSet(),
}
if flag.Lookup("test.v") == nil {
go ms.startPrometheusEndpoint()
if ic.PurgeEvery > 0 {
ticker := time.NewTicker(time.Duration(ic.PurgeEvery) * time.Second)
go func() {
for range ticker.C {
ms.PurgeMetrics()
}
}()
}
}
return ms
}
func (ms *MetricsSetup) startPrometheusEndpoint() {
app := fiber.New(fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
})
app.Get("/metrics", ms.metricsEndpoint)
if err := app.Listen(fmt.Sprintf(":%d", envutil.GetInt("MONITORING_PORT", 9393))); err != nil {
log.Critical(&libpack_logger.LogMessage{
Message: "Can't start the MONITORING service",
Pairs: map[string]interface{}{"error": err},
})
}
}
func (ms *MetricsSetup) metricsEndpoint(c *fiber.Ctx) error {
ms.metrics_set.WritePrometheus(c.Response().BodyWriter())
ms.metrics_set_custom.WritePrometheus(c.Response().BodyWriter())
if ms.ic.PurgeOnCrawl && ms.ic.PurgeEvery == 0 {
ms.PurgeMetrics()
}
return nil
}
func (ms *MetricsSetup) AddMetricsPrefix(prefix string) {
ms.metrics_prefix = prefix
}
func (ms *MetricsSetup) ListActiveMetrics() []string {
return ms.metrics_set.ListMetricNames()
}
func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[string]string, val float64) *metrics.Gauge {
if err := validate_metrics_name(metric_name); err != nil {
log.Error(&libpack_logger.LogMessage{
Message: "RegisterMetricsGauge() error - invalid metric name",
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
})
// Return a dummy gauge instead of nil to prevent panics
return &metrics.Gauge{}
}
return ms.metrics_set_custom.GetOrCreateGauge(ms.get_metrics_name(metric_name, labels), func() float64 {
return val
})
}
func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[string]string) *metrics.Counter {
if err := validate_metrics_name(metric_name); err != nil {
log.Error(&libpack_logger.LogMessage{
Message: "RegisterMetricsCounter() error - invalid metric name",
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
})
// Return a dummy counter instead of nil to prevent panics
return &metrics.Counter{}
}
if metric_name == MetricsSucceeded || metric_name == MetricsFailed || metric_name == MetricsSkipped {
return ms.metrics_set.GetOrCreateCounter(ms.get_metrics_name(metric_name, labels))
}
return ms.metrics_set_custom.GetOrCreateCounter(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) RegisterFloatCounter(metric_name string, labels map[string]string) *metrics.FloatCounter {
if err := validate_metrics_name(metric_name); err != nil {
log.Error(&libpack_logger.LogMessage{
Message: "RegisterFloatCounter() error - invalid metric name",
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
})
// Return a dummy float counter instead of nil to prevent panics
return &metrics.FloatCounter{}
}
return ms.metrics_set_custom.GetOrCreateFloatCounter(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) RegisterMetricsSummary(metric_name string, labels map[string]string) *metrics.Summary {
if err := validate_metrics_name(metric_name); err != nil {
log.Error(&libpack_logger.LogMessage{
Message: "RegisterMetricsSummary() error - invalid metric name",
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
})
// Return a dummy summary instead of nil to prevent panics
return &metrics.Summary{}
}
return ms.metrics_set_custom.GetOrCreateSummary(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) RegisterMetricsHistogram(metric_name string, labels map[string]string) *metrics.Histogram {
if err := validate_metrics_name(metric_name); err != nil {
log.Error(&libpack_logger.LogMessage{
Message: "RegisterMetricsHistogram() error - invalid metric name",
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
})
// Return a dummy histogram instead of nil to prevent panics
return &metrics.Histogram{}
}
return ms.metrics_set_custom.GetOrCreateHistogram(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) Increment(metric_name string, labels map[string]string) {
ms.RegisterMetricsCounter(metric_name, labels).Inc()
}
func (ms *MetricsSetup) IncrementFloat(metric_name string, labels map[string]string, value float64) {
ms.RegisterFloatCounter(metric_name, labels).Add(value)
}
func (ms *MetricsSetup) Set(metric_name string, labels map[string]string, value uint64) {
ms.RegisterMetricsCounter(metric_name, labels).Set(value)
}
func (ms *MetricsSetup) Update(metric_name string, labels map[string]string, value float64) {
ms.RegisterMetricsHistogram(metric_name, labels).Update(value)
}
func (ms *MetricsSetup) UpdateDuration(metric_name string, labels map[string]string, value time.Time) {
ms.RegisterMetricsHistogram(metric_name, labels).UpdateDuration(value)
}
func (ms *MetricsSetup) UpdateSummary(metric_name string, labels map[string]string, value float64) {
ms.RegisterMetricsSummary(metric_name, labels).Update(value)
}
func (ms *MetricsSetup) RemoveMetrics(metric_name string, labels map[string]string) {
ms.metrics_set_custom.UnregisterMetric(ms.get_metrics_name(metric_name, labels))
}
func (ms *MetricsSetup) PurgeMetrics() {
ms.metrics_set_custom.UnregisterAllMetrics()
}
package main
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"math"
"net"
"net/url"
"strings"
"sync"
"time"
"go.opentelemetry.io/otel/trace"
"github.com/avast/retry-go/v4"
"github.com/gofiber/fiber/v2"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
libpack_tracing "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing"
"github.com/sony/gobreaker"
"github.com/valyala/fasthttp"
)
// Errors related to circuit breaker
var (
ErrCircuitOpen = errors.New("circuit breaker is open")
)
// Default values for circuit breaker
const (
defaultMaxRequestsInHalfOpen = 10 // Default maximum requests in half-open state
)
// Global circuit breaker
var (
cb *gobreaker.CircuitBreaker
cbMutex sync.RWMutex
)
// safeUint32 converts an int to uint32 safely, handling negative values and values exceeding uint32 max
func safeUint32(value int) uint32 {
// Handle negative values
if value < 0 {
return 0
}
// Handle values exceeding uint32 max
if value > math.MaxUint32 {
return math.MaxUint32
}
return uint32(value)
}
// initCircuitBreaker initializes the circuit breaker with configured settings
func initCircuitBreaker(config *config) {
// Only initialize if enabled
if !config.CircuitBreaker.Enable {
config.Logger.Info(&libpack_logger.LogMessage{
Message: "Circuit breaker is disabled",
})
return
}
cbMutex.Lock()
defer cbMutex.Unlock()
// Initialize circuit breaker metrics
InitializeCircuitBreakerMetrics(config.Monitoring)
// Create circuit breaker settings
cbSettings := gobreaker.Settings{
Name: "graphql-proxy-circuit",
MaxRequests: safeMaxRequests(config.CircuitBreaker.MaxRequestsInHalfOpen),
Interval: 0, // No specific interval for counting failures
Timeout: time.Duration(config.CircuitBreaker.Timeout) * time.Second,
ReadyToTrip: createTripFunc(config),
OnStateChange: createStateChangeFunc(config),
}
// Initialize the circuit breaker
cb = gobreaker.NewCircuitBreaker(cbSettings)
config.Logger.Info(&libpack_logger.LogMessage{
Message: "Circuit breaker initialized",
Pairs: map[string]interface{}{
"max_failures": config.CircuitBreaker.MaxFailures,
"timeout_seconds": config.CircuitBreaker.Timeout,
"max_half_open_reqs": config.CircuitBreaker.MaxRequestsInHalfOpen,
},
})
}
// createTripFunc returns a function that determines when to trip the circuit
func createTripFunc(config *config) func(counts gobreaker.Counts) bool {
return func(counts gobreaker.Counts) bool {
failureRatio := float64(counts.TotalFailures) / float64(counts.Requests)
shouldTrip := counts.ConsecutiveFailures >= safeUint32(config.CircuitBreaker.MaxFailures)
if shouldTrip {
config.Logger.Warning(&libpack_logger.LogMessage{
Message: "Circuit breaker tripped",
Pairs: map[string]interface{}{
"consecutive_failures": counts.ConsecutiveFailures,
"failure_ratio": failureRatio,
"total_failures": counts.TotalFailures,
"total_requests": counts.Requests,
},
})
}
return shouldTrip
}
}
// createStateChangeFunc returns a function that handles circuit state changes
func createStateChangeFunc(config *config) func(name string, from gobreaker.State, to gobreaker.State) {
return func(name string, from gobreaker.State, to gobreaker.State) {
var stateValue float64
var stateName string
switch to {
case gobreaker.StateOpen:
stateValue = float64(libpack_monitoring.CircuitOpen)
stateName = "open"
case gobreaker.StateHalfOpen:
stateValue = float64(libpack_monitoring.CircuitHalfOpen)
stateName = "half-open"
case gobreaker.StateClosed:
stateValue = float64(libpack_monitoring.CircuitClosed)
stateName = "closed"
}
// Update metrics using atomic operations to prevent race conditions
// Use a separate atomic variable to track state instead of recreating gauges
updateCircuitBreakerState(config, stateValue)
// Log state change
config.Logger.Info(&libpack_logger.LogMessage{
Message: "Circuit breaker state changed",
Pairs: map[string]interface{}{
"from": from.String(),
"to": to.String(),
"name": name,
},
})
// Use the new metrics system
if cbMetrics != nil {
// Replace hyphens with underscores to avoid validation errors
safeStateName := strings.ReplaceAll(stateName, "-", "_")
stateKey := fmt.Sprintf("circuit_state_%s", safeStateName)
counter := cbMetrics.GetOrCreateFailCounter(config.Monitoring, stateKey)
counter.Inc()
}
}
}
// createFasthttpClient creates and configures a fasthttp client with optimized settings.
// The client is configured based on the provided configuration settings, with careful
// attention to performance and security considerations.
func createFasthttpClient(clientConfig *config) *fasthttp.Client {
tlsConfig := &tls.Config{
InsecureSkipVerify: clientConfig.Client.DisableTLSVerify,
}
// Calculate timeout values, ensuring they're always positive
clientTimeout := time.Duration(clientConfig.Client.ClientTimeout) * time.Second
if clientTimeout <= 0 {
clientTimeout = 30 * time.Second // Default timeout of 30 seconds
}
// For timeout behavior, use the client timeout for all timeout settings
// to ensure consistent behavior
readTimeout := clientTimeout
writeTimeout := clientTimeout
// Create a custom dialer with timeout
dialer := &fasthttp.TCPDialer{
Concurrency: 1000,
DNSCacheDuration: time.Hour,
}
client := &fasthttp.Client{
Name: "graphql_proxy",
NoDefaultUserAgentHeader: true,
TLSConfig: tlsConfig,
// Control connection pool size to prevent overwhelming backend services
MaxConnsPerHost: clientConfig.Client.MaxConnsPerHost,
// Configure timeouts to handle different network scenarios
// Setting all timeout-related parameters to ensure proper timeout behavior
Dial: func(addr string) (net.Conn, error) {
return dialer.DialTimeout(addr, clientTimeout)
},
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
MaxIdleConnDuration: time.Duration(clientConfig.Client.MaxIdleConnDuration) * time.Second,
MaxConnDuration: clientTimeout,
DisableHeaderNamesNormalizing: false,
// Performance tuning
ReadBufferSize: 4096,
WriteBufferSize: 4096,
MaxResponseBodySize: 1024 * 1024 * 10, // 10MB max response size
DisablePathNormalizing: false,
}
// Initialize connection pool manager
InitializeConnectionPool(client)
return client
}
// proxyTheRequest handles the request proxying logic.
func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
// Setup tracing if enabled
var span trace.Span
var ctx context.Context
if cfg.Tracing.Enable && tracer != nil {
ctx = setupTracing(c)
span, _ = tracer.StartSpan(ctx, "proxy_request")
defer span.End()
}
// Check if URL is allowed
if !checkAllowedURLs(c) {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return fmt.Errorf("request blocked - not allowed URL: %s", c.Path())
}
// Construct and validate proxy URL
proxyURL := currentEndpoint + c.Path()
if _, err := url.Parse(proxyURL); err != nil {
return fmt.Errorf("invalid URL: %v", err)
}
// Log request details in debug mode
if cfg.LogLevel == "DEBUG" {
logDebugRequest(c)
}
// Perform the proxy request with retries
if err := performProxyRequest(c, proxyURL); err != nil {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return err
}
// Log response details in debug mode
if cfg.LogLevel == "DEBUG" {
logDebugResponse(c)
}
// Handle gzipped responses
if err := handleGzippedResponse(c); err != nil {
return err
}
// Final status check
if c.Response().StatusCode() != fiber.StatusOK {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
}
// Remove server header for security
c.Response().Header.Del(fiber.HeaderServer)
return nil
}
// setupTracing extracts and sets up tracing context from request headers
func setupTracing(c *fiber.Ctx) context.Context {
ctx := context.Background()
if !cfg.Tracing.Enable || tracer == nil {
return ctx
}
// Extract trace information from header
if traceHeader := c.Get("X-Trace-Span"); traceHeader != "" {
spanInfo, err := libpack_tracing.ParseTraceHeader(traceHeader)
if err != nil {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Failed to parse trace header",
Pairs: map[string]interface{}{"error": err.Error()},
})
} else if spanCtx, err := tracer.ExtractSpanContext(spanInfo); err == nil {
ctx = trace.ContextWithSpanContext(ctx, spanCtx)
}
}
return ctx
}
// performProxyRequest executes the proxy request with retries and circuit breaker
func performProxyRequest(c *fiber.Ctx, proxyURL string) error {
// If circuit breaker is not enabled, use the original method
if !cfg.CircuitBreaker.Enable || cb == nil {
return performProxyRequestWithRetries(c, proxyURL)
}
// Calculate cache key for potential fallback
cacheKey := libpack_cache.CalculateHash(c)
// Execute request through circuit breaker
_, err := cb.Execute(func() (interface{}, error) {
// Execute the request with retries
err := performProxyRequestWithRetries(c, proxyURL)
// Check if the error or status code should trip the circuit breaker
if err != nil {
// Log error that could potentially trip the circuit
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Error in circuit-protected request",
Pairs: map[string]interface{}{
"path": c.Path(),
"error": err.Error(),
},
})
return nil, err
}
// Check if non-2xx responses should trip the circuit
statusCode := c.Response().StatusCode()
if cfg.CircuitBreaker.TripOn5xx && statusCode >= 500 && statusCode < 600 {
err := fmt.Errorf("received 5xx status code: %d", statusCode)
cfg.Monitoring.Increment(libpack_monitoring.MetricsCircuitFailed, nil)
return nil, err
}
// Request was successful
cfg.Monitoring.Increment(libpack_monitoring.MetricsCircuitSuccessful, nil)
return nil, nil
})
// If the circuit is open, try to serve from cache if configured
if err == gobreaker.ErrOpenState && cfg.CircuitBreaker.ReturnCachedOnOpen {
cfg.Monitoring.Increment(libpack_monitoring.MetricsCircuitRejected, nil)
// Try to fetch from cache
if cachedResponse := libpack_cache.CacheLookup(cacheKey); cachedResponse != nil {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Circuit open - serving from cache",
Pairs: map[string]interface{}{
"path": c.Path(),
},
})
// Set response from cache
c.Response().SetBody(cachedResponse)
c.Response().SetStatusCode(fiber.StatusOK)
// Mark as cache hit since we're serving from cache
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
cfg.Monitoring.Increment(libpack_monitoring.MetricsCircuitFallbackSuccess, nil)
return nil
}
// No cached response available
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Circuit open - no cached response available",
Pairs: map[string]interface{}{
"path": c.Path(),
},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsCircuitFallbackFailed, nil)
return ErrCircuitOpen
}
return err
}
// performProxyRequestWithRetries executes the proxy request with retries
// This is the original implementation extracted for reuse
func performProxyRequestWithRetries(c *fiber.Ctx, proxyURL string) error {
return retry.Do(
func() error {
if err := doProxyRequestWithTimeout(c, proxyURL, cfg.Client.FastProxyClient); err != nil {
// Check if this is a timeout error - don't retry timeouts
if strings.Contains(strings.ToLower(err.Error()), "timeout") ||
strings.Contains(strings.ToLower(err.Error()), "deadline exceeded") ||
strings.Contains(strings.ToLower(err.Error()), "context deadline exceeded") {
return retry.Unrecoverable(err)
}
return err
}
if c.Response().StatusCode() != fiber.StatusOK {
return fmt.Errorf("received non-200 response: %d", c.Response().StatusCode())
}
return nil
},
retry.Attempts(5),
retry.DelayType(retry.BackOffDelay),
retry.Delay(250*time.Millisecond),
retry.MaxDelay(5*time.Second),
retry.OnRetry(func(n uint, err error) {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Retrying the request",
Pairs: map[string]interface{}{
"path": c.Path(),
"attempt": n + 1,
"error": err.Error(),
"error_type": fmt.Sprintf("%T", err),
"is_timeout": strings.Contains(strings.ToLower(err.Error()), "timeout"),
},
})
}),
retry.LastErrorOnly(true),
)
}
// doProxyRequestWithTimeout performs a proxy request with proper timeout handling
func doProxyRequestWithTimeout(c *fiber.Ctx, proxyURL string, client *fasthttp.Client) error {
// Calculate timeout from client configuration
clientTimeout := time.Duration(cfg.Client.ClientTimeout) * time.Second
if clientTimeout <= 0 {
clientTimeout = 30 * time.Second
}
// Acquire request and response objects
req := fasthttp.AcquireRequest()
resp := fasthttp.AcquireResponse()
defer fasthttp.ReleaseRequest(req)
defer fasthttp.ReleaseResponse(resp)
// Copy the original request
c.Request().CopyTo(req)
req.SetRequestURI(proxyURL)
// Perform the request with timeout
err := client.DoTimeout(req, resp, clientTimeout)
if err != nil {
return err
}
// Copy response back to fiber context
resp.CopyTo(c.Response())
return nil
}
// handleGzippedResponse decompresses gzipped responses
func handleGzippedResponse(c *fiber.Ctx) error {
if !bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
return nil
}
// Use pooled gzip reader
reader, err := GetGzipReader(bytes.NewReader(c.Response().Body()))
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to create gzip reader",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
defer func() {
// Return reader to pool
PutGzipReader(reader)
}()
// Use pooled buffer for reading
buf := GetHTTPBuffer()
defer PutHTTPBuffer(buf)
// Read decompressed data into pooled buffer
_, err = io.Copy(buf, reader)
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to decompress response",
Pairs: map[string]interface{}{"error": err.Error()},
})
return err
}
// Get decompressed data
decompressed := buf.Bytes()
// Update response
c.Response().SetBody(decompressed)
c.Response().Header.Del("Content-Encoding")
return nil
}
// logDebugRequest logs the request details when in debug mode.
func logDebugRequest(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Proxying the request",
Pairs: map[string]interface{}{
"path": c.Path(),
"body": string(c.Body()),
"headers": c.GetReqHeaders(),
"request_uuid": c.Locals("request_uuid"),
},
})
}
// logDebugResponse logs the response details when in debug mode.
func logDebugResponse(c *fiber.Ctx) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Received proxied response",
Pairs: map[string]interface{}{
"path": c.Path(),
"response_body": string(c.Response().Body()),
"response_code": c.Response().StatusCode(),
"headers": c.GetRespHeaders(),
"request_uuid": c.Locals("request_uuid"),
},
})
}
// safeMaxRequests converts MaxRequestsInHalfOpen safely to uint32, providing a fallback value if out of bounds
func safeMaxRequests(maxRequestsInHalfOpen int) uint32 {
// Check if value is invalid (negative or too large)
if maxRequestsInHalfOpen < 0 || maxRequestsInHalfOpen > math.MaxUint32 {
// Log warning and return a default value
if cfg != nil && cfg.Logger != nil {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Invalid MaxRequestsInHalfOpen value, using default",
Pairs: map[string]interface{}{
"requested_value": maxRequestsInHalfOpen,
"default_value": defaultMaxRequestsInHalfOpen,
},
})
}
return uint32(defaultMaxRequestsInHalfOpen)
}
return uint32(maxRequestsInHalfOpen)
}
// updateCircuitBreakerState safely updates the circuit breaker state using atomic operations
func updateCircuitBreakerState(config *config, stateValue float64) {
// Update the state atomically using the new metrics system
if cbMetrics != nil {
cbMetrics.UpdateState(stateValue)
}
}
package main
import (
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"github.com/goccy/go-json"
goratecounter "github.com/lukaszraczylo/go-ratecounter"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
// RateLimitConfig holds the rate limit configuration for a role
type RateLimitConfig struct {
RateCounterTicker *goratecounter.RateCounter
Interval time.Duration `json:"interval"`
Req int `json:"req"`
}
// UnmarshalJSON implements custom JSON unmarshaling for RateLimitConfig
func (r *RateLimitConfig) UnmarshalJSON(data []byte) error {
// Use a temporary struct to unmarshal the JSON data
type RateLimitConfigTemp struct {
Interval interface{} `json:"interval"`
Req int `json:"req"`
}
var temp RateLimitConfigTemp
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
// Set the Req field directly
r.Req = temp.Req
// Handle the Interval field based on its type
switch v := temp.Interval.(type) {
case string:
// Convert string to time.Duration
switch v {
case "second":
r.Interval = time.Second
case "minute":
r.Interval = time.Minute
case "hour":
r.Interval = time.Hour
case "day":
r.Interval = 24 * time.Hour
default:
// Try to parse as a Go duration string (e.g. "1s", "5m")
var err error
r.Interval, err = time.ParseDuration(v)
if err != nil {
return fmt.Errorf("invalid duration format: %s", v)
}
}
case float64:
// Numeric value is assumed to be in seconds
r.Interval = time.Duration(v * float64(time.Second))
default:
return fmt.Errorf("interval must be a string or number, got %T", v)
}
return nil
}
var (
rateLimits = make(map[string]RateLimitConfig)
rateLimitMu sync.RWMutex
// Use atomic.Value for safe concurrent config swapping
rateLimitConfigAtomic atomic.Value
)
// Variable to hold the current load config function - allows for testing
var loadConfigFunc = loadConfigFromPath
// loadRatelimitConfig loads the rate limit configurations from file
func loadRatelimitConfig() error {
paths := []string{"/go/src/app/ratelimit.json", "./ratelimit.json", "./static/app/default-ratelimit.json"}
configError := NewRateLimitConfigError(paths)
// Try each path and collect detailed error information
for _, path := range paths {
if err := loadConfigFunc(path); err == nil {
return nil
} else {
// Store the specific error for this path
configError.PathErrors[path] = err.Error()
}
}
// Log detailed error information
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Failed to load rate limit configuration",
Pairs: map[string]interface{}{
"paths": paths,
"path_errors": configError.PathErrors,
},
})
return configError
}
func loadConfigFromPath(path string) error {
file, err := os.ReadFile(path)
if err != nil {
// Provide more specific error message based on the error type
errMsg := ""
if os.IsNotExist(err) {
errMsg = "File not found"
} else if os.IsPermission(err) {
errMsg = "Permission denied"
} else {
errMsg = "I/O error: " + err.Error()
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Failed to load rate limit config",
Pairs: map[string]interface{}{
"path": path,
"error": errMsg,
"error_details": err.Error(),
},
})
return fmt.Errorf("%s", errMsg)
}
var config struct {
RateLimit map[string]RateLimitConfig `json:"ratelimit"`
}
if err := json.Unmarshal(file, &config); err != nil {
errMsg := fmt.Sprintf("Invalid JSON format: %s", err.Error())
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Failed to parse rate limit config",
Pairs: map[string]interface{}{
"path": path,
"error": errMsg,
},
})
return fmt.Errorf("%s", errMsg)
}
// Validate configuration
if len(config.RateLimit) == 0 {
errMsg := "Empty rate limit configuration"
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Invalid rate limit config",
Pairs: map[string]interface{}{
"path": path,
"error": errMsg,
},
})
return fmt.Errorf("%s", errMsg)
}
newRateLimits := make(map[string]RateLimitConfig, len(config.RateLimit))
for key, value := range config.RateLimit {
value.RateCounterTicker = goratecounter.NewRateCounter().WithConfig(goratecounter.RateCounterConfig{
Interval: value.Interval,
})
if cfg.LogLevel == "DEBUG" {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Setting ratelimit config for role",
Pairs: map[string]interface{}{
"role": key,
"interval_used": value.Interval,
"ratelimit": value.Req,
},
})
}
newRateLimits[key] = value
}
// Use atomic swap for thread-safe configuration updates
rateLimitMu.Lock()
rateLimits = newRateLimits
// Store the new config atomically
rateLimitConfigAtomic.Store(newRateLimits)
rateLimitMu.Unlock()
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit config loaded",
Pairs: map[string]interface{}{"ratelimit": rateLimits},
})
return nil
}
// rateLimitedRequest checks if a request should be rate-limited
func rateLimitedRequest(userID, userRole string) bool {
// Try to get config from atomic value first for better performance
if configInterface := rateLimitConfigAtomic.Load(); configInterface != nil {
if config, ok := configInterface.(map[string]RateLimitConfig); ok {
if roleConfig, exists := config[userRole]; exists && roleConfig.RateCounterTicker != nil {
return checkRateLimit(userID, userRole, roleConfig)
}
}
}
// Fallback to mutex-protected access
rateLimitMu.RLock()
roleConfig, ok := rateLimits[userRole]
rateLimitMu.RUnlock()
if !ok || roleConfig.RateCounterTicker == nil {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit role not found or ticker not initialized - defaulting to deny",
Pairs: map[string]interface{}{"user_role": userRole},
})
// Default to deny when config not found (security fix)
return false
}
return checkRateLimit(userID, userRole, roleConfig)
}
// checkRateLimit performs the actual rate limit check
func checkRateLimit(userID, userRole string, roleConfig RateLimitConfig) bool {
roleConfig.RateCounterTicker.Incr(1)
tickerRate := roleConfig.RateCounterTicker.GetRate()
logDetails := map[string]interface{}{
"user_role": userRole,
"user_id": userID,
"rate": tickerRate,
"config_rate": roleConfig.Req,
"interval": roleConfig.Interval,
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit ticker",
Pairs: map[string]interface{}{"log_details": logDetails},
})
if tickerRate > float64(roleConfig.Req) {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Rate limit exceeded",
Pairs: map[string]interface{}{"log_details": logDetails},
})
return false
}
return true
}
package main
import (
"fmt"
"strings"
)
// RateLimitConfigError represents a detailed error when loading rate limit configuration
type RateLimitConfigError struct {
Paths []string
// Map of path -> error message
PathErrors map[string]string
}
// Error implements the error interface
func (e *RateLimitConfigError) Error() string {
sb := strings.Builder{}
sb.WriteString("Failed to load rate limit configuration. Please ensure a valid configuration file exists at one of these locations:\n")
for _, path := range e.Paths {
errMsg := e.PathErrors[path]
sb.WriteString(fmt.Sprintf(" - %s: %s\n", path, errMsg))
}
sb.WriteString("\nTo resolve this issue:\n")
sb.WriteString("1. Create a valid JSON file using the following template:\n")
sb.WriteString(` {
"ratelimit": {
"admin": {
"req": 100,
"interval": "second"
},
"guest": {
"req": 3,
"interval": "second"
},
"-": {
"req": 10,
"interval": "minute"
}
}
}`)
sb.WriteString("\n\nThe 'interval' field supports the following formats:\n")
sb.WriteString(" - String values: \"second\", \"minute\", \"hour\", \"day\"\n")
sb.WriteString(" - Go duration strings: \"5s\", \"10m\", \"1h\"\n")
sb.WriteString(" - Numeric values (in seconds): 60, 3600\n")
sb.WriteString("\n2. Save it as 'ratelimit.json' in the current directory or in '/go/src/app/' (in Docker)\n")
sb.WriteString("3. Ensure the file has correct permissions and is accessible by the service\n")
return sb.String()
}
// NewRateLimitConfigError creates a new rate limit configuration error
func NewRateLimitConfigError(paths []string) *RateLimitConfigError {
return &RateLimitConfigError{
Paths: paths,
PathErrors: make(map[string]string),
}
}
package main
import (
"fmt"
"strconv"
"time"
"github.com/goccy/go-json"
fiber "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/google/uuid"
graphql "github.com/lukaszraczylo/go-simple-graphql"
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
libpack_config "github.com/lukaszraczylo/graphql-monitoring-proxy/config"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
const (
healthCheckQueryStr = `{ __typename }`
)
// HealthCheckResponse represents the response structure for health check endpoints
type HealthCheckResponse struct {
Status string `json:"status"` // overall status: "healthy" or "unhealthy"
Dependencies map[string]DependencyStatus `json:"dependencies"` // status of each dependency
Timestamp string `json:"timestamp"` // when the health check was performed
}
// DependencyStatus represents the status of a dependency
type DependencyStatus struct {
Status string `json:"status"` // "up" or "down"
ResponseTime int64 `json:"responseTime"` // in milliseconds
Error *string `json:"error,omitempty"` // error message if any
}
// StartHTTPProxy initializes and starts the HTTP proxy server.
func StartHTTPProxy() error {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Starting the HTTP proxy",
})
serverConfig := fiber.Config{
DisableStartupMessage: true,
AppName: fmt.Sprintf("GraphQL Monitoring Proxy - %s v%s", libpack_config.PKG_NAME, libpack_config.PKG_VERSION),
IdleTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second,
ReadTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second,
WriteTimeout: time.Duration(cfg.Client.ClientTimeout) * time.Second,
JSONEncoder: json.Marshal,
JSONDecoder: json.Unmarshal,
}
server := fiber.New(serverConfig)
server.Use(cors.New(cors.Config{
AllowOrigins: "*",
}))
server.Use(AddRequestUUID)
server.Get("/healthz", healthCheck)
server.Get("/livez", healthCheck)
server.Get("/health", healthCheck)
server.Post("/*", processGraphQLRequest)
server.Get("/*", proxyTheRequestToDefault)
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "GraphQL proxy starting",
Pairs: map[string]interface{}{"port": cfg.Server.PortGraphQL},
})
if err := server.Listen(fmt.Sprintf(":%d", cfg.Server.PortGraphQL)); err != nil {
return fmt.Errorf("failed to start HTTP proxy server on port %d: %w",
cfg.Server.PortGraphQL, err)
}
return nil
}
// proxyTheRequestToDefault proxies the request to the default GraphQL endpoint.
func proxyTheRequestToDefault(c *fiber.Ctx) error {
return proxyTheRequest(c, cfg.Server.HostGraphQL)
}
// AddRequestUUID adds a unique request UUID to the context.
func AddRequestUUID(c *fiber.Ctx) error {
c.Locals("request_uuid", uuid.NewString())
return c.Next()
}
// checkAllowedURLs checks if the requested URL is allowed.
func checkAllowedURLs(c *fiber.Ctx) bool {
if len(allowedUrls) == 0 {
return true
}
path := c.OriginalURL()
_, ok := allowedUrls[path]
return ok
}
// healthCheck performs a comprehensive health check on the GraphQL server and its dependencies.
func healthCheck(c *fiber.Ctx) error {
// Prepare the response structure
response := HealthCheckResponse{
Status: "healthy",
Dependencies: make(map[string]DependencyStatus),
Timestamp: time.Now().UTC().Format(time.RFC3339),
}
// Configure checks from query parameters
checkGraphQL := true
checkRedis := cfg.Cache.CacheRedisEnable
// Parse query parameters to enable/disable specific checks
if c.Query("check_graphql") == "false" {
checkGraphQL = false
}
if c.Query("check_redis") == "false" {
checkRedis = false
}
// Check GraphQL backend service
if checkGraphQL {
startTime := time.Now()
graphqlStatus := DependencyStatus{
Status: "up",
}
// Try to connect to main GraphQL endpoint
endpoint := cfg.Server.HostGraphQL
if len(cfg.Server.HealthcheckGraphQL) > 0 {
endpoint = cfg.Server.HealthcheckGraphQL
}
// Create a new GraphQL client for the health check
tempClient := graphql.NewConnection()
tempClient.SetEndpoint(endpoint)
_, err := tempClient.Query(healthCheckQueryStr, nil, nil)
graphqlStatus.ResponseTime = time.Since(startTime).Milliseconds()
if err != nil {
errorMsg := err.Error()
graphqlStatus.Status = "down"
graphqlStatus.Error = &errorMsg
response.Status = "unhealthy"
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Health check: Can't reach the GraphQL server",
Pairs: map[string]interface{}{
"endpoint": endpoint,
"error": errorMsg,
"response_time_ms": graphqlStatus.ResponseTime,
},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
response.Dependencies["graphql"] = graphqlStatus
}
// Check Redis connectivity if enabled
if checkRedis && cfg.Cache.CacheRedisEnable {
startTime := time.Now()
redisStatus := DependencyStatus{
Status: "up",
}
// Implement proper Redis connectivity test
redisAccessible := false
var redisError error
if libpack_cache.IsCacheInitialized() {
// Try a simple Redis operation to test connectivity
testKey := "health_check_test"
testValue := []byte("test")
// Try to set and get a test value
libpack_cache.CacheStore(testKey, testValue)
retrievedValue := libpack_cache.CacheLookup(testKey)
if retrievedValue != nil && string(retrievedValue) == "test" {
redisAccessible = true
// Clean up test key
libpack_cache.CacheDelete(testKey)
} else {
redisError = fmt.Errorf("redis test operation failed")
}
} else {
redisError = fmt.Errorf("cache not initialized")
}
redisStatus.ResponseTime = time.Since(startTime).Milliseconds()
if !redisAccessible {
errorMsg := "Failed to connect to Redis"
if redisError != nil {
errorMsg = redisError.Error()
}
redisStatus.Status = "down"
redisStatus.Error = &errorMsg
response.Status = "unhealthy"
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Health check: Can't connect to Redis",
Pairs: map[string]interface{}{
"server": cfg.Cache.CacheRedisURL,
"error": errorMsg,
"response_time_ms": redisStatus.ResponseTime,
},
})
}
response.Dependencies["redis"] = redisStatus
}
// Determine appropriate HTTP status code
httpStatus := fiber.StatusOK
if response.Status == "unhealthy" {
httpStatus = fiber.StatusServiceUnavailable
}
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Health check completed",
Pairs: map[string]interface{}{
"status": response.Status,
"dependencies": response.Dependencies,
},
})
// Return JSON response
return c.Status(httpStatus).JSON(response)
}
// processGraphQLRequest handles the incoming GraphQL requests.
func processGraphQLRequest(c *fiber.Ctx) error {
startTime := time.Now()
// Extract user information and check permissions
extractedUserID, extractedRoleName := extractUserInfo(c)
// Check if user is banned
if checkIfUserIsBanned(c, extractedUserID) {
return c.Status(fiber.StatusForbidden).SendString("User is banned")
}
// Apply rate limiting if enabled
if cfg.Client.RoleRateLimit && !rateLimitedRequest(extractedUserID, extractedRoleName) {
return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later")
}
// Parse the GraphQL query
parsedResult := parseGraphQLQuery(c)
if parsedResult.shouldBlock {
return c.Status(fiber.StatusForbidden).SendString("Request blocked")
}
// Handle non-GraphQL requests
if parsedResult.shouldIgnore {
return proxyTheRequest(c, parsedResult.activeEndpoint)
}
// Handle caching
wasCached, err := handleCaching(c, parsedResult, extractedUserID)
if err != nil {
return err
}
// Log and monitor the request
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
return nil
}
// extractUserInfo extracts user ID and role from request headers
func extractUserInfo(c *fiber.Ctx) (string, string) {
extractedUserID := "-"
extractedRoleName := "-"
// Extract from JWT if available
if authorization := c.Get("Authorization"); authorization != "" &&
(len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization)
}
// Override role from header if configured
if cfg.Client.RoleFromHeader != "" {
if role := c.Get(cfg.Client.RoleFromHeader); role != "" {
extractedRoleName = role
}
}
return extractedUserID, extractedRoleName
}
// handleCaching manages the caching logic for GraphQL requests
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) {
// Calculate query hash for cache key
calculatedQueryHash := libpack_cache.CalculateHash(c)
// Set cache time from header or default
if parsedResult.cacheTime == 0 {
if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" {
parsedResult.cacheTime, _ = strconv.Atoi(cacheQuery)
} else {
parsedResult.cacheTime = cfg.Cache.CacheTTL
}
}
// Handle cache refresh directive
if parsedResult.cacheRefresh {
libpack_cache.CacheDelete(calculatedQueryHash)
}
// Check if caching is enabled
cacheEnabled := parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable
if !cacheEnabled {
// No caching, just proxy the request
if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return false, c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
return false, nil
}
// Try to get from cache
if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
c.Set("X-Cache-Hit", "true")
c.Set("Content-Type", "application/json")
return true, c.Send(cachedResponse)
}
// Cache miss, proxy and cache
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil {
return false, err
}
return false, nil
}
// proxyAndCacheTheRequest proxies and caches the request if needed.
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error {
if err := proxyTheRequest(c, currentEndpoint); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
libpack_cache.CacheStoreWithTTL(queryCacheHash, c.Response().Body(), time.Duration(cacheTime)*time.Second)
cfg.Monitoring.Increment(libpack_monitoring.MetricsQueriesCached, nil)
return c.Send(c.Response().Body())
}
// logAndMonitorRequest logs and monitors the request processing.
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
labels := map[string]string{
"op_type": opType,
"op_name": opName,
"cached": strconv.FormatBool(wasCached),
"user_id": userID,
}
if cfg.Server.AccessLog {
cfg.Logger.Info(&libpack_logger.LogMessage{
Message: "Request processed",
Pairs: map[string]interface{}{
"ip": c.IP(),
"fwd-ip": c.Get("X-Forwarded-For"),
"user_id": userID,
"op_type": opType,
"op_name": opName,
"time": duration,
"cache": wasCached,
"request_uuid": c.Locals("request_uuid"),
},
})
}
cfg.Monitoring.Increment(libpack_monitoring.MetricsSucceeded, nil)
cfg.Monitoring.Increment(libpack_monitoring.MetricsExecutedQuery, labels)
if !wasCached {
cfg.Monitoring.UpdateDuration(libpack_monitoring.MetricsTimedQuery, labels, startTime)
cfg.Monitoring.Update(libpack_monitoring.MetricsTimedQuery, labels, float64(duration.Milliseconds()))
}
}
package main
import (
"context"
"sync"
"time"
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
)
// ShutdownManager manages graceful shutdown for all components
type ShutdownManager struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
components []ShutdownComponent
mu sync.Mutex
}
// ShutdownComponent represents a component that needs graceful shutdown
type ShutdownComponent struct {
Name string
Shutdown func(context.Context) error
}
// NewShutdownManager creates a new shutdown manager
func NewShutdownManager(ctx context.Context) *ShutdownManager {
ctx, cancel := context.WithCancel(ctx)
return &ShutdownManager{
ctx: ctx,
cancel: cancel,
}
}
// RegisterComponent registers a component for graceful shutdown
func (sm *ShutdownManager) RegisterComponent(name string, shutdown func(context.Context) error) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.components = append(sm.components, ShutdownComponent{
Name: name,
Shutdown: shutdown,
})
}
// RunGoroutine starts a goroutine that respects the shutdown context
func (sm *ShutdownManager) RunGoroutine(name string, fn func(context.Context)) {
sm.wg.Add(1)
go func() {
defer sm.wg.Done()
cfg.Logger.Debug(&libpack_logging.LogMessage{
Message: "Starting managed goroutine",
Pairs: map[string]interface{}{"name": name},
})
fn(sm.ctx)
cfg.Logger.Debug(&libpack_logging.LogMessage{
Message: "Managed goroutine finished",
Pairs: map[string]interface{}{"name": name},
})
}()
}
// Shutdown initiates graceful shutdown of all components
func (sm *ShutdownManager) Shutdown(timeout time.Duration) error {
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Initiating graceful shutdown",
})
// Cancel the context to signal all goroutines to stop
sm.cancel()
// Create a timeout context for component shutdown
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), timeout)
defer shutdownCancel()
// Shutdown all registered components
sm.mu.Lock()
components := make([]ShutdownComponent, len(sm.components))
copy(components, sm.components)
sm.mu.Unlock()
var shutdownWg sync.WaitGroup
for _, comp := range components {
shutdownWg.Add(1)
go func(c ShutdownComponent) {
defer shutdownWg.Done()
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Shutting down component",
Pairs: map[string]interface{}{"component": c.Name},
})
if err := c.Shutdown(shutdownCtx); err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Error shutting down component",
Pairs: map[string]interface{}{
"component": c.Name,
"error": err.Error(),
},
})
}
}(comp)
}
// Wait for all components to shutdown
componentsDone := make(chan struct{})
go func() {
shutdownWg.Wait()
close(componentsDone)
}()
// Wait for goroutines with timeout
goroutinesDone := make(chan struct{})
go func() {
sm.wg.Wait()
close(goroutinesDone)
}()
select {
case <-componentsDone:
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "All components shut down successfully",
})
case <-shutdownCtx.Done():
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Component shutdown timed out",
})
}
select {
case <-goroutinesDone:
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "All goroutines finished",
})
case <-time.After(timeout):
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Some goroutines didn't finish within timeout",
})
}
return nil
}
package tracing
import (
"context"
"encoding/json"
"fmt"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
)
type TracingSetup struct {
tracerProvider *sdktrace.TracerProvider
tracer trace.Tracer
}
type TraceSpanInfo struct {
TraceParent string `json:"traceparent"`
}
// NewTracing creates a new tracing setup with OTLP exporter
func NewTracing(ctx context.Context, endpoint string) (*TracingSetup, error) {
if ctx == nil {
return nil, fmt.Errorf("context cannot be nil")
}
if endpoint == "" {
return nil, fmt.Errorf("endpoint cannot be empty")
}
// Validate endpoint format
// A simple validation to check if the endpoint has a reasonable format
// We're looking for hostname:port where port is a valid port number (0-65535)
var host string
var port int
if n, err := fmt.Sscanf(endpoint, "%s:%d", &host, &port); err != nil || n != 2 {
return nil, fmt.Errorf("invalid endpoint format: must be 'hostname:port'")
}
if port < 0 || port > 65535 {
return nil, fmt.Errorf("invalid port number: must be between 0 and 65535")
}
// Create the exporter directly with the endpoint
exporter, err := otlptracegrpc.New(ctx,
otlptracegrpc.WithEndpoint(endpoint),
otlptracegrpc.WithInsecure(),
otlptracegrpc.WithTimeout(5*time.Second),
otlptracegrpc.WithDialOption(grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16*1024*1024))), // 16MB max message size
)
if err != nil {
return nil, fmt.Errorf("failed to create trace exporter: %w", err)
}
// Create a resource with more detailed attributes
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceName("graphql-monitoring-proxy"),
semconv.ServiceVersion("1.0"),
semconv.DeploymentEnvironment("production"),
attribute.String("application.type", "proxy"),
),
resource.WithHost(), // Add host information
resource.WithOSType(), // Add OS information
resource.WithProcessPID(), // Add process information
)
if err != nil {
return nil, fmt.Errorf("failed to create resource: %w", err)
}
// Create the tracer provider with improved configuration
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exporter,
// Configure batch processing
sdktrace.WithMaxExportBatchSize(512),
sdktrace.WithBatchTimeout(3*time.Second),
sdktrace.WithMaxQueueSize(2048),
),
sdktrace.WithResource(res),
sdktrace.WithSampler(sdktrace.TraceIDRatioBased(0.1)), // Sample 10% of traces
)
// Set the global tracer provider and propagator
otel.SetTracerProvider(tracerProvider)
otel.SetTextMapPropagator(propagation.TraceContext{})
// Create a tracer
tracer := tracerProvider.Tracer("graphql-monitoring-proxy")
return &TracingSetup{
tracerProvider: tracerProvider,
tracer: tracer,
}, nil
}
// ExtractSpanContext extracts span context from TraceSpanInfo
func (ts *TracingSetup) ExtractSpanContext(spanInfo *TraceSpanInfo) (trace.SpanContext, error) {
carrier := propagation.MapCarrier{
"traceparent": spanInfo.TraceParent,
}
ctx := context.Background()
ctx = otel.GetTextMapPropagator().Extract(ctx, carrier)
spanCtx := trace.SpanContextFromContext(ctx)
if !spanCtx.IsValid() {
return trace.SpanContext{}, fmt.Errorf("invalid span context")
}
return spanCtx, nil
}
// ParseTraceHeader parses X-Trace-Span header content
func ParseTraceHeader(headerContent string) (*TraceSpanInfo, error) {
var spanInfo TraceSpanInfo
if err := json.Unmarshal([]byte(headerContent), &spanInfo); err != nil {
return nil, fmt.Errorf("failed to parse trace header: %w", err)
}
return &spanInfo, nil
}
// Shutdown cleanly shuts down the tracer provider
func (ts *TracingSetup) Shutdown(ctx context.Context) error {
if ts.tracerProvider == nil {
return nil
}
return ts.tracerProvider.Shutdown(ctx)
}
// StartSpan starts a new span with the given name and parent context
func (ts *TracingSetup) StartSpan(ctx context.Context, name string) (trace.Span, context.Context) {
if ts == nil || ts.tracer == nil {
// Return a no-op span if tracing is not configured
return trace.SpanFromContext(ctx), ctx
}
// Add common attributes to all spans
opts := []trace.SpanStartOption{
trace.WithAttributes(
semconv.ServiceName("graphql-monitoring-proxy"),
semconv.ServiceVersion("1.0"),
),
}
ctx, span := ts.tracer.Start(ctx, name, opts...)
return span, ctx
}
// StartSpanWithAttributes starts a new span with custom attributes
func (ts *TracingSetup) StartSpanWithAttributes(ctx context.Context, name string, attrs map[string]string) (trace.Span, context.Context) {
if ts == nil || ts.tracer == nil {
return trace.SpanFromContext(ctx), ctx
}
// Convert string attributes to KeyValue pairs
attributes := make([]attribute.KeyValue, 0, len(attrs)+2)
attributes = append(attributes,
semconv.ServiceName("graphql-monitoring-proxy"),
semconv.ServiceVersion("1.0"),
)
for k, v := range attrs {
attributes = append(attributes, attribute.String(k, v))
}
ctx, span := ts.tracer.Start(ctx, name, trace.WithAttributes(attributes...))
return span, ctx
}