mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-12 00:19:36 +00:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cc35031db9 | |||
| 6a69694ab3 | |||
| b210627fb7 | |||
| edcabe3cf0 | |||
| c99bf2b245 | |||
| 39dc7b49cf | |||
| 28223b40da | |||
| ee5618c699 | |||
| 94c097bc6c | |||
| 4e84cd7461 | |||
| e37a8beaa7 | |||
| 9dd8c11363 | |||
| 9fbee0d9a1 | |||
| 7df651c17a | |||
| 7ada94e4fa | |||
| c510c29a8f | |||
| 370602858a | |||
| 6261be6e53 | |||
| 5ae4ea1e25 |
@@ -3,3 +3,5 @@ test.sh
|
||||
banned.json*
|
||||
dist/
|
||||
coverage.out
|
||||
CLAUDE.md
|
||||
graphql-monitoring-proxy
|
||||
|
||||
@@ -155,6 +155,7 @@ You can still use the non-prefixed environment variables in the spirit of the ba
|
||||
| `CACHE_TTL` | The cache TTL | `60` |
|
||||
| `CACHE_MAX_MEMORY_SIZE` | Maximum memory size for cache in MB | `100` |
|
||||
| `CACHE_MAX_ENTRIES` | Maximum number of entries in cache | `10000` |
|
||||
| `CACHE_PER_USER_DISABLED` | **⚠️ SECURITY**: Disable per-user cache isolation | `false` (**DO NOT** set to `true` in multi-user apps) |
|
||||
| `ENABLE_REDIS_CACHE` | Enable distributed Redis cache | `false` |
|
||||
| `CACHE_REDIS_URL` | URL to redis server / cluster endpoint | `localhost:6379` |
|
||||
| `CACHE_REDIS_PASSWORD` | Redis connection password | `` |
|
||||
@@ -347,19 +348,38 @@ The admin dashboard (`/admin`) provides:
|
||||
The cache engine is enabled in the background by default, using no additional resources.
|
||||
You can then start using the cache by setting the `ENABLE_GLOBAL_CACHE` or `ENABLE_REDIS_CACHE` environment variable to `true` - which will enable the cache for all queries without introspection. You can leave the global cache disabled and enable the cache for specific queries by adding the `@cached` directive to the query.
|
||||
|
||||
**Important**: The cache key is calculated from the **entire request body**, which includes both the GraphQL query and variables. This means:
|
||||
**Important**: The cache key is calculated from the **request body + user context (user ID and role)**. This means:
|
||||
- Identical queries with different variables are cached separately
|
||||
- Identical queries with different variable values get their own cache entries
|
||||
- This ensures correct caching behavior for parameterized queries
|
||||
- **Identical queries from different users are cached separately** (security isolation)
|
||||
- **Identical queries with different roles are cached separately** (prevents privilege escalation)
|
||||
- This ensures correct caching behavior and prevents data leakage between users
|
||||
|
||||
**🔒 Security Update (v0.27.0+)**: Cache keys now include user context by default to prevent security vulnerabilities where users could see each other's cached data. This is enabled by default and should NOT be disabled in multi-user applications.
|
||||
|
||||
Example:
|
||||
```graphql
|
||||
# These two requests will have DIFFERENT cache keys:
|
||||
# These requests will have DIFFERENT cache keys:
|
||||
|
||||
# Different variables
|
||||
query GetUser($id: ID!) { user(id: $id) { name } }
|
||||
variables: { "id": "123" }
|
||||
variables: { "id": "123" } // Cache key: MD5(body + user:alice + role:user)
|
||||
|
||||
query GetUser($id: ID!) { user(id: $id) { name } }
|
||||
variables: { "id": "456" }
|
||||
variables: { "id": "456" } // Cache key: MD5(body + user:alice + role:user)
|
||||
|
||||
# Different users (SECURITY: prevents data leakage)
|
||||
query GetMyProfile { me { email } }
|
||||
Authorization: Bearer token_for_alice // Cache key: MD5(body + user:alice + role:user)
|
||||
|
||||
query GetMyProfile { me { email } }
|
||||
Authorization: Bearer token_for_bob // Cache key: MD5(body + user:bob + role:user)
|
||||
|
||||
# Different roles (SECURITY: prevents privilege escalation)
|
||||
query GetData { data { value } }
|
||||
Authorization: Bearer token_admin // Cache key: MD5(body + user:alice + role:admin)
|
||||
|
||||
query GetData { data { value } }
|
||||
Authorization: Bearer token_user // Cache key: MD5(body + user:alice + role:user)
|
||||
```
|
||||
|
||||
In the case of the `@cached` you can add additional parameters to the directive which will set the cache for specific queries to the provided time.
|
||||
@@ -425,6 +445,8 @@ You can now specify the read-only GraphQL endpoint by setting the `HOST_GRAPHQL_
|
||||
|
||||
You can check out the [example of combined deployment with RW and read-only hasura](static/kubernetes-single-deployment-with-ro.yaml).
|
||||
|
||||
**Important:** When using a read-only Hasura instance connected to a PostgreSQL read replica, you **must** disable event trigger processing on that instance by setting `HASURA_GRAPHQL_EVENTS_FETCH_INTERVAL=0` in the read-only Hasura container environment variables. This prevents the read-only instance from attempting to process event triggers (which require write access to event log tables), avoiding "cannot set transaction read-write mode during recovery" errors.
|
||||
|
||||
### Resilience
|
||||
|
||||
#### Circuit Breaker Pattern
|
||||
@@ -723,6 +745,8 @@ Following tables are being cleaned:
|
||||
- `hdb_catalog.hdb_cron_event_invocation_logs`
|
||||
- `hdb_catalog.hdb_scheduled_event_invocation_logs`
|
||||
|
||||
**Important for RO/RW setups:** The `HASURA_EVENT_METADATA_DB` connection string must point to the **read-write primary database** where the `hdb_catalog` schema resides. The cleaner executes DELETE operations which require write permissions. Do not point this to a read-only replica.
|
||||
|
||||
|
||||
### Security
|
||||
|
||||
|
||||
@@ -563,6 +563,26 @@
|
||||
<span class="metric-label">Enabled</span>
|
||||
<span class="metric-value" id="cb-enabled">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Total Requests</span>
|
||||
<span class="metric-value" id="cb-total-requests">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Total Successes</span>
|
||||
<span class="metric-value" id="cb-total-successes">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Total Failures</span>
|
||||
<span class="metric-value" id="cb-total-failures">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Consecutive Successes</span>
|
||||
<span class="metric-value" id="cb-consecutive-successes">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Consecutive Failures</span>
|
||||
<span class="metric-value" id="cb-consecutive-failures">--</span>
|
||||
</div>
|
||||
<div class="metric-row">
|
||||
<span class="metric-label">Max Failures</span>
|
||||
<span class="metric-value" id="cb-max-failures">--</span>
|
||||
@@ -1055,6 +1075,63 @@
|
||||
|
||||
// Update all statistics
|
||||
function updateAllStats(data) {
|
||||
// Check if this is cluster mode data
|
||||
if (data.cluster_mode && data.stats) {
|
||||
// Cluster mode: data is structured differently
|
||||
// Stats contains aggregated data with nested objects
|
||||
const stats = data.stats;
|
||||
|
||||
// Update cluster status section
|
||||
document.getElementById('cluster-status-section').style.display = 'block';
|
||||
document.getElementById('cluster-total-instances').textContent = data.total_instances || 0;
|
||||
document.getElementById('cluster-healthy-instances').textContent = data.healthy_instances || 0;
|
||||
document.getElementById('overview-title').textContent = 'Cluster Overview';
|
||||
|
||||
// Update cluster info in toggle
|
||||
const totalInstances = data.total_instances || 0;
|
||||
document.getElementById('cluster-info').textContent =
|
||||
`(${totalInstances} instance${totalInstances !== 1 ? 's' : ''} available)`;
|
||||
|
||||
// Build stats object with uptime from cluster_uptime
|
||||
const statsWithUptime = {
|
||||
...stats,
|
||||
uptime_seconds: stats.cluster_uptime || 0,
|
||||
uptime_human: formatUptime(stats.cluster_uptime || 0)
|
||||
};
|
||||
updateStats(statsWithUptime);
|
||||
|
||||
// Extract nested objects from stats for cluster mode
|
||||
if (stats.circuit_breaker) updateCircuitBreaker(stats.circuit_breaker);
|
||||
if (stats.coalescing) updateCoalescing(stats.coalescing);
|
||||
if (stats.retry_budget) updateRetryBudget(stats.retry_budget);
|
||||
if (stats.websocket) updateWebSocket(stats.websocket);
|
||||
if (stats.connections) updateConnections(stats.connections);
|
||||
|
||||
// Handle memory for cluster mode (Redis doesn't track memory per instance)
|
||||
if (stats.memory) {
|
||||
const totalMemMB = stats.memory.total_usage_mb;
|
||||
if (totalMemMB < 0) {
|
||||
// All instances are using Redis cache
|
||||
document.getElementById('cache-memory').textContent = 'N/A';
|
||||
document.getElementById('cache-memory').title = 'Memory tracking not available for Redis cache';
|
||||
document.getElementById('cache-memory-pct').textContent = 'Redis cache';
|
||||
document.getElementById('memory-progress').style.width = '0%';
|
||||
} else {
|
||||
document.getElementById('cache-memory').textContent = totalMemMB.toFixed(2) + ' MB';
|
||||
document.getElementById('cache-memory-pct').textContent = 'Cluster total';
|
||||
}
|
||||
}
|
||||
|
||||
// Update instance list if available
|
||||
if (data.instances && data.instances.length > 0) {
|
||||
document.getElementById('instance-details-section').style.display = 'block';
|
||||
updateInstanceList(data.instances, null);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Non-cluster mode: original behavior
|
||||
if (data.stats) updateStats(data.stats);
|
||||
if (data.health) updateHealth(data.health);
|
||||
if (data.circuit_breaker) updateCircuitBreaker(data.circuit_breaker);
|
||||
@@ -1208,6 +1285,19 @@
|
||||
|
||||
document.getElementById('cb-enabled').textContent = data.enabled ? 'Yes' : 'No';
|
||||
|
||||
if (data.counts) {
|
||||
document.getElementById('cb-total-requests').textContent =
|
||||
(data.counts.requests || 0).toLocaleString();
|
||||
document.getElementById('cb-total-successes').textContent =
|
||||
(data.counts.total_successes || 0).toLocaleString();
|
||||
document.getElementById('cb-total-failures').textContent =
|
||||
(data.counts.total_failures || 0).toLocaleString();
|
||||
document.getElementById('cb-consecutive-successes').textContent =
|
||||
(data.counts.consecutive_successes || 0).toLocaleString();
|
||||
document.getElementById('cb-consecutive-failures').textContent =
|
||||
(data.counts.consecutive_failures || 0).toLocaleString();
|
||||
}
|
||||
|
||||
if (data.config) {
|
||||
document.getElementById('cb-max-failures').textContent = data.config.max_failures || '--';
|
||||
document.getElementById('cb-timeout').textContent = (data.config.timeout || '--') + 's';
|
||||
|
||||
+173
-13
@@ -79,9 +79,44 @@ func (ad *AdminDashboard) serveDashboard(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// getStats returns overall proxy statistics
|
||||
// In cluster mode (when metrics aggregator is available), returns aggregated stats from all instances
|
||||
func (ad *AdminDashboard) getStats(c *fiber.Ctx) error {
|
||||
// Check if cluster mode is enabled - if so, return aggregated stats
|
||||
if aggregator := GetMetricsAggregator(); aggregator != nil {
|
||||
metrics, err := aggregator.GetAggregatedMetrics()
|
||||
if err != nil {
|
||||
if ad.logger != nil {
|
||||
ad.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to get aggregated metrics, falling back to local stats",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
}
|
||||
// Fall through to local stats on error
|
||||
} else {
|
||||
// Return aggregated cluster stats
|
||||
response := map[string]interface{}{
|
||||
"cluster_mode": true,
|
||||
"total_instances": metrics.TotalInstances,
|
||||
"healthy_instances": metrics.HealthyInstances,
|
||||
"timestamp": metrics.LastUpdate.Format(time.RFC3339),
|
||||
"version": "0.27.0",
|
||||
}
|
||||
|
||||
// Add combined stats from aggregation
|
||||
if metrics.CombinedStats != nil {
|
||||
for k, v := range metrics.CombinedStats {
|
||||
response[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
|
||||
// Local instance stats (fallback or non-cluster mode)
|
||||
uptimeSeconds := time.Since(startTime).Seconds()
|
||||
stats := map[string]interface{}{
|
||||
"cluster_mode": false,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
"uptime_seconds": uptimeSeconds,
|
||||
"uptime_human": formatDuration(time.Since(startTime)),
|
||||
@@ -208,9 +243,17 @@ func (ad *AdminDashboard) getCircuitBreakerStatus(c *fiber.Ctx) error {
|
||||
if cb != nil {
|
||||
cbMutex.RLock()
|
||||
state := cb.State()
|
||||
counts := cb.Counts()
|
||||
cbMutex.RUnlock()
|
||||
|
||||
status["state"] = state.String()
|
||||
status["counts"] = map[string]interface{}{
|
||||
"requests": counts.Requests,
|
||||
"total_successes": counts.TotalSuccesses,
|
||||
"total_failures": counts.TotalFailures,
|
||||
"consecutive_successes": counts.ConsecutiveSuccesses,
|
||||
"consecutive_failures": counts.ConsecutiveFailures,
|
||||
}
|
||||
status["config"] = map[string]interface{}{
|
||||
"max_failures": cfg.CircuitBreaker.MaxFailures,
|
||||
"failure_ratio": cfg.CircuitBreaker.FailureRatio,
|
||||
@@ -225,9 +268,62 @@ func (ad *AdminDashboard) getCircuitBreakerStatus(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// getCacheStats returns cache statistics
|
||||
// In cluster mode, returns aggregated cache stats from all instances
|
||||
func (ad *AdminDashboard) getCacheStats(c *fiber.Ctx) error {
|
||||
// Check if cluster mode is enabled - if so, return aggregated cache stats
|
||||
if aggregator := GetMetricsAggregator(); aggregator != nil {
|
||||
metrics, err := aggregator.GetAggregatedMetrics()
|
||||
if err != nil {
|
||||
if ad.logger != nil {
|
||||
ad.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to get aggregated cache metrics, falling back to local stats",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
}
|
||||
// Fall through to local stats on error
|
||||
} else {
|
||||
// Build aggregated cache stats from combined stats
|
||||
response := map[string]interface{}{
|
||||
"cluster_mode": true,
|
||||
"total_instances": metrics.TotalInstances,
|
||||
}
|
||||
|
||||
// Add cache config from local config
|
||||
if cfg != nil {
|
||||
response["enabled"] = cfg.Cache.CacheEnable
|
||||
response["redis_enabled"] = cfg.Cache.CacheRedisEnable
|
||||
response["ttl_seconds"] = cfg.Cache.CacheTTL
|
||||
response["max_memory_mb"] = cfg.Cache.CacheMaxMemorySize
|
||||
response["max_entries"] = cfg.Cache.CacheMaxEntries
|
||||
}
|
||||
|
||||
// Extract aggregated cache stats from combined stats
|
||||
if metrics.CombinedStats != nil {
|
||||
if cacheHits, ok := metrics.CombinedStats["cache_hits"]; ok {
|
||||
response["cache_hits"] = cacheHits
|
||||
}
|
||||
if cacheMisses, ok := metrics.CombinedStats["cache_misses"]; ok {
|
||||
response["cache_misses"] = cacheMisses
|
||||
}
|
||||
if cachedQueries, ok := metrics.CombinedStats["cached_queries"]; ok {
|
||||
response["cached_queries"] = cachedQueries
|
||||
}
|
||||
if hitRate, ok := metrics.CombinedStats["cache_hit_rate_pct"]; ok {
|
||||
response["hit_rate_pct"] = hitRate
|
||||
}
|
||||
if memoryMB, ok := metrics.CombinedStats["memory_usage_mb"]; ok {
|
||||
response["memory_usage_mb"] = memoryMB
|
||||
}
|
||||
}
|
||||
|
||||
return c.JSON(response)
|
||||
}
|
||||
}
|
||||
|
||||
// Local instance stats (fallback or non-cluster mode)
|
||||
stats := map[string]interface{}{
|
||||
"enabled": false,
|
||||
"cluster_mode": false,
|
||||
"enabled": false,
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
@@ -582,8 +678,8 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
||||
ticker := time.NewTicker(2 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Send initial stats immediately
|
||||
if stats := ad.gatherAllStats(); stats != nil {
|
||||
// Send initial stats immediately (cluster-aware for dashboard)
|
||||
if stats := ad.gatherAllStatsClusterAware(); stats != nil {
|
||||
if data, err := json.Marshal(stats); err == nil {
|
||||
c.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
@@ -593,8 +689,8 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Gather all stats
|
||||
stats := ad.gatherAllStats()
|
||||
// Gather all stats (cluster-aware for dashboard)
|
||||
stats := ad.gatherAllStatsClusterAware()
|
||||
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(stats)
|
||||
@@ -627,8 +723,56 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
||||
}
|
||||
|
||||
// gatherAllStats collects all statistics into a single structure
|
||||
// This always returns LOCAL stats for this instance (used by metrics aggregator)
|
||||
func (ad *AdminDashboard) gatherAllStats() map[string]interface{} {
|
||||
return ad.gatherAllStatsWithMode(false)
|
||||
}
|
||||
|
||||
// gatherAllStatsClusterAware collects statistics with cluster awareness
|
||||
// If cluster mode is available, returns aggregated stats from all instances
|
||||
func (ad *AdminDashboard) gatherAllStatsClusterAware() map[string]interface{} {
|
||||
return ad.gatherAllStatsWithMode(true)
|
||||
}
|
||||
|
||||
// gatherAllStatsWithMode collects statistics with optional cluster mode
|
||||
func (ad *AdminDashboard) gatherAllStatsWithMode(useClusterMode bool) map[string]interface{} {
|
||||
// Check if cluster mode is requested and available
|
||||
if useClusterMode {
|
||||
if aggregator := GetMetricsAggregator(); aggregator != nil {
|
||||
metrics, err := aggregator.GetAggregatedMetrics()
|
||||
if err == nil && metrics != nil {
|
||||
// Return aggregated cluster stats
|
||||
result := map[string]interface{}{
|
||||
"cluster_mode": true,
|
||||
"total_instances": metrics.TotalInstances,
|
||||
"healthy_instances": metrics.HealthyInstances,
|
||||
}
|
||||
|
||||
// Build stats section from combined stats
|
||||
stats := map[string]interface{}{
|
||||
"timestamp": metrics.LastUpdate.Format(time.RFC3339),
|
||||
"version": "0.27.0",
|
||||
}
|
||||
|
||||
// Copy all combined stats
|
||||
if metrics.CombinedStats != nil {
|
||||
for k, v := range metrics.CombinedStats {
|
||||
stats[k] = v
|
||||
}
|
||||
}
|
||||
result["stats"] = stats
|
||||
|
||||
// Add per-instance details
|
||||
result["instances"] = metrics.Instances
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to local stats
|
||||
result := make(map[string]interface{})
|
||||
result["cluster_mode"] = false
|
||||
|
||||
// Main stats
|
||||
uptimeSeconds := time.Since(startTime).Seconds()
|
||||
@@ -732,9 +876,17 @@ func (ad *AdminDashboard) gatherAllStats() map[string]interface{} {
|
||||
if cb != nil {
|
||||
cbMutex.RLock()
|
||||
state := cb.State()
|
||||
counts := cb.Counts()
|
||||
cbMutex.RUnlock()
|
||||
|
||||
cbStatus["state"] = state.String()
|
||||
cbStatus["counts"] = map[string]interface{}{
|
||||
"requests": counts.Requests,
|
||||
"total_successes": counts.TotalSuccesses,
|
||||
"total_failures": counts.TotalFailures,
|
||||
"consecutive_successes": counts.ConsecutiveSuccesses,
|
||||
"consecutive_failures": counts.ConsecutiveFailures,
|
||||
}
|
||||
cbStatus["config"] = map[string]interface{}{
|
||||
"max_failures": cfg.CircuitBreaker.MaxFailures,
|
||||
"failure_ratio": cfg.CircuitBreaker.FailureRatio,
|
||||
@@ -771,16 +923,24 @@ func (ad *AdminDashboard) gatherAllStats() map[string]interface{} {
|
||||
}
|
||||
cacheStats["hit_rate_pct"] = hitRate
|
||||
|
||||
memoryUsage := libpack_cache.GetCacheMemoryUsage()
|
||||
maxMemory := libpack_cache.GetCacheMaxMemorySize()
|
||||
cacheStats["memory_usage_bytes"] = memoryUsage
|
||||
cacheStats["memory_usage_mb"] = float64(memoryUsage) / (1024 * 1024)
|
||||
// Only get memory usage for in-memory cache (not Redis)
|
||||
if cfg.Cache.CacheEnable && !cfg.Cache.CacheRedisEnable {
|
||||
memoryUsage := libpack_cache.GetCacheMemoryUsage()
|
||||
maxMemory := libpack_cache.GetCacheMaxMemorySize()
|
||||
cacheStats["memory_usage_bytes"] = memoryUsage
|
||||
cacheStats["memory_usage_mb"] = float64(memoryUsage) / (1024 * 1024)
|
||||
|
||||
memoryUsagePct := 0.0
|
||||
if maxMemory > 0 {
|
||||
memoryUsagePct = float64(memoryUsage) / float64(maxMemory) * 100
|
||||
memoryUsagePct := 0.0
|
||||
if maxMemory > 0 {
|
||||
memoryUsagePct = float64(memoryUsage) / float64(maxMemory) * 100
|
||||
}
|
||||
cacheStats["memory_usage_pct"] = memoryUsagePct
|
||||
} else {
|
||||
// For Redis cache, memory tracking is not available per instance
|
||||
cacheStats["memory_usage_bytes"] = int64(-1)
|
||||
cacheStats["memory_usage_mb"] = float64(-1)
|
||||
cacheStats["memory_usage_pct"] = float64(-1)
|
||||
}
|
||||
cacheStats["memory_usage_pct"] = memoryUsagePct
|
||||
}
|
||||
}
|
||||
result["cache"] = cacheStats
|
||||
|
||||
@@ -215,11 +215,13 @@ func TestAdminDashboard_GetCacheStats(t *testing.T) {
|
||||
CacheMaxMemorySize int
|
||||
CacheMaxEntries int
|
||||
GraphQLQueryCacheSize int
|
||||
PerUserCacheDisabled bool
|
||||
}{
|
||||
CacheEnable: true,
|
||||
CacheTTL: 60,
|
||||
CacheMaxMemorySize: 100,
|
||||
CacheMaxEntries: 10000,
|
||||
CacheEnable: true,
|
||||
CacheTTL: 60,
|
||||
CacheMaxMemorySize: 100,
|
||||
CacheMaxEntries: 10000,
|
||||
PerUserCacheDisabled: false,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -170,9 +170,11 @@ func apiCircuitBreakerHealth(c *fiber.Ctx) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Get circuit breaker state
|
||||
// Get circuit breaker state with proper mutex protection
|
||||
cbMutex.RLock()
|
||||
state := cb.State()
|
||||
counts := cb.Counts()
|
||||
cbMutex.RUnlock()
|
||||
|
||||
// Determine health status
|
||||
var status string
|
||||
|
||||
Vendored
+42
-4
@@ -3,6 +3,7 @@ package libpack_cache
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -27,7 +28,9 @@ type CacheConfig struct {
|
||||
MaxMemorySize int64 `json:"max_memory_size"` // Maximum memory size in bytes
|
||||
MaxEntries int64 `json:"max_entries"` // Maximum number of entries
|
||||
}
|
||||
TTL int `json:"ttl"`
|
||||
TTL int `json:"ttl"`
|
||||
IncludeUserContext bool `json:"include_user_context"` // Include user ID and role in cache key
|
||||
PerUserCacheDisabled bool `json:"per_user_cache_disabled"` // Disable per-user caching (backward compatibility)
|
||||
}
|
||||
|
||||
type CacheStats struct {
|
||||
@@ -52,10 +55,14 @@ var (
|
||||
config *CacheConfig
|
||||
)
|
||||
|
||||
// CalculateHash generates an MD5 hash from the request body.
|
||||
// CalculateHash generates an MD5 hash from the request body and optionally user context.
|
||||
// For GraphQL requests, this includes both the query and variables,
|
||||
// ensuring that identical queries with different variables are cached separately.
|
||||
//
|
||||
// SECURITY FIX: This function now includes user ID and role in the cache key by default
|
||||
// to prevent data leakage between authenticated users. Set CACHE_PER_USER_DISABLED=true
|
||||
// to revert to the old behavior (NOT RECOMMENDED for multi-user applications).
|
||||
//
|
||||
// Example GraphQL request body:
|
||||
//
|
||||
// {
|
||||
@@ -63,8 +70,39 @@ var (
|
||||
// "variables": { "id": "123" }
|
||||
// }
|
||||
//
|
||||
// Different variable values will produce different cache keys.
|
||||
func CalculateHash(c *fiber.Ctx) string {
|
||||
// With user context enabled (default):
|
||||
// - Same query, same variables, same user → same cache key
|
||||
// - Same query, same variables, different user → different cache key
|
||||
//
|
||||
// Different variable values will always produce different cache keys.
|
||||
func CalculateHash(c *fiber.Ctx, userID string, userRole string) string {
|
||||
cacheKeyData := string(c.Body())
|
||||
|
||||
// Include user context in cache key (default behavior for security)
|
||||
// Only skip if explicitly disabled via configuration (backward compatibility)
|
||||
if config != nil && !config.PerUserCacheDisabled {
|
||||
// Normalize empty user values to prevent cache key collisions
|
||||
if userID == "" {
|
||||
userID = "-"
|
||||
}
|
||||
if userRole == "" {
|
||||
userRole = "-"
|
||||
}
|
||||
|
||||
// Append user context to ensure cache isolation between users
|
||||
cacheKeyData = fmt.Sprintf("%s|uid:%s|role:%s", cacheKeyData, userID, userRole)
|
||||
}
|
||||
|
||||
return strutil.Md5(cacheKeyData)
|
||||
}
|
||||
|
||||
// CalculateHashLegacy generates a cache hash using only the request body (DEPRECATED).
|
||||
// This function exists for backward compatibility only and should NOT be used
|
||||
// in production multi-user applications as it creates a security vulnerability
|
||||
// where users can see each other's cached data.
|
||||
//
|
||||
// Deprecated: Use CalculateHash with user context instead.
|
||||
func CalculateHashLegacy(c *fiber.Ctx) string {
|
||||
return strutil.Md5(c.Body())
|
||||
}
|
||||
|
||||
|
||||
Vendored
+78
-16
@@ -20,7 +20,7 @@ func (suite *Tests) Test_CalculateHash() {
|
||||
// Test with empty body
|
||||
suite.Run("empty body", func() {
|
||||
ctx.Request().SetBody([]byte(""))
|
||||
hash := CalculateHash(ctx)
|
||||
hash := CalculateHash(ctx, "user1", "admin")
|
||||
assert.NotEmpty(hash)
|
||||
assert.Equal(32, len(hash)) // MD5 hash is 32 characters
|
||||
})
|
||||
@@ -28,7 +28,7 @@ func (suite *Tests) Test_CalculateHash() {
|
||||
// Test with non-empty body
|
||||
suite.Run("non-empty body", func() {
|
||||
ctx.Request().SetBody([]byte("test body"))
|
||||
hash := CalculateHash(ctx)
|
||||
hash := CalculateHash(ctx, "user1", "admin")
|
||||
assert.NotEmpty(hash)
|
||||
assert.Equal(32, len(hash))
|
||||
})
|
||||
@@ -36,10 +36,10 @@ func (suite *Tests) Test_CalculateHash() {
|
||||
// Test with different bodies produce different hashes
|
||||
suite.Run("different bodies", func() {
|
||||
ctx.Request().SetBody([]byte("body1"))
|
||||
hash1 := CalculateHash(ctx)
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
ctx.Request().SetBody([]byte("body2"))
|
||||
hash2 := CalculateHash(ctx)
|
||||
hash2 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
assert.NotEqual(hash1, hash2)
|
||||
})
|
||||
@@ -51,10 +51,10 @@ func (suite *Tests) Test_CalculateHash() {
|
||||
query2 := []byte(`{"query":"query GetUser($id: ID!) { user(id: $id) { name } }","variables":{"id":"456"}}`)
|
||||
|
||||
ctx.Request().SetBody(query1)
|
||||
hash1 := CalculateHash(ctx)
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
ctx.Request().SetBody(query2)
|
||||
hash2 := CalculateHash(ctx)
|
||||
hash2 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
assert.NotEqual(hash1, hash2, "Different variables should produce different cache keys")
|
||||
})
|
||||
@@ -66,13 +66,83 @@ func (suite *Tests) Test_CalculateHash() {
|
||||
query2 := []byte(`{"query":"query GetUsers { users { name } }","variables":{}}`)
|
||||
|
||||
ctx.Request().SetBody(query1)
|
||||
hash1 := CalculateHash(ctx)
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
ctx.Request().SetBody(query2)
|
||||
hash2 := CalculateHash(ctx)
|
||||
hash2 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
assert.NotEqual(hash1, hash2, "Query with and without variables object should produce different cache keys")
|
||||
})
|
||||
|
||||
// SECURITY TEST: Different users should get different cache keys
|
||||
suite.Run("different users produce different cache keys", func() {
|
||||
// Same query, same variables, but different users - CRITICAL SECURITY TEST
|
||||
query := []byte(`{"query":"query GetMyProfile { me { id email } }"}`)
|
||||
ctx.Request().SetBody(query)
|
||||
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
hash2 := CalculateHash(ctx, "user2", "user")
|
||||
|
||||
assert.NotEqual(hash1, hash2, "Different users MUST produce different cache keys to prevent data leakage")
|
||||
})
|
||||
|
||||
// SECURITY TEST: Same user should get same cache key
|
||||
suite.Run("same user produces same cache key", func() {
|
||||
// Same query, same user
|
||||
query := []byte(`{"query":"query GetMyProfile { me { id email } }"}`)
|
||||
ctx.Request().SetBody(query)
|
||||
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
hash2 := CalculateHash(ctx, "user1", "admin")
|
||||
|
||||
assert.Equal(hash1, hash2, "Same user should get same cache key for cache effectiveness")
|
||||
})
|
||||
|
||||
// SECURITY TEST: Different roles should get different cache keys
|
||||
suite.Run("different roles produce different cache keys", func() {
|
||||
// Same query, same user ID, but different roles
|
||||
query := []byte(`{"query":"query GetData { data { value } }"}`)
|
||||
ctx.Request().SetBody(query)
|
||||
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
hash2 := CalculateHash(ctx, "user1", "user")
|
||||
|
||||
assert.NotEqual(hash1, hash2, "Different roles MUST produce different cache keys to prevent privilege escalation")
|
||||
})
|
||||
|
||||
// SECURITY TEST: Empty user context should be normalized
|
||||
suite.Run("empty user context is normalized", func() {
|
||||
query := []byte(`{"query":"query GetPublic { public { data } }"}`)
|
||||
ctx.Request().SetBody(query)
|
||||
|
||||
// Empty strings should be normalized to "-"
|
||||
hash1 := CalculateHash(ctx, "", "")
|
||||
hash2 := CalculateHash(ctx, "-", "-")
|
||||
|
||||
assert.Equal(hash1, hash2, "Empty user context should be normalized to prevent cache key collisions")
|
||||
})
|
||||
|
||||
// BACKWARD COMPATIBILITY TEST: Legacy mode without user context
|
||||
suite.Run("legacy mode without user context", func() {
|
||||
// Setup config with per-user cache disabled
|
||||
oldConfig := config
|
||||
config = &CacheConfig{
|
||||
Logger: libpack_logger.New(),
|
||||
Client: libpack_cache_memory.New(5 * time.Minute),
|
||||
TTL: 60,
|
||||
PerUserCacheDisabled: true, // Disable per-user caching
|
||||
}
|
||||
defer func() { config = oldConfig }()
|
||||
|
||||
query := []byte(`{"query":"query GetData { data { value } }"}`)
|
||||
ctx.Request().SetBody(query)
|
||||
|
||||
// In legacy mode, different users should get the SAME cache key (backward compatibility)
|
||||
hash1 := CalculateHash(ctx, "user1", "admin")
|
||||
hash2 := CalculateHash(ctx, "user2", "user")
|
||||
|
||||
assert.Equal(hash1, hash2, "With per-user cache disabled, all users get same cache key (backward compatibility)")
|
||||
})
|
||||
}
|
||||
|
||||
func (suite *Tests) Test_CacheDelete() {
|
||||
@@ -112,8 +182,6 @@ func (suite *Tests) Test_CacheDelete() {
|
||||
suite.Run("uninitialized cache", func() {
|
||||
// Save current config
|
||||
oldConfig := config
|
||||
|
||||
// Set config to nil
|
||||
config = nil
|
||||
|
||||
// This should not cause any errors
|
||||
@@ -156,8 +224,6 @@ func (suite *Tests) Test_CacheStoreWithTTL() {
|
||||
suite.Run("uninitialized cache", func() {
|
||||
// Save current config
|
||||
oldConfig := config
|
||||
|
||||
// Set config to nil
|
||||
config = nil
|
||||
|
||||
// This should not cause any errors
|
||||
@@ -194,8 +260,6 @@ func (suite *Tests) Test_CacheGetQueries() {
|
||||
suite.Run("uninitialized cache", func() {
|
||||
// Save current config
|
||||
oldConfig := config
|
||||
|
||||
// Set config to nil
|
||||
config = nil
|
||||
|
||||
// This should return 0
|
||||
@@ -280,8 +344,6 @@ func (suite *Tests) Test_GetCacheStats() {
|
||||
suite.Run("uninitialized cache", func() {
|
||||
// Save current config
|
||||
oldConfig := config
|
||||
|
||||
// Set config to nil
|
||||
config = nil
|
||||
|
||||
// This should return empty stats
|
||||
|
||||
@@ -33,8 +33,9 @@ func (suite *CircuitBreakerTestSuite) TestCircuitBreakerCacheFallback() {
|
||||
ctx := app.AcquireCtx(requestCtx)
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Calculate the cache key that would be used
|
||||
cacheKey := libpack_cache.CalculateHash(ctx)
|
||||
// Calculate the cache key that would be used (with default user context since no auth headers)
|
||||
// extractUserInfo() returns ("-", "-") when no auth is present
|
||||
cacheKey := libpack_cache.CalculateHash(ctx, "-", "-")
|
||||
|
||||
// Add a test response to the cache
|
||||
cachedResponse := []byte(`{"data":{"test":"cached-response"}}`)
|
||||
@@ -158,8 +159,9 @@ func (suite *CircuitBreakerTestSuite) TestCacheDisabledFallback() {
|
||||
ctx := app.AcquireCtx(requestCtx)
|
||||
defer app.ReleaseCtx(ctx)
|
||||
|
||||
// Calculate cache key and store a response
|
||||
cacheKey := libpack_cache.CalculateHash(ctx)
|
||||
// Calculate cache key and store a response (with default user context since no auth headers)
|
||||
// extractUserInfo() returns ("-", "-") when no auth is present
|
||||
cacheKey := libpack_cache.CalculateHash(ctx, "-", "-")
|
||||
cachedResponse := []byte(`{"data":{"test":"cached-response"}}`)
|
||||
libpack_cache.CacheStore(cacheKey, cachedResponse)
|
||||
|
||||
|
||||
@@ -23,18 +23,14 @@ func NewCircuitBreakerMetrics(monitoring *libpack_monitoring.MetricsSetup) *Circ
|
||||
// Initialize state value
|
||||
cbm.stateValue.Store(float64(0))
|
||||
|
||||
// Create gauge with callback that reads the atomic value
|
||||
cbm.stateGauge = monitoring.RegisterMetricsGauge(
|
||||
// Create gauge with callback that reads the atomic value on every scrape
|
||||
// This ensures the metric always reflects the current circuit breaker state
|
||||
cbm.stateGauge = monitoring.RegisterMetricsGaugeFunc(
|
||||
libpack_monitoring.MetricsCircuitState,
|
||||
nil,
|
||||
0, // Initial value doesn't matter as callback will be used
|
||||
)
|
||||
|
||||
// Override the gauge callback to read from atomic value
|
||||
cbm.stateGauge = monitoring.RegisterMetricsGauge(
|
||||
libpack_monitoring.MetricsCircuitState,
|
||||
nil,
|
||||
cbm.GetState(),
|
||||
func() float64 {
|
||||
return cbm.GetState()
|
||||
},
|
||||
)
|
||||
|
||||
return cbm
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
fiber "github.com/gofiber/fiber/v2"
|
||||
"github.com/graphql-go/graphql/language/ast"
|
||||
"github.com/graphql-go/graphql/language/parser"
|
||||
"github.com/graphql-go/graphql/language/source"
|
||||
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||
)
|
||||
|
||||
// debugParseGraphQLQuery provides detailed logging for mutation routing analysis
|
||||
// This is automatically called when LOG_LEVEL=DEBUG to help identify routing issues
|
||||
//
|
||||
// It logs:
|
||||
// - GraphQL query structure (operations, selections, directives)
|
||||
// - Final routing decision (which endpoint was chosen)
|
||||
// - Automatic detection of mutations routed to wrong endpoints
|
||||
//
|
||||
// To enable: Set LOG_LEVEL=DEBUG and restart the proxy
|
||||
func debugParseGraphQLQuery(c *fiber.Ctx, query string) {
|
||||
if cfg == nil || cfg.Logger == nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: "=== DEBUG: Parsing GraphQL Query ===",
|
||||
Pairs: map[string]interface{}{
|
||||
"query_length": len(query),
|
||||
"query_preview": truncateString(query, 100),
|
||||
},
|
||||
})
|
||||
|
||||
// Parse the query
|
||||
src := source.NewSource(&source.Source{
|
||||
Body: []byte(query),
|
||||
Name: "Debug GraphQL request",
|
||||
})
|
||||
|
||||
p, err := parser.Parse(parser.ParseParams{Source: src})
|
||||
if err != nil {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "DEBUG: Failed to parse query",
|
||||
Pairs: map[string]interface{}{"error": err.Error()},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: "DEBUG: Query parsed successfully",
|
||||
Pairs: map[string]interface{}{
|
||||
"definitions_count": len(p.Definitions),
|
||||
},
|
||||
})
|
||||
|
||||
// Analyze each definition
|
||||
for i, d := range p.Definitions {
|
||||
if oper, ok := d.(*ast.OperationDefinition); ok {
|
||||
operationType := strings.ToLower(oper.Operation)
|
||||
operationName := "unnamed"
|
||||
if oper.Name != nil {
|
||||
operationName = oper.Name.Value
|
||||
}
|
||||
|
||||
// Count selections
|
||||
selectionCount := 0
|
||||
if oper.SelectionSet != nil {
|
||||
selectionCount = len(oper.GetSelectionSet().Selections)
|
||||
}
|
||||
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: fmt.Sprintf("DEBUG: Definition #%d (OperationDefinition)", i),
|
||||
Pairs: map[string]interface{}{
|
||||
"operation_type": operationType,
|
||||
"operation_name": operationName,
|
||||
"selection_count": selectionCount,
|
||||
"is_mutation": operationType == "mutation",
|
||||
"directive_count": len(oper.Directives),
|
||||
},
|
||||
})
|
||||
|
||||
// Log selections for mutations
|
||||
if operationType == "mutation" && oper.SelectionSet != nil {
|
||||
for j, sel := range oper.GetSelectionSet().Selections {
|
||||
if field, ok := sel.(*ast.Field); ok {
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: fmt.Sprintf("DEBUG: Mutation field #%d", j),
|
||||
Pairs: map[string]interface{}{
|
||||
"field_name": field.Name.Value,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if frag, ok := d.(*ast.FragmentDefinition); ok {
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: fmt.Sprintf("DEBUG: Definition #%d (FragmentDefinition)", i),
|
||||
Pairs: map[string]interface{}{
|
||||
"fragment_name": frag.Name.Value,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Now run the actual parsing to see the result
|
||||
result := parseGraphQLQuery(c)
|
||||
|
||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||
Message: "DEBUG: Final routing decision",
|
||||
Pairs: map[string]interface{}{
|
||||
"operation_type": result.operationType,
|
||||
"operation_name": result.operationName,
|
||||
"active_endpoint": result.activeEndpoint,
|
||||
"should_block": result.shouldBlock,
|
||||
"should_ignore": result.shouldIgnore,
|
||||
"write_endpoint": cfg.Server.HostGraphQL,
|
||||
"read_endpoint": cfg.Server.HostGraphQLReadOnly,
|
||||
"is_using_write": result.activeEndpoint == cfg.Server.HostGraphQL,
|
||||
},
|
||||
})
|
||||
|
||||
// Check for potential issues
|
||||
if result.operationType == "mutation" && result.activeEndpoint != cfg.Server.HostGraphQL {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "DEBUG: ⚠️ BUG DETECTED: Mutation routed to wrong endpoint!",
|
||||
Pairs: map[string]interface{}{
|
||||
"expected_endpoint": cfg.Server.HostGraphQL,
|
||||
"actual_endpoint": result.activeEndpoint,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if result.operationType == "mutation" && strings.Contains(strings.ToLower(result.activeEndpoint), "read") {
|
||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "DEBUG: ⚠️ CRITICAL: Mutation endpoint contains 'read' in URL!",
|
||||
Pairs: map[string]interface{}{
|
||||
"endpoint": result.activeEndpoint,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,13 @@ const (
|
||||
)
|
||||
|
||||
// Use parameterized queries to prevent SQL injection
|
||||
// Cast $1 to interval type to allow proper parameterized interval values
|
||||
var delQueries = [...]string{
|
||||
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - INTERVAL $1",
|
||||
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - INTERVAL $1",
|
||||
"DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - INTERVAL $1",
|
||||
"DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - INTERVAL $1",
|
||||
"DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - INTERVAL $1",
|
||||
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - $1::INTERVAL",
|
||||
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - $1::INTERVAL",
|
||||
"DELETE FROM hdb_catalog.hdb_action_log WHERE created_at < NOW() - $1::INTERVAL",
|
||||
"DELETE FROM hdb_catalog.hdb_cron_event_invocation_logs WHERE created_at < NOW() - $1::INTERVAL",
|
||||
"DELETE FROM hdb_catalog.hdb_scheduled_event_invocation_logs WHERE created_at < NOW() - $1::INTERVAL",
|
||||
}
|
||||
|
||||
func enableHasuraEventCleaner(ctx context.Context) error {
|
||||
|
||||
@@ -340,8 +340,8 @@ func getDelQueries() []string {
|
||||
// This should return the actual delQueries from the main package
|
||||
// For testing purposes, we return expected parameterized queries
|
||||
return []string{
|
||||
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - INTERVAL '$1 days'",
|
||||
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - INTERVAL '$1 days'",
|
||||
"DELETE FROM hdb_catalog.event_log WHERE created_at < NOW() - $1::INTERVAL",
|
||||
"DELETE FROM hdb_catalog.event_invocation_logs WHERE created_at < NOW() - $1::INTERVAL",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,18 +9,18 @@ require (
|
||||
github.com/alicebob/miniredis/v2 v2.33.0
|
||||
github.com/avast/retry-go/v4 v4.7.0
|
||||
github.com/goccy/go-json v0.10.5
|
||||
github.com/gofiber/fiber/v2 v2.52.9
|
||||
github.com/gofiber/fiber/v2 v2.52.10
|
||||
github.com/gofiber/websocket/v2 v2.2.1
|
||||
github.com/gofrs/flock v0.13.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gookit/goutil v0.7.1
|
||||
github.com/gookit/goutil v0.7.2
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/graphql-go/graphql v0.8.1
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
github.com/lukaszraczylo/ask v0.0.0-20240916204100-6e9ef53a62d9
|
||||
github.com/lukaszraczylo/go-ratecounter v0.1.12
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.84
|
||||
github.com/redis/go-redis/v9 v9.16.0
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.89
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/sony/gobreaker v1.0.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/valyala/fasthttp v1.68.0
|
||||
@@ -28,7 +28,7 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
google.golang.org/grpc v1.76.0
|
||||
google.golang.org/grpc v1.77.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -61,14 +61,14 @@ require (
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
golang.org/x/crypto v0.43.0 // indirect
|
||||
golang.org/x/net v0.46.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/term v0.36.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||
golang.org/x/crypto v0.45.0 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/term v0.37.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251124214823-79d6a2a48846 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
@@ -36,8 +36,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-reflect v1.2.0 h1:O0T8rZCuNmGXewnATuKYnkL0xm6o8UNOJZd/gOkb9ms=
|
||||
github.com/goccy/go-reflect v1.2.0/go.mod h1:n0oYZn8VcV2CkWTxi8B9QjkCoq6GTtCEdfmR66YhFtE=
|
||||
github.com/gofiber/fiber/v2 v2.52.9 h1:YjKl5DOiyP3j0mO61u3NTmK7or8GzzWzCFzkboyP5cw=
|
||||
github.com/gofiber/fiber/v2 v2.52.9/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||
github.com/gofiber/fiber/v2 v2.52.10 h1:jRHROi2BuNti6NYXmZ6gbNSfT3zj/8c0xy94GOU5elY=
|
||||
github.com/gofiber/fiber/v2 v2.52.10/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw=
|
||||
github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w=
|
||||
github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU=
|
||||
github.com/gofrs/flock v0.13.0 h1:95JolYOvGMqeH31+FC7D2+uULf6mG61mEZ/A8dRYMzw=
|
||||
@@ -48,8 +48,8 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gookit/goutil v0.7.1 h1:AaFJPN9mrdeYBv8HOybri26EHGCC34WJVT7jUStGJsI=
|
||||
github.com/gookit/goutil v0.7.1/go.mod h1:vJS9HXctYTCLtCsZot5L5xF+O1oR17cDYO9R0HxBmnU=
|
||||
github.com/gookit/goutil v0.7.2 h1:NSiqWWY+BT0MwIlKDeSVPfQmr9xTkkAqwDjhplobdgo=
|
||||
github.com/gookit/goutil v0.7.2/go.mod h1:vJS9HXctYTCLtCsZot5L5xF+O1oR17cDYO9R0HxBmnU=
|
||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/graphql-go/graphql v0.8.1 h1:p7/Ou/WpmulocJeEx7wjQy611rtXGQaAcXGqanuMMgc=
|
||||
@@ -74,8 +74,8 @@ github.com/lukaszraczylo/ask v0.0.0-20240916204100-6e9ef53a62d9 h1:pL8B9mjv6RPUf
|
||||
github.com/lukaszraczylo/ask v0.0.0-20240916204100-6e9ef53a62d9/go.mod h1:M+UVdyqZs++xtEPrascaVmZdOMhCnxjZ2SgH+xHpR0c=
|
||||
github.com/lukaszraczylo/go-ratecounter v0.1.12 h1:VO6hHYGw/Jy9JUizXf/bS0AI2QX1ueWWAWckMFVJ/w4=
|
||||
github.com/lukaszraczylo/go-ratecounter v0.1.12/go.mod h1:TqXEOCtFJStk1i0tkipprv1kiDHGon1MVUisjSTBSKM=
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.84 h1:yP00k8XSYKFYo6PmZFOsDblexLOG6WZzVWhzdstrxiw=
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.84/go.mod h1:PxQYblQDZISmYYj8sNfazAWxAOh1rhAtU208y+uPV8s=
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.89 h1:Xbu1Ny+a0lT2Sr2SaSC8mcHmGQDwGD4TJKk4DDd+PwA=
|
||||
github.com/lukaszraczylo/go-simple-graphql v1.2.89/go.mod h1:PxQYblQDZISmYYj8sNfazAWxAOh1rhAtU208y+uPV8s=
|
||||
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
|
||||
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
@@ -84,8 +84,8 @@ github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byF
|
||||
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4=
|
||||
github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/savsgio/gotils v0.0.0-20250924091648-bce9a52d7761 h1:McifyVxygw1d67y6vxUqls2D46J8W9nrki9c8c0eVvE=
|
||||
@@ -129,27 +129,27 @@ go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjce
|
||||
go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251103181224-f26f9409b101 h1:vk5TfqZHNn0obhPIYeS+cxIFKFQgser/M2jnI+9c6MM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251103181224-f26f9409b101/go.mod h1:E17fc4PDhkr22dE3RgnH2hEubUaky6ZwW4VhANxyspg=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
|
||||
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251124214823-79d6a2a48846 h1:ZdyUkS9po3H7G0tuh955QVyyotWvOD4W0aEapeGeUYk=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251124214823-79d6a2a48846/go.mod h1:Fk4kyraUvqD7i5H6S43sj2W98fbZa75lpZz/eUyhfO0=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846 h1:Wgl1rcDNThT+Zn47YyCXOXyX/COgMTIdhJ717F0l4xk=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251124214823-79d6a2a48846/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.77.0 h1:wVVY6/8cGA6vvffn+wWK5ToddbgdU3d8MNENr4evgXM=
|
||||
google.golang.org/grpc v1.77.0/go.mod h1:z0BY1iVj0q8E1uSQCjL9cppRj+gnZjzDnzV0dHhrNig=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
+52
-13
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
fiber "github.com/gofiber/fiber/v2"
|
||||
@@ -37,6 +38,40 @@ var (
|
||||
currentCacheSize int64 // Use atomic operations for this
|
||||
)
|
||||
|
||||
// sanitizeOperationName removes null bytes and other invalid characters from operation names
|
||||
// This prevents panics when creating metrics with invalid label values
|
||||
func sanitizeOperationName(name string) string {
|
||||
if name == "" || name == "undefined" {
|
||||
return name
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(name))
|
||||
|
||||
for _, r := range name {
|
||||
// Skip null bytes entirely
|
||||
if r == '\x00' {
|
||||
continue
|
||||
}
|
||||
// Replace control characters with underscores
|
||||
if r < 32 || r == 127 {
|
||||
buf.WriteByte('_')
|
||||
continue
|
||||
}
|
||||
// Only allow printable characters
|
||||
if unicode.IsPrint(r) {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
result := buf.String()
|
||||
// Return "undefined" if we ended up with an empty string after sanitization
|
||||
if result == "" {
|
||||
return "undefined"
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func prepareQueriesAndExemptions() {
|
||||
introspectionAllowedQueries = make(map[string]struct{})
|
||||
allowedUrls = make(map[string]struct{})
|
||||
@@ -298,8 +333,8 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
res.operationType = "mutation"
|
||||
if oper.Name != nil {
|
||||
mutationName = oper.Name.Value
|
||||
// Use mutation name immediately
|
||||
res.operationName = mutationName
|
||||
// Use mutation name immediately, sanitized to prevent metric panics
|
||||
res.operationName = sanitizeOperationName(mutationName)
|
||||
}
|
||||
break // Found a mutation, no need to continue first pass
|
||||
}
|
||||
@@ -316,7 +351,7 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
// We already set operation type to mutation in first pass
|
||||
// Only set name if we didn't find a mutation name earlier
|
||||
if res.operationName == "undefined" && oper.Name != nil {
|
||||
res.operationName = oper.Name.Value
|
||||
res.operationName = sanitizeOperationName(oper.Name.Value)
|
||||
}
|
||||
} else {
|
||||
// No mutation found, use the normal logic
|
||||
@@ -325,18 +360,10 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
}
|
||||
|
||||
if res.operationName == "undefined" && oper.Name != nil {
|
||||
res.operationName = oper.Name.Value
|
||||
res.operationName = sanitizeOperationName(oper.Name.Value)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle endpoint routing - always use write endpoint for mutations
|
||||
if res.operationType == "mutation" {
|
||||
res.activeEndpoint = cfg.Server.HostGraphQL
|
||||
} else if cfg.Server.HostGraphQLReadOnly != "" {
|
||||
// Use read-only endpoint for non-mutation operations
|
||||
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
|
||||
}
|
||||
|
||||
// Block mutations in read-only mode
|
||||
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
|
||||
if ifNotInTest() {
|
||||
@@ -359,13 +386,25 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle endpoint routing AFTER processing all definitions
|
||||
// This ensures mutations are always routed to the write endpoint
|
||||
if res.operationType == "mutation" {
|
||||
res.activeEndpoint = cfg.Server.HostGraphQL
|
||||
} else if cfg.Server.HostGraphQLReadOnly != "" {
|
||||
// Use read-only endpoint for non-mutation operations
|
||||
res.activeEndpoint = cfg.Server.HostGraphQLReadOnly
|
||||
}
|
||||
|
||||
// Track parsing time
|
||||
if ifNotInTest() && cfg.Monitoring != nil {
|
||||
parseTime := float64(time.Since(startTime).Milliseconds())
|
||||
cfg.Monitoring.IncrementFloat(libpack_monitoring.MetricsGraphQLParsingTime, nil, parseTime)
|
||||
}
|
||||
|
||||
return res
|
||||
// Create a copy to return, since the original will be returned to the pool
|
||||
// This prevents race conditions where concurrent requests could modify the same result
|
||||
result := *res
|
||||
return &result
|
||||
}
|
||||
|
||||
// processDirectives extracts caching directives from the operation
|
||||
|
||||
+290
-2
@@ -5,10 +5,11 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gookit/goutil/strutil"
|
||||
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
|
||||
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||
"github.com/sony/gobreaker"
|
||||
@@ -115,7 +116,8 @@ func (suite *Tests) TestCachingAndCircuitBreakerInteraction() {
|
||||
suite.Equal(responseBody, firstResponseBody, "Response body should match server response")
|
||||
|
||||
// Calculate hash the same way the system does, before releasing context
|
||||
cacheKey := strutil.Md5(ctx.Body())
|
||||
// Use default user context ("-", "-") since no auth headers are set in this test
|
||||
cacheKey := libpack_cache.CalculateHash(ctx, "-", "-")
|
||||
|
||||
// Store in cache directly for test
|
||||
libpack_cache.CacheStore(cacheKey, []byte(responseBody))
|
||||
@@ -495,3 +497,289 @@ func getMetricValue(metricName string) int {
|
||||
}
|
||||
return int(counter.Get())
|
||||
}
|
||||
|
||||
// TestRequestCoalescingIntegration tests that request coalescing works end-to-end
|
||||
// through the proxy layer, ensuring concurrent identical requests result in only
|
||||
// one backend call while all clients receive the correct response.
|
||||
func (suite *Tests) TestRequestCoalescingIntegration() {
|
||||
// Save original config
|
||||
originalCoalescing := cfg.RequestCoalescing
|
||||
originalClient := cfg.Client.FastProxyClient
|
||||
originalHostGraphQL := cfg.Server.HostGraphQL
|
||||
|
||||
// Restore after test
|
||||
defer func() {
|
||||
cfg.RequestCoalescing = originalCoalescing
|
||||
cfg.Client.FastProxyClient = originalClient
|
||||
cfg.Server.HostGraphQL = originalHostGraphQL
|
||||
}()
|
||||
|
||||
// Track backend calls
|
||||
var backendCallCount atomic.Int32
|
||||
var requestDelay = 100 * time.Millisecond
|
||||
|
||||
// Create test server that counts requests and introduces delay
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCallCount.Add(1)
|
||||
time.Sleep(requestDelay) // Delay to allow concurrent requests to coalesce
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"data":{"users":[{"id":"1","name":"Test User"}]}}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// Configure for test
|
||||
cfg.Server.HostGraphQL = server.URL
|
||||
cfg.Client.ClientTimeout = 5
|
||||
cfg.Client.FastProxyClient = createFasthttpClient(cfg)
|
||||
cfg.RequestCoalescing.Enable = true
|
||||
|
||||
// Initialize request coalescer for this test
|
||||
// Reset the global coalescer by creating a new one
|
||||
testCoalescer := NewRequestCoalescer(true, cfg.Logger, cfg.Monitoring)
|
||||
|
||||
// Temporarily replace global coalescer
|
||||
originalCoalescer := requestCoalescer
|
||||
requestCoalescer = testCoalescer
|
||||
defer func() {
|
||||
requestCoalescer = originalCoalescer
|
||||
}()
|
||||
|
||||
// Test Case 1: Concurrent identical requests should coalesce
|
||||
suite.Run("concurrent_identical_requests_coalesce", func() {
|
||||
backendCallCount.Store(0)
|
||||
testCoalescer.Reset()
|
||||
|
||||
concurrentRequests := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentRequests)
|
||||
|
||||
responses := make([]string, concurrentRequests)
|
||||
errors := make([]error, concurrentRequests)
|
||||
|
||||
// Launch concurrent requests with identical query
|
||||
for i := 0; i < concurrentRequests; i++ {
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
reqCtx := &fasthttp.RequestCtx{}
|
||||
reqCtx.Request.SetRequestURI("/graphql")
|
||||
reqCtx.Request.Header.SetMethod("POST")
|
||||
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
reqCtx.Request.SetBody([]byte(`{"query": "query { users { id name } }"}`))
|
||||
|
||||
ctx := suite.app.AcquireCtx(reqCtx)
|
||||
err := proxyTheRequest(ctx, cfg.Server.HostGraphQL)
|
||||
errors[index] = err
|
||||
responses[index] = string(ctx.Response().Body())
|
||||
suite.app.ReleaseCtx(ctx)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify only 1 backend call was made
|
||||
suite.Equal(int32(1), backendCallCount.Load(),
|
||||
"Should make only 1 backend call for %d concurrent identical requests", concurrentRequests)
|
||||
|
||||
// Verify all requests succeeded with same response
|
||||
expectedResponse := `{"data":{"users":[{"id":"1","name":"Test User"}]}}`
|
||||
for i := 0; i < concurrentRequests; i++ {
|
||||
suite.Nil(errors[i], "Request %d should succeed", i)
|
||||
suite.Equal(expectedResponse, responses[i],
|
||||
"Request %d should have correct response", i)
|
||||
}
|
||||
|
||||
// Verify coalescing stats
|
||||
stats := testCoalescer.GetStats()
|
||||
suite.Equal(int64(concurrentRequests), stats["total_requests"],
|
||||
"Total requests should match")
|
||||
suite.Equal(int64(1), stats["primary_requests"],
|
||||
"Should have 1 primary request")
|
||||
suite.Equal(int64(concurrentRequests-1), stats["coalesced_requests"],
|
||||
"Should have %d coalesced requests", concurrentRequests-1)
|
||||
})
|
||||
|
||||
// Test Case 2: Different queries should NOT coalesce
|
||||
suite.Run("different_queries_not_coalesced", func() {
|
||||
backendCallCount.Store(0)
|
||||
testCoalescer.Reset()
|
||||
|
||||
// Create server that returns query-specific responses
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCallCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := make([]byte, r.ContentLength)
|
||||
_, _ = r.Body.Read(body)
|
||||
|
||||
var response string
|
||||
if strings.Contains(string(body), "query1") {
|
||||
response = `{"data":{"result":"query1"}}`
|
||||
} else if strings.Contains(string(body), "query2") {
|
||||
response = `{"data":{"result":"query2"}}`
|
||||
} else {
|
||||
response = `{"data":{"result":"unknown"}}`
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(response))
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
cfg.Server.HostGraphQL = server2.URL
|
||||
cfg.Client.FastProxyClient = createFasthttpClient(cfg)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var response1, response2 string
|
||||
var err1, err2 error
|
||||
|
||||
// Launch two requests with different queries concurrently
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reqCtx := &fasthttp.RequestCtx{}
|
||||
reqCtx.Request.SetRequestURI("/graphql")
|
||||
reqCtx.Request.Header.SetMethod("POST")
|
||||
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
reqCtx.Request.SetBody([]byte(`{"query": "query { query1 }"}`))
|
||||
|
||||
ctx := suite.app.AcquireCtx(reqCtx)
|
||||
err1 = proxyTheRequest(ctx, cfg.Server.HostGraphQL)
|
||||
response1 = string(ctx.Response().Body())
|
||||
suite.app.ReleaseCtx(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reqCtx := &fasthttp.RequestCtx{}
|
||||
reqCtx.Request.SetRequestURI("/graphql")
|
||||
reqCtx.Request.Header.SetMethod("POST")
|
||||
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
reqCtx.Request.SetBody([]byte(`{"query": "query { query2 }"}`))
|
||||
|
||||
ctx := suite.app.AcquireCtx(reqCtx)
|
||||
err2 = proxyTheRequest(ctx, cfg.Server.HostGraphQL)
|
||||
response2 = string(ctx.Response().Body())
|
||||
suite.app.ReleaseCtx(ctx)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Both requests should succeed
|
||||
suite.Nil(err1, "Query1 should succeed")
|
||||
suite.Nil(err2, "Query2 should succeed")
|
||||
|
||||
// Should have made 2 backend calls (no coalescing for different queries)
|
||||
suite.Equal(int32(2), backendCallCount.Load(),
|
||||
"Should make 2 backend calls for 2 different queries")
|
||||
|
||||
// Responses should be different
|
||||
suite.Contains(response1, "query1", "Response1 should be for query1")
|
||||
suite.Contains(response2, "query2", "Response2 should be for query2")
|
||||
})
|
||||
|
||||
// Test Case 3: Coalescing disabled should make separate calls
|
||||
suite.Run("coalescing_disabled", func() {
|
||||
// Create a fresh server for this test
|
||||
var disabledCallCount atomic.Int32
|
||||
serverDisabled := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
disabledCallCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"data":{"users":[{"id":"1"}]}}`))
|
||||
}))
|
||||
defer serverDisabled.Close()
|
||||
|
||||
cfg.Server.HostGraphQL = serverDisabled.URL
|
||||
cfg.Client.FastProxyClient = createFasthttpClient(cfg)
|
||||
|
||||
// Disable coalescing
|
||||
cfg.RequestCoalescing.Enable = false
|
||||
|
||||
concurrentRequests := 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentRequests)
|
||||
|
||||
// Launch concurrent identical requests
|
||||
for i := 0; i < concurrentRequests; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
reqCtx := &fasthttp.RequestCtx{}
|
||||
reqCtx.Request.SetRequestURI("/graphql")
|
||||
reqCtx.Request.Header.SetMethod("POST")
|
||||
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
reqCtx.Request.SetBody([]byte(`{"query": "query { users { id } }"}`))
|
||||
|
||||
ctx := suite.app.AcquireCtx(reqCtx)
|
||||
_ = proxyTheRequest(ctx, cfg.Server.HostGraphQL)
|
||||
suite.app.ReleaseCtx(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should make separate backend calls when coalescing is disabled
|
||||
suite.Equal(int32(concurrentRequests), disabledCallCount.Load(),
|
||||
"Should make %d backend calls when coalescing is disabled", concurrentRequests)
|
||||
|
||||
// Re-enable for subsequent tests
|
||||
cfg.RequestCoalescing.Enable = true
|
||||
})
|
||||
|
||||
// Test Case 4: Error responses should be shared correctly
|
||||
suite.Run("error_responses_coalesced", func() {
|
||||
backendCallCount.Store(0)
|
||||
testCoalescer.Reset()
|
||||
|
||||
// Create server that returns errors
|
||||
serverError := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
backendCallCount.Add(1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
_, _ = w.Write([]byte(`{"errors":[{"message":"Internal server error"}]}`))
|
||||
}))
|
||||
defer serverError.Close()
|
||||
|
||||
cfg.Server.HostGraphQL = serverError.URL
|
||||
cfg.Client.FastProxyClient = createFasthttpClient(cfg)
|
||||
|
||||
concurrentRequests := 5
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(concurrentRequests)
|
||||
|
||||
errors := make([]error, concurrentRequests)
|
||||
|
||||
for i := 0; i < concurrentRequests; i++ {
|
||||
go func(index int) {
|
||||
defer wg.Done()
|
||||
|
||||
reqCtx := &fasthttp.RequestCtx{}
|
||||
reqCtx.Request.SetRequestURI("/graphql")
|
||||
reqCtx.Request.Header.SetMethod("POST")
|
||||
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||
reqCtx.Request.SetBody([]byte(`{"query": "query { fail }"}`))
|
||||
|
||||
ctx := suite.app.AcquireCtx(reqCtx)
|
||||
errors[index] = proxyTheRequest(ctx, cfg.Server.HostGraphQL)
|
||||
suite.app.ReleaseCtx(ctx)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should still only make 1 backend call
|
||||
suite.Equal(int32(1), backendCallCount.Load(),
|
||||
"Should make only 1 backend call even for error responses")
|
||||
|
||||
// All requests should receive the same error
|
||||
for i := 0; i < concurrentRequests; i++ {
|
||||
suite.NotNil(errors[i], "Request %d should have error", i)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -133,6 +133,27 @@ func parseConfig() {
|
||||
c.Cache.CacheMaxEntries = getDetailsFromEnv("CACHE_MAX_ENTRIES", 10000) // Default 10000 entries
|
||||
// GraphQL query parsing cache - auto-calculate based on CPU cores if not set
|
||||
c.Cache.GraphQLQueryCacheSize = getDetailsFromEnv("GRAPHQL_QUERY_CACHE_SIZE", runtime.GOMAXPROCS(0)*250)
|
||||
|
||||
// SECURITY: Per-user cache isolation (enabled by default for security)
|
||||
// Set CACHE_PER_USER_DISABLED=true ONLY if you have a single-user application
|
||||
// or understand the security implications of shared cache across users
|
||||
c.Cache.PerUserCacheDisabled = getDetailsFromEnv("CACHE_PER_USER_DISABLED", false)
|
||||
|
||||
// Log warning if per-user caching is disabled
|
||||
if c.Cache.PerUserCacheDisabled {
|
||||
defer func() {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Warning(&libpack_logging.LogMessage{
|
||||
Message: "⚠️ Per-user cache isolation is DISABLED - Users may see each other's cached data!",
|
||||
Pairs: map[string]interface{}{
|
||||
"security_risk": "CRITICAL - Do not use in multi-user applications",
|
||||
"recommendation": "Remove CACHE_PER_USER_DISABLED or set it to false",
|
||||
},
|
||||
})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Redis cache
|
||||
c.Cache.CacheRedisEnable = getDetailsFromEnv("ENABLE_REDIS_CACHE", false)
|
||||
c.Cache.CacheRedisURL = getDetailsFromEnv("CACHE_REDIS_URL", "localhost:6379")
|
||||
@@ -246,7 +267,7 @@ func parseConfig() {
|
||||
c.Api.BannedUsersFile = validatedPath
|
||||
}
|
||||
c.Server.PurgeOnCrawl = getDetailsFromEnv("PURGE_METRICS_ON_CRAWL", false)
|
||||
c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 0)
|
||||
c.Server.PurgeEvery = getDetailsFromEnv("PURGE_METRICS_ON_TIMER", 1800) // Default: purge metrics every 30 minutes
|
||||
// Hasura event cleaner
|
||||
c.HasuraEventCleaner.Enable = getDetailsFromEnv("HASURA_EVENT_CLEANER", false)
|
||||
c.HasuraEventCleaner.ClearOlderThan = getDetailsFromEnv("HASURA_EVENT_CLEANER_OLDER_THAN", 1)
|
||||
@@ -355,8 +376,9 @@ func parseConfig() {
|
||||
// Initialize cache if enabled
|
||||
if cfg.Cache.CacheEnable || cfg.Cache.CacheRedisEnable {
|
||||
cacheConfig := &libpack_cache.CacheConfig{
|
||||
Logger: cfg.Logger,
|
||||
TTL: cfg.Cache.CacheTTL,
|
||||
Logger: cfg.Logger,
|
||||
TTL: cfg.Cache.CacheTTL,
|
||||
PerUserCacheDisabled: cfg.Cache.PerUserCacheDisabled,
|
||||
}
|
||||
// Redis cache configurations
|
||||
if cfg.Cache.CacheRedisEnable {
|
||||
@@ -387,15 +409,7 @@ func parseConfig() {
|
||||
initCircuitBreaker(cfg)
|
||||
}
|
||||
|
||||
// Initialize retry budget
|
||||
if cfg.RetryBudget.Enable {
|
||||
retryBudgetConfig := RetryBudgetConfig{
|
||||
TokensPerSecond: cfg.RetryBudget.TokensPerSecond,
|
||||
MaxTokens: cfg.RetryBudget.MaxTokens,
|
||||
Enabled: cfg.RetryBudget.Enable,
|
||||
}
|
||||
InitializeRetryBudget(retryBudgetConfig, cfg.Logger)
|
||||
}
|
||||
// Note: Retry budget is initialized in main() with context for graceful shutdown
|
||||
|
||||
// Initialize request coalescer
|
||||
if cfg.RequestCoalescing.Enable {
|
||||
@@ -420,11 +434,7 @@ func parseConfig() {
|
||||
healthMgr.StartHealthChecking()
|
||||
}
|
||||
|
||||
// Initialize RPS tracker for real-time requests per second monitoring
|
||||
InitializeRPSTracker()
|
||||
cfg.Logger.Info(&libpack_logging.LogMessage{
|
||||
Message: "RPS tracker initialized",
|
||||
})
|
||||
// Note: RPS tracker is initialized in main() with context for graceful shutdown
|
||||
|
||||
// Load rate limit configuration with improved error handling
|
||||
if err := loadRatelimitConfig(); err != nil {
|
||||
@@ -462,6 +472,22 @@ func main() {
|
||||
// Initialize shutdown manager
|
||||
shutdownManager = NewShutdownManager(ctx)
|
||||
|
||||
// Initialize RPS tracker with context for graceful shutdown
|
||||
InitializeRPSTracker(ctx)
|
||||
cfg.Logger.Info(&libpack_logging.LogMessage{
|
||||
Message: "RPS tracker initialized",
|
||||
})
|
||||
|
||||
// Initialize retry budget with context for graceful shutdown
|
||||
if cfg.RetryBudget.Enable {
|
||||
retryBudgetConfig := RetryBudgetConfig{
|
||||
TokensPerSecond: cfg.RetryBudget.TokensPerSecond,
|
||||
MaxTokens: cfg.RetryBudget.MaxTokens,
|
||||
Enabled: cfg.RetryBudget.Enable,
|
||||
}
|
||||
InitializeRetryBudgetWithContext(ctx, retryBudgetConfig, cfg.Logger)
|
||||
}
|
||||
|
||||
// Create a wait group to manage goroutines
|
||||
var wg sync.WaitGroup
|
||||
|
||||
|
||||
+11
-2
@@ -506,6 +506,7 @@ func (ma *MetricsAggregator) aggregateStats(instances []InstanceMetrics) map[str
|
||||
totalCacheMisses int64
|
||||
totalCachedQueries int64
|
||||
totalMemoryUsageMB float64
|
||||
hasValidMemoryStats bool // Track if any instance has valid memory stats
|
||||
totalCurrentRPS float64
|
||||
totalAvgRPS float64
|
||||
totalActiveConnections int64
|
||||
@@ -598,9 +599,11 @@ func (ma *MetricsAggregator) aggregateStats(instances []InstanceMetrics) map[str
|
||||
}
|
||||
|
||||
// Aggregate memory usage from full cache details
|
||||
// Skip -1 values which indicate Redis cache (memory tracking not available)
|
||||
if len(instance.Cache) > 0 {
|
||||
if memMB, ok := instance.Cache["memory_usage_mb"].(float64); ok {
|
||||
if memMB, ok := instance.Cache["memory_usage_mb"].(float64); ok && memMB >= 0 {
|
||||
totalMemoryUsageMB += memMB
|
||||
hasValidMemoryStats = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -718,7 +721,13 @@ func (ma *MetricsAggregator) aggregateStats(instances []InstanceMetrics) map[str
|
||||
"total_cached": totalCachedQueries,
|
||||
},
|
||||
"memory": map[string]interface{}{
|
||||
"total_usage_mb": totalMemoryUsageMB,
|
||||
"total_usage_mb": func() float64 {
|
||||
if hasValidMemoryStats {
|
||||
return totalMemoryUsageMB
|
||||
}
|
||||
return -1
|
||||
}(),
|
||||
"available": hasValidMemoryStats,
|
||||
},
|
||||
"connections": map[string]interface{}{
|
||||
"total_active": totalActiveConnections,
|
||||
|
||||
+107
-17
@@ -68,26 +68,74 @@ func ensureDefaultLabels(labels *map[string]string, podName string) {
|
||||
}
|
||||
}
|
||||
|
||||
// sanitizeLabelValue removes or replaces characters that are invalid in metric labels
|
||||
// This includes null bytes, newlines, carriage returns, quotes, and backslashes
|
||||
func sanitizeLabelValue(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
|
||||
var buf strings.Builder
|
||||
buf.Grow(len(value))
|
||||
|
||||
for _, r := range value {
|
||||
switch r {
|
||||
case '\x00': // null byte
|
||||
continue // Skip null bytes entirely
|
||||
case '\n', '\r', '\t': // newlines, carriage returns, tabs
|
||||
buf.WriteByte(' ') // Replace with space
|
||||
case '"', '\\': // quotes and backslashes need escaping
|
||||
buf.WriteByte('\\')
|
||||
buf.WriteRune(r)
|
||||
default:
|
||||
// Only allow printable ASCII and common unicode characters
|
||||
if unicode.IsPrint(r) {
|
||||
buf.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
|
||||
if len(labels) == 0 {
|
||||
// Add defer/recover to prevent panics from crashing the application
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Log the panic but don't crash
|
||||
fmt.Fprintf(os.Stderr, "Recovered from panic in appendSortedLabels: %v\n", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(labels) == 0 || buf == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Create a snapshot to avoid concurrent access issues
|
||||
labelsCopy := make(map[string]string, len(labels))
|
||||
for k, v := range labels {
|
||||
labelsCopy[k] = v
|
||||
if k == "" {
|
||||
continue // Skip empty keys
|
||||
}
|
||||
// Sanitize the label value to remove null bytes and other invalid characters
|
||||
labelsCopy[k] = sanitizeLabelValue(v)
|
||||
}
|
||||
|
||||
if len(labelsCopy) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
keys := getSortedKeys(labelsCopy)
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
if v, ok := labelsCopy[k]; ok {
|
||||
if i > 0 {
|
||||
buf.WriteByte(',')
|
||||
}
|
||||
buf.WriteString(k)
|
||||
buf.WriteString(`="`)
|
||||
buf.WriteString(v)
|
||||
buf.WriteByte('"')
|
||||
}
|
||||
buf.WriteString(k)
|
||||
buf.WriteString(`="`)
|
||||
buf.WriteString(labelsCopy[k])
|
||||
buf.WriteByte('"')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,7 +165,15 @@ func getSortedKeys(labels map[string]string) []string {
|
||||
}
|
||||
|
||||
func labelsToString(labels map[string]string) string {
|
||||
if labels == nil {
|
||||
// Add defer/recover to prevent panics from crashing the application
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Log the panic but don't crash
|
||||
fmt.Fprintf(os.Stderr, "Recovered from panic in labelsToString: %v\n", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if len(labels) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -126,17 +182,34 @@ func labelsToString(labels map[string]string) string {
|
||||
values := make(map[string]string, len(labels))
|
||||
|
||||
for k, v := range labels {
|
||||
if k == "" {
|
||||
continue // Skip empty keys
|
||||
}
|
||||
keys = append(keys, k)
|
||||
values[k] = v
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
sort.Strings(keys)
|
||||
|
||||
// Pre-allocate the builder with estimated capacity to avoid reallocation
|
||||
var sb strings.Builder
|
||||
estimatedSize := 0
|
||||
for _, k := range keys {
|
||||
sb.WriteString(k)
|
||||
sb.WriteByte('=')
|
||||
sb.WriteString(values[k])
|
||||
sb.WriteByte(';')
|
||||
estimatedSize += len(k) + len(values[k]) + 2 // key + value + '=' + ';'
|
||||
}
|
||||
sb.Grow(estimatedSize)
|
||||
|
||||
for _, k := range keys {
|
||||
if v, ok := values[k]; ok {
|
||||
sb.WriteString(k)
|
||||
sb.WriteByte('=')
|
||||
sb.WriteString(v)
|
||||
sb.WriteByte(';')
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
@@ -186,6 +259,14 @@ func is_special_rune(r rune) bool {
|
||||
}
|
||||
|
||||
func compile_metrics_with_labels(name string, labels map[string]string) string {
|
||||
// Add defer/recover to prevent panics from crashing the application
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Log the panic but don't crash
|
||||
fmt.Fprintf(os.Stderr, "Recovered from panic in compile_metrics_with_labels: %v\n", r)
|
||||
}
|
||||
}()
|
||||
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.WriteString(name)
|
||||
@@ -197,16 +278,25 @@ func compile_metrics_with_labels(name string, labels map[string]string) string {
|
||||
// Create a snapshot to avoid concurrent access issues
|
||||
labelsCopy := make(map[string]string, len(labels))
|
||||
for k, v := range labels {
|
||||
if k == "" {
|
||||
continue // Skip empty keys
|
||||
}
|
||||
labelsCopy[k] = v
|
||||
}
|
||||
|
||||
if len(labelsCopy) == 0 {
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
keys := getSortedKeys(labelsCopy)
|
||||
|
||||
for _, k := range keys {
|
||||
buf.WriteByte('_')
|
||||
buf.WriteString(k)
|
||||
buf.WriteByte('_')
|
||||
buf.WriteString(labelsCopy[k])
|
||||
if v, ok := labelsCopy[k]; ok {
|
||||
buf.WriteByte('_')
|
||||
buf.WriteString(k)
|
||||
buf.WriteByte('_')
|
||||
buf.WriteString(v)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package libpack_monitoring
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -17,6 +18,8 @@ type MetricsSetup struct {
|
||||
metrics_set_custom *metrics.Set
|
||||
ic *InitConfig
|
||||
metrics_prefix string
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
var log = libpack_logger.New().SetMinLogLevel(libpack_logger.LEVEL_INFO)
|
||||
@@ -27,10 +30,18 @@ type InitConfig struct {
|
||||
}
|
||||
|
||||
func NewMonitoring(ic *InitConfig) *MetricsSetup {
|
||||
return NewMonitoringWithContext(context.Background(), ic)
|
||||
}
|
||||
|
||||
// NewMonitoringWithContext creates a new monitoring instance with context for graceful shutdown
|
||||
func NewMonitoringWithContext(ctx context.Context, ic *InitConfig) *MetricsSetup {
|
||||
monCtx, cancel := context.WithCancel(ctx)
|
||||
ms := &MetricsSetup{
|
||||
ic: ic,
|
||||
metrics_set: metrics.NewSet(),
|
||||
metrics_set_custom: metrics.NewSet(),
|
||||
ctx: monCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
if flag.Lookup("test.v") == nil {
|
||||
@@ -39,8 +50,14 @@ func NewMonitoring(ic *InitConfig) *MetricsSetup {
|
||||
if ic.PurgeEvery > 0 {
|
||||
ticker := time.NewTicker(time.Duration(ic.PurgeEvery) * time.Second)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
ms.PurgeMetrics()
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ms.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
ms.PurgeMetrics()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -49,6 +66,13 @@ func NewMonitoring(ic *InitConfig) *MetricsSetup {
|
||||
return ms
|
||||
}
|
||||
|
||||
// Shutdown stops the monitoring goroutines
|
||||
func (ms *MetricsSetup) Shutdown() {
|
||||
if ms.cancel != nil {
|
||||
ms.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *MetricsSetup) startPrometheusEndpoint() {
|
||||
app := fiber.New(fiber.Config{
|
||||
DisableStartupMessage: true,
|
||||
@@ -95,6 +119,20 @@ func (ms *MetricsSetup) RegisterMetricsGauge(metric_name string, labels map[stri
|
||||
})
|
||||
}
|
||||
|
||||
// RegisterMetricsGaugeFunc registers a gauge with a callback function that is called on every scrape
|
||||
// This is useful for metrics that need to return a dynamic value
|
||||
func (ms *MetricsSetup) RegisterMetricsGaugeFunc(metric_name string, labels map[string]string, fn func() float64) *metrics.Gauge {
|
||||
if err := validate_metrics_name(metric_name); err != nil {
|
||||
log.Error(&libpack_logger.LogMessage{
|
||||
Message: "RegisterMetricsGaugeFunc() error - invalid metric name",
|
||||
Pairs: map[string]interface{}{"error": err.Error(), "metric_name": metric_name},
|
||||
})
|
||||
// Return a dummy gauge instead of nil to prevent panics
|
||||
return &metrics.Gauge{}
|
||||
}
|
||||
return ms.metrics_set_custom.GetOrCreateGauge(ms.get_metrics_name(metric_name, labels), fn)
|
||||
}
|
||||
|
||||
func (ms *MetricsSetup) RegisterMetricsCounter(metric_name string, labels map[string]string) *metrics.Counter {
|
||||
if err := validate_metrics_name(metric_name); err != nil {
|
||||
log.Error(&libpack_logger.LogMessage{
|
||||
|
||||
@@ -325,16 +325,72 @@ func setupTracing(c *fiber.Ctx) context.Context {
|
||||
return ctx
|
||||
}
|
||||
|
||||
// performProxyRequest executes the proxy request with retries and circuit breaker
|
||||
// performProxyRequest executes the proxy request with retries, circuit breaker, and request coalescing
|
||||
func performProxyRequest(c *fiber.Ctx, proxyURL string) error {
|
||||
// Extract user context for cache key (needed for coalescing and circuit breaker fallback)
|
||||
userID, userRole := extractUserInfo(c)
|
||||
|
||||
// Calculate cache key - includes user context for security
|
||||
// This key is used for both request coalescing and cache fallback
|
||||
cacheKey := libpack_cache.CalculateHash(c, userID, userRole)
|
||||
|
||||
// Check if request coalescing is enabled
|
||||
rc := GetRequestCoalescer()
|
||||
if rc != nil && cfg.RequestCoalescing.Enable {
|
||||
// Use request coalescing to deduplicate identical concurrent requests
|
||||
response, err := rc.Do(cacheKey, func() (*CoalescedResponse, error) {
|
||||
// Execute the actual proxy request
|
||||
proxyErr := performProxyRequestCore(c, proxyURL, cacheKey)
|
||||
|
||||
// Capture the response for coalescing
|
||||
if proxyErr != nil {
|
||||
return &CoalescedResponse{
|
||||
Err: proxyErr,
|
||||
StatusCode: c.Response().StatusCode(),
|
||||
}, proxyErr
|
||||
}
|
||||
|
||||
return &CoalescedResponse{
|
||||
Body: c.Response().Body(),
|
||||
StatusCode: c.Response().StatusCode(),
|
||||
Headers: make(map[string]string),
|
||||
}, nil
|
||||
})
|
||||
|
||||
// Check for error from rc.Do (though it typically returns nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check for error stored in the response (for coalesced requests)
|
||||
if response != nil && response.Err != nil {
|
||||
return response.Err
|
||||
}
|
||||
|
||||
// For coalesced requests (not the primary), we need to copy the response
|
||||
if response != nil && response.Body != nil && len(response.Body) > 0 {
|
||||
// Only set response if this is a coalesced request (body would be empty otherwise)
|
||||
if len(c.Response().Body()) == 0 {
|
||||
c.Response().SetStatusCode(response.StatusCode)
|
||||
c.Response().SetBody(response.Body)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// No coalescing - execute directly
|
||||
return performProxyRequestCore(c, proxyURL, cacheKey)
|
||||
}
|
||||
|
||||
// performProxyRequestCore executes the proxy request with retries and circuit breaker
|
||||
// This is the core implementation used by both direct calls and coalesced requests
|
||||
func performProxyRequestCore(c *fiber.Ctx, proxyURL string, cacheKey string) error {
|
||||
// If circuit breaker is not enabled, use the original method
|
||||
if !cfg.CircuitBreaker.Enable || cb == nil {
|
||||
return performProxyRequestWithRetries(c, proxyURL)
|
||||
}
|
||||
|
||||
// Calculate cache key for potential fallback
|
||||
cacheKey := libpack_cache.CalculateHash(c)
|
||||
|
||||
// Execute request through circuit breaker
|
||||
_, err := cb.Execute(func() (interface{}, error) {
|
||||
// Execute the request with retries
|
||||
|
||||
@@ -82,6 +82,36 @@ func (suite *Tests) Test_proxyTheRequest() {
|
||||
wantErr: false,
|
||||
wantEndpoint: "https://telegram-bot.app/",
|
||||
},
|
||||
{
|
||||
name: "Test mutation with multiple operations (bug fix regression test)",
|
||||
body: `{"query":"mutation getOrCreateUser { insert_tg_users_one(object: {id: 123}) { id } } query otherQuery { users { id } }"}`,
|
||||
host: "https://telegram-bot.app/",
|
||||
hostRO: "https://google.com/",
|
||||
path: "/v1/graphql",
|
||||
headers: supplied_headers,
|
||||
wantErr: false,
|
||||
wantEndpoint: "https://telegram-bot.app/",
|
||||
},
|
||||
{
|
||||
name: "Test mutation followed by fragment (bug fix regression test)",
|
||||
body: `{"query":"mutation insertUser { insert_users_one(object: {name: \"test\"}) { ...userFields } } fragment userFields on users { id name }"}`,
|
||||
host: "https://telegram-bot.app/",
|
||||
hostRO: "https://google.com/",
|
||||
path: "/v1/graphql",
|
||||
headers: supplied_headers,
|
||||
wantErr: false,
|
||||
wantEndpoint: "https://telegram-bot.app/",
|
||||
},
|
||||
{
|
||||
name: "Test complex mutation document (main-bot style)",
|
||||
body: `{"query":"mutation getOrCreateUser($user_id: bigint!, $group_id: bigint!) { insert_tg_users_one(object: {id: $user_id}, on_conflict: {constraint: tg_users_pkey, update_columns: last_seen}) { id } insert_tg_groups_one(object: {id: $group_id}, on_conflict: {constraint: tg_groups_pkey, update_columns: last_seen}) { id } }"}`,
|
||||
host: "https://telegram-bot.app/",
|
||||
hostRO: "https://google.com/",
|
||||
path: "/v1/graphql",
|
||||
headers: supplied_headers,
|
||||
wantErr: false,
|
||||
wantEndpoint: "https://telegram-bot.app/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
+33
-5
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -18,6 +19,8 @@ type RetryBudget struct {
|
||||
mu sync.RWMutex
|
||||
enabled bool
|
||||
logger *libpack_logger.Logger
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Statistics
|
||||
totalAttempts atomic.Int64
|
||||
@@ -32,13 +35,21 @@ type RetryBudgetConfig struct {
|
||||
Enabled bool // Whether retry budget is enabled
|
||||
}
|
||||
|
||||
// NewRetryBudget creates a new retry budget
|
||||
// NewRetryBudget creates a new retry budget (deprecated, use NewRetryBudgetWithContext)
|
||||
func NewRetryBudget(config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget {
|
||||
return NewRetryBudgetWithContext(context.Background(), config, logger)
|
||||
}
|
||||
|
||||
// NewRetryBudgetWithContext creates a new retry budget with context for graceful shutdown
|
||||
func NewRetryBudgetWithContext(ctx context.Context, config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget {
|
||||
budgetCtx, cancel := context.WithCancel(ctx)
|
||||
rb := &RetryBudget{
|
||||
tokensPerSecond: config.TokensPerSecond,
|
||||
maxTokens: int64(config.MaxTokens),
|
||||
enabled: config.Enabled,
|
||||
logger: logger,
|
||||
ctx: budgetCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Initialize with full bucket
|
||||
@@ -91,8 +102,20 @@ func (rb *RetryBudget) refillLoop() {
|
||||
ticker := time.NewTicker(100 * time.Millisecond) // Refill every 100ms
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rb.refill()
|
||||
for {
|
||||
select {
|
||||
case <-rb.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
rb.refill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown stops the retry budget goroutine
|
||||
func (rb *RetryBudget) Shutdown() {
|
||||
if rb.cancel != nil {
|
||||
rb.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -187,10 +210,15 @@ var (
|
||||
retryBudgetOnce sync.Once
|
||||
)
|
||||
|
||||
// InitializeRetryBudget initializes the global retry budget
|
||||
// InitializeRetryBudget initializes the global retry budget (deprecated, use InitializeRetryBudgetWithContext)
|
||||
func InitializeRetryBudget(config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget {
|
||||
return InitializeRetryBudgetWithContext(context.Background(), config, logger)
|
||||
}
|
||||
|
||||
// InitializeRetryBudgetWithContext initializes the global retry budget with context for graceful shutdown
|
||||
func InitializeRetryBudgetWithContext(ctx context.Context, config RetryBudgetConfig, logger *libpack_logger.Logger) *RetryBudget {
|
||||
retryBudgetOnce.Do(func() {
|
||||
retryBudget = NewRetryBudget(config, logger)
|
||||
retryBudget = NewRetryBudgetWithContext(ctx, config, logger)
|
||||
if logger != nil && config.Enabled {
|
||||
logger.Info(&libpack_logger.LogMessage{
|
||||
Message: "Retry budget initialized",
|
||||
|
||||
+27
-8
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -12,11 +13,17 @@ type RPSTracker struct {
|
||||
lastSampleTime atomic.Int64 // Unix nano
|
||||
currentRPS uint64 // stored as uint64, accessed with atomic operations
|
||||
mu sync.RWMutex // for currentRPS updates
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRPSTracker creates a new RPS tracker
|
||||
func NewRPSTracker() *RPSTracker {
|
||||
tracker := &RPSTracker{}
|
||||
// NewRPSTracker creates a new RPS tracker with context for graceful shutdown
|
||||
func NewRPSTracker(ctx context.Context) *RPSTracker {
|
||||
trackerCtx, cancel := context.WithCancel(ctx)
|
||||
tracker := &RPSTracker{
|
||||
ctx: trackerCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
tracker.lastSampleTime.Store(time.Now().UnixNano())
|
||||
go tracker.updateLoop()
|
||||
return tracker
|
||||
@@ -33,8 +40,20 @@ func (r *RPSTracker) updateLoop() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
r.sample()
|
||||
for {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.sample()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown stops the RPS tracker
|
||||
func (r *RPSTracker) Shutdown() {
|
||||
if r.cancel != nil {
|
||||
r.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,10 +94,10 @@ func (r *RPSTracker) GetCurrentRPS() float64 {
|
||||
|
||||
var globalRPSTracker *RPSTracker
|
||||
|
||||
// InitializeRPSTracker initializes the global RPS tracker
|
||||
func InitializeRPSTracker() *RPSTracker {
|
||||
// InitializeRPSTracker initializes the global RPS tracker with context for graceful shutdown
|
||||
func InitializeRPSTracker(ctx context.Context) *RPSTracker {
|
||||
if globalRPSTracker == nil {
|
||||
globalRPSTracker = NewRPSTracker()
|
||||
globalRPSTracker = NewRPSTracker(ctx)
|
||||
}
|
||||
return globalRPSTracker
|
||||
}
|
||||
|
||||
@@ -272,6 +272,17 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
||||
|
||||
// Parse the GraphQL query
|
||||
parsedResult := parseGraphQLQuery(c)
|
||||
|
||||
// Debug logging for mutation routing analysis (enabled when LOG_LEVEL=DEBUG)
|
||||
if cfg.LogLevel == "DEBUG" {
|
||||
var m map[string]interface{}
|
||||
if err := json.Unmarshal(c.Body(), &m); err == nil {
|
||||
if query, ok := m["query"].(string); ok {
|
||||
debugParseGraphQLQuery(c, query)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if parsedResult.shouldBlock {
|
||||
return c.Status(fiber.StatusForbidden).SendString("Request blocked")
|
||||
}
|
||||
@@ -316,8 +327,11 @@ func extractUserInfo(c *fiber.Ctx) (string, string) {
|
||||
|
||||
// handleCaching manages the caching logic for GraphQL requests
|
||||
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) {
|
||||
// Calculate query hash for cache key
|
||||
calculatedQueryHash := libpack_cache.CalculateHash(c)
|
||||
// Extract user role for cache key (in addition to userID already passed)
|
||||
_, userRole := extractUserInfo(c)
|
||||
|
||||
// Calculate query hash for cache key - now includes user context for security
|
||||
calculatedQueryHash := libpack_cache.CalculateHash(c, userID, userRole)
|
||||
|
||||
// Set cache time from header or default
|
||||
if parsedResult.cacheTime == 0 {
|
||||
|
||||
@@ -97,6 +97,9 @@ spec:
|
||||
value: "error"
|
||||
- name: HASURA_GRAPHQL_SERVER_PORT
|
||||
value: "8088"
|
||||
# Disable event trigger processing on read-only instance
|
||||
- name: HASURA_GRAPHQL_EVENTS_FETCH_INTERVAL
|
||||
value: "0"
|
||||
|
||||
- name: graphql-proxy
|
||||
image: ghcr.io/lukaszraczylo/graphql-monitoring-proxy:latest
|
||||
|
||||
+2
-1
@@ -44,7 +44,8 @@ type config struct {
|
||||
CacheRedisEnable bool
|
||||
CacheMaxMemorySize int
|
||||
CacheMaxEntries int
|
||||
GraphQLQueryCacheSize int // Max number of parsed GraphQL queries to cache
|
||||
GraphQLQueryCacheSize int // Max number of parsed GraphQL queries to cache
|
||||
PerUserCacheDisabled bool // Disable per-user cache isolation (SECURITY RISK - not recommended)
|
||||
}
|
||||
Client struct {
|
||||
GQLClient *graphql.BaseClient
|
||||
|
||||
+91
-2
@@ -8,6 +8,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
gorillaws "github.com/gorilla/websocket"
|
||||
@@ -141,8 +142,29 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
|
||||
// Set message size limit
|
||||
clientConn.SetReadLimit(wsp.maxMessageSize)
|
||||
|
||||
// Connect to backend WebSocket with forwarded headers
|
||||
backendConn, err := wsp.dialBackend(ctx, headers)
|
||||
// Read first message to extract authentication from connection_init payload
|
||||
// This bridges the gap between clients that send auth in payload vs Hasura expecting it in HTTP headers
|
||||
messageType, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to read first message from client",
|
||||
Pairs: map[string]interface{}{
|
||||
"connection_id": connectionID,
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to extract headers from connection_init payload (for GraphQL WebSocket protocols)
|
||||
enrichedHeaders := wsp.extractAuthFromPayload(message, headers)
|
||||
|
||||
// Connect to backend WebSocket with enriched headers
|
||||
backendConn, err := wsp.dialBackend(ctx, enrichedHeaders)
|
||||
if err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
@@ -159,6 +181,21 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// Forward the first message (connection_init) to backend
|
||||
if err := backendConn.WriteMessage(messageType, message); err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to forward connection_init to backend",
|
||||
Pairs: map[string]interface{}{
|
||||
"connection_id": connectionID,
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Backend WebSocket connection established",
|
||||
@@ -336,6 +373,58 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
|
||||
}
|
||||
}
|
||||
|
||||
// extractAuthFromPayload extracts authentication headers from GraphQL WebSocket connection_init payload
|
||||
// This bridges the gap between clients sending auth in payload and Hasura expecting it in HTTP headers
|
||||
func (wsp *WebSocketProxy) extractAuthFromPayload(message []byte, originalHeaders http.Header) http.Header {
|
||||
// Create a copy of original headers
|
||||
enrichedHeaders := make(http.Header)
|
||||
for k, v := range originalHeaders {
|
||||
enrichedHeaders[k] = v
|
||||
}
|
||||
|
||||
// Try to parse as JSON to extract headers from payload
|
||||
var msg map[string]interface{}
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
// Not JSON or parse error, return original headers
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Check if this is a connection_init message
|
||||
msgType, ok := msg["type"].(string)
|
||||
if !ok || (msgType != "connection_init" && msgType != "start") {
|
||||
// Not a connection_init, return original headers
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Extract payload
|
||||
payload, ok := msg["payload"].(map[string]interface{})
|
||||
if !ok {
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Try to extract headers from payload.headers (graphql-ws format)
|
||||
if payloadHeaders, ok := payload["headers"].(map[string]interface{}); ok {
|
||||
for key, value := range payloadHeaders {
|
||||
if strValue, ok := value.(string); ok {
|
||||
enrichedHeaders.Set(key, strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check top-level payload keys that look like headers (Apollo format)
|
||||
for key, value := range payload {
|
||||
if strValue, ok := value.(string); ok {
|
||||
// Common auth headers
|
||||
if key == "Authorization" || key == "authorization" ||
|
||||
key == "x-hasura-role" || key == "x-hasura-admin-secret" {
|
||||
enrichedHeaders.Set(key, strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// dialBackend establishes a WebSocket connection to the backend
|
||||
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
|
||||
// Convert http:// to ws:// or https:// to wss://
|
||||
|
||||
Reference in New Issue
Block a user