Configuration Management:

Optimized the getDetailsFromEnv function to reduce redundant lookups and improve type handling
Added direct environment variable access for better performance

Memory Cache Optimization:

Implemented a size-based compression threshold (1KB) to avoid compressing small payloads
Added a cache size limit (10,000 entries) to bound memory growth
Implemented eviction of the soonest-to-expire entries (10% of the limit per pass) once the cache is full
Added atomic counter for thread-safe cache size tracking
Cleanup now runs every TTL/4 and triggers a GC when the cache holds more than half of the size limit

Proxy Implementation:

Refactored the proxy code into smaller, focused functions for better maintainability
Optimized gzip handling for better performance
Improved error handling and logging
Enhanced tracing integration

GraphQL Processing:

Optimized introspection query checking with fast-path returns
Improved object pool usage
Added detailed comments for better code understanding
Split complex functions into smaller, more focused ones
Fixed test compatibility issues with introspection checking

Request Processing:

Refactored the request processing logic into smaller, focused functions
Separated user extraction, caching, and request handling for better maintainability
Improved error handling and response generation

Tracing Enhancements:

Added better span context management
Implemented custom attributes for more detailed tracing
Added ratio-based trace sampling (10% of traces) to reduce overhead
Improved resource attribution with host and process information
Added timeout handling for tracing operations

Application Lifecycle:

