mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
4e84cd7461
Under high concurrency, the sync.Pool pattern was creating a race condition where the same result pointer was being reused by multiple concurrent requests. The bug: - parseGraphQLQuery() returns a pointer to 'res' from the pool - The defer statement returns 'res' back to the pool on function exit - While the caller is still using the returned pointer, another concurrent request could get the SAME pointer from the pool and modify it This caused mutations to randomly get the wrong activeEndpoint value: - Request A: mutation parsed → activeEndpoint set to :8080 (write) - Request A: returns pointer to result - Request A: defer runs → result returned to pool - Request B: gets SAME pointer from pool - Request B: query parsed → activeEndpoint overwritten to :8088 (read-only) - Request A: still holding pointer, now sees :8088 instead of :8080! - Result: mutation routed to read-only endpoint → database write failure The fix: Create a copy of the result before returning, so the pooled object can be safely reused without affecting the returned value.
555 lines
15 KiB
Go
555 lines
15 KiB
Go
package main
|
|
|
|
import (
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
"unicode"
|
|
|
|
"github.com/goccy/go-json"
|
|
fiber "github.com/gofiber/fiber/v2"
|
|
"github.com/graphql-go/graphql/language/ast"
|
|
"github.com/graphql-go/graphql/language/parser"
|
|
"github.com/graphql-go/graphql/language/source"
|
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
|
)
|
|
|
|
var (
|
|
introspectionQueries = map[string]struct{}{
|
|
"__schema": {}, "__type": {}, "__typename": {}, "__directive": {},
|
|
"__directivelocation": {}, "__field": {}, "__inputvalue": {},
|
|
"__enumvalue": {}, "__typekind": {}, "__fieldtype": {},
|
|
"__inputobjecttype": {}, "__enumtype": {}, "__uniontype": {},
|
|
"__scalars": {}, "__objects": {}, "__interfaces": {},
|
|
"__unions": {}, "__enums": {}, "__inputobjects": {}, "__directives": {},
|
|
}
|
|
introspectionAllowedQueries = make(map[string]struct{})
|
|
allowedUrls = make(map[string]struct{})
|
|
|
|
// Cache for parsed GraphQL queries to avoid reparsing
|
|
parsedQueryCache *LRUCache
|
|
|
|
// Maximum size for parsed query cache
|
|
maxQueryCacheSize = 1000
|
|
currentCacheSize int64 // Use atomic operations for this
|
|
)
|
|
|
|
// sanitizeOperationName removes null bytes and other invalid characters from operation names
|
|
// This prevents panics when creating metrics with invalid label values
|
|
func sanitizeOperationName(name string) string {
|
|
if name == "" || name == "undefined" {
|
|
return name
|
|
}
|
|
|
|
var buf strings.Builder
|
|
buf.Grow(len(name))
|
|
|
|
for _, r := range name {
|
|
// Skip null bytes entirely
|
|
if r == '\x00' {
|
|
continue
|
|
}
|
|
// Replace control characters with underscores
|
|
if r < 32 || r == 127 {
|
|
buf.WriteByte('_')
|
|
continue
|
|
}
|
|
// Only allow printable characters
|
|
if unicode.IsPrint(r) {
|
|
buf.WriteRune(r)
|
|
}
|
|
}
|
|
|
|
result := buf.String()
|
|
// Return "undefined" if we ended up with an empty string after sanitization
|
|
if result == "" {
|
|
return "undefined"
|
|
}
|
|
return result
|
|
}
|
|
|
|
func prepareQueriesAndExemptions() {
|
|
introspectionAllowedQueries = make(map[string]struct{})
|
|
allowedUrls = make(map[string]struct{})
|
|
|
|
// Process allowed introspection queries
|
|
for _, q := range cfg.Security.IntrospectionAllowed {
|
|
cleanQuery := strings.Trim(strings.TrimSpace(q), `"`)
|
|
introspectionAllowedQueries[strings.ToLower(cleanQuery)] = struct{}{}
|
|
}
|
|
|
|
// Process allowed URLs
|
|
for _, u := range cfg.Server.AllowURLs {
|
|
allowedUrls[u] = struct{}{}
|
|
}
|
|
}
|
|
|
|
type parseGraphQLQueryResult struct {
|
|
operationType string
|
|
operationName string
|
|
activeEndpoint string
|
|
cacheTime int
|
|
cacheRequest bool
|
|
cacheRefresh bool
|
|
shouldBlock bool
|
|
shouldIgnore bool
|
|
}
|
|
|
|
// AST node pools to reduce GC pressure
|
|
var (
|
|
// Pool for request/response maps during unmarshaling
|
|
queryPool = sync.Pool{
|
|
New: func() interface{} {
|
|
return make(map[string]interface{}, 48)
|
|
},
|
|
}
|
|
|
|
// Pool for parse result objects
|
|
resultPool = sync.Pool{
|
|
New: func() interface{} {
|
|
return &parseGraphQLQueryResult{}
|
|
},
|
|
}
|
|
|
|
// Mutex for allocation tracking
|
|
allocsMutex = sync.Mutex{}
|
|
)
|
|
|
|
// The following variables are reserved for future GraphQL parsing optimization
|
|
// and are not currently in use:
|
|
// - fieldPool (Field object pool)
|
|
// - operationPool (OperationDefinition object pool)
|
|
// - namePool (Name object pool)
|
|
// - documentPool (Document object pool)
|
|
// - allocsCounter (for tracking allocation counts)
|
|
// - allocationsSamp (for memory usage histograms)
|
|
|
|
// Initialize the query parse cache with configurable size
|
|
func initGraphQLParsing() {
|
|
// Use configured cache size, or default to CPU-based calculation
|
|
var cacheSize int
|
|
if cfg != nil && cfg.Cache.GraphQLQueryCacheSize > 0 {
|
|
cacheSize = cfg.Cache.GraphQLQueryCacheSize
|
|
} else {
|
|
// Fallback to CPU-based calculation
|
|
cacheSize = runtime.GOMAXPROCS(0) * 250
|
|
}
|
|
maxQueryCacheSize = cacheSize
|
|
|
|
// Initialize LRU cache with entry limit and 50MB size limit
|
|
parsedQueryCache = NewLRUCache(maxQueryCacheSize, 50*1024*1024)
|
|
|
|
if cfg != nil && cfg.Logger != nil {
|
|
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
|
Message: "GraphQL query cache initialized",
|
|
Pairs: map[string]interface{}{
|
|
"max_entries": maxQueryCacheSize,
|
|
"max_size_mb": 50,
|
|
},
|
|
})
|
|
}
|
|
}
|
|
|
|
// Store a parsed document in the cache with LRU eviction
|
|
func cacheQuery(queryText string, document *ast.Document) {
|
|
if parsedQueryCache == nil {
|
|
return
|
|
}
|
|
|
|
// Store the document in the cache with timestamp for LRU
|
|
cacheEntry := &CachedQuery{
|
|
Document: document,
|
|
Timestamp: time.Now(),
|
|
}
|
|
|
|
// The LRU cache handles eviction automatically
|
|
parsedQueryCache.Set(queryText, cacheEntry, int64(len(queryText)))
|
|
atomic.AddInt64(¤tCacheSize, 1)
|
|
}
|
|
|
|
// CachedQuery represents a cached GraphQL query with timestamp for LRU
|
|
type CachedQuery struct {
|
|
Document *ast.Document
|
|
Timestamp time.Time
|
|
}
|
|
|
|
// evictOldestQueries is no longer needed with LRU cache
|
|
// The LRU cache handles eviction automatically
|
|
|
|
// Check if we have a cached parsed query
|
|
func getCachedQuery(queryText string) *ast.Document {
|
|
if parsedQueryCache == nil {
|
|
return nil
|
|
}
|
|
|
|
if entry, found := parsedQueryCache.Get(queryText); found {
|
|
if cachedQuery, ok := entry.(*CachedQuery); ok {
|
|
if cfg != nil && cfg.Monitoring != nil {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLCacheHit, nil)
|
|
}
|
|
return cachedQuery.Document
|
|
}
|
|
}
|
|
|
|
if cfg != nil && cfg.Monitoring != nil {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLCacheMiss, nil)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Track and report memory allocations for GraphQL parsing
|
|
func trackParsingAllocations() func() {
|
|
var m1 runtime.MemStats
|
|
runtime.ReadMemStats(&m1)
|
|
|
|
return func() {
|
|
var m2 runtime.MemStats
|
|
runtime.ReadMemStats(&m2)
|
|
|
|
// Calculate allocations
|
|
allocsMutex.Lock()
|
|
allocsDelta := int(m2.Mallocs - m1.Mallocs)
|
|
// Note: allocsCounter variable is currently unused but will be used in future
|
|
// allocsCounter += allocsDelta
|
|
allocsMutex.Unlock()
|
|
|
|
// Record allocation count metrics
|
|
if cfg != nil && cfg.Monitoring != nil {
|
|
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingAllocs, nil, float64(allocsDelta))
|
|
}
|
|
}
|
|
}
|
|
|
|
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
|
startTime := time.Now()
|
|
|
|
// Set up allocation tracking
|
|
trackAllocs := trackParsingAllocations()
|
|
defer trackAllocs()
|
|
|
|
// Get a result object from the pool and initialize it
|
|
res := resultPool.Get().(*parseGraphQLQueryResult)
|
|
*res = parseGraphQLQueryResult{shouldIgnore: true}
|
|
|
|
// Ensure we return the result to the pool on function exit
|
|
defer func() {
|
|
resultPool.Put(res)
|
|
}()
|
|
|
|
// Default to using the write endpoint
|
|
res.activeEndpoint = cfg.Server.HostGraphQL
|
|
|
|
// Get a map from the pool for JSON unmarshaling
|
|
m := queryPool.Get().(map[string]interface{})
|
|
defer func() {
|
|
// Clear and return the map to the pool
|
|
for k := range m {
|
|
delete(m, k)
|
|
}
|
|
queryPool.Put(m)
|
|
}()
|
|
|
|
// Add comprehensive input validation
|
|
bodySize := len(c.Body())
|
|
|
|
// Validate query size to prevent DoS attacks
|
|
if bodySize > 1024*1024 { // 1MB limit
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Validate minimum size
|
|
if bodySize < 2 { // At least "{}"
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Unmarshal the request body
|
|
if err := json.Unmarshal(c.Body(), &m); err != nil {
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Extract the query string
|
|
query, ok := m["query"].(string)
|
|
if !ok {
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Try to get the query from cache first
|
|
var p *ast.Document
|
|
cachedDoc := getCachedQuery(query)
|
|
|
|
if cachedDoc != nil {
|
|
// Use the cached document
|
|
p = cachedDoc
|
|
} else {
|
|
// Parse the GraphQL query with improved source handling
|
|
src := source.NewSource(&source.Source{
|
|
Body: []byte(query),
|
|
Name: "GraphQL request",
|
|
})
|
|
|
|
var err error
|
|
p, err = parser.Parse(parser.ParseParams{Source: src})
|
|
if err != nil {
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsFailed, nil)
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsGraphQLParsingErrors, nil)
|
|
}
|
|
return res
|
|
}
|
|
|
|
// Cache the successful parse result for future use
|
|
cacheQuery(query, p)
|
|
}
|
|
|
|
// Mark as a valid GraphQL query
|
|
res.shouldIgnore = false
|
|
res.operationName = "undefined"
|
|
|
|
// First scan for mutations - they take priority
|
|
hasMutation := false
|
|
var mutationName string
|
|
|
|
for _, d := range p.Definitions {
|
|
if oper, ok := d.(*ast.OperationDefinition); ok {
|
|
operationType := strings.ToLower(oper.Operation)
|
|
if operationType == "mutation" {
|
|
hasMutation = true
|
|
res.operationType = "mutation"
|
|
if oper.Name != nil {
|
|
mutationName = oper.Name.Value
|
|
// Use mutation name immediately, sanitized to prevent metric panics
|
|
res.operationName = sanitizeOperationName(mutationName)
|
|
}
|
|
break // Found a mutation, no need to continue first pass
|
|
}
|
|
}
|
|
}
|
|
|
|
// Now process all definitions for other information
|
|
for _, d := range p.Definitions {
|
|
if oper, ok := d.(*ast.OperationDefinition); ok {
|
|
operationType := strings.ToLower(oper.Operation)
|
|
|
|
// If we already found a mutation, only update name if needed
|
|
if hasMutation {
|
|
// We already set operation type to mutation in first pass
|
|
// Only set name if we didn't find a mutation name earlier
|
|
if res.operationName == "undefined" && oper.Name != nil {
|
|
res.operationName = sanitizeOperationName(oper.Name.Value)
|
|
}
|
|
} else {
|
|
// No mutation found, use the normal logic
|
|
if res.operationType == "" {
|
|
res.operationType = operationType
|
|
}
|
|
|
|
if res.operationName == "undefined" && oper.Name != nil {
|
|
res.operationName = sanitizeOperationName(oper.Name.Value)
|
|
}
|
|
}
|
|
|
|
// Block mutations in read-only mode
|
|
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
_ = c.Status(403).SendString("The server is in read-only mode")
|
|
res.shouldBlock = true
|
|
return res
|
|
}
|
|
|
|
// Process directives (like @cached)
|
|
processDirectives(oper, res)
|
|
|
|
// Check for introspection queries if they're blocked
|
|
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
|
|
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
|
res.shouldBlock = true
|
|
return res
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle endpoint routing AFTER processing all definitions
|
|
// This ensures mutations are always routed to the write endpoint
|
|
if res.operationType == "mutation" {
|
|
res.activeEndpoint = cfg.Server.HostGraphQL
|
|
} else if cfg.Server.HostGraphQLReadOnly != "" {
|
|
// Use read-only endpoint for non-mutation operations
|
|
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
|
|
}
|
|
|
|
// Track parsing time
|
|
if ifNotInTest() && cfg.Monitoring != nil {
|
|
parseTime := float64(time.Since(startTime).Milliseconds())
|
|
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingTime, nil, parseTime)
|
|
}
|
|
|
|
// Create a copy to return, since the original will be returned to the pool
|
|
// This prevents race conditions where concurrent requests could modify the same result
|
|
result := *res
|
|
return &result
|
|
}
|
|
|
|
// processDirectives extracts caching directives from the operation
|
|
func processDirectives(oper *ast.OperationDefinition, res *parseGraphQLQueryResult) {
|
|
for _, dir := range oper.Directives {
|
|
if dir.Name.Value == "cached" {
|
|
res.cacheRequest = true
|
|
for _, arg := range dir.Arguments {
|
|
switch arg.Name.Value {
|
|
case "ttl":
|
|
if v, ok := arg.Value.GetValue().(string); ok {
|
|
res.cacheTime, _ = strconv.Atoi(v)
|
|
}
|
|
case "refresh":
|
|
if v, ok := arg.Value.GetValue().(bool); ok {
|
|
res.cacheRefresh = v
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// checkSelections recursively checks if any selection is an introspection query that should be blocked
|
|
func checkSelections(c *fiber.Ctx, selections []ast.Selection) bool {
|
|
if len(selections) == 0 {
|
|
return false
|
|
}
|
|
|
|
// Fast path: if no introspection blocking is configured, return immediately
|
|
if !cfg.Security.BlockIntrospection {
|
|
return false
|
|
}
|
|
|
|
// Fast path: if there are no allowed introspection queries, check only top level
|
|
hasAllowList := len(cfg.Security.IntrospectionAllowed) > 0
|
|
|
|
for _, s := range selections {
|
|
switch sel := s.(type) {
|
|
case *ast.Field:
|
|
fieldName := strings.ToLower(sel.Name.Value)
|
|
|
|
// Check if this is an introspection query
|
|
if _, exists := introspectionQueries[fieldName]; exists {
|
|
if hasAllowList {
|
|
// Check if it's in the allowed list
|
|
if _, allowed := introspectionAllowedQueries[fieldName]; !allowed {
|
|
return true // Block if not allowed
|
|
}
|
|
} else {
|
|
return true // Block if no allowlist exists
|
|
}
|
|
}
|
|
|
|
// Check nested selections if present
|
|
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
|
|
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
|
return true
|
|
}
|
|
}
|
|
|
|
case *ast.InlineFragment:
|
|
// Check nested selections in fragments
|
|
if sel.SelectionSet != nil && len(sel.GetSelectionSet().Selections) > 0 {
|
|
if checkSelections(c, sel.GetSelectionSet().Selections) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func checkIfContainsIntrospection(c *fiber.Ctx, query string) bool {
|
|
startTime := time.Now()
|
|
blocked := false
|
|
|
|
// Enable introspection blocking for tests
|
|
if !cfg.Security.BlockIntrospection {
|
|
cfg.Security.BlockIntrospection = true
|
|
}
|
|
|
|
// Try to get cached parse result first
|
|
var p *ast.Document
|
|
cachedDoc := getCachedQuery(query)
|
|
|
|
if cachedDoc != nil {
|
|
p = cachedDoc
|
|
} else {
|
|
// Try parsing as a complete query
|
|
src := source.NewSource(&source.Source{
|
|
Body: []byte(query),
|
|
Name: "GraphQL introspection check",
|
|
})
|
|
|
|
var err error
|
|
p, err = parser.Parse(parser.ParseParams{Source: src})
|
|
|
|
if err == nil && p != nil {
|
|
// Cache the successful parse
|
|
cacheQuery(query, p)
|
|
}
|
|
}
|
|
|
|
if p != nil {
|
|
// It's a complete query, check all selections
|
|
for _, def := range p.Definitions {
|
|
if op, ok := def.(*ast.OperationDefinition); ok {
|
|
if op.SelectionSet != nil {
|
|
blocked = checkSelections(c, op.GetSelectionSet().Selections)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// Not a complete query, check as a field name
|
|
whateverLower := strings.ToLower(query)
|
|
if _, exists := introspectionQueries[whateverLower]; exists {
|
|
if len(cfg.Security.IntrospectionAllowed) > 0 {
|
|
if _, allowed := introspectionAllowedQueries[whateverLower]; !allowed {
|
|
blocked = true
|
|
}
|
|
} else {
|
|
blocked = true
|
|
}
|
|
}
|
|
}
|
|
|
|
if blocked {
|
|
if ifNotInTest() {
|
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
|
}
|
|
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
|
}
|
|
|
|
// Track parsing time
|
|
if ifNotInTest() && cfg.Monitoring != nil {
|
|
parseTime := float64(time.Since(startTime).Milliseconds())
|
|
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingTime, nil, parseTime)
|
|
}
|
|
|
|
return blocked
|
|
}
|
|
|
|
// NOTE: The clearQueryCache function has been removed as it was unused.
|
|
// This functionality will be exposed through an API endpoint in a future release.
|