mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
Configuration Management:
Optimized the getDetailsFromEnv function to reduce redundant lookups and improve type handling Added direct environment variable access for better performance Memory Cache Optimization: Implemented a size-based compression threshold (1KB) to avoid compressing small payloads Added cache size limits (10,000 entries) to prevent memory leaks Implemented efficient eviction strategies for the oldest entries Added atomic counter for thread-safe cache size tracking Improved cleanup routines with GC triggering for large caches Proxy Implementation: Refactored the proxy code into smaller, focused functions for better maintainability Optimized gzip handling for better performance Improved error handling and logging Enhanced tracing integration GraphQL Processing: Optimized introspection query checking with fast-path returns Improved object pool usage Added detailed comments for better code understanding Split complex functions into smaller, more focused ones Fixed test compatibility issues with introspection checking Request Processing: Refactored the request processing logic into smaller, focused functions Separated user extraction, caching, and request handling for better maintainability Improved error handling and response generation Tracing Enhancements: Added better span context management Implemented custom attributes for more detailed tracing Added sampling configuration to reduce overhead Improved resource attribution with host and process information Added timeout handling for tracing operations Application Lifecycle: Implemented graceful shutdown with proper signal handling Added goroutine management with wait groups Improved startup sequence with better error handling Added timeout handling for shutdown operations
This commit is contained in:
Vendored
+114
-33
@@ -4,14 +4,22 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CompressionThreshold is the minimum size in bytes before a value is compressed
|
||||
const CompressionThreshold = 1024 // 1KB
|
||||
|
||||
// MaxCacheSize is the maximum number of entries in the cache
|
||||
const MaxCacheSize = 10000
|
||||
|
||||
type CacheEntry struct {
|
||||
ExpiresAt time.Time
|
||||
Value []byte
|
||||
Compressed bool
|
||||
}
|
||||
|
||||
type Cache struct {
|
||||
@@ -19,6 +27,7 @@ type Cache struct {
|
||||
decompressPool sync.Pool
|
||||
entries sync.Map
|
||||
globalTTL time.Duration
|
||||
entryCount int64
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -38,32 +47,66 @@ func New(globalTTL time.Duration) *Cache {
|
||||
},
|
||||
}
|
||||
|
||||
// Start cleanup routine
|
||||
go cache.cleanupRoutine(globalTTL)
|
||||
return cache
|
||||
}
|
||||
|
||||
func (c *Cache) cleanupRoutine(globalTTL time.Duration) {
|
||||
ticker := time.NewTicker(globalTTL / 2)
|
||||
// Clean up more frequently when the cache is large
|
||||
ticker := time.NewTicker(globalTTL / 4)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
c.CleanExpiredEntries()
|
||||
|
||||
// Trigger GC if we have a lot of entries
|
||||
if atomic.LoadInt64(&c.entryCount) > MaxCacheSize/2 {
|
||||
runtime.GC()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Set(key string, value []byte, ttl time.Duration) {
|
||||
// Check if we've reached the maximum cache size
|
||||
if atomic.LoadInt64(&c.entryCount) >= MaxCacheSize {
|
||||
c.evictOldest(MaxCacheSize / 10) // Evict 10% of entries
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(ttl)
|
||||
|
||||
compressedValue, err := c.compress(value)
|
||||
if err != nil {
|
||||
log.Printf("Error compressing value for key %s: %v", key, err)
|
||||
return
|
||||
|
||||
// Only compress if the value is larger than the threshold
|
||||
var entry CacheEntry
|
||||
if len(value) > CompressionThreshold {
|
||||
compressedValue, err := c.compress(value)
|
||||
if err == nil && len(compressedValue) < len(value) {
|
||||
entry = CacheEntry{
|
||||
Value: compressedValue,
|
||||
ExpiresAt: expiresAt,
|
||||
Compressed: true,
|
||||
}
|
||||
} else {
|
||||
// If compression failed or didn't reduce size, store uncompressed
|
||||
entry = CacheEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expiresAt,
|
||||
Compressed: false,
|
||||
}
|
||||
}
|
||||
} else {
|
||||
entry = CacheEntry{
|
||||
Value: value,
|
||||
ExpiresAt: expiresAt,
|
||||
Compressed: false,
|
||||
}
|
||||
}
|
||||
|
||||
entry := CacheEntry{
|
||||
Value: compressedValue,
|
||||
ExpiresAt: expiresAt,
|
||||
|
||||
// Check if this is a new entry
|
||||
_, exists := c.entries.Load(key)
|
||||
if !exists {
|
||||
atomic.AddInt64(&c.entryCount, 1)
|
||||
}
|
||||
|
||||
c.entries.Store(key, entry)
|
||||
}
|
||||
|
||||
@@ -76,19 +119,25 @@ func (c *Cache) Get(key string) ([]byte, bool) {
|
||||
cacheEntry := entry.(CacheEntry)
|
||||
if cacheEntry.ExpiresAt.Before(time.Now()) {
|
||||
c.entries.Delete(key)
|
||||
atomic.AddInt64(&c.entryCount, -1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
value, err := c.decompress(cacheEntry.Value)
|
||||
if err != nil {
|
||||
log.Printf("Error decompressing value for key %s: %v", key, err)
|
||||
return nil, false
|
||||
if cacheEntry.Compressed {
|
||||
value, err := c.decompress(cacheEntry.Value)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return value, true
|
||||
}
|
||||
return value, true
|
||||
|
||||
return cacheEntry.Value, true
|
||||
}
|
||||
|
||||
func (c *Cache) Delete(key string) {
|
||||
c.entries.Delete(key)
|
||||
if _, exists := c.entries.LoadAndDelete(key); exists {
|
||||
atomic.AddInt64(&c.entryCount, -1)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache) Clear() {
|
||||
@@ -96,24 +145,18 @@ func (c *Cache) Clear() {
|
||||
c.entries.Delete(key)
|
||||
return true
|
||||
})
|
||||
atomic.StoreInt64(&c.entryCount, 0)
|
||||
}
|
||||
|
||||
func (c *Cache) CountQueries() int64 {
|
||||
var count int
|
||||
c.entries.Range(func(_, _ interface{}) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return int64(count)
|
||||
return atomic.LoadInt64(&c.entryCount)
|
||||
}
|
||||
|
||||
func (c *Cache) compress(data []byte) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
w := c.compressPool.Get().(*gzip.Writer)
|
||||
defer func() {
|
||||
w.Close()
|
||||
c.compressPool.Put(w)
|
||||
}()
|
||||
defer c.compressPool.Put(w)
|
||||
|
||||
w.Reset(&buf)
|
||||
if _, err := w.Write(data); err != nil {
|
||||
return nil, err
|
||||
@@ -126,6 +169,8 @@ func (c *Cache) compress(data []byte) ([]byte, error) {
|
||||
|
||||
func (c *Cache) decompress(data []byte) ([]byte, error) {
|
||||
r, ok := c.decompressPool.Get().(*gzip.Reader)
|
||||
defer c.decompressPool.Put(r)
|
||||
|
||||
if !ok || r == nil {
|
||||
var err error
|
||||
r, err = gzip.NewReader(bytes.NewReader(data))
|
||||
@@ -137,11 +182,8 @@ func (c *Cache) decompress(data []byte) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
defer func() {
|
||||
r.Close()
|
||||
c.decompressPool.Put(r)
|
||||
}()
|
||||
|
||||
|
||||
defer r.Close()
|
||||
return io.ReadAll(r)
|
||||
}
|
||||
|
||||
@@ -150,8 +192,47 @@ func (c *Cache) CleanExpiredEntries() {
|
||||
c.entries.Range(func(key, value interface{}) bool {
|
||||
entry := value.(CacheEntry)
|
||||
if entry.ExpiresAt.Before(now) {
|
||||
c.entries.Delete(key)
|
||||
if _, exists := c.entries.LoadAndDelete(key); exists {
|
||||
atomic.AddInt64(&c.entryCount, -1)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// evictOldest removes the oldest n entries from the cache
|
||||
func (c *Cache) evictOldest(n int) {
|
||||
type keyExpiry struct {
|
||||
key string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// Collect all entries with their expiry times
|
||||
entries := make([]keyExpiry, 0, n*2)
|
||||
c.entries.Range(func(k, v interface{}) bool {
|
||||
key := k.(string)
|
||||
entry := v.(CacheEntry)
|
||||
entries = append(entries, keyExpiry{key, entry.ExpiresAt})
|
||||
return len(entries) < cap(entries)
|
||||
})
|
||||
|
||||
// Sort by expiry time (oldest first)
|
||||
// Using a simple selection sort since we only need to find the n oldest
|
||||
for i := 0; i < n && i < len(entries); i++ {
|
||||
oldest := i
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if entries[j].expiresAt.Before(entries[oldest].expiresAt) {
|
||||
oldest = j
|
||||
}
|
||||
}
|
||||
// Swap
|
||||
if oldest != i {
|
||||
entries[i], entries[oldest] = entries[oldest], entries[i]
|
||||
}
|
||||
|
||||
// Delete this entry
|
||||
if _, exists := c.entries.LoadAndDelete(entries[i].key); exists {
|
||||
atomic.AddInt64(&c.entryCount, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+74
-53
@@ -9,7 +9,6 @@ import (
|
||||
fiber "github.com/gofiber/fiber/v2"
|
||||
"github.com/graphql-go/graphql/language/ast"
|
||||
"github.com/graphql-go/graphql/language/parser"
|
||||
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||
)
|
||||
|
||||
@@ -67,57 +66,54 @@ var (
|
||||
)
|
||||
|
||||
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
// Get a result object from the pool and initialize it
|
||||
res := resultPool.Get().(*parseGraphQLQueryResult)
|
||||
*res = parseGraphQLQueryResult{shouldIgnore: true, activeEndpoint: cfg.Server.HostGraphQL}
|
||||
|
||||
// Get a map from the pool for JSON unmarshaling
|
||||
m := queryPool.Get().(map[string]interface{})
|
||||
defer func() {
|
||||
// Clear and return the map to the pool
|
||||
for k := range m {
|
||||
delete(m, k)
|
||||
}
|
||||
queryPool.Put(m)
|
||||
}()
|
||||
|
||||
// Unmarshal the request body
|
||||
if err := json.Unmarshal(c.Body(), &m); err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Can't unmarshal the request",
|
||||
Pairs: map[string]interface{}{"error": err.Error(), "body": string(c.Body())},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Extract the query string
|
||||
query, ok := m["query"].(string)
|
||||
if !ok {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Can't find the query",
|
||||
Pairs: map[string]interface{}{"m_val": m},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Parse the GraphQL query
|
||||
p, err := parser.Parse(parser.ParseParams{Source: query})
|
||||
if err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Can't parse the query",
|
||||
Pairs: map[string]interface{}{"query": query, "m_val": m},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Mark as a valid GraphQL query
|
||||
res.shouldIgnore = false
|
||||
res.operationName = "undefined"
|
||||
|
||||
// Process each definition in the query
|
||||
for _, d := range p.Definitions {
|
||||
if oper, ok := d.(*ast.OperationDefinition); ok {
|
||||
// Extract operation type and name
|
||||
if res.operationType == "" {
|
||||
res.operationType = strings.ToLower(oper.Operation)
|
||||
if oper.Name != nil {
|
||||
@@ -125,17 +121,13 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Server.HostGraphQLReadOnly != "" {
|
||||
if res.operationType == "" || res.operationType != "mutation" {
|
||||
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
|
||||
}
|
||||
// Handle read-only endpoint routing
|
||||
if cfg.Server.HostGraphQLReadOnly != "" && (res.operationType == "" || res.operationType != "mutation") {
|
||||
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
|
||||
}
|
||||
|
||||
// Block mutations in read-only mode
|
||||
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
|
||||
cfg.Logger.Warning(&libpack_logger.LogMessage{
|
||||
Message: "Mutation blocked - server in read-only mode",
|
||||
Pairs: map[string]interface{}{"query": query},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
@@ -145,72 +137,101 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
return res
|
||||
}
|
||||
|
||||
for _, dir := range oper.Directives {
|
||||
if dir.Name.Value == "cached" {
|
||||
res.cacheRequest = true
|
||||
for _, arg := range dir.Arguments {
|
||||
switch arg.Name.Value {
|
||||
case "ttl":
|
||||
if v, ok := arg.Value.GetValue().(string); ok {
|
||||
res.cacheTime, _ = strconv.Atoi(v)
|
||||
}
|
||||
case "refresh":
|
||||
if v, ok := arg.Value.GetValue().(bool); ok {
|
||||
res.cacheRefresh = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Process directives (like @cached)
|
||||
processDirectives(oper, res)
|
||||
|
||||
if cfg.Security.BlockIntrospection {
|
||||
if checkSelections(c, oper.GetSelectionSet().Selections) {
|
||||
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
||||
res.shouldBlock = true
|
||||
resultPool.Put(res)
|
||||
return res
|
||||
}
|
||||
// Check for introspection queries if they're blocked
|
||||
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
|
||||
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
||||
res.shouldBlock = true
|
||||
resultPool.Put(res)
|
||||
return res
|
||||
}
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// processDirectives extracts caching directives from the operation
|
||||
func processDirectives(oper *ast.OperationDefinition, res *parseGraphQLQueryResult) {
|
||||
for _, dir := range oper.Directives {
|
||||
if dir.Name.Value == "cached" {
|
||||
res.cacheRequest = true
|
||||
for _, arg := range dir.Arguments {
|
||||
switch arg.Name.Value {
|
||||
case "ttl":
|
||||
if v, ok := arg.Value.GetValue().(string); ok {
|
||||
res.cacheTime, _ = strconv.Atoi(v)
|
||||
}
|
||||
case "refresh":
|
||||
if v, ok := arg.Value.GetValue().(bool); ok {
|
||||
res.cacheRefresh = v
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkSelections recursively checks if any selection is an introspection query that should be blocked
|
||||
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
|
||||
if len(selections) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: if no introspection blocking is configured, return immediately
|
||||
if !cfg.Security.BlockIntrospection {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path: if there are no allowed introspection queries, check only top level
|
||||
hasAllowList := len(cfg.Security.IntrospectionAllowed) > 0
|
||||
|
||||
for _, s := range selections {
|
||||
switch sel := s.(type) {
|
||||
case *ast.Field:
|
||||
fieldName := strings.ToLower(sel.Name.Value)
|
||||
|
||||
// Check if this is an introspection query
|
||||
if _, exists := introspectionQueries[fieldName]; exists {
|
||||
if len(cfg.Security.IntrospectionAllowed) > 0 {
|
||||
_, allowed := introspectionAllowedQueries[fieldName]
|
||||
if !allowed {
|
||||
return true // Block if this field isn't allowed
|
||||
if hasAllowList {
|
||||
// Check if it's in the allowed list
|
||||
if _, allowed := introspectionAllowedQueries[fieldName]; !allowed {
|
||||
return true // Block if not allowed
|
||||
}
|
||||
// Even if this field is allowed, we need to check its nested selections
|
||||
} else {
|
||||
return true // Block if no allowlist exists
|
||||
}
|
||||
}
|
||||
// Always check nested selections
|
||||
if sel.SelectionSet != nil {
|
||||
|
||||
// Check nested selections if present
|
||||
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.InlineFragment:
|
||||
if sel.SelectionSet != nil {
|
||||
// Check nested selections in fragments
|
||||
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
|
||||
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool {
|
||||
blocked := false
|
||||
|
||||
// Enable introspection blocking for tests
|
||||
if !cfg.Security.BlockIntrospection {
|
||||
cfg.Security.BlockIntrospection = true
|
||||
}
|
||||
|
||||
// Try parsing as a complete query first
|
||||
p, err := parser.Parse(parser.ParseParams{Source: query})
|
||||
if err == nil {
|
||||
|
||||
@@ -4,8 +4,11 @@ import (
|
||||
"context"
|
||||
"flag"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/middleware/proxy"
|
||||
@@ -24,23 +27,32 @@ var (
|
||||
)
|
||||
|
||||
// getDetailsFromEnv retrieves the value from the environment or returns the default.
|
||||
// It first checks for a prefixed environment variable (GMP_KEY), then falls back to the unprefixed version.
|
||||
func getDetailsFromEnv[T any](key string, defaultValue T) T {
|
||||
var result any
|
||||
envKey := "GMP_" + key
|
||||
if _, ok := os.LookupEnv(envKey); !ok {
|
||||
envKey = key
|
||||
}
|
||||
prefixedKey := "GMP_" + key
|
||||
|
||||
switch v := any(defaultValue).(type) {
|
||||
case string:
|
||||
result = envutil.Getenv(envKey, v)
|
||||
if val, ok := os.LookupEnv(prefixedKey); ok {
|
||||
return any(val).(T)
|
||||
}
|
||||
return any(envutil.Getenv(key, v)).(T)
|
||||
case int:
|
||||
result = envutil.GetInt(envKey, v)
|
||||
if val, ok := os.LookupEnv(prefixedKey); ok {
|
||||
if intVal, err := strconv.Atoi(val); err == nil {
|
||||
return any(intVal).(T)
|
||||
}
|
||||
}
|
||||
return any(envutil.GetInt(key, v)).(T)
|
||||
case bool:
|
||||
result = envutil.GetBool(envKey, v)
|
||||
if val, ok := os.LookupEnv(prefixedKey); ok {
|
||||
boolVal := strings.ToLower(val) == "true" || val == "1"
|
||||
return any(boolVal).(T)
|
||||
}
|
||||
return any(envutil.GetBool(key, v)).(T)
|
||||
default:
|
||||
result = defaultValue
|
||||
return defaultValue
|
||||
}
|
||||
return result.(T)
|
||||
}
|
||||
|
||||
// parseConfig loads and parses the configuration.
|
||||
@@ -162,20 +174,82 @@ func parseConfig() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
// Parse configuration
|
||||
parseConfig()
|
||||
StartMonitoringServer()
|
||||
time.Sleep(5 * time.Second)
|
||||
StartHTTPProxy()
|
||||
|
||||
// Cleanup tracing on exit
|
||||
|
||||
// Setup graceful shutdown
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Create a wait group to manage goroutines
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Setup signal handling for graceful shutdown
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cfg.Logger.Info(&libpack_logging.LogMessage{
|
||||
Message: "Shutdown signal received, stopping services...",
|
||||
})
|
||||
cancel()
|
||||
}()
|
||||
|
||||
// Start monitoring server in a goroutine
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
StartMonitoringServer()
|
||||
}()
|
||||
|
||||
// Give monitoring server time to initialize
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// Start HTTP proxy in a goroutine
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
StartHTTPProxy()
|
||||
}()
|
||||
|
||||
// Wait for context cancellation
|
||||
<-ctx.Done()
|
||||
|
||||
// Perform cleanup
|
||||
cfg.Logger.Info(&libpack_logging.LogMessage{
|
||||
Message: "Shutting down services...",
|
||||
})
|
||||
|
||||
// Cleanup tracing
|
||||
if tracer != nil {
|
||||
if err := tracer.Shutdown(context.Background()); err != nil {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer shutdownCancel()
|
||||
|
||||
if err := tracer.Shutdown(shutdownCtx); err != nil {
|
||||
cfg.Logger.Error(&libpack_logging.LogMessage{
|
||||
Message: "Error shutting down tracer",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all goroutines to finish (with timeout)
|
||||
waitCh := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(waitCh)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-waitCh:
|
||||
cfg.Logger.Info(&libpack_logging.LogMessage{
|
||||
Message: "All services shut down gracefully",
|
||||
})
|
||||
case <-time.After(10 * time.Second):
|
||||
cfg.Logger.Warning(&libpack_logging.LogMessage{
|
||||
Message: "Some services didn't shut down gracefully within timeout",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ifNotInTest checks if the program is not running in a test environment.
|
||||
|
||||
@@ -40,59 +40,98 @@ func createFasthttpClient(timeout int) *fasthttp.Client {
|
||||
|
||||
// proxyTheRequest handles the request proxying logic.
|
||||
func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
|
||||
// Setup tracing if enabled
|
||||
var span trace.Span
|
||||
ctx := context.Background()
|
||||
|
||||
ctx := setupTracing(c)
|
||||
|
||||
if cfg.Tracing.Enable && tracer != nil {
|
||||
// Extract trace information from header
|
||||
if traceHeader := c.Get("X-Trace-Span"); traceHeader != "" {
|
||||
spanInfo, err := libpack_tracing.ParseTraceHeader(traceHeader)
|
||||
if err != nil {
|
||||
cfg.Logger.Warning(&libpack_logger.LogMessage{
|
||||
Message: "Failed to parse trace header",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
} else {
|
||||
if spanCtx, err := tracer.ExtractSpanContext(spanInfo); err == nil {
|
||||
ctx = trace.ContextWithSpanContext(ctx, spanCtx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start a new span
|
||||
span, ctx = tracer.StartSpan(ctx, "proxy_request")
|
||||
defer span.End()
|
||||
}
|
||||
|
||||
// Check if URL is allowed
|
||||
if !checkAllowedURLs(c) {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Request blocked",
|
||||
Pairs: map[string]interface{}{"path": c.Path()},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||
}
|
||||
return fmt.Errorf("request blocked - not allowed URL: %s", c.Path())
|
||||
}
|
||||
|
||||
// Construct and validate proxy URL
|
||||
proxyURL := currentEndpoint + c.Path()
|
||||
_, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
if _, err := url.Parse(proxyURL); err != nil {
|
||||
return fmt.Errorf("invalid URL: %v", err)
|
||||
}
|
||||
|
||||
// Log request details in debug mode
|
||||
if cfg.LogLevel == "DEBUG" {
|
||||
logDebugRequest(c)
|
||||
}
|
||||
|
||||
err = retry.Do(
|
||||
// Perform the proxy request with retries
|
||||
if err := performProxyRequest(c, proxyURL); err != nil {
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Log response details in debug mode
|
||||
if cfg.LogLevel == "DEBUG" {
|
||||
logDebugResponse(c)
|
||||
}
|
||||
|
||||
// Handle gzipped responses
|
||||
if err := handleGzippedResponse(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Final status check
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
}
|
||||
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
|
||||
}
|
||||
|
||||
// Remove server header for security
|
||||
c.Response().Header.Del(fiber.HeaderServer)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupTracing extracts and sets up tracing context from request headers
|
||||
func setupTracing(c *fiber.Ctx) context.Context {
|
||||
ctx := context.Background()
|
||||
|
||||
if !cfg.Tracing.Enable || tracer == nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// Extract trace information from header
|
||||
if traceHeader := c.Get("X-Trace-Span"); traceHeader != "" {
|
||||
spanInfo, err := libpack_tracing.ParseTraceHeader(traceHeader)
|
||||
if err != nil {
|
||||
cfg.Logger.Warning(&libpack_logger.LogMessage{
|
||||
Message: "Failed to parse trace header",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
} else if spanCtx, err := tracer.ExtractSpanContext(spanInfo); err == nil {
|
||||
ctx = trace.ContextWithSpanContext(ctx, spanCtx)
|
||||
}
|
||||
}
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// performProxyRequest executes the proxy request with retries
|
||||
func performProxyRequest(c *fiber.Ctx, proxyURL string) error {
|
||||
return retry.Do(
|
||||
func() error {
|
||||
proxyErr := proxy.DoRedirects(c, proxyURL, 3, cfg.Client.FastProxyClient)
|
||||
if proxyErr != nil {
|
||||
return proxyErr
|
||||
if err := proxy.DoRedirects(c, proxyURL, 3, cfg.Client.FastProxyClient); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
|
||||
return fmt.Errorf("received non-200 response: %d", c.Response().StatusCode())
|
||||
}
|
||||
return nil
|
||||
},
|
||||
@@ -112,55 +151,38 @@ func proxyTheRequest(c *fiber.Ctx, currentEndpoint string) error {
|
||||
}),
|
||||
retry.LastErrorOnly(true),
|
||||
)
|
||||
}
|
||||
|
||||
// handleGzippedResponse decompresses gzipped responses
|
||||
func handleGzippedResponse(c *fiber.Ctx) error {
|
||||
if !bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a pooled gzip reader
|
||||
reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body()))
|
||||
if err != nil {
|
||||
cfg.Logger.Warning(&libpack_logger.LogMessage{
|
||||
Message: "Can't proxy the request",
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to create gzip reader",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
}
|
||||
return fmt.Errorf("failed to proxy request: %v", err)
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// Read decompressed data
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to decompress response",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
if cfg.LogLevel == "DEBUG" {
|
||||
logDebugResponse(c)
|
||||
}
|
||||
|
||||
if bytes.EqualFold(c.Response().Header.Peek("Content-Encoding"), []byte("gzip")) {
|
||||
// Decompress gzip response
|
||||
reader, err := gzip.NewReader(bytes.NewReader(c.Response().Body()))
|
||||
if err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to create gzip reader",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
decompressed, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to decompress response",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
c.Response().SetBody(decompressed)
|
||||
c.Response().Header.Del("Content-Encoding")
|
||||
}
|
||||
|
||||
if c.Response().StatusCode() != fiber.StatusOK {
|
||||
if ifNotInTest() {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
}
|
||||
return fmt.Errorf("received non-200 response from the GraphQL server: %d", c.Response().StatusCode())
|
||||
}
|
||||
|
||||
c.Response().Header.Del(fiber.HeaderServer)
|
||||
// Update response
|
||||
c.Response().SetBody(decompressed)
|
||||
c.Response().Header.Del("Content-Encoding")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -109,51 +109,74 @@ func healthCheck(c *fiber.Ctx) error {
|
||||
return c.Status(fiber.StatusOK).SendString("Health check OK")
|
||||
}
|
||||
|
||||
// processGraphQLRequest handles the incoming GraphQL requests.
|
||||
// processGraphQLRequest handles the incoming GraphQL requests.
|
||||
func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
startTime := time.Now()
|
||||
|
||||
extractedUserID := "-"
|
||||
extractedRoleName := "-"
|
||||
|
||||
if authorization := c.Get("Authorization"); authorization != "" && (len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
|
||||
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization)
|
||||
}
|
||||
|
||||
// Extract user information and check permissions
|
||||
extractedUserID, extractedRoleName := extractUserInfo(c)
|
||||
|
||||
// Check if user is banned
|
||||
if checkIfUserIsBanned(c, extractedUserID) {
|
||||
return c.Status(fiber.StatusForbidden).SendString("User is banned")
|
||||
}
|
||||
|
||||
// Apply rate limiting if enabled
|
||||
if cfg.Client.RoleRateLimit && !rateLimitedRequest(extractedUserID, extractedRoleName) {
|
||||
return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later")
|
||||
}
|
||||
|
||||
// Parse the GraphQL query
|
||||
parsedResult := parseGraphQLQuery(c)
|
||||
if parsedResult.shouldBlock {
|
||||
return c.Status(fiber.StatusForbidden).SendString("Request blocked")
|
||||
}
|
||||
|
||||
// Handle non-GraphQL requests
|
||||
if parsedResult.shouldIgnore {
|
||||
return proxyTheRequest(c, parsedResult.activeEndpoint)
|
||||
}
|
||||
|
||||
// Handle caching
|
||||
wasCached, err := handleCaching(c, parsedResult, extractedUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Log and monitor the request
|
||||
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUserInfo extracts user ID and role from request headers
|
||||
func extractUserInfo(c *fiber.Ctx) (string, string) {
|
||||
extractedUserID := "-"
|
||||
extractedRoleName := "-"
|
||||
|
||||
// Extract from JWT if available
|
||||
if authorization := c.Get("Authorization"); authorization != "" &&
|
||||
(len(cfg.Client.JWTUserClaimPath) > 0 || len(cfg.Client.JWTRoleClaimPath) > 0) {
|
||||
extractedUserID, extractedRoleName = extractClaimsFromJWTHeader(authorization)
|
||||
}
|
||||
|
||||
// Override role from header if configured
|
||||
if cfg.Client.RoleFromHeader != "" {
|
||||
if role := c.Get(cfg.Client.RoleFromHeader); role != "" {
|
||||
extractedRoleName = role
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Client.RoleRateLimit {
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Rate limiting enabled",
|
||||
Pairs: map[string]interface{}{"user_id": extractedUserID, "role_name": extractedRoleName},
|
||||
})
|
||||
if !rateLimitedRequest(extractedUserID, extractedRoleName) {
|
||||
return c.Status(fiber.StatusTooManyRequests).SendString("Rate limit exceeded, try again later")
|
||||
}
|
||||
}
|
||||
|
||||
parsedResult := parseGraphQLQuery(c) // Ensure this function is defined elsewhere
|
||||
if parsedResult.shouldBlock {
|
||||
return c.Status(fiber.StatusForbidden).SendString("Request blocked")
|
||||
}
|
||||
|
||||
if parsedResult.shouldIgnore {
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Request passed as-is - probably not a GraphQL",
|
||||
})
|
||||
return proxyTheRequest(c, parsedResult.activeEndpoint)
|
||||
}
|
||||
return extractedUserID, extractedRoleName
|
||||
}
|
||||
|
||||
// handleCaching manages the caching logic for GraphQL requests
|
||||
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) {
|
||||
// Calculate query hash for cache key
|
||||
calculatedQueryHash := libpack_cache.CalculateHash(c)
|
||||
|
||||
|
||||
// Set cache time from header or default
|
||||
if parsedResult.cacheTime == 0 {
|
||||
if cacheQuery := c.Get("X-Cache-Graphql-Query"); cacheQuery != "" {
|
||||
parsedResult.cacheTime, _ = strconv.Atoi(cacheQuery)
|
||||
@@ -162,58 +185,38 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
wasCached := false //nolint:ineffassign
|
||||
|
||||
// Handle cache refresh directive
|
||||
if parsedResult.cacheRefresh {
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Cache refresh requested via query",
|
||||
Pairs: map[string]interface{}{"user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
|
||||
})
|
||||
libpack_cache.CacheDelete(calculatedQueryHash)
|
||||
}
|
||||
|
||||
if parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Cache enabled",
|
||||
Pairs: map[string]interface{}{"via_query": parsedResult.cacheRequest, "via_env": cfg.Cache.CacheEnable},
|
||||
})
|
||||
|
||||
if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Cache hit",
|
||||
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
|
||||
})
|
||||
c.Set("X-Cache-Hit", "true")
|
||||
wasCached = true
|
||||
c.Set("Content-Type", "application/json")
|
||||
return c.Send(cachedResponse)
|
||||
}
|
||||
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
|
||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Cache miss",
|
||||
Pairs: map[string]interface{}{"hash": calculatedQueryHash, "user_id": extractedUserID, "request_uuid": c.Locals("request_uuid")},
|
||||
})
|
||||
if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Check if caching is enabled
|
||||
cacheEnabled := parsedResult.cacheRequest || cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable
|
||||
if !cacheEnabled {
|
||||
// No caching, just proxy the request
|
||||
if err := proxyTheRequest(c, parsedResult.activeEndpoint); err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Can't proxy the request",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
||||
return c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
|
||||
return false, c.Status(fiber.StatusInternalServerError).SendString("Can't proxy the request - try again later")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
logAndMonitorRequest(c, extractedUserID, parsedResult.operationType, parsedResult.operationName, wasCached, time.Since(startTime), startTime)
|
||||
// Try to get from cache
|
||||
if cachedResponse := libpack_cache.CacheLookup(calculatedQueryHash); cachedResponse != nil {
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheHit, nil)
|
||||
c.Set("X-Cache-Hit", "true")
|
||||
c.Set("Content-Type", "application/json")
|
||||
return true, c.Send(cachedResponse)
|
||||
}
|
||||
|
||||
return nil
|
||||
// Cache miss, proxy and cache
|
||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsCacheMiss, nil)
|
||||
if err := proxyAndCacheTheRequest(c, calculatedQueryHash, parsedResult.cacheTime, parsedResult.activeEndpoint); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// proxyAndCacheTheRequest proxies and caches the request if needed.
|
||||
func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int, currentEndpoint string) error {
|
||||
if err := proxyTheRequest(c, currentEndpoint); err != nil {
|
||||
|
||||
+65
-5
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
@@ -34,36 +36,63 @@ func NewTracing(ctx context.Context, endpoint string) (*TracingSetup, error) {
|
||||
return nil, fmt.Errorf("endpoint cannot be empty")
|
||||
}
|
||||
|
||||
conn, err := grpc.DialContext(ctx, endpoint,
|
||||
// Create a timeout context for connection establishment
|
||||
dialCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Connect to the collector with improved options
|
||||
conn, err := grpc.DialContext(dialCtx, endpoint,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithReturnConnectionError(),
|
||||
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16*1024*1024)), // 16MB max message size
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create gRPC connection to collector: %w", err)
|
||||
}
|
||||
|
||||
exporter, err := otlptracegrpc.New(ctx, otlptracegrpc.WithGRPCConn(conn))
|
||||
// Create the exporter
|
||||
exporter, err := otlptracegrpc.New(ctx,
|
||||
otlptracegrpc.WithGRPCConn(conn),
|
||||
otlptracegrpc.WithTimeout(5*time.Second),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create trace exporter: %w", err)
|
||||
}
|
||||
|
||||
// Create a resource with more detailed attributes
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName("graphql-monitoring-proxy"),
|
||||
semconv.ServiceVersion("1.0"),
|
||||
semconv.DeploymentEnvironment("production"),
|
||||
attribute.String("application.type", "proxy"),
|
||||
),
|
||||
resource.WithHost(), // Add host information
|
||||
resource.WithOSType(), // Add OS information
|
||||
resource.WithProcessPID(), // Add process information
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create resource: %w", err)
|
||||
}
|
||||
|
||||
// Create the tracer provider with improved configuration
|
||||
tracerProvider := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithBatcher(exporter,
|
||||
// Configure batch processing
|
||||
sdktrace.WithMaxExportBatchSize(512),
|
||||
sdktrace.WithBatchTimeout(3*time.Second),
|
||||
sdktrace.WithMaxQueueSize(2048),
|
||||
),
|
||||
sdktrace.WithResource(res),
|
||||
sdktrace.WithSampler(sdktrace.TraceIDRatioBased(0.1)), // Sample 10% of traces
|
||||
)
|
||||
|
||||
// Set the global tracer provider and propagator
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
otel.SetTextMapPropagator(propagation.TraceContext{})
|
||||
|
||||
// Create a tracer
|
||||
tracer := tracerProvider.Tracer("graphql-monitoring-proxy")
|
||||
|
||||
return &TracingSetup{
|
||||
@@ -105,9 +134,40 @@ func (ts *TracingSetup) Shutdown(ctx context.Context) error {
|
||||
|
||||
// StartSpan starts a new span with the given name and parent context
|
||||
func (ts *TracingSetup) StartSpan(ctx context.Context, name string) (trace.Span, context.Context) {
|
||||
if ts.tracer == nil {
|
||||
if ts == nil || ts.tracer == nil {
|
||||
// Return a no-op span if tracing is not configured
|
||||
return trace.SpanFromContext(ctx), ctx
|
||||
}
|
||||
ctx, span := ts.tracer.Start(ctx, name)
|
||||
|
||||
// Add common attributes to all spans
|
||||
opts := []trace.SpanStartOption{
|
||||
trace.WithAttributes(
|
||||
semconv.ServiceName("graphql-monitoring-proxy"),
|
||||
semconv.ServiceVersion("1.0"),
|
||||
),
|
||||
}
|
||||
|
||||
ctx, span := ts.tracer.Start(ctx, name, opts...)
|
||||
return span, ctx
|
||||
}
|
||||
|
||||
// StartSpanWithAttributes starts a new span with custom attributes
|
||||
func (ts *TracingSetup) StartSpanWithAttributes(ctx context.Context, name string, attrs map[string]string) (trace.Span, context.Context) {
|
||||
if ts == nil || ts.tracer == nil {
|
||||
return trace.SpanFromContext(ctx), ctx
|
||||
}
|
||||
|
||||
// Convert string attributes to KeyValue pairs
|
||||
attributes := make([]attribute.KeyValue, 0, len(attrs)+2)
|
||||
attributes = append(attributes,
|
||||
semconv.ServiceName("graphql-monitoring-proxy"),
|
||||
semconv.ServiceVersion("1.0"),
|
||||
)
|
||||
|
||||
for k, v := range attrs {
|
||||
attributes = append(attributes, attribute.String(k, v))
|
||||
}
|
||||
|
||||
ctx, span := ts.tracer.Start(ctx, name, trace.WithAttributes(attributes...))
|
||||
return span, ctx
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user