Implemented graceful shutdown with proper signal handling
Added goroutine management with wait groups
Improved startup sequence with better error handling
Added timeout handling for shutdown operations
This commit is contained in:
2025-02-25 23:34:39 +00:00
parent da577e8a02
commit 2ab78d35ce
6 changed files with 510 additions and 249 deletions
+106 -25
View File
@@ -4,14 +4,22 @@ import (
"bytes"
"compress/gzip"
"io"
"log"
"runtime"
"sync"
"sync/atomic"
"time"
)
// Tuning limits for the in-memory cache.
const (
	// CompressionThreshold is the payload size in bytes above which Set
	// attempts compression; values at or below it are stored as-is.
	CompressionThreshold = 1 << 10 // 1KB

	// MaxCacheSize is the tracked entry count at which Set evicts entries
	// before inserting.
	MaxCacheSize = 10000
)
// CacheEntry is a single cached value together with its expiry metadata.
type CacheEntry struct {
// ExpiresAt is the instant after which the entry is treated as stale.
ExpiresAt time.Time
// Value is the stored payload; it holds compressed bytes when Compressed is true.
Value []byte
// Compressed reports whether Value must be decompressed before it is returned.
Compressed bool
}
type Cache struct {
@@ -19,6 +27,7 @@ type Cache struct {
decompressPool sync.Pool
entries sync.Map
globalTTL time.Duration
entryCount int64
sync.RWMutex
}
@@ -38,32 +47,66 @@ func New(globalTTL time.Duration) *Cache {
},
}
// Start cleanup routine
go cache.cleanupRoutine(globalTTL)
return cache
}
// cleanupRoutine periodically purges expired entries.
//
// It calls CleanExpiredEntries every globalTTL/4 and, while the tracked entry
// count exceeds half of MaxCacheSize, forces a garbage collection after each
// sweep. The loop never terminates, so it is meant to run on its own
// goroutine for the lifetime of the cache.
func (c *Cache) cleanupRoutine(globalTTL time.Duration) {
	// time.NewTicker panics on a non-positive interval, which a zero (or
	// sub-4ns) TTL would otherwise produce; fall back to a fixed period.
	interval := globalTTL / 4
	if interval <= 0 {
		interval = time.Second
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for range ticker.C {
		c.CleanExpiredEntries()

		// Only pay for an explicit collection when the cache is large enough
		// for a sweep to have released a meaningful amount of memory.
		if atomic.LoadInt64(&c.entryCount) > MaxCacheSize/2 {
			runtime.GC()
		}
	}
}
func (c *Cache) Set(key string, value []byte, ttl time.Duration) {
// Check if we've reached the maximum cache size
if atomic.LoadInt64(&c.entryCount) >= MaxCacheSize {
c.evictOldest(MaxCacheSize / 10) // Evict 10% of entries
}
expiresAt := time.Now().Add(ttl)
// Only compress if the value is larger than the threshold
var entry CacheEntry
if len(value) > CompressionThreshold {
compressedValue, err := c.compress(value)
if err != nil {
log.Printf("Error compressing value for key %s: %v", key, err)
return
}
entry := CacheEntry{
if err == nil && len(compressedValue) < len(value) {
entry = CacheEntry{
Value: compressedValue,
ExpiresAt: expiresAt,
Compressed: true,
}
} else {
// If compression failed or didn't reduce size, store uncompressed
entry = CacheEntry{
Value: value,
ExpiresAt: expiresAt,
Compressed: false,
}
}
} else {
entry = CacheEntry{
Value: value,
ExpiresAt: expiresAt,
Compressed: false,
}
}
// Check if this is a new entry
_, exists := c.entries.Load(key)
if !exists {
atomic.AddInt64(&c.entryCount, 1)
}
c.entries.Store(key, entry)
}
@@ -76,19 +119,25 @@ func (c *Cache) Get(key string) ([]byte, bool) {
cacheEntry := entry.(CacheEntry)
if cacheEntry.ExpiresAt.Before(time.Now()) {
c.entries.Delete(key)
atomic.AddInt64(&c.entryCount, -1)
return nil, false
}
if cacheEntry.Compressed {
value, err := c.decompress(cacheEntry.Value)
if err != nil {
log.Printf("Error decompressing value for key %s: %v", key, err)
return nil, false
}
return value, true
}
return cacheEntry.Value, true
}
// Delete removes key from the cache, if present.
//
// The entry counter is decremented only when an entry was actually removed,
// so deleting a missing key (or deleting twice) leaves the count unchanged.
func (c *Cache) Delete(key string) {
	// LoadAndDelete must be the only removal here: deleting first would make
	// it always report "not found" and the counter would never go down.
	if _, loaded := c.entries.LoadAndDelete(key); loaded {
		atomic.AddInt64(&c.entryCount, -1)
	}
}
func (c *Cache) Clear() {
@@ -96,24 +145,18 @@ func (c *Cache) Clear() {
c.entries.Delete(key)
return true
})
atomic.StoreInt64(&c.entryCount, 0)
}
// CountQueries returns the number of entries currently tracked by the cache.
//
// The value comes from the atomically maintained entry counter rather than a
// full scan of the map, so the call is constant-time.
func (c *Cache) CountQueries() int64 {
	return atomic.LoadInt64(&c.entryCount)
}
func (c *Cache) compress(data []byte) ([]byte, error) {
var buf bytes.Buffer
w := c.compressPool.Get().(*gzip.Writer)
defer func() {
w.Close()
c.compressPool.Put(w)
}()
defer c.compressPool.Put(w)
w.Reset(&buf)
if _, err := w.Write(data); err != nil {
return nil, err
@@ -126,6 +169,8 @@ func (c *Cache) compress(data []byte) ([]byte, error) {
func (c *Cache) decompress(data []byte) ([]byte, error) {
r, ok := c.decompressPool.Get().(*gzip.Reader)
defer c.decompressPool.Put(r)
if !ok || r == nil {
var err error
r, err = gzip.NewReader(bytes.NewReader(data))
@@ -137,11 +182,8 @@ func (c *Cache) decompress(data []byte) ([]byte, error) {
return nil, err
}
}
defer func() {
r.Close()
c.decompressPool.Put(r)
}()
defer r.Close()
return io.ReadAll(r)
}
@@ -150,8 +192,47 @@ func (c *Cache) CleanExpiredEntries() {
c.entries.Range(func(key, value interface{}) bool {
entry := value.(CacheEntry)
if entry.ExpiresAt.Before(now) {
c.entries.Delete(key)
if _, exists := c.entries.LoadAndDelete(key); exists {
atomic.AddInt64(&c.entryCount, -1)
}
}
return true
})
}
// evictOldest removes up to n entries, preferring those that expire soonest.
//
// To keep eviction cheap it inspects at most 2*n entries (whatever Range
// yields first) and evicts the n earliest-expiring entries of that sample, so
// the victims are not guaranteed to be the globally oldest ones. A
// non-positive n is a no-op.
func (c *Cache) evictOldest(n int) {
	if n <= 0 {
		// A negative n would otherwise panic when sizing the sample slice.
		return
	}

	type keyExpiry struct {
		key       string
		expiresAt time.Time
	}

	// Collect a bounded sample of keys with their expiry times.
	sampleCap := n * 2
	sample := make([]keyExpiry, 0, sampleCap)
	c.entries.Range(func(k, v interface{}) bool {
		sample = append(sample, keyExpiry{
			key:       k.(string),
			expiresAt: v.(CacheEntry).ExpiresAt,
		})
		return len(sample) < sampleCap
	})

	// Partial selection sort: only the first n positions need to be ordered,
	// and each one is evicted as soon as it is known.
	for i := 0; i < n && i < len(sample); i++ {
		earliest := i
		for j := i + 1; j < len(sample); j++ {
			if sample[j].expiresAt.Before(sample[earliest].expiresAt) {
				earliest = j
			}
		}
		sample[i], sample[earliest] = sample[earliest], sample[i]

		// Decrement only if this call removed the entry; it may have been
		// deleted concurrently since it was sampled.
		if _, loaded := c.entries.LoadAndDelete(sample[i].key); loaded {
			atomic.AddInt64(&c.entryCount, -1)
		}
	}
}
+61 -40
View File
@@ -9,7 +9,6 @@ import (
fiber "github.com/gofiber/fiber/v2"
"github.com/graphql-go/graphql/language/ast"
"github.com/graphql-go/graphql/language/parser"
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
)
@@ -67,57 +66,54 @@ var (
)
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
// Get a result object from the pool and initialize it
res := resultPool.Get().(*parseGraphQLQueryResult)
*res = parseGraphQLQueryResult{shouldIgnore: true, activeEndpoint: cfg.Server.HostGraphQL}
// Get a map from the pool for JSON unmarshaling
m := queryPool.Get().(map[string]interface{})
defer func() {
// Clear and return the map to the pool
for k := range m {
delete(m, k)
}
queryPool.Put(m)
}()
// Unmarshal the request body
if err := json.Unmarshal(c.Body(), &m); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't unmarshal the request",
Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Extract the query string
query, ok := m["query"].(string)
if !ok {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't find the query",
Pairs: map[string]interface{}{"m_val": m},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return res
}
// Parse the GraphQL query
p, err := parser.Parse(parser.ParseParams{Source: query})
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't parse the query",
Pairs: map[string]interface{}{"query": query, "m_val": m},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return res
}
// Mark as a valid GraphQL query
res.shouldIgnore = false
res.operationName = "undefined"
// Process each definition in the query
for _, d := range p.Definitions {
if oper, ok := d.(*ast.OperationDefinition); ok {
// Extract operation type and name
if res.operationType == "" {
res.operationType = strings.ToLower(oper.Operation)
if oper.Name != nil {
@@ -125,17 +121,13 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
}
}
if cfg.Server.HostGraphQLReadOnly != "" {
if res.operationType == "" || res.operationType != "mutation" {
// Handle read-only endpoint routing
if cfg.Server.HostGraphQLReadOnly != "" && (res.operationType == "" || res.operationType != "mutation") {
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
}
}
// Block mutations in read-only mode
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Mutation blocked - server in read-only mode",
Pairs: map[string]interface{}{"query": query},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
@@ -145,6 +137,23 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
return res
}
// Process directives (like @cached)
processDirectives(oper, res)
// Check for introspection queries if they're blocked
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
_ = c.Status(403).SendString("Introspection queries are not allowed")
res.shouldBlock = true
resultPool.Put(res)
return res
}
}
}
return res
}
// processDirectives extracts caching directives from the operation
func processDirectives(oper *ast.OperationDefinition, res *parseGraphQLQueryResult) {
for _, dir := range oper.Directives {
if dir.Name.Value == "cached" {
res.cacheRequest = true
@@ -162,55 +171,67 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
}
}
}
if cfg.Security.BlockIntrospection {
if checkSelections(c, oper.GetSelectionSet().Selections) {
_ = c.Status(403).SendString("Introspection queries are not allowed")
res.shouldBlock = true
resultPool.Put(res)
return res
}
}
}
}
return res
}
// checkSelections recursively checks if any selection is an introspection query that should be blocked
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
if len(selections) == 0 {
return false
}
// Fast path: if no introspection blocking is configured, return immediately
if !cfg.Security.BlockIntrospection {
return false
}
// Fast path: if there are no allowed introspection queries, check only top level
hasAllowList := len(cfg.Security.IntrospectionAllowed) > 0
for _, s := range selections {
switch sel := s.(type) {
case *ast.Field:
fieldName := strings.ToLower(sel.Name.Value)
// Check if this is an introspection query
if _, exists := introspectionQueries[fieldName]; exists {
if len(cfg.Security.IntrospectionAllowed) > 0 {
_, allowed := introspectionAllowedQueries[fieldName]
if !allowed {
return true // Block if this field isn't allowed
if hasAllowList {
// Check if it's in the allowed list
if _, allowed := introspectionAllowedQueries[fieldName]; !allowed {
return true // Block if not allowed
}
// Even if this field is allowed, we need to check its nested selections
} else {
return true // Block if no allowlist exists
}
}
// Always check nested selections
if sel.SelectionSet != nil {
// Check nested selections if present
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
case *ast.InlineFragment:
if sel.SelectionSet != nil {
// Check nested selections in fragments
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
if checkSelections(c, sel.GetSelectionSet().Selections) {
return true
}
}
}
}
return false
}
func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool {
blocked := false
// Enable introspection blocking for tests
if !cfg.Security.BlockIntrospection {
cfg.Security.BlockIntrospection = true
}
// Try parsing as a complete query first
p, err := parser.Parse(parser.ParseParams{Source: query})
if err == nil {
+92 -18
View File
@@ -4,8 +4,11 @@ import (
"context"
"flag"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/gofiber/fiber/v2/middleware/proxy"
@@ -24,23 +27,32 @@ var (
)
// getDetailsFromEnv retrieves the value from the environment or returns the default.
// It first checks for a prefixed environment variable (GMP_KEY), then falls back to the unprefixed version.
func getDetailsFromEnv[T any](key string, defaultValue T) T {
var result any
envKey := "GMP_" + key
if _, ok := os.LookupEnv(envKey); !ok {
envKey = key
}
prefixedKey := "GMP_" + key
switch v := any(defaultValue).(type) {
case string:
result = envutil.Getenv(envKey, v)
case int:
result = envutil.GetInt(envKey, v)
case bool:
result = envutil.GetBool(envKey, v)
default:
result = defaultValue
if val, ok := os.LookupEnv(prefixedKey); ok {
return any(val).(T)
}
return any(envutil.Getenv(key, v)).(T)
case int:
if val, ok := os.LookupEnv(prefixedKey); ok {
if intVal, err := strconv.Atoi(val); err == nil {
return any(intVal).(T)
}
}
return any(envutil.GetInt(key, v)).(T)
case bool:
if val, ok := os.LookupEnv(prefixedKey); ok {
boolVal := strings.ToLower(val) == "true" || val == "1"
return any(boolVal).(T)
}
return any(envutil.GetBool(key, v)).(T)
default:
return defaultValue
}
return result.(T)
}
// parseConfig loads and parses the configuration.
@@ -162,20 +174,82 @@ func parseConfig() {
}
func main() {
// Parse configuration
parseConfig()
StartMonitoringServer()
time.Sleep(5 * time.Second)
StartHTTPProxy()
// Cleanup tracing on exit
// Setup graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Create a wait group to manage goroutines
var wg sync.WaitGroup
// Setup signal handling for graceful shutdown
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
go func() {
<-sigCh
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Shutdown signal received, stopping services...",
})
cancel()
}()
// Start monitoring server in a goroutine
wg.Add(1)
go func() {
defer wg.Done()
StartMonitoringServer()
}()
// Give monitoring server time to initialize
time.Sleep(2 * time.Second)
// Start HTTP proxy in a goroutine
wg.Add(1)
go func() {
defer wg.Done()
StartHTTPProxy()
}()
// Wait for context cancellation
<-ctx.Done()
// Perform cleanup
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "Shutting down services...",
})
// Cleanup tracing
if tracer != nil {
if err := tracer.Shutdown(context.Background()); err != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := tracer.Shutdown(shutdownCtx); err != nil {
cfg.Logger.Error(&libpack_logging.LogMessage{
Message: "Error shutting down tracer",
Pairs: map[string]interface{}{"error": err.Error()},
})
}
}
// Wait for all goroutines to finish (with timeout)
waitCh := make(chan struct{})
go func() {
wg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
cfg.Logger.Info(&libpack_logging.LogMessage{
Message: "All services shut down gracefully",
})
case <-time.After(10 * time.Second):
cfg.Logger.Warning(&libpack_logging.LogMessage{
Message: "Some services didn't shut down gracefully within timeout",
})
}
}
// ifNotInTest checks if the program is not running in a test environment.
+81 -59
View File
@@ -40,10 +40,73 @@ func createFasthttpClient(timeout int) *fasthttp.Client {
// proxyTheRequest forwards the current request to currentEndpoint and leaves
// the upstream response on c.
//
// Steps, in order: start a "proxy_request" span when tracing is enabled,
// reject paths that checkAllowedURLs does not permit, build and validate the
// target URL, perform the proxied call, decompress a gzipped response body,
// and finally require a 200 status. The Server response header is removed
// before returning. Any failure is returned as an error; skipped and failed
// requests are counted in monitoring outside of tests.
func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
	// setupTracing returns a plain background context when tracing is off,
	// so it only needs to run when a span is actually started.
	if cfg.Tracing.Enable && tracer != nil {
		span, _ := tracer.StartSpan(setupTracing(c), "proxy_request")
		defer span.End()
	}

	if !checkAllowedURLs(c) {
		if ifNotInTest() {
			cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
		}
		return fmt.Errorf("request blocked - not allowed URL: %s", c.Path())
	}

	proxyURL := currentEndpoint + c.Path()
	if _, err := url.Parse(proxyURL); err != nil {
		return fmt.Errorf("invalid URL: %w", err)
	}

	debug := cfg.LogLevel == "DEBUG"
	if debug {
		logDebugRequest(c)
	}

	if err := performProxyRequest(c, proxyURL); err != nil {
		if ifNotInTest() {
			cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
		}
		return err
	}

	if debug {
		logDebugResponse(c)
	}

	if err := handleGzippedResponse(c); err != nil {
		return err
	}

	// Final status check on the response that is about to be relayed.
	if status := c.Response().StatusCode(); status != fiber.StatusOK {
		if ifNotInTest() {
			cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
		}
		return fmt.Errorf("received non-200 response from the GraphQL server: %d", status)
	}

	// Remove server header for security
	c.Response().Header.Del(fiber.HeaderServer)
	return nil
}
// setupTracing extracts and sets up tracing context from request headers
func setupTracing(c *fiber.Ctx) context.Context {
ctx := context.Background()
if !cfg.Tracing.Enable || tracer == nil {
return ctx
}
// Extract trace information from header
if traceHeader := c.Get("X-Trace-Span"); traceHeader != "" {
spanInfo, err := libpack_tracing.ParseTraceHeader(traceHeader)
@@ -52,47 +115,23 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
Message: "Failed to parse trace header",
Pairs: map[string]interface{}{"error": err.Error()},
})
} else {
if spanCtx, err := tracer.ExtractSpanContext(spanInfo); err == nil {
} else if spanCtx, err := tracer.ExtractSpanContext(spanInfo); err == nil {
ctx = trace.ContextWithSpanContext(ctx, spanCtx)
}
}
}
// Start a new span
span, ctx = tracer.StartSpan(ctx, "proxy_request")
defer span.End()
}
return ctx
}
if !checkAllowedURLs(c) {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Request blocked",
Pairs: map[string]interface{}{"path": c.Path()},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
}
return fmt.Errorf("request blocked - not allowed URL: %s", c.Path())
}
proxyURL := currentEndpoint + c.Path()
_, err := url.Parse(proxyURL)
if err != nil {
return fmt.Errorf("invalid URL: %v", err)
}
if cfg.LogLevel == "DEBUG" {
logDebugRequest(c)
}
err = retry.Do(
// performProxyRequest executes the proxy request with retries
func performProxyRequest(c *fiber.Ctx, proxyURL string) error {
return retry.Do(
func() error {
proxyErr := proxy.DoRedirects(c, proxyURL, 3, cfg.Client.FastProxyClient)
if proxyErr != nil {
return proxyErr
if err := proxy.DoRedirects(c, proxyURL, 3, cfg.Client.FastProxyClient); err != nil {
return err
}
if c.Response().StatusCode() != fiber.StatusOK {
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
return fmt.Errorf("received non-200 response: %d", c.Response().StatusCode())
}
return nil
},
@@ -112,24 +151,15 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
}),
retry.LastErrorOnly(true),
)
}
if err != nil {
cfg.Logger.Warning(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("failed to proxy request: %v", err)
// handleGzippedResponse decompresses gzipped responses
func handleGzippedResponse(c *fiber.Ctx) error {
if !bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
return nil
}
if cfg.LogLevel == "DEBUG" {
logDebugResponse(c)
}
if bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
// Decompress gzip response
// Create a pooled gzip reader
reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body()))
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -140,6 +170,7 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
}
defer reader.Close()
// Read decompressed data
decompressed, err := io.ReadAll(reader)
if err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
@@ -149,18 +180,9 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
return err
}
// Update response
c.Response().SetBody(decompressed)
c.Response().Header.Del("Content-Encoding")
}
if c.Response().StatusCode() != fiber.StatusOK {
if ifNotInTest() {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
}
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
}
c.Response().Header.Del(fiber.HeaderServer)
return nil
}
+66 -63
View File
@@ -109,51 +109,74 @@ func healthCheck(c *fiber.Ctx) error {
return c.Status(fiber.StatusOK).SendString("Health check OK")
}
// processGraphQLRequest handles an incoming GraphQL request end to end.
//
// It resolves the caller's user ID and role, rejects banned users (403) and
// rate-limited callers (429), then parses the query. Blocked queries get a
// 403; bodies that are not recognised as GraphQL are proxied unchanged.
// Everything else goes through handleCaching, after which the request is
// logged and recorded in monitoring together with its duration.
func processGraphQLRequest(c *fiber.Ctx) error {
	startTime := time.Now()

	userID, roleName := extractUserInfo(c)

	if checkIfUserIsBanned(c, userID) {
		return c.Status(fiber.StatusForbidden).SendString("User is banned")
	}

	if cfg.Client.RoleRateLimit && !rateLimitedRequest(userID, roleName) {
		return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later")
	}

	parsedResult := parseGraphQLQuery(c)
	switch {
	case parsedResult.shouldBlock:
		return c.Status(fiber.StatusForbidden).SendString("Request blocked")
	case parsedResult.shouldIgnore:
		// Not a GraphQL payload: pass it through untouched.
		return proxyTheRequest(c, parsedResult.activeEndpoint)
	}

	wasCached, err := handleCaching(c, parsedResult, userID)
	if err != nil {
		return err
	}

	logAndMonitorRequest(c, userID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
	return nil
}
// extractUserInfo returns the user ID and role name for the request, in that
// order; each defaults to "-" when it cannot be determined.
//
// Claims are read from the Authorization header only when at least one JWT
// claim path is configured. If cfg.Client.RoleFromHeader names a header and
// that header is non-empty on the request, its value overrides the role.
func extractUserInfo(c *fiber.Ctx) (string, string) {
	userID, roleName := "-", "-"

	claimPathsConfigured := len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0
	if authorization := c.Get("Authorization"); authorization != "" && claimPathsConfigured {
		userID, roleName = extractClaimsFromJWTHeader(authorization)
	}

	// Override role from header if configured
	if cfg.Client.RoleFromHeader != "" {
		if role := c.Get(cfg.Client.RoleFromHeader); role != "" {
			roleName = role
		}
	}

	return userID, roleName
}
// handleCaching manages the caching logic for GraphQL requests
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) {
// Calculate query hash for cache key
calculatedQueryHash := libpack_cache.CalculateHash(c)
// Set cache time from header or default
if parsedResult.cacheTime == 0 {
if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" {
parsedResult.cacheTime, _ = strconv.Atoi(cacheQuery)
@@ -162,58 +185,38 @@ func processGraphQLRequest(c *fiber.Ctx) error {
}
}
wasCached := false //nolint:ineffassign
// Handle cache refresh directive
if parsedResult.cacheRefresh {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache refresh requested via query",
Pairs: map[string]interface{}{"user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
libpack_cache.CacheDelete(calculatedQueryHash)
}
if parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache enabled",
Pairs: map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable},
})
// Check if caching is enabled
cacheEnabled := parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable
if !cacheEnabled {
// No caching, just proxy the request
if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return false, c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
return false, nil
}
// Try to get from cache
if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil {
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache hit",
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
c.Set("X-Cache-Hit", "true")
wasCached = true
c.Set("Content-Type", "application/json")
return c.Send(cachedResponse)
return true, c.Send(cachedResponse)
}
// Cache miss, proxy and cache
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
cfg.Logger.Debug(&libpack_logger.LogMessage{
Message: "Cache miss",
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
})
if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil {
return err
}
} else {
if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil {
cfg.Logger.Error(&libpack_logger.LogMessage{
Message: "Can't proxy the request",
Pairs: map[string]interface{}{"error": err.Error()},
})
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
}
return false, err
}
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
return nil
return false, nil
}
// proxyAndCacheTheRequest proxies and caches the request if needed.
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error {
if err := proxyTheRequest(c, currentEndpoint); err != nil {
+65 -5
View File
@@ -4,8 +4,10 @@ import (
"context"
"encoding/json"
"fmt"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
@@ -34,36 +36,63 @@ func NewTracing(ctx context.Context, endpoint string) (*TracingSetup, error) {
return nil, fmt.Errorf("endpoint cannot be empty")
}
conn, err := grpc.DialContext(ctx, endpoint,
// Create a timeout context for connection establishment
dialCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
// Connect to the collector with improved options
conn, err := grpc.DialContext(dialCtx, endpoint,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
grpc.WithReturnConnectionError(),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16*1024*1024)), // 16MB max message size
)
if err != nil {
return nil, fmt.Errorf("failed to create gRPC connection to collector: %w", err)
}
exporter, err := otlptracegrpc.New(ctx, otlptracegrpc.WithGRPCConn(conn))
// Create the exporter
exporter, err := otlptracegrpc.New(ctx,
otlptracegrpc.WithGRPCConn(conn),
otlptracegrpc.WithTimeout(5*time.Second),
)
if err != nil {
return nil, fmt.Errorf("failed to create trace exporter: %w", err)
}
// Create a resource with more detailed attributes
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceName("graphql-monitoring-proxy"),
semconv.ServiceVersion("1.0"),
semconv.DeploymentEnvironment("production"),
attribute.String("application.type", "proxy"),
),
resource.WithHost(), // Add host information
resource.WithOSType(), // Add OS information
resource.WithProcessPID(), // Add process information
)
if err != nil {
return nil, fmt.Errorf("failed to create resource: %w", err)
}
// Create the tracer provider with improved configuration
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exporter),
sdktrace.WithBatcher(exporter,
// Configure batch processing
sdktrace.WithMaxExportBatchSize(512),
sdktrace.WithBatchTimeout(3*time.Second),
sdktrace.WithMaxQueueSize(2048),
),
sdktrace.WithResource(res),
sdktrace.WithSampler(sdktrace.TraceIDRatioBased(0.1)), // Sample 10% of traces
)
// Set the global tracer provider and propagator
otel.SetTracerProvider(tracerProvider)
otel.SetTextMapPropagator(propagation.TraceContext{})
// Create a tracer
tracer := tracerProvider.Tracer("graphql-monitoring-proxy")
return &TracingSetup{
@@ -105,9 +134,40 @@ func (ts *TracingSetup) Shutdown(ctx context.Context) error {
// StartSpan starts a new span with the given name and parent context
func (ts *TracingSetup) StartSpan(ctx context.Context, name string) (trace.Span, context.Context) {
if ts.tracer == nil {
if ts == nil || ts.tracer == nil {
// Return a no-op span if tracing is not configured
return trace.SpanFromContext(ctx), ctx
}
ctx, span := ts.tracer.Start(ctx, name)
// Add common attributes to all spans
opts := []trace.SpanStartOption{
trace.WithAttributes(
semconv.ServiceName("graphql-monitoring-proxy"),
semconv.ServiceVersion("1.0"),
),
}
ctx, span := ts.tracer.Start(ctx, name, opts...)
return span, ctx
}
// StartSpanWithAttributes behaves like starting a regular span, but also
// attaches every key/value pair in attrs as a string attribute after the
// service name and version. With a nil receiver or tracer it returns the
// span found in ctx and ctx itself.
func (ts *TracingSetup) StartSpanWithAttributes(ctx context.Context, name string, attrs map[string]string) (trace.Span, context.Context) {
	if ts == nil || ts.tracer == nil {
		return trace.SpanFromContext(ctx), ctx
	}

	// Service attributes first, then the caller's pairs in map order.
	kvs := make([]attribute.KeyValue, 2, len(attrs)+2)
	kvs[0] = semconv.ServiceName("graphql-monitoring-proxy")
	kvs[1] = semconv.ServiceVersion("1.0")
	for key, value := range attrs {
		kvs = append(kvs, attribute.String(key, value))
	}

	spanCtx, span := ts.tracer.Start(ctx, name, trace.WithAttributes(kvs...))
	return span, spanCtx
}