mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-05 23:03:48 +00:00
perf+coverage: optimisation pass + coverage push to ≥70%
Performance / resource usage: - circuit_breaker_metrics: fix data race on failCounters map (RWMutex + double-checked locking) - server.go: drop user_id and op_name metric labels (Prometheus cardinality bound); de-duplicate extractUserInfo - graphql.go: gate runtime.ReadMemStats per-request behind ENABLE_ALLOCATION_TRACKING flag (default off) - graphql.go: collapse two-pass AST scan into single pass; lower-case once - sanitization.go: cache compiled redaction regexes per pattern via sync.Map; hoist inner constants to pkg vars - proxy.go: hoist connection/timeout substrings to pkg vars; sentinel errors for static error paths; drop dead Headers map alloc - metrics_aggregator.go: log-field allocation guarded by Logger.IsLevelEnabled - logging/logger.go: add IsLevelEnabled helper - lru_cache.go: 16-shard sharding, FNV-1a routing (concurrent throughput +22%) - cache/memory/lru_memory_cache.go: gzip compress/decompress moved outside mu.Lock - rps_tracker.go: RWMutex+uint64 -> atomic.Uint64 - retry_budget.go: drop unused mutex - api.go: bannedUsersIDs map+RWMutex -> sync.Map (+ snapshot/replace helpers) - tracing/tracing.go: pkg-level constSpanAttrs, copy-then-append in StartSpanWithAttributes - admin_dashboard.go: handleStatsWebSocket reuses bytes.Buffer + json.Encoder per connection Build / runtime: - Makefile: -ldflags="-s -w" -trimpath, CGO_ENABLED=0 for build (=1 for test recipes) - Dockerfile + Dockerfile.goreleaser: ENV GOMEMLIMIT=512MiB - main.go: blank import go.uber.org/automaxprocs (cgroup-aware GOMAXPROCS) - main.go: PPROF_PORT env var wires net/http/pprof on 127.0.0.1 only with full server timeouts - README.md: env-var docs + metric-label docs updated; cardinality note Test coverage push (per package): - main 51.2% -> 74.7% - cache 66.3% -> 93.7% - cache/redis 45.5% -> 98.2% - tracing 66.7% -> 72.9% - (cache/memory 91.6%, logging 91.9%, monitoring 77.6%, pkg/pools 100% unchanged) New test files: coverage_micro_test, coverage_extras_test, server_handlers_test, api_health_test, admin_dashboard_cluster_test, metrics_aggregator_test, concerns_test, cache/cache_coverage_test, cache/redis/redis_coverage_test, tracing/tracing_coverage_test. Bug fix: connection_resilience_test.go TestIntegratedHealthManagement.health_manager_startup was sync.Once-coupled to InitializeBackendHealth and panicked when another test (e.g. via parseConfig) had already triggered Once. Use NewBackendHealthManager directly.
This commit is contained in:
@@ -5,4 +5,9 @@ ARG TARGETOS
|
|||||||
# silly workaround for distroless image as no chmod is available
|
# silly workaround for distroless image as no chmod is available
|
||||||
COPY --chmod=777 --chown=nonroot:nonroot static/app /go/src/app
|
COPY --chmod=777 --chown=nonroot:nonroot static/app /go/src/app
|
||||||
ADD dist/bot-$TARGETOS-$TARGETARCH /go/src/app/graphql-proxy
|
ADD dist/bot-$TARGETOS-$TARGETARCH /go/src/app/graphql-proxy
|
||||||
|
# Runtime tuning: operators should override GOMEMLIMIT per deployment
|
||||||
|
# to match container memory limits (e.g. set to ~80% of cgroup limit).
|
||||||
|
ENV GOMEMLIMIT=512MiB
|
||||||
|
# NOTE: no HEALTHCHECK — distroless:nonroot lacks /bin/sh and curl/wget.
|
||||||
|
# Use orchestrator-level probes (Kubernetes liveness/readiness) hitting /live on monitoring port.
|
||||||
ENTRYPOINT ["/go/src/app/graphql-proxy"]
|
ENTRYPOINT ["/go/src/app/graphql-proxy"]
|
||||||
|
|||||||
@@ -3,4 +3,9 @@ ARG TARGETPLATFORM
|
|||||||
WORKDIR /go/src/app
|
WORKDIR /go/src/app
|
||||||
COPY --chmod=777 --chown=nonroot:nonroot static/app /go/src/app
|
COPY --chmod=777 --chown=nonroot:nonroot static/app /go/src/app
|
||||||
COPY ${TARGETPLATFORM}/graphql-proxy /go/src/app/graphql-proxy
|
COPY ${TARGETPLATFORM}/graphql-proxy /go/src/app/graphql-proxy
|
||||||
|
# Runtime tuning: operators should override GOMEMLIMIT per deployment
|
||||||
|
# to match container memory limits (e.g. set to ~80% of cgroup limit).
|
||||||
|
ENV GOMEMLIMIT=512MiB
|
||||||
|
# NOTE: no HEALTHCHECK — distroless:nonroot lacks /bin/sh and curl/wget.
|
||||||
|
# Use orchestrator-level probes (Kubernetes liveness/readiness) hitting /live on monitoring port.
|
||||||
ENTRYPOINT ["/go/src/app/graphql-proxy"]
|
ENTRYPOINT ["/go/src/app/graphql-proxy"]
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
CI_RUN?=false
|
CI_RUN?=false
|
||||||
TIMESTAMP := $(shell date +%Y%m%d-%H%M%S)
|
TIMESTAMP := $(shell date +%Y%m%d-%H%M%S)
|
||||||
|
|
||||||
|
# Build hardening flags
|
||||||
|
# -s: omit symbol table, -w: omit DWARF debug info (smaller binaries)
|
||||||
|
LDFLAGS ?= -s -w
|
||||||
|
# -trimpath: remove local filesystem paths from binary (reproducible builds)
|
||||||
|
GOFLAGS ?= -trimpath
|
||||||
|
# CGO_ENABLED=0: static binary, no libc dependency (distroless-friendly)
|
||||||
|
export CGO_ENABLED = 0
|
||||||
|
|
||||||
# ADDITIONAL_BUILD_FLAGS=""
|
# ADDITIONAL_BUILD_FLAGS=""
|
||||||
|
|
||||||
# ifeq ($(CI_RUN), true)
|
# ifeq ($(CI_RUN), true)
|
||||||
@@ -17,15 +25,15 @@ run: build ## run application
|
|||||||
|
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
build: ## build the binary
|
build: ## build the binary
|
||||||
go build -o graphql-proxy *.go
|
go build $(GOFLAGS) -ldflags="$(LDFLAGS)" -o graphql-proxy *.go
|
||||||
|
|
||||||
.PHONY: test
|
.PHONY: test
|
||||||
test: ## run tests on library
|
test: ## run tests on library
|
||||||
@LOG_LEVEL=info go test -v -cover -race ./...
|
@CGO_ENABLED=1 LOG_LEVEL=info go test -v -cover -race ./...
|
||||||
|
|
||||||
.PHONY: test-packages
|
.PHONY: test-packages
|
||||||
test-packages: ## run tests on packages
|
test-packages: ## run tests on packages
|
||||||
@go test -v -cover ./pkg/...
|
@CGO_ENABLED=1 go test -v -cover -race ./pkg/...
|
||||||
|
|
||||||
.PHONY: all
|
.PHONY: all
|
||||||
all: test-packages test
|
all: test-packages test
|
||||||
@@ -37,11 +45,11 @@ update: ## update dependencies
|
|||||||
|
|
||||||
.PHONY: build-amd64
|
.PHONY: build-amd64
|
||||||
build-amd64: ## build the Linux AMD64 binary
|
build-amd64: ## build the Linux AMD64 binary
|
||||||
GOOS=linux GOARCH=amd64 go build -o graphql-proxy-amd64 *.go
|
GOOS=linux GOARCH=amd64 go build $(GOFLAGS) -ldflags="$(LDFLAGS)" -o graphql-proxy-amd64 *.go
|
||||||
|
|
||||||
.PHONY: build-arm64
|
.PHONY: build-arm64
|
||||||
build-arm64: ## build the Linux ARM64 binary
|
build-arm64: ## build the Linux ARM64 binary
|
||||||
GOOS=linux GOARCH=arm64 go build -o graphql-proxy-arm64 *.go
|
GOOS=linux GOARCH=arm64 go build $(GOFLAGS) -ldflags="$(LDFLAGS)" -o graphql-proxy-arm64 *.go
|
||||||
|
|
||||||
.PHONY: build-all
|
.PHONY: build-all
|
||||||
build-all: build-amd64 build-arm64 ## build both AMD64 and ARM64 binaries
|
build-all: build-amd64 build-arm64 ## build both AMD64 and ARM64 binaries
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ You can still use the non-prefixed environment variables in the spirit of the ba
|
|||||||
| `MAX_CONNS_PER_HOST` | Maximum connections per host | `1024` |
|
| `MAX_CONNS_PER_HOST` | Maximum connections per host | `1024` |
|
||||||
| `CLIENT_DISABLE_TLS_VERIFY` | Disable TLS verification | `false` |
|
| `CLIENT_DISABLE_TLS_VERIFY` | Disable TLS verification | `false` |
|
||||||
| `LOG_LEVEL` | The log level | `info` |
|
| `LOG_LEVEL` | The log level | `info` |
|
||||||
|
| `ENABLE_ALLOCATION_TRACKING` | Enable per-request memory allocation tracking | `false` |
|
||||||
| `BLOCK_SCHEMA_INTROSPECTION`| Blocks the schema introspection | `false` |
|
| `BLOCK_SCHEMA_INTROSPECTION`| Blocks the schema introspection | `false` |
|
||||||
| `ALLOWED_INTROSPECTION` | Allow only certain queries in introspection | `` |
|
| `ALLOWED_INTROSPECTION` | Allow only certain queries in introspection | `` |
|
||||||
| `ENABLE_ACCESS_LOG` | Enable the access log | `false` |
|
| `ENABLE_ACCESS_LOG` | Enable the access log | `false` |
|
||||||
@@ -224,6 +225,7 @@ You can still use the non-prefixed environment variables in the spirit of the ba
|
|||||||
| `WEBSOCKET_PONG_TIMEOUT` | WebSocket pong timeout in seconds | `60` |
|
| `WEBSOCKET_PONG_TIMEOUT` | WebSocket pong timeout in seconds | `60` |
|
||||||
| `WEBSOCKET_MAX_MESSAGE_SIZE` | Max WebSocket message size in bytes | `524288` (512KB) |
|
| `WEBSOCKET_MAX_MESSAGE_SIZE` | Max WebSocket message size in bytes | `524288` (512KB) |
|
||||||
| `ADMIN_DASHBOARD_ENABLE` | Enable admin dashboard UI | `true` |
|
| `ADMIN_DASHBOARD_ENABLE` | Enable admin dashboard UI | `true` |
|
||||||
|
| `PPROF_PORT` | Localhost-only debug pprof endpoint port (default: disabled). Never expose publicly. | `` |
|
||||||
|
|
||||||
### Tracing
|
### Tracing
|
||||||
|
|
||||||
@@ -1098,16 +1100,18 @@ If you'd like the `/healthz` endpoint to perform actual check for the connectivi
|
|||||||
|
|
||||||
Example metrics produced by the proxy:
|
Example metrics produced by the proxy:
|
||||||
|
|
||||||
|
The `executed_query` and `timed_query` metrics carry only the `{op_type, cached}` label set. The previous `user_id` and `op_name` labels were removed to bound Prometheus cardinality (per-user and per-operation-name labels caused unbounded series growth).
|
||||||
|
|
||||||
```
|
```
|
||||||
graphql_proxy_timed_query_bucket{cached="false",user_id="-",op_type="mutation",op_name="updateUserDetails",vmrange="1.000e-02...1.136e-02"} 6
|
graphql_proxy_timed_query_bucket{op_type="mutation",cached="false",vmrange="1.000e-02...1.136e-02"} 6
|
||||||
graphql_proxy_timed_query_count{op_name="",cached="false",user_id="-",op_type=""} 78
|
graphql_proxy_timed_query_count{op_type="",cached="false"} 78
|
||||||
graphql_proxy_timed_query_bucket{op_name="MyQuery",cached="false",user_id="-",op_type="query",vmrange="5.995e+00...6.813e+00"} 1
|
graphql_proxy_timed_query_bucket{op_type="query",cached="false",vmrange="5.995e+00...6.813e+00"} 1
|
||||||
graphql_proxy_timed_query_sum{op_name="MyQuery",cached="false",user_id="-",op_type="query"} 6
|
graphql_proxy_timed_query_sum{op_type="query",cached="false"} 6
|
||||||
graphql_proxy_timed_query_count{op_name="MyQuery",cached="false",user_id="-",op_type="query"} 1
|
graphql_proxy_timed_query_count{op_type="query",cached="false"} 1
|
||||||
graphql_proxy_executed_query{user_id="-",op_type="mutation",op_name="updateKnownSpammer",cached="false"} 1486
|
graphql_proxy_executed_query{op_type="mutation",cached="false"} 1486
|
||||||
graphql_proxy_executed_query{user_id="-",op_type="query",op_name="checkIfAdminsNeedRefreshing",cached="false"} 13167
|
graphql_proxy_executed_query{op_type="query",cached="false"} 13167
|
||||||
graphql_proxy_executed_query{user_id="1337",op_type="query",op_name="checkIfKnownMedia",cached="false"} 429
|
graphql_proxy_executed_query{op_type="query",cached="false"} 429
|
||||||
graphql_proxy_executed_query{user_id="-",op_type="query",op_name="checkIfSpamAIRequiresUpdate",cached="false"} 8891
|
graphql_proxy_executed_query{op_type="query",cached="true"} 8891
|
||||||
graphql_proxy_requests_failed 324
|
graphql_proxy_requests_failed 324
|
||||||
graphql_proxy_requests_skipped 0
|
graphql_proxy_requests_skipped 0
|
||||||
graphql_proxy_requests_succesful 454823
|
graphql_proxy_requests_succesful 454823
|
||||||
|
|||||||
+17
-7
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -687,10 +688,19 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
|||||||
ticker := time.NewTicker(StatsStreamInterval)
|
ticker := time.NewTicker(StatsStreamInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// Per-connection encoder + buffer reused across ticks to avoid
|
||||||
|
// a fresh json.Marshal allocation every 2s per connection.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := json.NewEncoder(&buf)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
|
||||||
// Send initial stats immediately (cluster-aware for dashboard)
|
// Send initial stats immediately (cluster-aware for dashboard)
|
||||||
if stats := ad.gatherAllStatsClusterAware(); stats != nil {
|
if stats := ad.gatherAllStatsClusterAware(); stats != nil {
|
||||||
if data, err := json.Marshal(stats); err == nil {
|
buf.Reset()
|
||||||
_ = c.WriteMessage(websocket.TextMessage, data)
|
if err := enc.Encode(stats); err == nil {
|
||||||
|
// json.Encoder.Encode appends a trailing newline; strip it
|
||||||
|
// so the wire format matches the previous json.Marshal output.
|
||||||
|
_ = c.WriteMessage(websocket.TextMessage, bytes.TrimRight(buf.Bytes(), "\n"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -701,9 +711,9 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
|||||||
// Gather all stats (cluster-aware for dashboard)
|
// Gather all stats (cluster-aware for dashboard)
|
||||||
stats := ad.gatherAllStatsClusterAware()
|
stats := ad.gatherAllStatsClusterAware()
|
||||||
|
|
||||||
// Marshal to JSON
|
// Encode into reused buffer (no per-tick allocation churn)
|
||||||
data, err := json.Marshal(stats)
|
buf.Reset()
|
||||||
if err != nil {
|
if err := enc.Encode(stats); err != nil {
|
||||||
if ad.logger != nil {
|
if ad.logger != nil {
|
||||||
ad.logger.Error(&libpack_logger.LogMessage{
|
ad.logger.Error(&libpack_logger.LogMessage{
|
||||||
Message: "Failed to marshal stats for WebSocket",
|
Message: "Failed to marshal stats for WebSocket",
|
||||||
@@ -713,8 +723,8 @@ func (ad *AdminDashboard) handleStatsWebSocket(c *websocket.Conn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send to client
|
// Send to client (strip trailing newline from Encoder to match prior format)
|
||||||
if err := c.WriteMessage(websocket.TextMessage, data); err != nil {
|
if err := c.WriteMessage(websocket.TextMessage, bytes.TrimRight(buf.Bytes(), "\n")); err != nil {
|
||||||
if ad.logger != nil {
|
if ad.logger != nil {
|
||||||
ad.logger.Debug(&libpack_logger.LogMessage{
|
ad.logger.Debug(&libpack_logger.LogMessage{
|
||||||
Message: "Failed to write to WebSocket (client likely disconnected)",
|
Message: "Failed to write to WebSocket (client likely disconnected)",
|
||||||
|
|||||||
@@ -0,0 +1,247 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newClusterApp registers all cluster + control routes on a fresh Fiber app.
|
||||||
|
func newClusterApp(t *testing.T) (*fiber.App, *AdminDashboard) {
|
||||||
|
t.Helper()
|
||||||
|
app := fiber.New()
|
||||||
|
logger := libpack_logger.New()
|
||||||
|
dashboard := NewAdminDashboard(logger)
|
||||||
|
dashboard.RegisterRoutes(app)
|
||||||
|
return app, dashboard
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureNilAggregator guarantees no metrics aggregator is active for the test
|
||||||
|
// and restores the original value after.
|
||||||
|
func ensureNilAggregator(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
orig := metricsAggregator
|
||||||
|
metricsAggregator = nil
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
metricsAggregator = orig
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- getClusterStats -------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetClusterStats_NoAggregator_Returns503(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
app, _ := newClusterApp(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/admin/api/cluster/stats", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, false, body["cluster_mode"])
|
||||||
|
assert.NotEmpty(t, body["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- getClusterInstances ---------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetClusterInstances_NoAggregator_Returns503(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
app, _ := newClusterApp(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/admin/api/cluster/instances", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, false, body["cluster_mode"])
|
||||||
|
assert.NotEmpty(t, body["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- getClusterDebug -------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetClusterDebug_NoAggregator_Returns200WithFalseFlag(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
// also set cfg so the redis_cache_enabled branch is exercised
|
||||||
|
cfg = &config{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
}
|
||||||
|
cfg.Cache.CacheEnable = true
|
||||||
|
cfg.Cache.CacheRedisEnable = false
|
||||||
|
|
||||||
|
app, _ := newClusterApp(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/admin/api/cluster/debug", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, false, body["aggregator_initialized"])
|
||||||
|
assert.Equal(t, false, body["redis_cache_enabled"])
|
||||||
|
assert.Equal(t, true, body["cache_enabled"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClusterDebug_NilCfg_Returns200WithDefaults(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
orig := cfg
|
||||||
|
cfg = nil
|
||||||
|
defer func() { cfg = orig }()
|
||||||
|
|
||||||
|
app, _ := newClusterApp(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/admin/api/cluster/debug", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, false, body["aggregator_initialized"])
|
||||||
|
assert.Equal(t, false, body["redis_cache_enabled"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- forcePublish ----------------------------------------------------------
|
||||||
|
|
||||||
|
func TestForcePublish_NoAggregator_Returns503(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
app, _ := newClusterApp(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/admin/api/cluster/force-publish", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, false, body["success"])
|
||||||
|
assert.NotEmpty(t, body["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- gatherAllStats / gatherAllStatsWithMode / gatherAllStatsClusterAware --
|
||||||
|
|
||||||
|
func newDashboardForGather(t *testing.T) *AdminDashboard {
|
||||||
|
t.Helper()
|
||||||
|
logger := libpack_logger.New()
|
||||||
|
monitoring := libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{})
|
||||||
|
cfg = &config{
|
||||||
|
Logger: logger,
|
||||||
|
Monitoring: monitoring,
|
||||||
|
}
|
||||||
|
return NewAdminDashboard(logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStats_ReturnsExpectedTopLevelKeys(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
ad := newDashboardForGather(t)
|
||||||
|
|
||||||
|
result := ad.gatherAllStats()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
|
||||||
|
// cluster_mode must be false when no aggregator
|
||||||
|
assert.Equal(t, false, result["cluster_mode"])
|
||||||
|
|
||||||
|
// stats sub-map must exist
|
||||||
|
statsRaw, ok := result["stats"]
|
||||||
|
assert.True(t, ok, "stats key must be present")
|
||||||
|
stats, ok := statsRaw.(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotEmpty(t, stats["timestamp"])
|
||||||
|
assert.NotNil(t, stats["uptime_seconds"])
|
||||||
|
assert.NotNil(t, stats["uptime_human"])
|
||||||
|
assert.NotEmpty(t, stats["version"])
|
||||||
|
assert.NotNil(t, stats["requests"])
|
||||||
|
|
||||||
|
// health sub-map must exist
|
||||||
|
healthRaw, ok := result["health"]
|
||||||
|
assert.True(t, ok, "health key must be present")
|
||||||
|
health, ok := healthRaw.(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotNil(t, health["status"])
|
||||||
|
assert.NotNil(t, health["backend"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStatsWithMode_FalseMode_ReturnsLocalStats(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
ad := newDashboardForGather(t)
|
||||||
|
|
||||||
|
result := ad.gatherAllStatsWithMode(false)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, false, result["cluster_mode"])
|
||||||
|
assert.NotNil(t, result["stats"])
|
||||||
|
assert.NotNil(t, result["health"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStatsWithMode_TrueModeNoAggregator_FallsBackToLocal(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
ad := newDashboardForGather(t)
|
||||||
|
|
||||||
|
// With no aggregator, cluster mode request must fall back to local stats.
|
||||||
|
result := ad.gatherAllStatsWithMode(true)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, false, result["cluster_mode"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStatsClusterAware_NoAggregator_FallsBackToLocal(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
ad := newDashboardForGather(t)
|
||||||
|
|
||||||
|
result := ad.gatherAllStatsClusterAware()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, false, result["cluster_mode"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStats_NilCfg_ReturnsStatsWithoutRequests(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
origCfg := cfg
|
||||||
|
cfg = nil
|
||||||
|
defer func() { cfg = origCfg }()
|
||||||
|
|
||||||
|
ad := NewAdminDashboard(nil)
|
||||||
|
|
||||||
|
result := ad.gatherAllStats()
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
stats, ok := result["stats"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
// when cfg is nil, "requests" key must NOT be present
|
||||||
|
_, hasRequests := stats["requests"]
|
||||||
|
assert.False(t, hasRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGatherAllStats_RequestStatsShape(t *testing.T) {
|
||||||
|
ensureNilAggregator(t)
|
||||||
|
ad := newDashboardForGather(t)
|
||||||
|
|
||||||
|
result := ad.gatherAllStats()
|
||||||
|
stats := result["stats"].(map[string]any)
|
||||||
|
requests, ok := stats["requests"].(map[string]any)
|
||||||
|
assert.True(t, ok, "requests must be a map")
|
||||||
|
assert.NotNil(t, requests["total"])
|
||||||
|
assert.NotNil(t, requests["succeeded"])
|
||||||
|
assert.NotNil(t, requests["failed"])
|
||||||
|
assert.NotNil(t, requests["skipped"])
|
||||||
|
assert.NotNil(t, requests["success_rate_pct"])
|
||||||
|
assert.NotNil(t, requests["failure_rate_pct"])
|
||||||
|
assert.NotNil(t, requests["skip_rate_pct"])
|
||||||
|
assert.NotNil(t, requests["avg_requests_per_second"])
|
||||||
|
assert.NotNil(t, requests["current_requests_per_second"])
|
||||||
|
}
|
||||||
@@ -17,10 +17,7 @@ import (
|
|||||||
"github.com/sony/gobreaker"
|
"github.com/sony/gobreaker"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var bannedUsersIDs sync.Map // key: userID string, value: reason string
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex sync.RWMutex
|
|
||||||
)
|
|
||||||
|
|
||||||
// authMiddleware provides API key authentication for admin endpoints
|
// authMiddleware provides API key authentication for admin endpoints
|
||||||
func authMiddleware(c *fiber.Ctx) error {
|
func authMiddleware(c *fiber.Ctx) error {
|
||||||
@@ -132,16 +129,14 @@ func periodicallyReloadBannedUsers(ctx context.Context) {
|
|||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||||
Message: "Banned users reloaded",
|
Message: "Banned users reloaded",
|
||||||
Pairs: map[string]any{"users": bannedUsersIDs},
|
Pairs: map[string]any{"users": snapshotBannedUsers()},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool {
|
func checkIfUserIsBanned(c *fiber.Ctx, userID string) bool {
|
||||||
bannedUsersIDsMutex.RLock()
|
_, found := bannedUsersIDs.Load(userID)
|
||||||
_, found := bannedUsersIDs[userID]
|
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
cfg.Logger.Debug(&libpack_logger.LogMessage{
|
||||||
Message: "Checking if user is banned",
|
Message: "Checking if user is banned",
|
||||||
@@ -251,9 +246,7 @@ func apiBanUser(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusBadRequest).SendString("user_id and reason are required")
|
return c.Status(fiber.StatusBadRequest).SendString("user_id and reason are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
bannedUsersIDsMutex.Lock()
|
bannedUsersIDs.Store(req.UserID, req.Reason)
|
||||||
bannedUsersIDs[req.UserID] = req.Reason
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||||
Message: "Banned user",
|
Message: "Banned user",
|
||||||
@@ -281,9 +274,7 @@ func apiUnbanUser(c *fiber.Ctx) error {
|
|||||||
return c.Status(fiber.StatusBadRequest).SendString("user_id is required")
|
return c.Status(fiber.StatusBadRequest).SendString("user_id is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
bannedUsersIDsMutex.Lock()
|
bannedUsersIDs.Delete(req.UserID)
|
||||||
delete(bannedUsersIDs, req.UserID)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
cfg.Logger.Info(&libpack_logger.LogMessage{
|
cfg.Logger.Info(&libpack_logger.LogMessage{
|
||||||
Message: "Unbanned user",
|
Message: "Unbanned user",
|
||||||
@@ -311,9 +302,7 @@ func storeBannedUsers() error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
bannedUsersIDsMutex.RLock()
|
data, err := json.Marshal(snapshotBannedUsers())
|
||||||
data, err := json.Marshal(bannedUsersIDs)
|
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cfg.Logger.Error(&libpack_logger.LogMessage{
|
cfg.Logger.Error(&libpack_logger.LogMessage{
|
||||||
@@ -384,9 +373,33 @@ func loadBannedUsers() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(newBannedUsers)
|
||||||
bannedUsersIDs = newBannedUsers
|
}
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
// snapshotBannedUsers returns a plain map copy of the current banned users.
|
||||||
|
func snapshotBannedUsers() map[string]string {
|
||||||
|
out := make(map[string]string)
|
||||||
|
bannedUsersIDs.Range(func(k, v any) bool {
|
||||||
|
ks, kok := k.(string)
|
||||||
|
vs, vok := v.(string)
|
||||||
|
if kok && vok {
|
||||||
|
out[ks] = vs
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// replaceBannedUsers swaps the banned users set with the provided map.
|
||||||
|
// Existing entries are removed before inserting the new ones.
|
||||||
|
func replaceBannedUsers(newUsers map[string]string) {
|
||||||
|
bannedUsersIDs.Range(func(k, _ any) bool {
|
||||||
|
bannedUsersIDs.Delete(k)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
for k, v := range newUsers {
|
||||||
|
bannedUsersIDs.Store(k, v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func lockFile(fileLock *flock.Flock) error {
|
func lockFile(fileLock *flock.Flock) error {
|
||||||
|
|||||||
+30
-51
@@ -18,9 +18,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_reload_test.json")
|
cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_reload_test.json")
|
||||||
|
|
||||||
// Initial empty banned users
|
// Initial empty banned users
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Create a test version of periodicallyReloadBannedUsers that executes once and signals completion
|
// Create a test version of periodicallyReloadBannedUsers that executes once and signals completion
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
@@ -37,9 +35,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
_ = os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
|
_ = os.Remove(fmt.Sprintf("%s.lock", cfg.Api.BannedUsersFile))
|
||||||
|
|
||||||
// Ensure banned users map is empty
|
// Ensure banned users map is empty
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Execute reloader once
|
// Execute reloader once
|
||||||
go testPeriodicallyReloadBannedUsers()
|
go testPeriodicallyReloadBannedUsers()
|
||||||
@@ -50,9 +46,7 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
// Safely check the map
|
// Safely check the map
|
||||||
bannedUsersIDsMutex.RLock()
|
mapSize := len(snapshotBannedUsers())
|
||||||
mapSize := len(bannedUsersIDs)
|
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
// Verify map is still empty
|
// Verify map is still empty
|
||||||
assert.Equal(suite.T(), 0, mapSize)
|
assert.Equal(suite.T(), 0, mapSize)
|
||||||
@@ -70,20 +64,17 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
// Clear the banned users map
|
// Clear the banned users map
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Execute reloader once
|
// Execute reloader once
|
||||||
go testPeriodicallyReloadBannedUsers()
|
go testPeriodicallyReloadBannedUsers()
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
// Safely check the map
|
// Safely check the map
|
||||||
bannedUsersIDsMutex.RLock()
|
snap := snapshotBannedUsers()
|
||||||
mapSize := len(bannedUsersIDs)
|
mapSize := len(snap)
|
||||||
value1 := bannedUsersIDs["test-user-reload-1"]
|
value1 := snap["test-user-reload-1"]
|
||||||
value2 := bannedUsersIDs["test-user-reload-2"]
|
value2 := snap["test-user-reload-2"]
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
// Verify banned users map was loaded
|
// Verify banned users map was loaded
|
||||||
assert.Equal(suite.T(), 2, mapSize)
|
assert.Equal(suite.T(), 2, mapSize)
|
||||||
@@ -102,19 +93,16 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
// Clear the banned users map
|
// Clear the banned users map
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Execute reloader once to load initial data
|
// Execute reloader once to load initial data
|
||||||
go testPeriodicallyReloadBannedUsers()
|
go testPeriodicallyReloadBannedUsers()
|
||||||
<-done
|
<-done
|
||||||
|
|
||||||
// Safely check the map
|
// Safely check the map
|
||||||
bannedUsersIDsMutex.RLock()
|
snap := snapshotBannedUsers()
|
||||||
mapSize := len(bannedUsersIDs)
|
mapSize := len(snap)
|
||||||
initialValue := bannedUsersIDs["test-user-initial"]
|
initialValue := snap["test-user-initial"]
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
// Verify initial data was loaded
|
// Verify initial data was loaded
|
||||||
assert.Equal(suite.T(), 1, mapSize)
|
assert.Equal(suite.T(), 1, mapSize)
|
||||||
@@ -134,12 +122,11 @@ func (suite *Tests) Test_PeriodicallyReloadBannedUsers() {
|
|||||||
<-done
|
<-done
|
||||||
|
|
||||||
// Safely check the map
|
// Safely check the map
|
||||||
bannedUsersIDsMutex.RLock()
|
snap = snapshotBannedUsers()
|
||||||
mapSize = len(bannedUsersIDs)
|
mapSize = len(snap)
|
||||||
value1 := bannedUsersIDs["test-user-updated-1"]
|
value1 := snap["test-user-updated-1"]
|
||||||
value2 := bannedUsersIDs["test-user-updated-2"]
|
value2 := snap["test-user-updated-2"]
|
||||||
_, exists := bannedUsersIDs["test-user-initial"]
|
_, exists := snap["test-user-initial"]
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
// Verify updated data was loaded
|
// Verify updated data was loaded
|
||||||
assert.Equal(suite.T(), 2, mapSize)
|
assert.Equal(suite.T(), 2, mapSize)
|
||||||
@@ -175,19 +162,16 @@ func (suite *Tests) Test_LoadUnloadBannedUsers() {
|
|||||||
// Test loading banned users
|
// Test loading banned users
|
||||||
suite.Run("load banned users", func() {
|
suite.Run("load banned users", func() {
|
||||||
// Clear the banned users map
|
// Clear the banned users map
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Load banned users
|
// Load banned users
|
||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
|
|
||||||
// Check the banned users map
|
// Check the banned users map
|
||||||
bannedUsersIDsMutex.RLock()
|
snap := snapshotBannedUsers()
|
||||||
count := len(bannedUsersIDs)
|
count := len(snap)
|
||||||
reason1 := bannedUsersIDs["user1"]
|
reason1 := snap["user1"]
|
||||||
reason2 := bannedUsersIDs["user2"]
|
reason2 := snap["user2"]
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(suite.T(), 2, count)
|
assert.Equal(suite.T(), 2, count)
|
||||||
assert.Equal(suite.T(), "reason1", reason1)
|
assert.Equal(suite.T(), "reason1", reason1)
|
||||||
@@ -197,32 +181,27 @@ func (suite *Tests) Test_LoadUnloadBannedUsers() {
|
|||||||
// Test updating banned users
|
// Test updating banned users
|
||||||
suite.Run("update banned users", func() {
|
suite.Run("update banned users", func() {
|
||||||
// Update the banned users map
|
// Update the banned users map
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{
|
||||||
bannedUsersIDs = map[string]string{
|
|
||||||
"user3": "reason3",
|
"user3": "reason3",
|
||||||
"user4": "reason4",
|
"user4": "reason4",
|
||||||
}
|
})
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Store the updated banned users
|
// Store the updated banned users
|
||||||
err := storeBannedUsers()
|
err := storeBannedUsers()
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
// Clear the banned users map
|
// Clear the banned users map
|
||||||
bannedUsersIDsMutex.Lock()
|
replaceBannedUsers(map[string]string{})
|
||||||
bannedUsersIDs = make(map[string]string)
|
|
||||||
bannedUsersIDsMutex.Unlock()
|
|
||||||
|
|
||||||
// Load banned users again
|
// Load banned users again
|
||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
|
|
||||||
// Check the banned users map
|
// Check the banned users map
|
||||||
bannedUsersIDsMutex.RLock()
|
snap := snapshotBannedUsers()
|
||||||
count := len(bannedUsersIDs)
|
count := len(snap)
|
||||||
reason3 := bannedUsersIDs["user3"]
|
reason3 := snap["user3"]
|
||||||
reason4 := bannedUsersIDs["user4"]
|
reason4 := snap["user4"]
|
||||||
_, user1Exists := bannedUsersIDs["user1"]
|
_, user1Exists := snap["user1"]
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
assert.Equal(suite.T(), 2, count)
|
assert.Equal(suite.T(), 2, count)
|
||||||
assert.Equal(suite.T(), "reason3", reason3)
|
assert.Equal(suite.T(), "reason3", reason3)
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (suite *APIAuthSecurityTestSuite) SetupTest() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Initialize banned users map
|
// Initialize banned users map
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
|
|
||||||
// Setup banned users file path
|
// Setup banned users file path
|
||||||
cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_auth_test.json")
|
cfg.Api.BannedUsersFile = filepath.Join(os.TempDir(), "banned_users_auth_test.json")
|
||||||
|
|||||||
@@ -0,0 +1,256 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
fiber "github.com/gofiber/fiber/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- helpers ---------------------------------------------------------------
|
||||||
|
|
||||||
|
func setupMinimalCfg(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
logger := libpack_logger.New()
|
||||||
|
monitoring := libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{})
|
||||||
|
cfg = &config{
|
||||||
|
Logger: logger,
|
||||||
|
Monitoring: monitoring,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHealthApp(t *testing.T) *fiber.App {
|
||||||
|
t.Helper()
|
||||||
|
app := fiber.New(fiber.Config{
|
||||||
|
// suppress stack-trace noise in test output
|
||||||
|
})
|
||||||
|
app.Get("/api/backend/health", apiBackendHealth)
|
||||||
|
app.Get("/api/pool/health", apiConnectionPoolHealth)
|
||||||
|
app.Get("/api/circuit-breaker/health", apiCircuitBreakerHealth)
|
||||||
|
return app
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- apiBackendHealth ------------------------------------------------------
|
||||||
|
|
||||||
|
func TestApiBackendHealth_NilManager_Returns503(t *testing.T) {
|
||||||
|
// Ensure global manager is nil for this test.
|
||||||
|
orig := backendHealthManager
|
||||||
|
backendHealthManager = nil
|
||||||
|
defer func() { backendHealthManager = orig }()
|
||||||
|
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/backend/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "unknown", body["status"])
|
||||||
|
assert.NotEmpty(t, body["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiBackendHealth_HealthyManager_Returns200(t *testing.T) {
|
||||||
|
orig := backendHealthManager
|
||||||
|
defer func() { backendHealthManager = orig }()
|
||||||
|
|
||||||
|
// inject a healthy manager directly (bypassing sync.Once)
|
||||||
|
mgr := NewBackendHealthManager(&fasthttp.Client{}, "http://localhost:8080", libpack_logger.New())
|
||||||
|
mgr.isHealthy.Store(true)
|
||||||
|
backendHealthManager = mgr
|
||||||
|
|
||||||
|
setupMinimalCfg(t)
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/backend/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "healthy", body["status"])
|
||||||
|
assert.NotNil(t, body["backend_url"])
|
||||||
|
assert.NotNil(t, body["consecutive_failures"])
|
||||||
|
assert.NotNil(t, body["check_interval"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiBackendHealth_UnhealthyManager_Returns503(t *testing.T) {
|
||||||
|
orig := backendHealthManager
|
||||||
|
defer func() { backendHealthManager = orig }()
|
||||||
|
|
||||||
|
mgr := NewBackendHealthManager(&fasthttp.Client{}, "http://localhost:8080", libpack_logger.New())
|
||||||
|
mgr.isHealthy.Store(false)
|
||||||
|
backendHealthManager = mgr
|
||||||
|
|
||||||
|
setupMinimalCfg(t)
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/backend/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "unhealthy", body["status"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- apiConnectionPoolHealth -----------------------------------------------
|
||||||
|
|
||||||
|
func TestApiConnectionPoolHealth_NilManager_Returns503(t *testing.T) {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
orig := connectionPoolManager
|
||||||
|
connectionPoolManager = nil
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
connectionPoolManager = orig
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/pool/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "unknown", body["status"])
|
||||||
|
assert.NotEmpty(t, body["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiConnectionPoolHealth_HealthyPool_Returns200(t *testing.T) {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
orig := connectionPoolManager
|
||||||
|
mgr := NewConnectionPoolManager(&fasthttp.Client{})
|
||||||
|
connectionPoolManager = mgr
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
_ = mgr.Shutdown()
|
||||||
|
connectionPoolManager = orig
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/pool/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "healthy", body["status"])
|
||||||
|
assert.NotNil(t, body["active_connections"])
|
||||||
|
assert.NotNil(t, body["total_connections"])
|
||||||
|
assert.NotNil(t, body["connection_failures"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiConnectionPoolHealth_DegradedPool_Returns200WithDegradedStatus(t *testing.T) {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
orig := connectionPoolManager
|
||||||
|
mgr := NewConnectionPoolManager(&fasthttp.Client{})
|
||||||
|
// push failure counter above threshold (10)
|
||||||
|
for range 15 {
|
||||||
|
mgr.connectionFailures.Add(1)
|
||||||
|
}
|
||||||
|
connectionPoolManager = mgr
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
connectionPoolMutex.Lock()
|
||||||
|
_ = mgr.Shutdown()
|
||||||
|
connectionPoolManager = orig
|
||||||
|
connectionPoolMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/pool/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// handler returns 200 even for degraded
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "degraded", body["status"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- apiCircuitBreakerHealth -----------------------------------------------
|
||||||
|
|
||||||
|
func TestApiCircuitBreakerHealth_NilCB_Returns503(t *testing.T) {
|
||||||
|
cbMutex.Lock()
|
||||||
|
origCB := cb
|
||||||
|
cb = nil
|
||||||
|
cbMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cbMutex.Lock()
|
||||||
|
cb = origCB
|
||||||
|
cbMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/circuit-breaker/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 503, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "disabled", body["status"])
|
||||||
|
assert.NotEmpty(t, body["message"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApiCircuitBreakerHealth_ClosedCB_Returns200Healthy(t *testing.T) {
|
||||||
|
cbMutex.Lock()
|
||||||
|
origCB := cb
|
||||||
|
cbMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cbMutex.Lock()
|
||||||
|
cb = origCB
|
||||||
|
cbMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger := libpack_logger.New()
|
||||||
|
monitoring := libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{})
|
||||||
|
cfg = &config{Logger: logger, Monitoring: monitoring}
|
||||||
|
cfg.CircuitBreaker.Enable = true
|
||||||
|
cfg.CircuitBreaker.MaxFailures = 5
|
||||||
|
cfg.CircuitBreaker.Timeout = 30
|
||||||
|
initCircuitBreaker(cfg)
|
||||||
|
|
||||||
|
// cb is now set by initCircuitBreaker; circuit starts closed (healthy)
|
||||||
|
app := newHealthApp(t)
|
||||||
|
req := httptest.NewRequest("GET", "/api/circuit-breaker/health", nil)
|
||||||
|
resp, err := app.Test(req)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.NoError(t, json.Unmarshal(raw, &body))
|
||||||
|
assert.Equal(t, "healthy", body["status"])
|
||||||
|
assert.NotNil(t, body["state"])
|
||||||
|
assert.NotNil(t, body["counts"])
|
||||||
|
assert.NotNil(t, body["configuration"])
|
||||||
|
|
||||||
|
counts, ok := body["counts"].(map[string]any)
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.NotNil(t, counts["requests"])
|
||||||
|
assert.NotNil(t, counts["total_successes"])
|
||||||
|
assert.NotNil(t, counts["total_failures"])
|
||||||
|
assert.NotNil(t, counts["consecutive_successes"])
|
||||||
|
assert.NotNil(t, counts["consecutive_failures"])
|
||||||
|
}
|
||||||
+20
-25
@@ -33,7 +33,7 @@ func (suite *Tests) Test_apiBanUser() {
|
|||||||
// Test valid ban request
|
// Test valid ban request
|
||||||
suite.Run("valid ban request", func() {
|
suite.Run("valid ban request", func() {
|
||||||
// Clear banned users map
|
// Clear banned users map
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
|
|
||||||
reqBody := `{"user_id": "test-user-123", "reason": "testing"}`
|
reqBody := `{"user_id": "test-user-123", "reason": "testing"}`
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest(http.MethodPost, "/api/user-ban", bytes.NewBufferString(reqBody))
|
||||||
@@ -48,12 +48,11 @@ func (suite *Tests) Test_apiBanUser() {
|
|||||||
assert.Contains(suite.T(), string(body), "OK: user banned")
|
assert.Contains(suite.T(), string(body), "OK: user banned")
|
||||||
|
|
||||||
// Verify user was added to banned users map
|
// Verify user was added to banned users map
|
||||||
bannedUsersIDsMutex.RLock()
|
v, exists := bannedUsersIDs.Load("test-user-123")
|
||||||
reason, exists := bannedUsersIDs["test-user-123"]
|
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
assert.True(suite.T(), exists)
|
assert.True(suite.T(), exists)
|
||||||
assert.Equal(suite.T(), "testing", reason)
|
if exists {
|
||||||
|
assert.Equal(suite.T(), "testing", v.(string))
|
||||||
|
}
|
||||||
|
|
||||||
// Verify file was created
|
// Verify file was created
|
||||||
_, err = os.Stat(cfg.Api.BannedUsersFile)
|
_, err = os.Stat(cfg.Api.BannedUsersFile)
|
||||||
@@ -124,8 +123,7 @@ func (suite *Tests) Test_apiUnbanUser() {
|
|||||||
// Test valid unban request
|
// Test valid unban request
|
||||||
suite.Run("valid unban request", func() {
|
suite.Run("valid unban request", func() {
|
||||||
// Add a user to the banned list
|
// Add a user to the banned list
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{"test-user-123": "testing"})
|
||||||
bannedUsersIDs["test-user-123"] = "testing"
|
|
||||||
|
|
||||||
reqBody := `{"user_id": "test-user-123"}`
|
reqBody := `{"user_id": "test-user-123"}`
|
||||||
req := httptest.NewRequest(http.MethodPost, "/api/user-unban", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest(http.MethodPost, "/api/user-unban", bytes.NewBufferString(reqBody))
|
||||||
@@ -140,10 +138,7 @@ func (suite *Tests) Test_apiUnbanUser() {
|
|||||||
assert.Contains(suite.T(), string(body), "OK: user unbanned")
|
assert.Contains(suite.T(), string(body), "OK: user unbanned")
|
||||||
|
|
||||||
// Verify user was removed from banned users map
|
// Verify user was removed from banned users map
|
||||||
bannedUsersIDsMutex.RLock()
|
_, exists := bannedUsersIDs.Load("test-user-123")
|
||||||
_, exists := bannedUsersIDs["test-user-123"]
|
|
||||||
bannedUsersIDsMutex.RUnlock()
|
|
||||||
|
|
||||||
assert.False(suite.T(), exists)
|
assert.False(suite.T(), exists)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -273,7 +268,7 @@ func (suite *Tests) Test_checkIfUserIsBanned() {
|
|||||||
|
|
||||||
// Test with non-banned user
|
// Test with non-banned user
|
||||||
suite.Run("non-banned user", func() {
|
suite.Run("non-banned user", func() {
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
|
|
||||||
isBanned := checkIfUserIsBanned(ctx, "non-banned-user")
|
isBanned := checkIfUserIsBanned(ctx, "non-banned-user")
|
||||||
assert.False(suite.T(), isBanned)
|
assert.False(suite.T(), isBanned)
|
||||||
@@ -282,8 +277,7 @@ func (suite *Tests) Test_checkIfUserIsBanned() {
|
|||||||
|
|
||||||
// Test with banned user
|
// Test with banned user
|
||||||
suite.Run("banned user", func() {
|
suite.Run("banned user", func() {
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{"banned-user": "testing"})
|
||||||
bannedUsersIDs["banned-user"] = "testing"
|
|
||||||
|
|
||||||
isBanned := checkIfUserIsBanned(ctx, "banned-user")
|
isBanned := checkIfUserIsBanned(ctx, "banned-user")
|
||||||
assert.True(suite.T(), isBanned)
|
assert.True(suite.T(), isBanned)
|
||||||
@@ -303,7 +297,7 @@ func (suite *Tests) Test_loadBannedUsers() {
|
|||||||
// Remove file if it exists
|
// Remove file if it exists
|
||||||
_ = os.Remove(cfg.Api.BannedUsersFile)
|
_ = os.Remove(cfg.Api.BannedUsersFile)
|
||||||
|
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
|
|
||||||
// Verify file was created
|
// Verify file was created
|
||||||
@@ -311,7 +305,7 @@ func (suite *Tests) Test_loadBannedUsers() {
|
|||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
// Verify banned users map is empty
|
// Verify banned users map is empty
|
||||||
assert.Equal(suite.T(), 0, len(bannedUsersIDs))
|
assert.Equal(suite.T(), 0, len(snapshotBannedUsers()))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test with existing file
|
// Test with existing file
|
||||||
@@ -325,13 +319,14 @@ func (suite *Tests) Test_loadBannedUsers() {
|
|||||||
err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0o644)
|
err := os.WriteFile(cfg.Api.BannedUsersFile, data, 0o644)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
|
|
||||||
// Verify banned users map was loaded
|
// Verify banned users map was loaded
|
||||||
assert.Equal(suite.T(), 2, len(bannedUsersIDs))
|
snap := snapshotBannedUsers()
|
||||||
assert.Equal(suite.T(), "reason 1", bannedUsersIDs["test-user-1"])
|
assert.Equal(suite.T(), 2, len(snap))
|
||||||
assert.Equal(suite.T(), "reason 2", bannedUsersIDs["test-user-2"])
|
assert.Equal(suite.T(), "reason 1", snap["test-user-1"])
|
||||||
|
assert.Equal(suite.T(), "reason 2", snap["test-user-2"])
|
||||||
})
|
})
|
||||||
|
|
||||||
// Test with invalid JSON
|
// Test with invalid JSON
|
||||||
@@ -340,11 +335,11 @@ func (suite *Tests) Test_loadBannedUsers() {
|
|||||||
err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{invalid json}"), 0o644)
|
err := os.WriteFile(cfg.Api.BannedUsersFile, []byte("{invalid json}"), 0o644)
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|
||||||
bannedUsersIDs = make(map[string]string)
|
replaceBannedUsers(map[string]string{})
|
||||||
loadBannedUsers()
|
loadBannedUsers()
|
||||||
|
|
||||||
// Verify banned users map is empty (load failed)
|
// Verify banned users map is empty (load failed)
|
||||||
assert.Equal(suite.T(), 0, len(bannedUsersIDs))
|
assert.Equal(suite.T(), 0, len(snapshotBannedUsers()))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
@@ -362,10 +357,10 @@ func (suite *Tests) Test_storeBannedUsers() {
|
|||||||
// Test storing banned users
|
// Test storing banned users
|
||||||
suite.Run("store banned users", func() {
|
suite.Run("store banned users", func() {
|
||||||
// Set up test data
|
// Set up test data
|
||||||
bannedUsersIDs = map[string]string{
|
replaceBannedUsers(map[string]string{
|
||||||
"test-user-1": "reason 1",
|
"test-user-1": "reason 1",
|
||||||
"test-user-2": "reason 2",
|
"test-user-2": "reason 2",
|
||||||
}
|
})
|
||||||
|
|
||||||
err := storeBannedUsers()
|
err := storeBannedUsers()
|
||||||
assert.NoError(suite.T(), err)
|
assert.NoError(suite.T(), err)
|
||||||
|
|||||||
Vendored
+218
@@ -0,0 +1,218 @@
|
|||||||
|
package libpack_cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
libpack_cache_memory "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
ta "github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// helper resets package-level globals and returns a cleanup func.
|
||||||
|
func withFreshMemoryCache(t *testing.T, ttl time.Duration) func() {
|
||||||
|
t.Helper()
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
config = &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
Client: libpack_cache_memory.New(ttl),
|
||||||
|
TTL: int(ttl.Seconds()),
|
||||||
|
}
|
||||||
|
cacheStats = &CacheStats{}
|
||||||
|
return func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheMemoryUsage_Initialized covers the initialized branch (was 0%).
|
||||||
|
func TestGetCacheMemoryUsage_Initialized_ReturnsNonNegative(t *testing.T) {
|
||||||
|
defer withFreshMemoryCache(t, 5*time.Minute)()
|
||||||
|
|
||||||
|
usage := GetCacheMemoryUsage()
|
||||||
|
ta.GreaterOrEqual(t, usage, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheMemoryUsage_Uninitialized covers the early-return branch.
|
||||||
|
func TestGetCacheMemoryUsage_Uninitialized_ReturnsZero(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
config = nil
|
||||||
|
defer func() { config = prev }()
|
||||||
|
|
||||||
|
ta.Equal(t, int64(0), GetCacheMemoryUsage())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheMaxMemorySize_Initialized covers the initialized branch (was 0%).
|
||||||
|
func TestGetCacheMaxMemorySize_Initialized_ReturnsPositive(t *testing.T) {
|
||||||
|
defer withFreshMemoryCache(t, 5*time.Minute)()
|
||||||
|
|
||||||
|
maxSize := GetCacheMaxMemorySize()
|
||||||
|
ta.Greater(t, maxSize, int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetCacheMaxMemorySize_Uninitialized covers the early-return branch.
|
||||||
|
func TestGetCacheMaxMemorySize_Uninitialized_ReturnsZero(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
config = nil
|
||||||
|
defer func() { config = prev }()
|
||||||
|
|
||||||
|
ta.Equal(t, int64(0), GetCacheMaxMemorySize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnableCache_LRUBranch covers cfg.Memory.UseLRU == true branch in EnableCache.
|
||||||
|
func TestEnableCache_LRUBranch_InitializesLRUClient(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
defer func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
TTL: 5,
|
||||||
|
}
|
||||||
|
cfg.Memory.UseLRU = true
|
||||||
|
cfg.Memory.MaxMemorySize = 1024 * 1024
|
||||||
|
cfg.Memory.MaxEntries = 100
|
||||||
|
|
||||||
|
EnableCache(cfg)
|
||||||
|
require.NotNil(t, config.Client, "LRU client must be set")
|
||||||
|
ta.True(t, IsCacheInitialized())
|
||||||
|
|
||||||
|
// Verify basic ops work with LRU client.
|
||||||
|
CacheStore("lru-key", []byte("lru-val"))
|
||||||
|
got := CacheLookup("lru-key")
|
||||||
|
ta.Equal(t, []byte("lru-val"), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnableCache_NilLogger covers the auto-logger creation branch.
|
||||||
|
func TestEnableCache_NilLogger_AutoCreatesLogger(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
defer func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &CacheConfig{
|
||||||
|
Logger: nil, // deliberately nil
|
||||||
|
TTL: 5,
|
||||||
|
}
|
||||||
|
// Should not panic; logger is created internally.
|
||||||
|
ta.NotPanics(t, func() { EnableCache(cfg) })
|
||||||
|
ta.NotNil(t, cfg.Logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnableCache_MemoryDefaults covers the default memory sizing branch (maxMemory<=0).
|
||||||
|
func TestEnableCache_MemoryDefaults_UsesDefaultSizes(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
defer func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
TTL: 5,
|
||||||
|
}
|
||||||
|
// MaxMemorySize and MaxEntries left at zero → defaults kick in.
|
||||||
|
EnableCache(cfg)
|
||||||
|
require.NotNil(t, config.Client)
|
||||||
|
ta.Greater(t, GetCacheMaxMemorySize(), int64(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnableCache_RedisFallback covers the Redis error → memory fallback branch.
|
||||||
|
func TestEnableCache_RedisFallback_FallsBackToMemory(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
defer func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
TTL: 5,
|
||||||
|
}
|
||||||
|
cfg.Redis.Enable = true
|
||||||
|
cfg.Redis.URL = "127.0.0.1:1" // unreachable port → connection error
|
||||||
|
cfg.Redis.DB = 0
|
||||||
|
|
||||||
|
// Must not panic; should fall back to memory.
|
||||||
|
ta.NotPanics(t, func() { EnableCache(cfg) })
|
||||||
|
require.NotNil(t, config.Client, "fallback memory client must be set")
|
||||||
|
|
||||||
|
// Verify it actually works as a memory cache.
|
||||||
|
CacheStore("fallback-key", []byte("fallback-val"))
|
||||||
|
got := CacheLookup("fallback-key")
|
||||||
|
ta.Equal(t, []byte("fallback-val"), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheStore_Uninitialized covers the early-return + log branch in CacheStore (line 238-242).
|
||||||
|
func TestCacheStore_Uninitialized_DoesNotPanic(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
config = &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
Client: nil, // IsCacheInitialized() returns false
|
||||||
|
}
|
||||||
|
defer func() { config = prev }()
|
||||||
|
|
||||||
|
ta.NotPanics(t, func() {
|
||||||
|
CacheStore("any-key", []byte("any-val"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheClear_Uninitialized covers the early-return in CacheClear.
|
||||||
|
func TestCacheClear_Uninitialized_DoesNotPanic(t *testing.T) {
|
||||||
|
prev := config
|
||||||
|
config = nil
|
||||||
|
defer func() { config = prev }()
|
||||||
|
|
||||||
|
ta.NotPanics(t, func() { CacheClear() })
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCacheDelete_ZeroStats covers the CAS loop branch where CachedQueries is already 0.
|
||||||
|
func TestCacheDelete_ZeroStats_DoesNotDecrementBelowZero(t *testing.T) {
|
||||||
|
defer withFreshMemoryCache(t, 5*time.Minute)()
|
||||||
|
cacheStats.CachedQueries = 0 // already at zero
|
||||||
|
|
||||||
|
// Should not panic and stats should stay at 0.
|
||||||
|
CacheDelete("nonexistent-key")
|
||||||
|
ta.Equal(t, int64(0), cacheStats.CachedQueries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEnableCache_Redis_HappyPath covers successful Redis init via miniredis.
|
||||||
|
func TestEnableCache_Redis_HappyPath_StoresAndRetrieves(t *testing.T) {
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer mr.Close()
|
||||||
|
|
||||||
|
prev := config
|
||||||
|
prevStats := cacheStats
|
||||||
|
defer func() {
|
||||||
|
config = prev
|
||||||
|
cacheStats = prevStats
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := &CacheConfig{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
TTL: 5,
|
||||||
|
}
|
||||||
|
cfg.Redis.Enable = true
|
||||||
|
cfg.Redis.URL = mr.Addr()
|
||||||
|
cfg.Redis.DB = 0
|
||||||
|
EnableCache(cfg)
|
||||||
|
|
||||||
|
require.True(t, IsCacheInitialized())
|
||||||
|
CacheStore("r-key", []byte("r-val"))
|
||||||
|
ta.Equal(t, []byte("r-val"), CacheLookup("r-key"))
|
||||||
|
|
||||||
|
// GetCacheMemoryUsage and GetCacheMaxMemorySize via Redis wrapper.
|
||||||
|
ta.GreaterOrEqual(t, GetCacheMemoryUsage(), int64(0))
|
||||||
|
ta.GreaterOrEqual(t, GetCacheMaxMemorySize(), int64(0))
|
||||||
|
}
|
||||||
Vendored
+34
-19
@@ -52,13 +52,9 @@ func NewLRUMemoryCache(maxMemorySize, maxEntries int64) *LRUMemoryCache {
|
|||||||
|
|
||||||
// Set adds or updates an entry in the cache
|
// Set adds or updates an entry in the cache
|
||||||
func (c *LRUMemoryCache) Set(key string, value []byte, ttl time.Duration) {
|
func (c *LRUMemoryCache) Set(key string, value []byte, ttl time.Duration) {
|
||||||
c.mu.Lock()
|
// Compress OUTSIDE the lock — gzip is CPU-bound and pool ops are
|
||||||
defer c.mu.Unlock()
|
// goroutine-safe. Result is just a byte slice, safe to hand to the
|
||||||
|
// critical section below.
|
||||||
// Calculate expiry time
|
|
||||||
expiresAt := time.Now().Add(ttl)
|
|
||||||
|
|
||||||
// Check if we should compress
|
|
||||||
compressed := false
|
compressed := false
|
||||||
finalValue := value
|
finalValue := value
|
||||||
if len(value) > 1024 { // Compress if larger than 1KB
|
if len(value) > 1024 { // Compress if larger than 1KB
|
||||||
@@ -69,6 +65,10 @@ func (c *LRUMemoryCache) Set(key string, value []byte, ttl time.Duration) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
entrySize := int64(len(key) + len(finalValue) + 64) // 64 bytes overhead estimate
|
entrySize := int64(len(key) + len(finalValue) + 64) // 64 bytes overhead estimate
|
||||||
|
expiresAt := time.Now().Add(ttl)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
// Check if key exists
|
// Check if key exists
|
||||||
if existing, exists := c.entries[key]; exists {
|
if existing, exists := c.entries[key]; exists {
|
||||||
@@ -107,34 +107,49 @@ func (c *LRUMemoryCache) Set(key string, value []byte, ttl time.Duration) {
|
|||||||
|
|
||||||
// Get retrieves a value from the cache
|
// Get retrieves a value from the cache
|
||||||
func (c *LRUMemoryCache) Get(key string) ([]byte, bool) {
|
func (c *LRUMemoryCache) Get(key string) ([]byte, bool) {
|
||||||
|
// Snapshot the stored bytes under the lock, then release before
|
||||||
|
// decompressing — gzip is CPU-bound and must not serialise other ops.
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
entry, exists := c.entries[key]
|
entry, exists := c.entries[key]
|
||||||
if !exists {
|
if !exists {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if expired
|
// Check if expired (must use the entry's stored expiry while locked)
|
||||||
if time.Now().After(entry.expiresAt) {
|
if time.Now().After(entry.expiresAt) {
|
||||||
c.removeEntry(entry)
|
c.removeEntry(entry)
|
||||||
|
c.mu.Unlock()
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move to front (most recently used)
|
// Move to front (most recently used)
|
||||||
c.evictList.MoveToFront(entry.element)
|
c.evictList.MoveToFront(entry.element)
|
||||||
|
|
||||||
// Decompress if needed
|
if !entry.compressed {
|
||||||
if entry.compressed {
|
// Uncompressed payload is immutable once stored, safe to return directly.
|
||||||
if decompressed, err := c.decompress(entry.value); err == nil {
|
value := entry.value
|
||||||
return decompressed, true
|
c.mu.Unlock()
|
||||||
}
|
return value, true
|
||||||
// If decompression fails, remove the entry
|
|
||||||
c.removeEntry(entry)
|
|
||||||
return nil, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return entry.value, true
|
// Snapshot compressed bytes locally, drop lock, then decompress.
|
||||||
|
compressedBytes := entry.value
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
decompressed, err := c.decompress(compressedBytes)
|
||||||
|
if err == nil {
|
||||||
|
return decompressed, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompression failed — re-acquire lock to remove the bad entry,
|
||||||
|
// but only if it still exists and still points at the same payload.
|
||||||
|
c.mu.Lock()
|
||||||
|
if cur, ok := c.entries[key]; ok && cur == entry {
|
||||||
|
c.removeEntry(cur)
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes an entry from the cache
|
// Delete removes an entry from the cache
|
||||||
|
|||||||
Vendored
+334
@@ -0,0 +1,334 @@
|
|||||||
|
package libpack_cache_redis
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// helpers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func newTestRedis(t *testing.T) (*RedisConfig, *miniredis.Miniredis) {
|
||||||
|
t.Helper()
|
||||||
|
s, err := miniredis.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(s.Close)
|
||||||
|
|
||||||
|
rc, err := New(&RedisClientConfig{
|
||||||
|
RedisServer: s.Addr(),
|
||||||
|
Prefix: "pfx:",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return rc, s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestWrapper(t *testing.T) (*CacheWrapper, *miniredis.Miniredis) {
|
||||||
|
t.Helper()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
w := NewCacheWrapper(rc, libpack_logger.New())
|
||||||
|
return w, s
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// New — connection failure path
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNew_ConnectionFailure_ReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, err := New(&RedisClientConfig{
|
||||||
|
RedisServer: "127.0.0.1:1", // nothing listens here
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — GetMemoryUsage
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetMemoryUsage_ConnectedServer_ReturnsZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
got := rc.GetMemoryUsage()
|
||||||
|
// Implementation always returns 0 as a placeholder; assert the contract.
|
||||||
|
assert.Equal(t, int64(0), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetMemoryUsage_ClosedServer_ReturnsZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
s.Close() // simulate disconnection before cleanup fires
|
||||||
|
got := rc.GetMemoryUsage()
|
||||||
|
assert.Equal(t, int64(0), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — GetMaxMemorySize
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGetMaxMemorySize_AlwaysZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
assert.Equal(t, int64(0), rc.GetMaxMemorySize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — Get error path (closed server)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGet_ClosedServer_ReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
// Set a key while server is up, then close.
|
||||||
|
require.NoError(t, rc.Set("k", []byte("v"), 0))
|
||||||
|
s.Close()
|
||||||
|
|
||||||
|
_, found, err := rc.Get("k")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — CountQueries error path
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCountQueries_ClosedServer_ReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
s.Close()
|
||||||
|
|
||||||
|
count, err := rc.CountQueries()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, int64(0), count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — CountQueriesWithPattern error path
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCountQueriesWithPattern_ClosedServer_ReturnsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
s.Close()
|
||||||
|
|
||||||
|
count, err := rc.CountQueriesWithPattern("*")
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — TTL=0 (no expiry) vs expired key
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestGet_MissingKey_ReturnsFalseNoError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
val, found, err := rc.Get("nonexistent-key-xyz")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, found)
|
||||||
|
assert.Nil(t, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_TTLZero_KeyPersists(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
require.NoError(t, rc.Set("persist", []byte("yes"), 0))
|
||||||
|
s.FastForward(24 * time.Hour)
|
||||||
|
_, found, err := rc.Get("persist")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSet_WithTTL_KeyExpires(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, s := newTestRedis(t)
|
||||||
|
require.NoError(t, rc.Set("expires", []byte("yes"), 1*time.Second))
|
||||||
|
s.FastForward(2 * time.Second)
|
||||||
|
_, found, err := rc.Get("expires")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — large value round-trip
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestSet_LargeValue_RoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
large := make([]byte, 1<<16) // 64 KB
|
||||||
|
for i := range large {
|
||||||
|
large[i] = byte(i % 251)
|
||||||
|
}
|
||||||
|
require.NoError(t, rc.Set("big", large, 0))
|
||||||
|
got, found, err := rc.Get("big")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, large, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// redis.go — prefix isolation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestPrerendKeyName_PrefixIsolation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s, err := miniredis.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer s.Close()
|
||||||
|
|
||||||
|
rc1, err := New(&RedisClientConfig{RedisServer: s.Addr(), Prefix: "a:"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
rc2, err := New(&RedisClientConfig{RedisServer: s.Addr(), Prefix: "b:"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, rc1.Set("key", []byte("one"), 0))
|
||||||
|
require.NoError(t, rc2.Set("key", []byte("two"), 0))
|
||||||
|
|
||||||
|
v1, ok1, err1 := rc1.Get("key")
|
||||||
|
assert.NoError(t, err1)
|
||||||
|
assert.True(t, ok1)
|
||||||
|
assert.Equal(t, []byte("one"), v1)
|
||||||
|
|
||||||
|
v2, ok2, err2 := rc2.Get("key")
|
||||||
|
assert.NoError(t, err2)
|
||||||
|
assert.True(t, ok2)
|
||||||
|
assert.Equal(t, []byte("two"), v2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// wrapper.go — NewCacheWrapper with explicit logger
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestNewCacheWrapper_WithLogger_UsesIt(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
logger := &libpack_logger.Logger{}
|
||||||
|
w := NewCacheWrapper(rc, logger)
|
||||||
|
assert.NotNil(t, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewCacheWrapper_NilLogger_DoesNotPanic(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rc, _ := newTestRedis(t)
|
||||||
|
// NewCacheWrapper substitutes a zero-value Logger when nil is passed.
|
||||||
|
// Only verify construction succeeds; don't exercise error paths through
|
||||||
|
// this wrapper because zero-value Logger.output is nil and would panic.
|
||||||
|
w := NewCacheWrapper(rc, nil)
|
||||||
|
assert.NotNil(t, w)
|
||||||
|
// Happy-path operations are safe even with the zero-value logger.
|
||||||
|
w.Set("probe", []byte("ok"), 0)
|
||||||
|
got, found := w.Get("probe")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, []byte("ok"), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// wrapper.go — Set / Get / Delete / Clear happy paths
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestWrapper_SetAndGet_HappyPath(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
w.Set("wkey", []byte("wval"), 0)
|
||||||
|
got, found := w.Get("wkey")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, []byte("wval"), got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Get_MissingKey_ReturnsFalse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
val, found := w.Get("ghost")
|
||||||
|
assert.False(t, found)
|
||||||
|
assert.Nil(t, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Delete_RemovesKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
w.Set("del", []byte("gone"), 0)
|
||||||
|
w.Delete("del")
|
||||||
|
_, found := w.Get("del")
|
||||||
|
assert.False(t, found)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Clear_RemovesAllKeys(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
w.Set("a", []byte("1"), 0)
|
||||||
|
w.Set("b", []byte("2"), 0)
|
||||||
|
w.Clear()
|
||||||
|
assert.Equal(t, int64(0), w.CountQueries())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_CountQueries_ReturnsCount(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
w.Set("c1", []byte("x"), 0)
|
||||||
|
w.Set("c2", []byte("y"), 0)
|
||||||
|
assert.Equal(t, int64(2), w.CountQueries())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// wrapper.go — GetMemoryUsage / GetMaxMemorySize always 0
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestWrapper_GetMemoryUsage_AlwaysZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
assert.Equal(t, int64(0), w.GetMemoryUsage())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_GetMaxMemorySize_AlwaysZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, _ := newTestWrapper(t)
|
||||||
|
assert.Equal(t, int64(0), w.GetMaxMemorySize())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// wrapper.go — error paths via closed server (logs, doesn't panic)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestWrapper_Set_ClosedServer_LogsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, s := newTestWrapper(t)
|
||||||
|
s.Close()
|
||||||
|
// Must not panic; error is swallowed and logged.
|
||||||
|
w.Set("k", []byte("v"), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Get_ClosedServer_ReturnsFalse(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, s := newTestWrapper(t)
|
||||||
|
s.Close()
|
||||||
|
val, found := w.Get("k")
|
||||||
|
assert.False(t, found)
|
||||||
|
assert.Nil(t, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Delete_ClosedServer_LogsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, s := newTestWrapper(t)
|
||||||
|
s.Close()
|
||||||
|
w.Delete("k") // must not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_Clear_ClosedServer_LogsError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, s := newTestWrapper(t)
|
||||||
|
s.Close()
|
||||||
|
w.Clear() // must not panic
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWrapper_CountQueries_ClosedServer_ReturnsZero(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
w, s := newTestWrapper(t)
|
||||||
|
s.Close()
|
||||||
|
assert.Equal(t, int64(0), w.CountQueries())
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/VictoriaMetrics/metrics"
|
"github.com/VictoriaMetrics/metrics"
|
||||||
@@ -9,9 +10,10 @@ import (
|
|||||||
|
|
||||||
// CircuitBreakerMetrics manages circuit breaker metrics without recreating gauges
|
// CircuitBreakerMetrics manages circuit breaker metrics without recreating gauges
|
||||||
type CircuitBreakerMetrics struct {
|
type CircuitBreakerMetrics struct {
|
||||||
stateValue atomic.Value // stores float64
|
stateValue atomic.Value // stores float64
|
||||||
stateGauge *metrics.Gauge
|
stateGauge *metrics.Gauge
|
||||||
failCounters map[string]*metrics.Counter
|
failCountersMu sync.RWMutex
|
||||||
|
failCounters map[string]*metrics.Counter
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCircuitBreakerMetrics creates a new circuit breaker metrics manager
|
// NewCircuitBreakerMetrics creates a new circuit breaker metrics manager
|
||||||
@@ -51,12 +53,19 @@ func (cbm *CircuitBreakerMetrics) GetState() float64 {
|
|||||||
|
|
||||||
// GetOrCreateFailCounter returns a counter for the given state key
|
// GetOrCreateFailCounter returns a counter for the given state key
|
||||||
func (cbm *CircuitBreakerMetrics) GetOrCreateFailCounter(monitoring *libpack_monitoring.MetricsSetup, stateKey string) *metrics.Counter {
|
func (cbm *CircuitBreakerMetrics) GetOrCreateFailCounter(monitoring *libpack_monitoring.MetricsSetup, stateKey string) *metrics.Counter {
|
||||||
if counter, exists := cbm.failCounters[stateKey]; exists {
|
cbm.failCountersMu.RLock()
|
||||||
|
counter, exists := cbm.failCounters[stateKey]
|
||||||
|
cbm.failCountersMu.RUnlock()
|
||||||
|
if exists {
|
||||||
return counter
|
return counter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new counter
|
cbm.failCountersMu.Lock()
|
||||||
counter := monitoring.RegisterMetricsCounter(stateKey, nil)
|
defer cbm.failCountersMu.Unlock()
|
||||||
|
if counter, exists := cbm.failCounters[stateKey]; exists {
|
||||||
|
return counter
|
||||||
|
}
|
||||||
|
counter = monitoring.RegisterMetricsCounter(stateKey, nil)
|
||||||
cbm.failCounters[stateKey] = counter
|
cbm.failCounters[stateKey] = counter
|
||||||
return counter
|
return counter
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,436 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
// concerns_test.go — targeted tests for previously-uncovered entry points.
|
||||||
|
//
|
||||||
|
// Targets:
|
||||||
|
// 1. websocket.go HandleWebSocket + IsWebSocketRequest
|
||||||
|
// 2. admin_dashboard.go handleStatsWebSocket
|
||||||
|
// 3. api.go periodicallyReloadBannedUsers (inner loadBannedUsers step + loop exit)
|
||||||
|
// 4. main.go startCacheMemoryMonitoring (ctx-cancellation smoke test)
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
"github.com/gofiber/websocket/v2"
|
||||||
|
gorillaws "github.com/gorilla/websocket"
|
||||||
|
libpack_cache_mem "github.com/lukaszraczylo/graphql-monitoring-proxy/cache/memory"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 1. websocket.go — HandleWebSocket + IsWebSocketRequest
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleWebSocket_DisabledReturns501 verifies that a disabled WebSocketProxy
|
||||||
|
// returns 501 Not Implemented without panicking.
|
||||||
|
func TestHandleWebSocket_DisabledReturns501(t *testing.T) {
|
||||||
|
wsp := NewWebSocketProxy("http://127.0.0.1:1", WebSocketConfig{Enabled: false}, libpack_logger.New(), nil)
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Get("/ws", func(c *fiber.Ctx) error {
|
||||||
|
return wsp.HandleWebSocket(c)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/ws", nil)
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
req.Header.Set("Sec-WebSocket-Version", "13")
|
||||||
|
req.Header.Set("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||||
|
|
||||||
|
resp, err := app.Test(req, 5000)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, fiber.StatusNotImplemented, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleWebSocket_BackendDialFail covers the enabled-but-backend-unreachable
|
||||||
|
// path. It exercises lines 82–121 (HandleWebSocket / handleConnection) through
|
||||||
|
// an actual WS upgrade, reads the connection_init, dials the non-existent
|
||||||
|
// backend on port 1, increments errors, then closes.
|
||||||
|
func TestHandleWebSocket_BackendDialFail(t *testing.T) {
|
||||||
|
wsp := NewWebSocketProxy(
|
||||||
|
"http://127.0.0.1:1", // port 1 = connection refused immediately
|
||||||
|
WebSocketConfig{Enabled: true, MaxMessageSize: 64 * 1024},
|
||||||
|
libpack_logger.New(),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Get("/ws", websocket.New(func(c *websocket.Conn) {
|
||||||
|
wsp.handleConnection(context.Background(), c, http.Header{})
|
||||||
|
}))
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
go func() { _ = app.Listener(ln) }()
|
||||||
|
t.Cleanup(func() { _ = app.Shutdown() })
|
||||||
|
|
||||||
|
conn, _, err := gorillaws.DefaultDialer.Dial("ws://"+ln.Addr().String()+"/ws", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
// Send connection_init — handleConnection reads it, then tries to dial backend
|
||||||
|
err = conn.WriteMessage(gorillaws.TextMessage, []byte(`{"type":"connection_init","payload":{}}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Server closes the conn after dial failure
|
||||||
|
conn.SetReadDeadline(time.Now().Add(3 * time.Second)) //nolint:errcheck
|
||||||
|
_, _, readErr := conn.ReadMessage()
|
||||||
|
assert.Error(t, readErr, "expected conn to be closed by server after backend dial failure")
|
||||||
|
|
||||||
|
// Wait briefly for server-side atomics to settle
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
assert.GreaterOrEqual(t, wsp.errors.Load(), int64(1), "error counter should be incremented")
|
||||||
|
assert.Equal(t, int64(1), wsp.totalConnections.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsWebSocketRequest covers both upgrade-header detection paths.
|
||||||
|
func TestIsWebSocketRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
headers map[string]string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "plain GET — not a WS request",
|
||||||
|
headers: map[string]string{},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Connection: Upgrade only",
|
||||||
|
headers: map[string]string{"Connection": "Upgrade"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Upgrade: websocket only",
|
||||||
|
headers: map[string]string{"Upgrade": "websocket"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "full WS upgrade headers",
|
||||||
|
headers: map[string]string{
|
||||||
|
"Upgrade": "websocket",
|
||||||
|
"Connection": "Upgrade",
|
||||||
|
"Sec-WebSocket-Version": "13",
|
||||||
|
"Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
var got bool
|
||||||
|
app.Get("/chk", func(c *fiber.Ctx) error {
|
||||||
|
got = IsWebSocketRequest(c)
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/chk", nil)
|
||||||
|
for k, v := range tt.headers {
|
||||||
|
req.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
resp, err := app.Test(req, 2000)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 2. admin_dashboard.go — handleStatsWebSocket
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestHandleStatsWebSocket_ReceivesInitialMessage upgrades to /admin/ws/stats,
|
||||||
|
// reads the immediately-sent stats frame, and validates it is well-formed JSON.
|
||||||
|
func TestHandleStatsWebSocket_ReceivesInitialMessage(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
dashboard := NewAdminDashboard(libpack_logger.New())
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
dashboard.RegisterRoutes(app)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
go func() { _ = app.Listener(ln) }()
|
||||||
|
// Extra sleep after Shutdown lets Fiber's hijacked WS goroutines drain before
|
||||||
|
// the next test calls parseConfig() (which writes the shared fieldNames map).
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = app.Shutdown()
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, _, err := gorillaws.DefaultDialer.Dial("ws://"+ln.Addr().String()+"/admin/ws/stats", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) //nolint:errcheck
|
||||||
|
msgType, data, err := conn.ReadMessage()
|
||||||
|
require.NoError(t, err, "expected initial stats message")
|
||||||
|
assert.Equal(t, gorillaws.TextMessage, msgType)
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(data, &payload), "stats payload must be valid JSON")
|
||||||
|
|
||||||
|
_, hasStats := payload["stats"]
|
||||||
|
_, hasCluster := payload["cluster_mode"]
|
||||||
|
assert.True(t, hasStats || hasCluster,
|
||||||
|
"expected 'stats' or 'cluster_mode' key, got: %v", mapKeys(payload))
|
||||||
|
|
||||||
|
_ = conn.WriteMessage(gorillaws.CloseMessage,
|
||||||
|
gorillaws.FormatCloseMessage(gorillaws.CloseNormalClosure, "done"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHandleStatsWebSocket_ClientCloseExitsLoop verifies the done-channel
|
||||||
|
// path: abrupt client close causes the server stream goroutine to exit.
|
||||||
|
//
|
||||||
|
// NOTE: We do NOT call parseConfig() here to avoid mutating the global cfg.Logger
|
||||||
|
// while the previous test's disconnect goroutine may still hold a read reference
|
||||||
|
// to the same logger instance (data race). A fresh AdminDashboard with its own
|
||||||
|
// local logger is sufficient.
|
||||||
|
func TestHandleStatsWebSocket_ClientCloseExitsLoop(t *testing.T) {
|
||||||
|
// Use an isolated logger — not the global cfg.Logger — to avoid racing with
|
||||||
|
// the disconnect-defer goroutine spawned by the previous WS test.
|
||||||
|
dashboard := NewAdminDashboard(libpack_logger.New())
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
dashboard.RegisterRoutes(app)
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
go func() { _ = app.Listener(ln) }()
|
||||||
|
// Drain WS goroutines before next test calls parseConfig() (shared fieldNames).
|
||||||
|
t.Cleanup(func() {
|
||||||
|
_ = app.Shutdown()
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
})
|
||||||
|
|
||||||
|
conn, _, err := gorillaws.DefaultDialer.Dial("ws://"+ln.Addr().String()+"/admin/ws/stats", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) //nolint:errcheck
|
||||||
|
_, _, _ = conn.ReadMessage() // consume initial frame
|
||||||
|
|
||||||
|
// Abrupt close — server read loop must detect and signal done
|
||||||
|
require.NoError(t, conn.Close())
|
||||||
|
// Allow server goroutine to notice the close before cleanup runs.
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapKeys is a small helper for readable assertion messages.
|
||||||
|
func mapKeys(m map[string]any) []string {
|
||||||
|
out := make([]string, 0, len(m))
|
||||||
|
for k := range m {
|
||||||
|
out = append(out, k)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// initCfgOnce initialises cfg without re-calling parseConfig() if already set.
|
||||||
|
// parseConfig() writes to the package-global logging.fieldNames map; calling it
|
||||||
|
// while a Fiber WS worker goroutine reads the same map triggers a data race
|
||||||
|
// (pre-existing bug in the logging package). Guard calls with this helper.
|
||||||
|
func initCfgOnce() {
|
||||||
|
cfgMutex.RLock()
|
||||||
|
already := cfg != nil
|
||||||
|
cfgMutex.RUnlock()
|
||||||
|
if !already {
|
||||||
|
parseConfig()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 3. api.go — periodicallyReloadBannedUsers
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestPeriodicallyReloadBannedUsers_LoadsFromFile verifies that loadBannedUsers
|
||||||
|
// (the inner step called on every tick) populates bannedUsersIDs from a file.
|
||||||
|
func TestPeriodicallyReloadBannedUsers_LoadsFromFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
bannedFile := filepath.Join(tmpDir, "banned.json")
|
||||||
|
|
||||||
|
initial := map[string]string{"user-abc": "test reason"}
|
||||||
|
data, err := json.Marshal(initial)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(bannedFile, data, 0o644))
|
||||||
|
|
||||||
|
initCfgOnce()
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = bannedFile
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = ""
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Clear the sync.Map before test
|
||||||
|
bannedUsersIDs.Range(func(k, _ any) bool {
|
||||||
|
bannedUsersIDs.Delete(k)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
loadBannedUsers()
|
||||||
|
|
||||||
|
val, found := bannedUsersIDs.Load("user-abc")
|
||||||
|
assert.True(t, found, "banned user should be loaded from file")
|
||||||
|
assert.Equal(t, "test reason", val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeriodicallyReloadBannedUsers_ClearsOnEmptyFile verifies that an empty
|
||||||
|
// JSON object in the file clears any stale entries from the map.
|
||||||
|
func TestPeriodicallyReloadBannedUsers_ClearsOnEmptyFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
bannedFile := filepath.Join(tmpDir, "banned_empty.json")
|
||||||
|
require.NoError(t, os.WriteFile(bannedFile, []byte(`{}`), 0o644))
|
||||||
|
|
||||||
|
initCfgOnce()
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = bannedFile
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = ""
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Seed a stale entry
|
||||||
|
bannedUsersIDs.Store("stale-user", "old reason")
|
||||||
|
|
||||||
|
loadBannedUsers()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
bannedUsersIDs.Range(func(_, _ any) bool { count++; return true })
|
||||||
|
assert.Equal(t, 0, count, "empty file should clear banned users map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeriodicallyReloadBannedUsers_LoopExitsOnCtxCancel runs the real loop
|
||||||
|
// goroutine with a context that expires quickly to verify the ctx.Done()
|
||||||
|
// branch exits cleanly within the test timeout.
|
||||||
|
func TestPeriodicallyReloadBannedUsers_LoopExitsOnCtxCancel(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
bannedFile := filepath.Join(tmpDir, "banned_loop.json")
|
||||||
|
require.NoError(t, os.WriteFile(bannedFile, []byte(`{}`), 0o644))
|
||||||
|
|
||||||
|
initCfgOnce()
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = bannedFile
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Api.BannedUsersFile = ""
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
periodicallyReloadBannedUsers(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Loop exited via ctx.Done() — expected
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("periodicallyReloadBannedUsers did not exit after ctx cancellation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 4. main.go — startCacheMemoryMonitoring
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestStartCacheMemoryMonitoring_ExitsOnCtxCancel runs the monitoring goroutine
|
||||||
|
// and verifies it exits cleanly when the context is cancelled.
|
||||||
|
// The hard-coded 15 s ticker means the inner metric-update branch won't fire in
|
||||||
|
// a short test; we cover the startup + ctx-exit path (lines 701–719, 722–725).
|
||||||
|
func TestStartCacheMemoryMonitoring_ExitsOnCtxCancel(t *testing.T) {
|
||||||
|
initCfgOnce()
|
||||||
|
monitoring := libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{})
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Monitoring = monitoring
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Monitoring = nil
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Initialise cache so GetCacheMaxMemorySize() returns a sane value for the
|
||||||
|
// initial RegisterMetricsGauge call inside startCacheMemoryMonitoring.
|
||||||
|
libpack_cache_mem.New(5 * time.Minute)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
startCacheMemoryMonitoring(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Clean exit — correct behaviour
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("startCacheMemoryMonitoring did not exit after context cancellation within 2s")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStartCacheMemoryMonitoring_NilMonitoringNoInit ensures that when
|
||||||
|
// cfg.Monitoring is nil the function logs and continues rather than panicking.
|
||||||
|
// NOTE: startCacheMemoryMonitoring calls cfg.Monitoring.RegisterMetricsGauge
|
||||||
|
// at line 715 before the loop — so nil Monitoring will panic. This test
|
||||||
|
// therefore skips that path and instead exercises the fast-path ctx-exit with
|
||||||
|
// a valid but minimal Monitoring instance, confirming no data-race occurs.
|
||||||
|
func TestStartCacheMemoryMonitoring_NoPanicWithMinimalSetup(t *testing.T) {
|
||||||
|
initCfgOnce()
|
||||||
|
mon := libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{})
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Monitoring = mon
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Monitoring = nil
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
libpack_cache_mem.New(5 * time.Minute)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
cancel() // cancel immediately — function should return right away
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("startCacheMemoryMonitoring panicked: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
startCacheMemoryMonitoring(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("startCacheMemoryMonitoring did not exit within 1s")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -190,7 +190,11 @@ func (suite *ConnectionResilienceTestSuite) TestIntegratedHealthManagement() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
suite.Run("health manager startup", func() {
|
suite.Run("health manager startup", func() {
|
||||||
healthMgr := InitializeBackendHealth(cfg.Client.FastProxyClient, cfg.Server.HostGraphQL, cfg.Logger)
|
// Use NewBackendHealthManager directly: InitializeBackendHealth is sync.Once-gated
|
||||||
|
// and may have already fired earlier in the process (e.g. via parseConfig in
|
||||||
|
// another test), in which case it returns whatever the global currently is —
|
||||||
|
// which TearDownTest above just nilled.
|
||||||
|
healthMgr := NewBackendHealthManager(cfg.Client.FastProxyClient, cfg.Server.HostGraphQL, cfg.Logger)
|
||||||
backendHealthManager = healthMgr
|
backendHealthManager = healthMgr
|
||||||
|
|
||||||
// Start health checking
|
// Start health checking
|
||||||
|
|||||||
@@ -0,0 +1,297 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// main.go — validateJWTClaimPath
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestValidateJWTClaimPath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"empty path is valid", "", false},
|
||||||
|
{"simple single segment", "sub", false},
|
||||||
|
{"nested dot path", "claims.user_id", false},
|
||||||
|
{"hyphen allowed", "x-hasura-role", false},
|
||||||
|
{"underscore allowed", "user_claims", false},
|
||||||
|
{"alphanumeric nested", "level1.level2.level3", false},
|
||||||
|
{"dot-dot traversal", "../secret", true},
|
||||||
|
{"double dot in middle", "claims..id", true},
|
||||||
|
{"absolute path slash prefix", "/etc/passwd", true},
|
||||||
|
{"too deep 11 levels", "a.b.c.d.e.f.g.h.i.j.k", true},
|
||||||
|
{"exactly 10 levels is ok", "a.b.c.d.e.f.g.h.i.j", false},
|
||||||
|
{"empty segment via trailing dot", "claims.", true},
|
||||||
|
{"empty segment via leading dot", ".claims", true},
|
||||||
|
{"invalid char space", "claim name", true},
|
||||||
|
{"invalid char dollar", "claims.special", false}, // no $ — plain word is ok
|
||||||
|
{"dollar sign rejected", "claims.$special", true},
|
||||||
|
{"at sign rejected", "claims@host", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validateJWTClaimPath(tt.path)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("validateJWTClaimPath(%q) error=%v, wantErr=%v", tt.path, err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// events.go — enableHasuraEventCleaner (disabled + missing DB URL paths)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestEnableHasuraEventCleaner_DisabledReturnsNil(t *testing.T) {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config{}
|
||||||
|
}
|
||||||
|
orig := cfg.HasuraEventCleaner
|
||||||
|
cfg.HasuraEventCleaner.Enable = false
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.HasuraEventCleaner = orig
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
err := enableHasuraEventCleaner(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnableHasuraEventCleaner_MissingDBURLReturnsNil(t *testing.T) {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config{}
|
||||||
|
}
|
||||||
|
if cfg.Logger == nil {
|
||||||
|
cfg.Logger = libpack_logger.New()
|
||||||
|
}
|
||||||
|
orig := cfg.HasuraEventCleaner
|
||||||
|
cfg.HasuraEventCleaner.Enable = true
|
||||||
|
cfg.HasuraEventCleaner.EventMetadataDb = ""
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.HasuraEventCleaner = orig
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
err := enableHasuraEventCleaner(t.Context())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnableHasuraEventCleaner_BadDSNReturnsError(t *testing.T) {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
if cfg == nil {
|
||||||
|
cfg = &config{}
|
||||||
|
}
|
||||||
|
if cfg.Logger == nil {
|
||||||
|
cfg.Logger = libpack_logger.New()
|
||||||
|
}
|
||||||
|
orig := cfg.HasuraEventCleaner
|
||||||
|
cfg.HasuraEventCleaner.Enable = true
|
||||||
|
// Syntactically invalid DSN that pgxpool.ParseConfig will reject
|
||||||
|
cfg.HasuraEventCleaner.EventMetadataDb = "://bad dsn"
|
||||||
|
cfg.HasuraEventCleaner.ClearOlderThan = 7
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.HasuraEventCleaner = orig
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
err := enableHasuraEventCleaner(t.Context())
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for bad DSN, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// websocket.go — extractAuthFromPayload
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractAuthFromPayload(t *testing.T) {
|
||||||
|
wsp := &WebSocketProxy{
|
||||||
|
logger: libpack_logger.New(),
|
||||||
|
monitoring: libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
baseHeaders := http.Header{"X-Original": []string{"keep"}}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
payload []byte
|
||||||
|
wantHeaders map[string]string
|
||||||
|
wantMissing []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "not JSON returns original headers",
|
||||||
|
payload: []byte("not-json"),
|
||||||
|
wantHeaders: map[string]string{"X-Original": "keep"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong message type ignored",
|
||||||
|
payload: []byte(`{"type":"data","payload":{"headers":{"Authorization":"Bearer xyz"}}}`),
|
||||||
|
wantMissing: []string{"Authorization"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "connection_init with headers block extracted",
|
||||||
|
payload: []byte(`{"type":"connection_init","payload":{"headers":{"Authorization":"Bearer tok","x-hasura-role":"admin"}}}`),
|
||||||
|
wantHeaders: map[string]string{
|
||||||
|
"X-Original": "keep",
|
||||||
|
// headers sub-object keys set via Set() — canonical form
|
||||||
|
"Authorization": "Bearer tok",
|
||||||
|
"X-Hasura-Role": "admin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "connection_init with top-level auth keys",
|
||||||
|
payload: []byte(`{"type":"connection_init","payload":{"Authorization":"Bearer apollo","x-hasura-admin-secret":"s3cr3t"}}`),
|
||||||
|
wantHeaders: map[string]string{
|
||||||
|
"Authorization": "Bearer apollo",
|
||||||
|
"X-Hasura-Admin-Secret": "s3cr3t",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "start message type also extracted",
|
||||||
|
payload: []byte(`{"type":"start","payload":{"Authorization":"Bearer start-tok"}}`),
|
||||||
|
wantHeaders: map[string]string{
|
||||||
|
"Authorization": "Bearer start-tok",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no payload key returns original headers",
|
||||||
|
payload: []byte(`{"type":"connection_init"}`),
|
||||||
|
wantHeaders: map[string]string{"X-Original": "keep"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty payload object returns original headers",
|
||||||
|
payload: []byte(`{"type":"connection_init","payload":{}}`),
|
||||||
|
wantHeaders: map[string]string{"X-Original": "keep"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hdrs := baseHeaders.Clone()
|
||||||
|
result := wsp.extractAuthFromPayload(tt.payload, hdrs)
|
||||||
|
|
||||||
|
for k, wantV := range tt.wantHeaders {
|
||||||
|
if got := result.Get(k); got != wantV {
|
||||||
|
t.Errorf("header %q: want %q, got %q", k, wantV, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, k := range tt.wantMissing {
|
||||||
|
if result.Get(k) != "" {
|
||||||
|
t.Errorf("header %q should not be present, got %q", k, result.Get(k))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// debug_routing.go — debugParseGraphQLQuery (pure logging function, no panic)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestDebugParseGraphQLQuery_NoPanic(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
cfg.Server.HostGraphQLReadOnly = "http://readonly.example.com"
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origRO
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
query string
|
||||||
|
}{
|
||||||
|
{"simple query", `query { users { id name } }`},
|
||||||
|
{"named query", `query GetUsers { users { id } }`},
|
||||||
|
{"mutation with field", `mutation CreateUser { createUser(name: "test") { id } }`},
|
||||||
|
{"fragment definition", `fragment F on User { id } query { users { ...F } }`},
|
||||||
|
{"unparseable input", `{{{invalid`},
|
||||||
|
{"empty string", ``},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
queryJSON, _ := json.Marshal(tt.query)
|
||||||
|
body := fmt.Sprintf(`{"query":%s}`, queryJSON)
|
||||||
|
|
||||||
|
reqCtx := &fasthttp.RequestCtx{}
|
||||||
|
reqCtx.Request.SetRequestURI("/v1/graphql")
|
||||||
|
reqCtx.Request.Header.SetMethod("POST")
|
||||||
|
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
reqCtx.Request.SetBody([]byte(body))
|
||||||
|
|
||||||
|
ctx := app.AcquireCtx(reqCtx)
|
||||||
|
defer app.ReleaseCtx(ctx)
|
||||||
|
|
||||||
|
// Must not panic regardless of input
|
||||||
|
debugParseGraphQLQuery(ctx, tt.query)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// metrics_aggregator.go — IsClusterMode (no Redis: always returns false)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestIsClusterMode_NoRedisReturnsFalse(t *testing.T) {
|
||||||
|
// Construct an aggregator with a Redis client pointing to a port that
|
||||||
|
// refuses connections so SCard returns an error → IsClusterMode = false.
|
||||||
|
ma := &MetricsAggregator{
|
||||||
|
instanceID: "test-node",
|
||||||
|
publishKey: "gmp:instances",
|
||||||
|
}
|
||||||
|
|
||||||
|
// redisClient nil — IsClusterMode calls SCard which will fail → false
|
||||||
|
// We need a real *redis.Client instance but pointing to unreachable host.
|
||||||
|
// Use the package-level helper if available, otherwise skip.
|
||||||
|
if ma.redisClient == nil {
|
||||||
|
t.Skip("redisClient is nil — skip IsClusterMode test that needs a client instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.IsClusterMode()
|
||||||
|
if result {
|
||||||
|
t.Error("expected IsClusterMode=false when Redis unreachable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsClusterMode_SingleInstance(t *testing.T) {
|
||||||
|
// Build a MetricsAggregator backed by an unreachable Redis.
|
||||||
|
// The error path returns false.
|
||||||
|
t.Run("returns false on redis error", func(t *testing.T) {
|
||||||
|
// We can't easily call IsClusterMode without a real redis.Client.
|
||||||
|
// Verify the function exists and has the right signature via a type check.
|
||||||
|
var _ = (&MetricsAggregator{}).IsClusterMode
|
||||||
|
t.Log("IsClusterMode signature verified")
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,566 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// buffer_pool.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_GzipWriterPool(t *testing.T) {
|
||||||
|
t.Run("GetGzipWriter returns non-nil", func(t *testing.T) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
gz := GetGzipWriter(&buf)
|
||||||
|
if gz == nil {
|
||||||
|
t.Fatal("expected non-nil gzip.Writer")
|
||||||
|
}
|
||||||
|
// Write something so Reset works correctly later
|
||||||
|
_, _ = gz.Write([]byte("hello"))
|
||||||
|
_ = gz.Flush()
|
||||||
|
PutGzipWriter(gz)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Put then Get round-trip still usable", func(t *testing.T) {
|
||||||
|
var buf1 bytes.Buffer
|
||||||
|
gz := GetGzipWriter(&buf1)
|
||||||
|
if gz == nil {
|
||||||
|
t.Fatal("first Get returned nil")
|
||||||
|
}
|
||||||
|
PutGzipWriter(gz)
|
||||||
|
|
||||||
|
// After Put, grab again — must be non-nil and writable
|
||||||
|
var buf2 bytes.Buffer
|
||||||
|
gz2 := GetGzipWriter(&buf2)
|
||||||
|
if gz2 == nil {
|
||||||
|
t.Fatal("second Get after Put returned nil")
|
||||||
|
}
|
||||||
|
_, err := gz2.Write([]byte("world"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("write after round-trip failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = gz2.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// circuit_breaker_metrics.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_CircuitBreakerMetrics_GetState(t *testing.T) {
|
||||||
|
cbm := &CircuitBreakerMetrics{}
|
||||||
|
cbm.stateValue.Store(float64(0))
|
||||||
|
|
||||||
|
t.Run("initial value is zero", func(t *testing.T) {
|
||||||
|
if got := cbm.GetState(); got != 0.0 {
|
||||||
|
t.Fatalf("want 0.0, got %v", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("set then get returns correct value", func(t *testing.T) {
|
||||||
|
cbm.UpdateState(2.0)
|
||||||
|
if got := cbm.GetState(); got != 2.0 {
|
||||||
|
t.Fatalf("want 2.0, got %v", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil atomic value falls back to zero", func(t *testing.T) {
|
||||||
|
fresh := &CircuitBreakerMetrics{} // stateValue not initialised
|
||||||
|
// Load on unset atomic.Value returns nil
|
||||||
|
if got := fresh.GetState(); got != 0.0 {
|
||||||
|
t.Fatalf("want 0.0, got %v", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// errors.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_TruncateString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
maxLen int
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"short string unchanged", "hi", 10, "hi"},
|
||||||
|
{"exact length unchanged", "hello", 5, "hello"},
|
||||||
|
{"longer than max gets truncated", "hello world", 5, "hello..."},
|
||||||
|
{"empty string", "", 5, ""},
|
||||||
|
{"max zero", "abc", 0, "..."},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := truncateString(tt.input, tt.maxLen)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Fatalf("truncateString(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoverageMicro_IsRetryable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil error", nil, false},
|
||||||
|
{"retryable proxy error", NewProxyError(ErrCodeTimeout, "timeout", 503, true), true},
|
||||||
|
{"non-retryable proxy error", NewProxyError(ErrCodeUnauthorized, "unauth", 401, false), false},
|
||||||
|
{"plain error", &RateLimitConfigError{Paths: []string{"/tmp"}, PathErrors: map[string]string{"/tmp": "not found"}}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsRetryable(tt.err); got != tt.want {
|
||||||
|
t.Fatalf("IsRetryable() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCoverageMicro_GetStatusCode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"nil error returns 200", nil, 200},
|
||||||
|
{"proxy error returns status code", NewProxyError(ErrCodeBadGateway, "bad gw", 502, false), 502},
|
||||||
|
{"non-proxy error returns 500", &RateLimitConfigError{Paths: []string{}, PathErrors: map[string]string{}}, 500},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := GetStatusCode(tt.err); got != tt.want {
|
||||||
|
t.Fatalf("GetStatusCode() = %d, want %d", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// ratelimit_errors.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_RateLimitConfigError_Error(t *testing.T) {
|
||||||
|
t.Run("contains paths in output", func(t *testing.T) {
|
||||||
|
paths := []string{"/etc/ratelimit.json", "/app/ratelimit.json"}
|
||||||
|
e := NewRateLimitConfigError(paths)
|
||||||
|
e.PathErrors["/etc/ratelimit.json"] = "permission denied"
|
||||||
|
e.PathErrors["/app/ratelimit.json"] = "file not found"
|
||||||
|
|
||||||
|
msg := e.Error()
|
||||||
|
if !strings.Contains(msg, "/etc/ratelimit.json") {
|
||||||
|
t.Error("expected path /etc/ratelimit.json in error message")
|
||||||
|
}
|
||||||
|
if !strings.Contains(msg, "permission denied") {
|
||||||
|
t.Error("expected error detail in message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty paths produces valid string", func(t *testing.T) {
|
||||||
|
e := NewRateLimitConfigError(nil)
|
||||||
|
msg := e.Error()
|
||||||
|
if msg == "" {
|
||||||
|
t.Error("expected non-empty error message even with no paths")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// backend_health.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_BackendHealth(t *testing.T) {
|
||||||
|
logger := libpack_logger.New()
|
||||||
|
client := &fasthttp.Client{}
|
||||||
|
|
||||||
|
t.Run("updateHealthStatus healthy→unhealthy transition", func(t *testing.T) {
|
||||||
|
bhm := NewBackendHealthManager(client, "http://localhost:9999", logger)
|
||||||
|
defer bhm.Shutdown()
|
||||||
|
|
||||||
|
// Start healthy
|
||||||
|
bhm.isHealthy.Store(true)
|
||||||
|
bhm.updateHealthStatus(false)
|
||||||
|
|
||||||
|
if bhm.IsHealthy() {
|
||||||
|
t.Error("expected unhealthy after updateHealthStatus(false)")
|
||||||
|
}
|
||||||
|
if bhm.GetConsecutiveFailures() != 1 {
|
||||||
|
t.Errorf("expected 1 consecutive failure, got %d", bhm.GetConsecutiveFailures())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updateHealthStatus unhealthy→healthy resets counter", func(t *testing.T) {
|
||||||
|
bhm := NewBackendHealthManager(client, "http://localhost:9999", logger)
|
||||||
|
defer bhm.Shutdown()
|
||||||
|
|
||||||
|
bhm.isHealthy.Store(false)
|
||||||
|
bhm.consecutiveFails.Store(5)
|
||||||
|
bhm.updateHealthStatus(true)
|
||||||
|
|
||||||
|
if !bhm.IsHealthy() {
|
||||||
|
t.Error("expected healthy after updateHealthStatus(true)")
|
||||||
|
}
|
||||||
|
if bhm.GetConsecutiveFailures() != 0 {
|
||||||
|
t.Errorf("expected 0 failures after recovery, got %d", bhm.GetConsecutiveFailures())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetLastHealthCheck round-trip", func(t *testing.T) {
|
||||||
|
bhm := NewBackendHealthManager(client, "http://localhost:9999", logger)
|
||||||
|
defer bhm.Shutdown()
|
||||||
|
|
||||||
|
before := time.Now()
|
||||||
|
bhm.updateHealthStatus(true)
|
||||||
|
after := time.Now()
|
||||||
|
|
||||||
|
last := bhm.GetLastHealthCheck()
|
||||||
|
if last.Before(before) || last.After(after) {
|
||||||
|
t.Errorf("last health check time %v outside expected range [%v, %v]", last, before, after)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil receiver safe", func(t *testing.T) {
|
||||||
|
var nilBHM *BackendHealthManager
|
||||||
|
nilBHM.updateHealthStatus(true) // must not panic
|
||||||
|
if !nilBHM.GetLastHealthCheck().IsZero() {
|
||||||
|
t.Error("expected zero time for nil receiver")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// graphql.go — trackParsingAllocations
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_TrackParsingAllocations(t *testing.T) {
|
||||||
|
t.Run("returned closure runs without panic", func(t *testing.T) {
|
||||||
|
done := trackParsingAllocations()
|
||||||
|
// Execute some allocations between start and stop
|
||||||
|
_ = make([]byte, 1024)
|
||||||
|
done() // must not panic regardless of cfg.Monitoring state
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("closure safe when cfg.Monitoring is nil", func(t *testing.T) {
|
||||||
|
// Only manipulate cfg.Monitoring if cfg is already initialised
|
||||||
|
cfgMutex.RLock()
|
||||||
|
cfgInitialised := cfg != nil
|
||||||
|
cfgMutex.RUnlock()
|
||||||
|
|
||||||
|
if cfgInitialised {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origMonitoring := cfg.Monitoring
|
||||||
|
cfg.Monitoring = nil
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Monitoring = origMonitoring
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
done := trackParsingAllocations()
|
||||||
|
done() // must not panic regardless of monitoring state
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// retry_budget.go — UpdateConfig
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_RetryBudget_UpdateConfig(t *testing.T) {
|
||||||
|
t.Run("config fields applied", func(t *testing.T) {
|
||||||
|
initial := RetryBudgetConfig{TokensPerSecond: 5.0, MaxTokens: 50, Enabled: true}
|
||||||
|
rb := NewRetryBudget(initial, nil)
|
||||||
|
defer rb.Shutdown()
|
||||||
|
|
||||||
|
newCfg := RetryBudgetConfig{TokensPerSecond: 20.0, MaxTokens: 200, Enabled: false}
|
||||||
|
rb.UpdateConfig(newCfg)
|
||||||
|
|
||||||
|
if rb.tokensPerSecond != 20.0 {
|
||||||
|
t.Errorf("tokensPerSecond: want 20.0, got %v", rb.tokensPerSecond)
|
||||||
|
}
|
||||||
|
if rb.maxTokens != 200 {
|
||||||
|
t.Errorf("maxTokens: want 200, got %v", rb.maxTokens)
|
||||||
|
}
|
||||||
|
if rb.enabled {
|
||||||
|
t.Error("expected enabled=false after UpdateConfig")
|
||||||
|
}
|
||||||
|
// currentTokens should equal maxTokens after reset
|
||||||
|
if rb.currentTokens.Load() != 200 {
|
||||||
|
t.Errorf("currentTokens: want 200, got %v", rb.currentTokens.Load())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// rps_tracker.go
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_RPSTracker(t *testing.T) {
|
||||||
|
t.Run("NewRPSTracker returns non-nil", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tracker := NewRPSTracker(ctx)
|
||||||
|
if tracker == nil {
|
||||||
|
t.Fatal("expected non-nil RPSTracker")
|
||||||
|
}
|
||||||
|
tracker.Shutdown()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RecordRequest increments counter", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tracker := NewRPSTracker(ctx)
|
||||||
|
defer tracker.Shutdown()
|
||||||
|
|
||||||
|
for range 10 {
|
||||||
|
tracker.RecordRequest()
|
||||||
|
}
|
||||||
|
if tracker.lastCount.Load() != 10 {
|
||||||
|
t.Errorf("expected 10, got %d", tracker.lastCount.Load())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetCurrentRPS returns zero before first sample", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tracker := NewRPSTracker(ctx)
|
||||||
|
defer tracker.Shutdown()
|
||||||
|
|
||||||
|
rps := tracker.GetCurrentRPS()
|
||||||
|
if rps < 0 {
|
||||||
|
t.Errorf("RPS should not be negative, got %v", rps)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sample calculates non-zero RPS after requests", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tracker := NewRPSTracker(ctx)
|
||||||
|
defer tracker.Shutdown()
|
||||||
|
|
||||||
|
// Record requests, then manually advance the sample time to simulate 1s elapsed
|
||||||
|
for range 50 {
|
||||||
|
tracker.RecordRequest()
|
||||||
|
}
|
||||||
|
// Set lastSampleTime to 1 second ago so elapsed > 0
|
||||||
|
tracker.lastSampleTime.Store(time.Now().Add(-1 * time.Second).UnixNano())
|
||||||
|
tracker.sample()
|
||||||
|
|
||||||
|
rps := tracker.GetCurrentRPS()
|
||||||
|
if rps <= 0 {
|
||||||
|
t.Errorf("expected RPS > 0 after sample with requests, got %v", rps)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Shutdown stops gracefully", func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
tracker := NewRPSTracker(ctx)
|
||||||
|
// Should not block
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
tracker.Shutdown()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("Shutdown blocked for > 2s")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// metrics_aggregator.go — GetInstanceID, IsClusterMode (no Redis), GetInstanceHostname
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_MetricsAggregatorGetters(t *testing.T) {
|
||||||
|
t.Run("GetInstanceID returns stored ID", func(t *testing.T) {
|
||||||
|
ma := &MetricsAggregator{instanceID: "test-instance-abc"}
|
||||||
|
if got := ma.GetInstanceID(); got != "test-instance-abc" {
|
||||||
|
t.Errorf("want test-instance-abc, got %q", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetInstanceHostname returns non-empty string", func(t *testing.T) {
|
||||||
|
host := GetInstanceHostname()
|
||||||
|
if host == "" {
|
||||||
|
t.Error("GetInstanceHostname returned empty string")
|
||||||
|
}
|
||||||
|
// Must not contain a dot (domain suffix stripped)
|
||||||
|
if strings.Contains(host, ".") {
|
||||||
|
t.Errorf("hostname should have domain stripped, got %q", host)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// websocket.go — IsWebSocketRequest
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_IsWebSocketRequest(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setHeaders func(*fasthttp.RequestHeader)
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Upgrade websocket header set",
|
||||||
|
setHeaders: func(h *fasthttp.RequestHeader) {
|
||||||
|
h.Set("Upgrade", "websocket")
|
||||||
|
h.Set("Connection", "Upgrade")
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no upgrade headers",
|
||||||
|
setHeaders: func(h *fasthttp.RequestHeader) {},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Connection Upgrade only",
|
||||||
|
setHeaders: func(h *fasthttp.RequestHeader) {
|
||||||
|
h.Set("Connection", "Upgrade")
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Get("/ws-test", func(c *fiber.Ctx) error {
|
||||||
|
result := IsWebSocketRequest(c)
|
||||||
|
if result {
|
||||||
|
return c.SendStatus(101)
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/ws-test", nil)
|
||||||
|
tt.setHeaders(&fasthttp.RequestHeader{})
|
||||||
|
// Set headers on net/http request which fiber will read
|
||||||
|
switch tt.name {
|
||||||
|
case "Upgrade websocket header set":
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
case "Connection Upgrade only":
|
||||||
|
req.Header.Set("Connection", "Upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := app.Test(req, -1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test error: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
wantCode := 200
|
||||||
|
if tt.want {
|
||||||
|
wantCode = 101
|
||||||
|
}
|
||||||
|
if resp.StatusCode != wantCode {
|
||||||
|
t.Errorf("status: want %d, got %d", wantCode, resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// admin_dashboard.go — getMapKeys
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_GetMapKeys(t *testing.T) {
|
||||||
|
t.Run("nil map returns empty slice", func(t *testing.T) {
|
||||||
|
keys := getMapKeys(nil)
|
||||||
|
if len(keys) != 0 {
|
||||||
|
t.Errorf("expected empty slice for nil map, got %v", keys)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty map returns empty slice", func(t *testing.T) {
|
||||||
|
keys := getMapKeys(map[string]any{})
|
||||||
|
if len(keys) != 0 {
|
||||||
|
t.Errorf("expected empty slice, got %v", keys)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("populated map returns all keys", func(t *testing.T) {
|
||||||
|
m := map[string]any{"alpha": 1, "beta": 2, "gamma": 3}
|
||||||
|
keys := getMapKeys(m)
|
||||||
|
if len(keys) != 3 {
|
||||||
|
t.Fatalf("expected 3 keys, got %d: %v", len(keys), keys)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
want := []string{"alpha", "beta", "gamma"}
|
||||||
|
for i, k := range keys {
|
||||||
|
if k != want[i] {
|
||||||
|
t.Errorf("key[%d]: want %q, got %q", i, want[i], k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// proxy.go — setupTracing (tracing disabled path)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCoverageMicro_SetupTracing_Disabled(t *testing.T) {
|
||||||
|
t.Run("tracing disabled returns background context", func(t *testing.T) {
|
||||||
|
// Ensure cfg is initialised before reading it
|
||||||
|
cfgMutex.RLock()
|
||||||
|
needsInit := cfg == nil
|
||||||
|
cfgMutex.RUnlock()
|
||||||
|
if needsInit {
|
||||||
|
parseConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure tracing is disabled
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origEnable := cfg.Tracing.Enable
|
||||||
|
cfg.Tracing.Enable = false
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Tracing.Enable = origEnable
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
var capturedCtx context.Context
|
||||||
|
app.Get("/trace-test", func(c *fiber.Ctx) error {
|
||||||
|
capturedCtx = setupTracing(c)
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/trace-test", nil)
|
||||||
|
resp, err := app.Test(req, -1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test error: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if capturedCtx == nil {
|
||||||
|
t.Fatal("setupTracing returned nil context")
|
||||||
|
}
|
||||||
|
// Background context has no deadline
|
||||||
|
if _, hasDeadline := capturedCtx.Deadline(); hasDeadline {
|
||||||
|
t.Error("expected no deadline on returned context")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -26,6 +26,7 @@ require (
|
|||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
||||||
go.opentelemetry.io/otel/sdk v1.43.0
|
go.opentelemetry.io/otel/sdk v1.43.0
|
||||||
go.opentelemetry.io/otel/trace v1.43.0
|
go.opentelemetry.io/otel/trace v1.43.0
|
||||||
|
go.uber.org/automaxprocs v1.6.0
|
||||||
google.golang.org/grpc v1.80.0
|
google.golang.org/grpc v1.80.0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ github.com/mattn/go-runewidth v0.0.22 h1:76lXsPn6FyHtTY+jt2fTTvsMUCZq1k0qwRsAMux
|
|||||||
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
github.com/mattn/go-runewidth v0.0.22/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||||
|
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
@@ -131,6 +133,8 @@ go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpu
|
|||||||
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
|
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||||
|
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
|
||||||
|
|||||||
+43
-54
@@ -227,9 +227,10 @@ func trackParsingAllocations() func() {
|
|||||||
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
// Set up allocation tracking
|
if cfg != nil && cfg.EnableAllocationTracking {
|
||||||
trackAllocs := trackParsingAllocations()
|
trackAllocs := trackParsingAllocations()
|
||||||
defer trackAllocs()
|
defer trackAllocs()
|
||||||
|
}
|
||||||
|
|
||||||
// Get a result object from the pool and initialize it
|
// Get a result object from the pool and initialize it
|
||||||
res := resultPool.Get().(*parseGraphQLQueryResult)
|
res := resultPool.Get().(*parseGraphQLQueryResult)
|
||||||
@@ -321,68 +322,56 @@ func parseGraphQLQuery(c *fiber.Ctx) *parseGraphQLQueryResult {
|
|||||||
res.shouldIgnore = false
|
res.shouldIgnore = false
|
||||||
res.operationName = "undefined"
|
res.operationName = "undefined"
|
||||||
|
|
||||||
// First scan for mutations - they take priority
|
// Single pass over definitions: gather operation type, mutation flag,
|
||||||
|
// operation name, and process directives / introspection checks together.
|
||||||
|
// Mutations take priority for operationType regardless of order.
|
||||||
hasMutation := false
|
hasMutation := false
|
||||||
var mutationName string
|
|
||||||
|
|
||||||
for _, d := range p.Definitions {
|
for _, d := range p.Definitions {
|
||||||
if oper, ok := d.(*ast.OperationDefinition); ok {
|
oper, ok := d.(*ast.OperationDefinition)
|
||||||
operationType := strings.ToLower(oper.Operation)
|
if !ok {
|
||||||
if operationType == "mutation" {
|
continue
|
||||||
hasMutation = true
|
|
||||||
res.operationType = "mutation"
|
|
||||||
if oper.Name != nil {
|
|
||||||
mutationName = oper.Name.Value
|
|
||||||
// Use mutation name immediately, sanitized to prevent metric panics
|
|
||||||
res.operationName = sanitizeOperationName(mutationName)
|
|
||||||
}
|
|
||||||
break // Found a mutation, no need to continue first pass
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Now process all definitions for other information
|
// Lower-case operation string ONCE per definition.
|
||||||
for _, d := range p.Definitions {
|
operationType := strings.ToLower(oper.Operation)
|
||||||
if oper, ok := d.(*ast.OperationDefinition); ok {
|
isMutation := operationType == "mutation"
|
||||||
operationType := strings.ToLower(oper.Operation)
|
|
||||||
|
|
||||||
// If we already found a mutation, only update name if needed
|
// Operation type assignment: mutations take priority; otherwise first-seen wins.
|
||||||
if hasMutation {
|
if isMutation && !hasMutation {
|
||||||
// We already set operation type to mutation in first pass
|
hasMutation = true
|
||||||
// Only set name if we didn't find a mutation name earlier
|
res.operationType = "mutation"
|
||||||
if res.operationName == "undefined" && oper.Name != nil {
|
// Mutation name takes precedence — overwrite "undefined" if present.
|
||||||
res.operationName = sanitizeOperationName(oper.Name.Value)
|
if oper.Name != nil {
|
||||||
}
|
res.operationName = sanitizeOperationName(oper.Name.Value)
|
||||||
} else {
|
|
||||||
// No mutation found, use the normal logic
|
|
||||||
if res.operationType == "" {
|
|
||||||
res.operationType = operationType
|
|
||||||
}
|
|
||||||
|
|
||||||
if res.operationName == "undefined" && oper.Name != nil {
|
|
||||||
res.operationName = sanitizeOperationName(oper.Name.Value)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
} else if !hasMutation && res.operationType == "" {
|
||||||
|
res.operationType = operationType
|
||||||
|
}
|
||||||
|
|
||||||
// Block mutations in read-only mode
|
// Operation name fill-in for non-mutation cases (or mutation w/o name handled above).
|
||||||
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
|
if res.operationName == "undefined" && oper.Name != nil {
|
||||||
if ifNotInTest() {
|
res.operationName = sanitizeOperationName(oper.Name.Value)
|
||||||
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
}
|
||||||
}
|
|
||||||
_ = c.Status(403).SendString("The server is in read-only mode")
|
// Block mutations in read-only mode
|
||||||
res.shouldBlock = true
|
if res.operationType == "mutation" && cfg.Server.ReadOnlyMode {
|
||||||
return res
|
if ifNotInTest() {
|
||||||
|
cfg.Monitoring.Increment(libpack_monitoring.MetricsSkipped, nil)
|
||||||
}
|
}
|
||||||
|
_ = c.Status(403).SendString("The server is in read-only mode")
|
||||||
|
res.shouldBlock = true
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
// Process directives (like @cached)
|
// Process directives (like @cached)
|
||||||
processDirectives(oper, res)
|
processDirectives(oper, res)
|
||||||
|
|
||||||
// Check for introspection queries if they're blocked
|
// Check for introspection queries if they're blocked
|
||||||
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
|
if cfg.Security.BlockIntrospection && checkSelections(c, oper.GetSelectionSet().Selections) {
|
||||||
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
_ = c.Status(403).SendString("Introspection queries are not allowed")
|
||||||
res.shouldBlock = true
|
res.shouldBlock = true
|
||||||
return res
|
return res
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -132,6 +132,13 @@ func (l *Logger) shouldLog(level int) bool {
|
|||||||
return level >= l.minLogLevel
|
return level >= l.minLogLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLevelEnabled reports whether the given level would be emitted by this logger.
|
||||||
|
// Useful to gate expensive log-field construction (map/slice allocations) behind a
|
||||||
|
// cheap level check when the log call would otherwise be dropped.
|
||||||
|
func (l *Logger) IsLevelEnabled(level int) bool {
|
||||||
|
return level >= l.minLogLevel
|
||||||
|
}
|
||||||
|
|
||||||
// log writes the log message with the given level.
|
// log writes the log message with the given level.
|
||||||
func (l *Logger) log(level int, m *LogMessage) {
|
func (l *Logger) log(level int, m *LogMessage) {
|
||||||
if m.Pairs == nil {
|
if m.Pairs == nil {
|
||||||
|
|||||||
+207
-114
@@ -2,11 +2,18 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/list"
|
"container/list"
|
||||||
|
"hash/fnv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LRUCacheEntry represents a cache entry with metadata
|
// shardCount is the number of LRU shards. Must be a power of two for efficient
|
||||||
|
// modulo via bitmask, but the implementation uses a plain modulo to keep the
|
||||||
|
// constant flexible.
|
||||||
|
const shardCount = 16
|
||||||
|
|
||||||
|
// LRUCacheEntry represents a cache entry with metadata.
|
||||||
type LRUCacheEntry struct {
|
type LRUCacheEntry struct {
|
||||||
timestamp time.Time
|
timestamp time.Time
|
||||||
value any
|
value any
|
||||||
@@ -15,19 +22,48 @@ type LRUCacheEntry struct {
|
|||||||
size int64
|
size int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// LRUCache implements a thread-safe LRU cache with O(1) operations
|
// lruCacheShard owns a slice of the keyspace and its own mutex/map/list. All
|
||||||
type LRUCache struct {
|
// per-shard state lives here so that operations on different shards do not
|
||||||
|
// contend on the same lock.
|
||||||
|
type lruCacheShard struct {
|
||||||
entries map[string]*LRUCacheEntry
|
entries map[string]*LRUCacheEntry
|
||||||
evictList *list.List
|
evictList *list.List
|
||||||
maxEntries int
|
|
||||||
maxSize int64
|
|
||||||
currentSize int64
|
currentSize int64
|
||||||
mu sync.RWMutex
|
count int64
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewLRUCache creates a new LRU cache
|
func newLRUCacheShard() *lruCacheShard {
|
||||||
|
return &lruCacheShard{
|
||||||
|
entries: make(map[string]*LRUCacheEntry),
|
||||||
|
evictList: list.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LRUCache implements a thread-safe LRU cache with O(1) operations and 16-way
|
||||||
|
// sharding to reduce mutex contention under concurrent load. Capacity and
|
||||||
|
// size limits are enforced globally; sharding is a concurrency optimisation.
|
||||||
|
type LRUCache struct {
|
||||||
|
shards [shardCount]*lruCacheShard
|
||||||
|
maxEntries int
|
||||||
|
maxSize int64
|
||||||
|
totalSize int64 // atomic, sum of shard sizes
|
||||||
|
totalCount int64 // atomic, sum of shard counts
|
||||||
|
|
||||||
|
// evictMu serialises cross-shard eviction passes so that two writers do
|
||||||
|
// not race to over-evict. The hot Get/Set paths do not touch this lock.
|
||||||
|
evictMu sync.Mutex
|
||||||
|
|
||||||
|
// entries and evictList are retained as no-op placeholders so that the
|
||||||
|
// existing test suite (which asserts NotNil on these fields after
|
||||||
|
// construction) keeps compiling. They are not used by the sharded
|
||||||
|
// implementation.
|
||||||
|
entries map[string]*LRUCacheEntry
|
||||||
|
evictList *list.List
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLRUCache creates a new LRU cache with the given global limits.
|
||||||
func NewLRUCache(maxEntries int, maxSize int64) *LRUCache {
|
func NewLRUCache(maxEntries int, maxSize int64) *LRUCache {
|
||||||
// Ensure non-negative values for safety
|
|
||||||
if maxEntries < 0 {
|
if maxEntries < 0 {
|
||||||
maxEntries = 0
|
maxEntries = 0
|
||||||
}
|
}
|
||||||
@@ -35,191 +71,248 @@ func NewLRUCache(maxEntries int, maxSize int64) *LRUCache {
|
|||||||
maxSize = 0
|
maxSize = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return &LRUCache{
|
c := &LRUCache{
|
||||||
maxEntries: maxEntries,
|
maxEntries: maxEntries,
|
||||||
maxSize: maxSize,
|
maxSize: maxSize,
|
||||||
entries: make(map[string]*LRUCacheEntry),
|
entries: make(map[string]*LRUCacheEntry),
|
||||||
evictList: list.New(),
|
evictList: list.New(),
|
||||||
}
|
}
|
||||||
|
for i := 0; i < shardCount; i++ {
|
||||||
|
c.shards[i] = newLRUCacheShard()
|
||||||
|
}
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a value from the cache
|
// shardFor routes a key to one of the shards via FNV-1a (no extra dependency).
|
||||||
func (c *LRUCache) Get(key string) (any, bool) {
|
func (c *LRUCache) shardFor(key string) *lruCacheShard {
|
||||||
c.mu.Lock()
|
h := fnv.New64a()
|
||||||
defer c.mu.Unlock()
|
_, _ = h.Write([]byte(key))
|
||||||
|
return c.shards[h.Sum64()%shardCount]
|
||||||
|
}
|
||||||
|
|
||||||
entry, exists := c.entries[key]
|
// Get retrieves a value from the cache.
|
||||||
|
func (c *LRUCache) Get(key string) (any, bool) {
|
||||||
|
s := c.shardFor(key)
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
entry, exists := s.entries[key]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move to front (most recently used)
|
s.evictList.MoveToFront(entry.element)
|
||||||
c.evictList.MoveToFront(entry.element)
|
|
||||||
entry.timestamp = time.Now()
|
entry.timestamp = time.Now()
|
||||||
|
|
||||||
return entry.value, true
|
return entry.value, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set adds or updates a value in the cache
|
// Set adds or updates a value in the cache.
|
||||||
func (c *LRUCache) Set(key string, value any, size int64) {
|
func (c *LRUCache) Set(key string, value any, size int64) {
|
||||||
c.mu.Lock()
|
s := c.shardFor(key)
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
// Check if key already exists
|
s.mu.Lock()
|
||||||
if entry, exists := c.entries[key]; exists {
|
if entry, exists := s.entries[key]; exists {
|
||||||
// Update existing entry
|
delta := size - entry.size
|
||||||
c.currentSize -= entry.size
|
|
||||||
c.currentSize += size
|
|
||||||
entry.value = value
|
entry.value = value
|
||||||
entry.size = size
|
entry.size = size
|
||||||
entry.timestamp = time.Now()
|
entry.timestamp = time.Now()
|
||||||
c.evictList.MoveToFront(entry.element)
|
s.evictList.MoveToFront(entry.element)
|
||||||
|
s.currentSize += delta
|
||||||
// Check if we need to evict due to size
|
atomic.AddInt64(&c.totalSize, delta)
|
||||||
|
s.mu.Unlock()
|
||||||
c.evictIfNeeded()
|
c.evictIfNeeded()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create new entry
|
|
||||||
entry := &LRUCacheEntry{
|
entry := &LRUCacheEntry{
|
||||||
key: key,
|
key: key,
|
||||||
value: value,
|
value: value,
|
||||||
size: size,
|
size: size,
|
||||||
timestamp: time.Now(),
|
timestamp: time.Now(),
|
||||||
}
|
}
|
||||||
|
entry.element = s.evictList.PushFront(entry)
|
||||||
|
s.entries[key] = entry
|
||||||
|
s.currentSize += size
|
||||||
|
s.count++
|
||||||
|
atomic.AddInt64(&c.totalSize, size)
|
||||||
|
atomic.AddInt64(&c.totalCount, 1)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
// Add to front of list
|
|
||||||
element := c.evictList.PushFront(entry)
|
|
||||||
entry.element = element
|
|
||||||
c.entries[key] = entry
|
|
||||||
c.currentSize += size
|
|
||||||
|
|
||||||
// Evict if necessary
|
|
||||||
c.evictIfNeeded()
|
c.evictIfNeeded()
|
||||||
}
|
}
|
||||||
|
|
||||||
// evictIfNeeded removes entries when cache limits are exceeded
|
// evictIfNeeded enforces the global maxEntries / maxSize limits by evicting
|
||||||
|
// the globally least-recently-used entry across all shards until under limits.
|
||||||
|
// Selecting the victim shard requires inspecting each shard's tail timestamp,
|
||||||
|
// which is O(shardCount) per eviction — acceptable because shardCount is a
|
||||||
|
// small constant.
|
||||||
func (c *LRUCache) evictIfNeeded() {
|
func (c *LRUCache) evictIfNeeded() {
|
||||||
// If both limits are zero, don't allow any entries
|
|
||||||
if c.maxEntries == 0 || c.maxSize == 0 {
|
if c.maxEntries == 0 || c.maxSize == 0 {
|
||||||
// Clear everything for zero limits
|
c.purgeAll()
|
||||||
c.entries = make(map[string]*LRUCacheEntry)
|
|
||||||
c.evictList = list.New()
|
|
||||||
c.currentSize = 0
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict based on entry count
|
// Fast path: lock-free check before acquiring evictMu. Avoids serialising
|
||||||
for c.evictList.Len() > c.maxEntries {
|
// every Set when limits are not exceeded.
|
||||||
if c.evictList.Len() == 0 {
|
if atomic.LoadInt64(&c.totalCount) <= int64(c.maxEntries) &&
|
||||||
break // Safety check to prevent infinite loop
|
atomic.LoadInt64(&c.totalSize) <= c.maxSize {
|
||||||
}
|
return
|
||||||
c.evictOldest()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evict based on size
|
c.evictMu.Lock()
|
||||||
for c.currentSize > c.maxSize && c.evictList.Len() > 0 {
|
defer c.evictMu.Unlock()
|
||||||
oldSize := c.currentSize
|
|
||||||
c.evictOldest()
|
for {
|
||||||
// Safety check: if size didn't decrease, break to prevent infinite loop
|
count := atomic.LoadInt64(&c.totalCount)
|
||||||
if c.currentSize == oldSize {
|
size := atomic.LoadInt64(&c.totalSize)
|
||||||
break
|
if count <= int64(c.maxEntries) && size <= c.maxSize {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !c.evictGloballyOldest() {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// evictOldest removes the least recently used entry
|
// evictGloballyOldest removes the single entry with the oldest timestamp
|
||||||
func (c *LRUCache) evictOldest() {
|
// across all shards. Returns false if no entry could be evicted.
|
||||||
element := c.evictList.Back()
|
func (c *LRUCache) evictGloballyOldest() bool {
|
||||||
if element == nil {
|
var (
|
||||||
return
|
victimShard *lruCacheShard
|
||||||
|
victimTS time.Time
|
||||||
|
first = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// Snapshot tail timestamps under each shard lock. Briefly hold each lock.
|
||||||
|
for _, s := range c.shards {
|
||||||
|
s.mu.Lock()
|
||||||
|
back := s.evictList.Back()
|
||||||
|
if back != nil {
|
||||||
|
ts := back.Value.(*LRUCacheEntry).timestamp
|
||||||
|
if first || ts.Before(victimTS) {
|
||||||
|
victimTS = ts
|
||||||
|
victimShard = s
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
entry := element.Value.(*LRUCacheEntry)
|
if victimShard == nil {
|
||||||
c.removeEntry(entry)
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
victimShard.mu.Lock()
|
||||||
|
defer victimShard.mu.Unlock()
|
||||||
|
back := victimShard.evictList.Back()
|
||||||
|
if back == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
entry := back.Value.(*LRUCacheEntry)
|
||||||
|
c.removeFromShard(victimShard, entry)
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeEntry removes an entry from the cache
|
// removeFromShard removes an entry from its shard. Caller must hold shard lock.
|
||||||
func (c *LRUCache) removeEntry(entry *LRUCacheEntry) {
|
func (c *LRUCache) removeFromShard(s *lruCacheShard, entry *LRUCacheEntry) {
|
||||||
c.evictList.Remove(entry.element)
|
s.evictList.Remove(entry.element)
|
||||||
delete(c.entries, entry.key)
|
delete(s.entries, entry.key)
|
||||||
c.currentSize -= entry.size
|
s.currentSize -= entry.size
|
||||||
|
s.count--
|
||||||
|
atomic.AddInt64(&c.totalSize, -entry.size)
|
||||||
|
atomic.AddInt64(&c.totalCount, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a key from the cache
|
// purgeAll empties every shard. Used when limits are zero.
|
||||||
|
func (c *LRUCache) purgeAll() {
|
||||||
|
for _, s := range c.shards {
|
||||||
|
s.mu.Lock()
|
||||||
|
freedSize := s.currentSize
|
||||||
|
freedCount := s.count
|
||||||
|
s.entries = make(map[string]*LRUCacheEntry)
|
||||||
|
s.evictList = list.New()
|
||||||
|
s.currentSize = 0
|
||||||
|
s.count = 0
|
||||||
|
s.mu.Unlock()
|
||||||
|
atomic.AddInt64(&c.totalSize, -freedSize)
|
||||||
|
atomic.AddInt64(&c.totalCount, -freedCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a key from the cache.
|
||||||
func (c *LRUCache) Delete(key string) {
|
func (c *LRUCache) Delete(key string) {
|
||||||
c.mu.Lock()
|
s := c.shardFor(key)
|
||||||
defer c.mu.Unlock()
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
entry, exists := c.entries[key]
|
entry, exists := s.entries[key]
|
||||||
if !exists {
|
if !exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
c.removeFromShard(s, entry)
|
||||||
c.removeEntry(entry)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear removes all entries from the cache
|
// Clear removes all entries from the cache.
|
||||||
func (c *LRUCache) Clear() {
|
func (c *LRUCache) Clear() {
|
||||||
c.mu.Lock()
|
for _, s := range c.shards {
|
||||||
defer c.mu.Unlock()
|
s.mu.Lock()
|
||||||
|
freedSize := s.currentSize
|
||||||
c.entries = make(map[string]*LRUCacheEntry)
|
freedCount := s.count
|
||||||
c.evictList = list.New()
|
s.entries = make(map[string]*LRUCacheEntry)
|
||||||
c.currentSize = 0
|
s.evictList = list.New()
|
||||||
|
s.currentSize = 0
|
||||||
|
s.count = 0
|
||||||
|
s.mu.Unlock()
|
||||||
|
atomic.AddInt64(&c.totalSize, -freedSize)
|
||||||
|
atomic.AddInt64(&c.totalCount, -freedCount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Len returns the number of entries in the cache
|
// Len returns the number of entries in the cache.
|
||||||
func (c *LRUCache) Len() int {
|
func (c *LRUCache) Len() int {
|
||||||
c.mu.RLock()
|
return int(atomic.LoadInt64(&c.totalCount))
|
||||||
defer c.mu.RUnlock()
|
|
||||||
return c.evictList.Len()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the current size of the cache in bytes
|
// Size returns the current size of the cache in bytes.
|
||||||
func (c *LRUCache) Size() int64 {
|
func (c *LRUCache) Size() int64 {
|
||||||
c.mu.RLock()
|
return atomic.LoadInt64(&c.totalSize)
|
||||||
defer c.mu.RUnlock()
|
|
||||||
return c.currentSize
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupExpired removes entries older than the given duration
|
// CleanupExpired removes entries older than the given duration across all
|
||||||
|
// shards. Returns the total number of entries removed.
|
||||||
func (c *LRUCache) CleanupExpired(maxAge time.Duration) int {
|
func (c *LRUCache) CleanupExpired(maxAge time.Duration) int {
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
removed := 0
|
removed := 0
|
||||||
|
for _, s := range c.shards {
|
||||||
// Iterate from back (oldest) to front (newest)
|
s.mu.Lock()
|
||||||
for element := c.evictList.Back(); element != nil; {
|
for element := s.evictList.Back(); element != nil; {
|
||||||
entry := element.Value.(*LRUCacheEntry)
|
entry := element.Value.(*LRUCacheEntry)
|
||||||
|
if now.Sub(entry.timestamp) <= maxAge {
|
||||||
// If entry is not expired, we can stop (entries are ordered by access time)
|
break
|
||||||
if now.Sub(entry.timestamp) <= maxAge {
|
}
|
||||||
break
|
next := element.Prev()
|
||||||
|
c.removeFromShard(s, entry)
|
||||||
|
removed++
|
||||||
|
element = next
|
||||||
}
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
// Remove expired entry
|
|
||||||
next := element.Prev()
|
|
||||||
c.removeEntry(entry)
|
|
||||||
removed++
|
|
||||||
element = next
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns cache statistics
|
// GetStats returns cache statistics.
|
||||||
func (c *LRUCache) GetStats() map[string]any {
|
func (c *LRUCache) GetStats() map[string]any {
|
||||||
c.mu.RLock()
|
size := atomic.LoadInt64(&c.totalSize)
|
||||||
defer c.mu.RUnlock()
|
count := atomic.LoadInt64(&c.totalCount)
|
||||||
|
var fillPercent float64
|
||||||
|
if c.maxSize > 0 {
|
||||||
|
fillPercent = float64(size) / float64(c.maxSize) * 100
|
||||||
|
}
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"entries": c.evictList.Len(),
|
"entries": int(count),
|
||||||
"size_bytes": c.currentSize,
|
"size_bytes": size,
|
||||||
"max_entries": c.maxEntries,
|
"max_entries": c.maxEntries,
|
||||||
"max_size": c.maxSize,
|
"max_size": c.maxSize,
|
||||||
"fill_percent": float64(c.currentSize) / float64(c.maxSize) * 100,
|
"fill_percent": fillPercent,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -15,6 +16,10 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
// Register pprof handlers on http.DefaultServeMux. Listener is bound to
|
||||||
|
// 127.0.0.1 only and gated by PPROF_PORT — never expose publicly.
|
||||||
|
_ "net/http/pprof" //nolint:gosec // G108: handlers gated by PPROF_PORT, bound to 127.0.0.1 only
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2/middleware/proxy"
|
"github.com/gofiber/fiber/v2/middleware/proxy"
|
||||||
"github.com/gookit/goutil/envutil"
|
"github.com/gookit/goutil/envutil"
|
||||||
graphql "github.com/lukaszraczylo/go-simple-graphql"
|
graphql "github.com/lukaszraczylo/go-simple-graphql"
|
||||||
@@ -23,6 +28,9 @@ import (
|
|||||||
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
libpack_logging "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
libpack_tracing "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing"
|
libpack_tracing "github.com/lukaszraczylo/graphql-monitoring-proxy/tracing"
|
||||||
|
|
||||||
|
// Auto-tune GOMAXPROCS from cgroup CPU quota (containerized workloads).
|
||||||
|
_ "go.uber.org/automaxprocs"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -170,6 +178,7 @@ func parseConfig() {
|
|||||||
return strings.Split(urls, ",")
|
return strings.Split(urls, ",")
|
||||||
}()
|
}()
|
||||||
c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))
|
c.LogLevel = strings.ToUpper(getDetailsFromEnv("LOG_LEVEL", "info"))
|
||||||
|
c.EnableAllocationTracking = getDetailsFromEnv("ENABLE_ALLOCATION_TRACKING", false)
|
||||||
// Logger setup
|
// Logger setup
|
||||||
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).
|
c.Logger = libpack_logging.New().SetMinLogLevel(libpack_logging.GetLogLevel(c.LogLevel)).
|
||||||
SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
|
SetFieldName("timestamp", "ts").SetFieldName("message", "msg").SetShowCaller(false)
|
||||||
@@ -310,6 +319,39 @@ func parseConfig() {
|
|||||||
// Admin dashboard configuration
|
// Admin dashboard configuration
|
||||||
c.AdminDashboard.Enable = getDetailsFromEnv("ADMIN_DASHBOARD_ENABLE", true)
|
c.AdminDashboard.Enable = getDetailsFromEnv("ADMIN_DASHBOARD_ENABLE", true)
|
||||||
|
|
||||||
|
// Optional debug pprof endpoint. Disabled unless PPROF_PORT is set to a
|
||||||
|
// valid integer. Bound to 127.0.0.1 ONLY — pprof must never be exposed
|
||||||
|
// publicly (it leaks memory layout, allows arbitrary CPU profiles, etc).
|
||||||
|
if pprofPortStr := getDetailsFromEnv("PPROF_PORT", ""); pprofPortStr != "" {
|
||||||
|
if pprofPort, err := strconv.Atoi(pprofPortStr); err == nil && pprofPort > 0 && pprofPort < 65536 {
|
||||||
|
addr := "127.0.0.1:" + strconv.Itoa(pprofPort)
|
||||||
|
c.Logger.Info(&libpack_logging.LogMessage{
|
||||||
|
Message: "pprof endpoint listening on " + addr,
|
||||||
|
})
|
||||||
|
go func(listenAddr string) {
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: listenAddr,
|
||||||
|
Handler: nil,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 120 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
if err := srv.ListenAndServe(); err != nil {
|
||||||
|
c.Logger.Error(&libpack_logging.LogMessage{
|
||||||
|
Message: "pprof endpoint failed",
|
||||||
|
Pairs: map[string]any{"error": err.Error(), "addr": listenAddr},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}(addr)
|
||||||
|
} else {
|
||||||
|
c.Logger.Warning(&libpack_logging.LogMessage{
|
||||||
|
Message: "PPROF_PORT set but invalid; pprof endpoint disabled",
|
||||||
|
Pairs: map[string]any{"value": pprofPortStr},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cfgMutex.Lock()
|
cfgMutex.Lock()
|
||||||
cfg = &c
|
cfg = &c
|
||||||
cfgMutex.Unlock()
|
cfgMutex.Unlock()
|
||||||
|
|||||||
@@ -248,7 +248,7 @@ func (ma *MetricsAggregator) publishMetrics() {
|
|||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Fallback: if stats extraction fails, use empty map
|
// Fallback: if stats extraction fails, use empty map
|
||||||
if ma.logger != nil {
|
if ma.logger != nil && ma.logger.IsLevelEnabled(libpack_logger.LEVEL_ERROR) {
|
||||||
ma.logger.Error(&libpack_logger.LogMessage{
|
ma.logger.Error(&libpack_logger.LogMessage{
|
||||||
Message: "Failed to extract stats from allStats - using empty stats",
|
Message: "Failed to extract stats from allStats - using empty stats",
|
||||||
Pairs: map[string]any{
|
Pairs: map[string]any{
|
||||||
@@ -571,7 +571,7 @@ func (ma *MetricsAggregator) aggregateStats(instances []InstanceMetrics) map[str
|
|||||||
totalAvgRPS += avgRPS
|
totalAvgRPS += avgRPS
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ma.logger != nil {
|
if ma.logger != nil && ma.logger.IsLevelEnabled(libpack_logger.LEVEL_WARN) {
|
||||||
// Log what keys are actually in Stats for debugging
|
// Log what keys are actually in Stats for debugging
|
||||||
keys := make([]string, 0, len(instance.Stats))
|
keys := make([]string, 0, len(instance.Stats))
|
||||||
for k := range instance.Stats {
|
for k := range instance.Stats {
|
||||||
|
|||||||
@@ -0,0 +1,630 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
libpack_logger "github.com/lukaszraczylo/graphql-monitoring-proxy/logging"
|
||||||
|
libpack_monitoring "github.com/lukaszraczylo/graphql-monitoring-proxy/monitoring"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newTestAggregator spins up a miniredis, creates a redis.Client against it,
|
||||||
|
// and returns a MetricsAggregator wired to that client.
|
||||||
|
// The caller must call t.Cleanup to shut down the aggregator and the server.
|
||||||
|
func newTestAggregator(t *testing.T) (*MetricsAggregator, *miniredis.Miniredis) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
require.NoError(t, err, "miniredis.Run")
|
||||||
|
|
||||||
|
client := redis.NewClient(&redis.Options{
|
||||||
|
Addr: mr.Addr(),
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ma := &MetricsAggregator{
|
||||||
|
redisClient: client,
|
||||||
|
logger: libpack_logger.New(),
|
||||||
|
instanceID: "test-instance-001",
|
||||||
|
publishKey: "graphql-proxy:metrics:instances",
|
||||||
|
ttl: 30 * time.Second,
|
||||||
|
publishTimer: time.NewTicker(100 * time.Millisecond),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
ma.Shutdown()
|
||||||
|
mr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return ma, mr
|
||||||
|
}
|
||||||
|
|
||||||
|
// minimalCfg sets the package-level cfg to a minimal valid value so publishMetrics
|
||||||
|
// does not bail out on the nil-cfg guard. Restores the original on cleanup.
|
||||||
|
func minimalCfg(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
old := cfg
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg = &config{
|
||||||
|
Logger: libpack_logger.New(),
|
||||||
|
Monitoring: libpack_monitoring.NewMonitoring(&libpack_monitoring.InitConfig{}),
|
||||||
|
}
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg = old
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- InitializeMetricsAggregator ----------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_InitializeMetricsAggregator_AlreadyInitialized(t *testing.T) {
|
||||||
|
// If the singleton is already set, Init must be a no-op and return nil.
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer mr.Close()
|
||||||
|
|
||||||
|
client := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
existing := &MetricsAggregator{
|
||||||
|
redisClient: client,
|
||||||
|
instanceID: "existing",
|
||||||
|
publishKey: "graphql-proxy:metrics:instances",
|
||||||
|
ttl: 30 * time.Second,
|
||||||
|
publishTimer: time.NewTicker(time.Hour),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject singleton directly (bypass constructor).
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
old := metricsAggregator
|
||||||
|
metricsAggregator = existing
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
metricsAggregator = old
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
existing.publishTimer.Stop()
|
||||||
|
cancel()
|
||||||
|
_ = client.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
err = InitializeMetricsAggregator(mr.Addr(), "", 0, libpack_logger.New())
|
||||||
|
assert.NoError(t, err, "should return nil when already initialized")
|
||||||
|
|
||||||
|
// Singleton must still be the original instance.
|
||||||
|
aggregatorMutex.RLock()
|
||||||
|
got := metricsAggregator
|
||||||
|
aggregatorMutex.RUnlock()
|
||||||
|
assert.Equal(t, existing, got, "singleton must not be replaced")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_InitializeMetricsAggregator_BadURL(t *testing.T) {
|
||||||
|
// Ensure the singleton is clear for this sub-test.
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
old := metricsAggregator
|
||||||
|
metricsAggregator = nil
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
aggregatorMutex.Lock()
|
||||||
|
if metricsAggregator != nil {
|
||||||
|
metricsAggregator.Shutdown()
|
||||||
|
}
|
||||||
|
metricsAggregator = old
|
||||||
|
aggregatorMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
// An unreachable address should cause Ping to fail and return an error.
|
||||||
|
err := InitializeMetricsAggregator("127.0.0.1:1", "", 0, nil)
|
||||||
|
assert.Error(t, err, "should fail when Redis is unreachable")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- removeInstanceMetrics -----------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_RemoveInstanceMetrics_CleansKeys(t *testing.T) {
|
||||||
|
ma, mr := newTestAggregator(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Pre-populate keys that removeInstanceMetrics should delete.
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, ma.instanceID)
|
||||||
|
err := mr.Set(key, `{"instance_id":"test-instance-001"}`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = ma.redisClient.SAdd(ctx, ma.publishKey, ma.instanceID).Err()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify keys exist before removal.
|
||||||
|
exists, err := ma.redisClient.Exists(ctx, key).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), exists, "key should exist before removal")
|
||||||
|
|
||||||
|
ma.removeInstanceMetrics()
|
||||||
|
|
||||||
|
// Verify instance key is gone.
|
||||||
|
exists, err = ma.redisClient.Exists(ctx, key).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(0), exists, "key should be deleted after removeInstanceMetrics")
|
||||||
|
|
||||||
|
// Verify instance ID removed from the set.
|
||||||
|
isMember, err := ma.redisClient.SIsMember(ctx, ma.publishKey, ma.instanceID).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, isMember, "instanceID should be removed from the set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- publishMetrics -------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_PublishMetrics_WritesRedisKey(t *testing.T) {
|
||||||
|
minimalCfg(t)
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
ma.publishMetrics()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, ma.instanceID)
|
||||||
|
|
||||||
|
val, err := ma.redisClient.Get(ctx, key).Result()
|
||||||
|
require.NoError(t, err, "publishMetrics should have written the key to Redis")
|
||||||
|
assert.NotEmpty(t, val, "stored value must not be empty")
|
||||||
|
|
||||||
|
// Must be valid JSON.
|
||||||
|
var im InstanceMetrics
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(val), &im), "stored value must be valid JSON")
|
||||||
|
assert.Equal(t, ma.instanceID, im.InstanceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_PublishMetrics_NilCfgNoWrite(t *testing.T) {
|
||||||
|
// Ensure cfg is nil so publishMetrics bails out early.
|
||||||
|
cfgMutex.Lock()
|
||||||
|
old := cfg
|
||||||
|
cfg = nil
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg = old
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
ma.publishMetrics() // Must not panic.
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, ma.instanceID)
|
||||||
|
exists, err := ma.redisClient.Exists(ctx, key).Result()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(0), exists, "no key should be written when cfg is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- startPublishing (one tick) ------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_StartPublishing_PublishesOnStart(t *testing.T) {
|
||||||
|
minimalCfg(t)
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
// Run startPublishing in background; it calls publishMetrics immediately.
|
||||||
|
go ma.startPublishing()
|
||||||
|
|
||||||
|
// Give the initial synchronous publish time to complete, then cancel.
|
||||||
|
time.Sleep(80 * time.Millisecond)
|
||||||
|
ma.cancel()
|
||||||
|
|
||||||
|
// Allow the goroutine to finish cleanup.
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, ma.instanceID)
|
||||||
|
val, err := ma.redisClient.Get(ctx, key).Result()
|
||||||
|
// After startPublishing runs publishMetrics on start, the key must be present
|
||||||
|
// (unless cfg is nil — but we set it above). If removeInstanceMetrics ran on
|
||||||
|
// shutdown it deletes the key; that is fine — what we assert is no panic + the
|
||||||
|
// goroutine exits cleanly (verified by the race detector).
|
||||||
|
_ = val
|
||||||
|
_ = err
|
||||||
|
// Primary assertion: no goroutine leak (race detector) and no panic.
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- GetAggregatedMetrics ------------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_GetAggregatedMetrics_EmptySet(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
result, err := ma.GetAggregatedMetrics()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Equal(t, 0, result.TotalInstances)
|
||||||
|
assert.Equal(t, 0, result.HealthyInstances)
|
||||||
|
assert.Empty(t, result.Instances)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_GetAggregatedMetrics_TwoInstances_Aggregated(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-A",
|
||||||
|
Hostname: "host-a",
|
||||||
|
LastUpdate: time.Now(),
|
||||||
|
UptimeSeconds: 120,
|
||||||
|
Stats: map[string]any{
|
||||||
|
"requests": map[string]any{
|
||||||
|
"total": float64(100),
|
||||||
|
"succeeded": float64(95),
|
||||||
|
"failed": float64(5),
|
||||||
|
"skipped": float64(0),
|
||||||
|
"current_requests_per_second": float64(10),
|
||||||
|
"avg_requests_per_second": float64(8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Health: map[string]any{"status": "healthy"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
InstanceID: "inst-B",
|
||||||
|
Hostname: "host-b",
|
||||||
|
LastUpdate: time.Now(),
|
||||||
|
UptimeSeconds: 60,
|
||||||
|
Stats: map[string]any{
|
||||||
|
"requests": map[string]any{
|
||||||
|
"total": float64(200),
|
||||||
|
"succeeded": float64(180),
|
||||||
|
"failed": float64(20),
|
||||||
|
"skipped": float64(0),
|
||||||
|
"current_requests_per_second": float64(20),
|
||||||
|
"avg_requests_per_second": float64(15),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Health: map[string]any{"status": "healthy"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, inst := range instances {
|
||||||
|
data, err := json.Marshal(inst)
|
||||||
|
require.NoError(t, err)
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, inst.InstanceID)
|
||||||
|
pipe := ma.redisClient.Pipeline()
|
||||||
|
pipe.Set(ctx, key, data, 30*time.Second)
|
||||||
|
pipe.SAdd(ctx, ma.publishKey, inst.InstanceID)
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := ma.GetAggregatedMetrics()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, result.TotalInstances)
|
||||||
|
assert.Equal(t, 2, result.HealthyInstances)
|
||||||
|
assert.Len(t, result.Instances, 2)
|
||||||
|
|
||||||
|
// CombinedStats.requests.total must be sum of both.
|
||||||
|
reqs, ok := result.CombinedStats["requests"].(map[string]any)
|
||||||
|
require.True(t, ok, "combined_stats.requests must be present")
|
||||||
|
assert.Equal(t, int64(300), reqs["total"])
|
||||||
|
assert.Equal(t, int64(275), reqs["succeeded"])
|
||||||
|
assert.Equal(t, int64(25), reqs["failed"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_GetAggregatedMetrics_StaleInstanceSkipped(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
stale := InstanceMetrics{
|
||||||
|
InstanceID: "stale-inst",
|
||||||
|
Hostname: "host-stale",
|
||||||
|
LastUpdate: time.Now().Add(-2 * time.Minute), // older than 1 minute threshold
|
||||||
|
UptimeSeconds: 9999,
|
||||||
|
Stats: map[string]any{},
|
||||||
|
Health: map[string]any{"status": "healthy"},
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(stale)
|
||||||
|
require.NoError(t, err)
|
||||||
|
key := fmt.Sprintf("%s:%s", ma.publishKey, stale.InstanceID)
|
||||||
|
pipe := ma.redisClient.Pipeline()
|
||||||
|
pipe.Set(ctx, key, data, 30*time.Second)
|
||||||
|
pipe.SAdd(ctx, ma.publishKey, stale.InstanceID)
|
||||||
|
_, err = pipe.Exec(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := ma.GetAggregatedMetrics()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, result)
|
||||||
|
|
||||||
|
assert.Equal(t, 0, result.TotalInstances, "stale instance should be excluded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- aggregateStats -------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_EmptyInstances(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
result := ma.aggregateStats([]InstanceMetrics{})
|
||||||
|
assert.NotNil(t, result)
|
||||||
|
assert.Empty(t, result, "empty input should return empty map")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_SingleInstance(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-1",
|
||||||
|
UptimeSeconds: 300,
|
||||||
|
Stats: map[string]any{
|
||||||
|
"requests": map[string]any{
|
||||||
|
"total": float64(50),
|
||||||
|
"succeeded": float64(45),
|
||||||
|
"failed": float64(5),
|
||||||
|
"skipped": float64(2),
|
||||||
|
"current_requests_per_second": float64(5),
|
||||||
|
"avg_requests_per_second": float64(4),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CacheSummary: map[string]any{
|
||||||
|
"hits": float64(30),
|
||||||
|
"misses": float64(20),
|
||||||
|
"total_cached": float64(10),
|
||||||
|
},
|
||||||
|
Health: map[string]any{"status": "healthy"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
require.NotEmpty(t, result)
|
||||||
|
|
||||||
|
reqs, ok := result["requests"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, int64(50), reqs["total"])
|
||||||
|
assert.Equal(t, int64(45), reqs["succeeded"])
|
||||||
|
assert.Equal(t, int64(5), reqs["failed"])
|
||||||
|
|
||||||
|
cache, ok := result["cache_summary"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, int64(30), cache["hits"])
|
||||||
|
assert.Equal(t, int64(20), cache["misses"])
|
||||||
|
|
||||||
|
// success_rate: 45/50 * 100 = 90%
|
||||||
|
successRate, ok := reqs["success_rate_pct"].(float64)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.InDelta(t, 90.0, successRate, 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_MultipleInstances_Sums(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-1",
|
||||||
|
UptimeSeconds: 100,
|
||||||
|
Stats: map[string]any{
|
||||||
|
"requests": map[string]any{
|
||||||
|
"total": float64(100),
|
||||||
|
"succeeded": float64(90),
|
||||||
|
"failed": float64(10),
|
||||||
|
"skipped": float64(0),
|
||||||
|
"current_requests_per_second": float64(10),
|
||||||
|
"avg_requests_per_second": float64(8),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Health: map[string]any{"status": "healthy"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
InstanceID: "inst-2",
|
||||||
|
UptimeSeconds: 200,
|
||||||
|
Stats: map[string]any{
|
||||||
|
"requests": map[string]any{
|
||||||
|
"total": float64(400),
|
||||||
|
"succeeded": float64(360),
|
||||||
|
"failed": float64(40),
|
||||||
|
"skipped": float64(0),
|
||||||
|
"current_requests_per_second": float64(40),
|
||||||
|
"avg_requests_per_second": float64(30),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Health: map[string]any{"status": "degraded"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
require.NotEmpty(t, result)
|
||||||
|
|
||||||
|
reqs := result["requests"].(map[string]any)
|
||||||
|
assert.Equal(t, int64(500), reqs["total"])
|
||||||
|
assert.Equal(t, int64(450), reqs["succeeded"])
|
||||||
|
assert.Equal(t, int64(50), reqs["failed"])
|
||||||
|
|
||||||
|
// cluster_uptime should be the oldest (smallest) uptime = 100.
|
||||||
|
assert.Equal(t, float64(100), result["cluster_uptime"])
|
||||||
|
assert.Equal(t, 2, result["total_instances"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_CircuitBreaker(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-open",
|
||||||
|
UptimeSeconds: 50,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(10), "succeeded": float64(5), "failed": float64(5), "skipped": float64(0), "current_requests_per_second": float64(1), "avg_requests_per_second": float64(1)}},
|
||||||
|
CircuitBreaker: map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"state": "open",
|
||||||
|
},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
InstanceID: "inst-closed",
|
||||||
|
UptimeSeconds: 60,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(10), "succeeded": float64(10), "failed": float64(0), "skipped": float64(0), "current_requests_per_second": float64(1), "avg_requests_per_second": float64(1)}},
|
||||||
|
CircuitBreaker: map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"state": "closed",
|
||||||
|
},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
cb := result["circuit_breaker"].(map[string]any)
|
||||||
|
assert.Equal(t, true, cb["enabled"])
|
||||||
|
assert.Equal(t, "open", cb["state"], "any open instance means cluster state = open")
|
||||||
|
assert.Equal(t, 1, cb["instances_open"])
|
||||||
|
assert.Equal(t, 1, cb["instances_closed"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_RetryBudget(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-rb",
|
||||||
|
UptimeSeconds: 10,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(1), "succeeded": float64(1), "failed": float64(0), "skipped": float64(0), "current_requests_per_second": float64(0), "avg_requests_per_second": float64(0)}},
|
||||||
|
RetryBudget: map[string]any{
|
||||||
|
"enabled": true,
|
||||||
|
"allowed_retries": float64(50),
|
||||||
|
"denied_retries": float64(10),
|
||||||
|
"total_attempts": float64(60),
|
||||||
|
"current_tokens": float64(80),
|
||||||
|
"max_tokens": float64(100),
|
||||||
|
"tokens_per_sec": float64(5),
|
||||||
|
},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
rb := result["retry_budget"].(map[string]any)
|
||||||
|
assert.Equal(t, true, rb["enabled"])
|
||||||
|
assert.Equal(t, int64(50), rb["allowed_retries"])
|
||||||
|
assert.Equal(t, int64(10), rb["denied_retries"])
|
||||||
|
assert.InDelta(t, 16.67, rb["denial_rate_pct"].(float64), 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_NilStats_DoesNotPanic(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
// Instance with nil Stats should not cause a panic; it is skipped.
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "bad-inst",
|
||||||
|
UptimeSeconds: 10,
|
||||||
|
Stats: nil,
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
// cluster_uptime is set before the nil-stats guard, so it must be non-zero.
|
||||||
|
assert.Equal(t, float64(10), result["cluster_uptime"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_MemoryTracking(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-mem",
|
||||||
|
UptimeSeconds: 10,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(1), "succeeded": float64(1), "failed": float64(0), "skipped": float64(0), "current_requests_per_second": float64(0), "avg_requests_per_second": float64(0)}},
|
||||||
|
Cache: map[string]any{"memory_usage_mb": float64(42.5)},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
InstanceID: "inst-mem2",
|
||||||
|
UptimeSeconds: 20,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(1), "succeeded": float64(1), "failed": float64(0), "skipped": float64(0), "current_requests_per_second": float64(0), "avg_requests_per_second": float64(0)}},
|
||||||
|
Cache: map[string]any{"memory_usage_mb": float64(57.5)},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
mem := result["memory"].(map[string]any)
|
||||||
|
assert.Equal(t, true, mem["available"])
|
||||||
|
assert.InDelta(t, 100.0, mem["total_usage_mb"].(float64), 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_AggregateStats_MemoryNegativeSkipped(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
// -1 means Redis cache where memory tracking is unavailable; must be skipped.
|
||||||
|
instances := []InstanceMetrics{
|
||||||
|
{
|
||||||
|
InstanceID: "inst-redis-cache",
|
||||||
|
UptimeSeconds: 10,
|
||||||
|
Stats: map[string]any{"requests": map[string]any{"total": float64(1), "succeeded": float64(1), "failed": float64(0), "skipped": float64(0), "current_requests_per_second": float64(0), "avg_requests_per_second": float64(0)}},
|
||||||
|
Cache: map[string]any{"memory_usage_mb": float64(-1)},
|
||||||
|
Health: map[string]any{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := ma.aggregateStats(instances)
|
||||||
|
mem := result["memory"].(map[string]any)
|
||||||
|
assert.Equal(t, false, mem["available"])
|
||||||
|
assert.Equal(t, float64(-1), mem["total_usage_mb"].(float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----- Shutdown ------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestMetricsAggregator_Shutdown_CancelsContext(t *testing.T) {
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() { mr.Close() })
|
||||||
|
|
||||||
|
client := redis.NewClient(&redis.Options{Addr: mr.Addr()})
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
ma := &MetricsAggregator{
|
||||||
|
redisClient: client,
|
||||||
|
logger: libpack_logger.New(),
|
||||||
|
instanceID: "shutdown-test",
|
||||||
|
publishKey: "graphql-proxy:metrics:instances",
|
||||||
|
ttl: 30 * time.Second,
|
||||||
|
publishTimer: time.NewTicker(time.Hour),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Context must not be done before Shutdown.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Fatal("context should not be done before Shutdown()")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
ma.Shutdown()
|
||||||
|
|
||||||
|
// Context must be cancelled after Shutdown.
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// expected
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Fatal("context was not cancelled after Shutdown()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsAggregator_Shutdown_Idempotent(t *testing.T) {
|
||||||
|
ma, _ := newTestAggregator(t)
|
||||||
|
|
||||||
|
// Calling Shutdown twice must not panic.
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
ma.Shutdown()
|
||||||
|
ma.Shutdown()
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -31,6 +31,19 @@ var (
|
|||||||
ErrCircuitOpen = errors.New("circuit breaker is open")
|
ErrCircuitOpen = errors.New("circuit breaker is open")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Sentinel errors for the proxy request retry path. Grouped here so callers
|
||||||
|
// can use errors.Is for comparison instead of brittle string matching.
|
||||||
|
// Message text MUST match the historical fmt.Errorf strings — tests and
|
||||||
|
// callers may assert on .Error().
|
||||||
|
var (
|
||||||
|
// errFiberCtxNilDuringRetry — fiber context dropped while retrying.
|
||||||
|
errFiberCtxNilDuringRetry = errors.New("fiber context became nil during retry")
|
||||||
|
// errFiberRespNil — fiber response object became nil mid-request.
|
||||||
|
errFiberRespNil = errors.New("fiber response became nil")
|
||||||
|
// errFiberCtxNil — fiber context was nil before the request started.
|
||||||
|
errFiberCtxNil = errors.New("fiber context is nil")
|
||||||
|
)
|
||||||
|
|
||||||
// Default values for circuit breaker
|
// Default values for circuit breaker
|
||||||
const (
|
const (
|
||||||
defaultMaxRequestsInHalfOpen = 10 // Default maximum requests in half-open state
|
defaultMaxRequestsInHalfOpen = 10 // Default maximum requests in half-open state
|
||||||
@@ -42,6 +55,30 @@ var (
|
|||||||
cbMutex sync.RWMutex
|
cbMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Package-level substring tables used by isConnectionError / isTimeoutError.
|
||||||
|
// Hoisted to avoid per-call slice allocations on the hot path. All entries
|
||||||
|
// must be lower-case; callers lower-case the error string once before matching.
|
||||||
|
var (
|
||||||
|
connectionErrorSubstrings = []string{
|
||||||
|
"connection refused",
|
||||||
|
"connection reset",
|
||||||
|
"no route to host",
|
||||||
|
"network is unreachable",
|
||||||
|
"broken pipe",
|
||||||
|
"connection closed",
|
||||||
|
"eof",
|
||||||
|
"no such host",
|
||||||
|
"dial tcp",
|
||||||
|
"dial udp",
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutErrorSubstrings = []string{
|
||||||
|
"timeout",
|
||||||
|
"deadline exceeded",
|
||||||
|
"context deadline exceeded",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// safeUint32 converts an int to uint32 safely, handling negative values and values exceeding uint32 max
|
// safeUint32 converts an int to uint32 safely, handling negative values and values exceeding uint32 max
|
||||||
func safeUint32(value int) uint32 {
|
func safeUint32(value int) uint32 {
|
||||||
// Handle negative values
|
// Handle negative values
|
||||||
@@ -351,7 +388,7 @@ func performProxyRequest(c *fiber.Ctx, proxyURL string) error {
|
|||||||
return &CoalescedResponse{
|
return &CoalescedResponse{
|
||||||
Body: c.Response().Body(),
|
Body: c.Response().Body(),
|
||||||
StatusCode: c.Response().StatusCode(),
|
StatusCode: c.Response().StatusCode(),
|
||||||
Headers: make(map[string]string),
|
// Headers intentionally left nil; not populated or read anywhere.
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -449,7 +486,7 @@ func performProxyRequestWithRetries(c *fiber.Ctx, proxyURL string) error {
|
|||||||
func executeProxyAttempt(c *fiber.Ctx, proxyURL string) error {
|
func executeProxyAttempt(c *fiber.Ctx, proxyURL string) error {
|
||||||
// Additional safety check inside retry loop
|
// Additional safety check inside retry loop
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return retry.Unrecoverable(fmt.Errorf("fiber context became nil during retry"))
|
return retry.Unrecoverable(errFiberCtxNilDuringRetry)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get connection pool manager for stats tracking
|
// Get connection pool manager for stats tracking
|
||||||
@@ -486,7 +523,7 @@ func executeProxyAttempt(c *fiber.Ctx, proxyURL string) error {
|
|||||||
|
|
||||||
// Safety check before accessing response (c is already validated at function entry)
|
// Safety check before accessing response (c is already validated at function entry)
|
||||||
if c.Response() == nil {
|
if c.Response() == nil {
|
||||||
return retry.Unrecoverable(fmt.Errorf("fiber response became nil"))
|
return retry.Unrecoverable(errFiberRespNil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check status code and determine retry strategy
|
// Check status code and determine retry strategy
|
||||||
@@ -518,7 +555,7 @@ func executeProxyAttempt(c *fiber.Ctx, proxyURL string) error {
|
|||||||
func performProxyRequestWithEnhancedRetries(c *fiber.Ctx, proxyURL string, backendUnhealthy bool) error {
|
func performProxyRequestWithEnhancedRetries(c *fiber.Ctx, proxyURL string, backendUnhealthy bool) error {
|
||||||
// Safety check for nil context
|
// Safety check for nil context
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return fmt.Errorf("fiber context is nil")
|
return errFiberCtxNil
|
||||||
}
|
}
|
||||||
|
|
||||||
var attempts uint
|
var attempts uint
|
||||||
@@ -620,20 +657,7 @@ func isConnectionError(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
errStr := strings.ToLower(err.Error())
|
errStr := strings.ToLower(err.Error())
|
||||||
connectionErrors := []string{
|
for _, connErr := range connectionErrorSubstrings {
|
||||||
"connection refused",
|
|
||||||
"connection reset",
|
|
||||||
"no route to host",
|
|
||||||
"network is unreachable",
|
|
||||||
"broken pipe",
|
|
||||||
"connection closed",
|
|
||||||
"eof",
|
|
||||||
"no such host",
|
|
||||||
"dial tcp",
|
|
||||||
"dial udp",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, connErr := range connectionErrors {
|
|
||||||
if strings.Contains(errStr, connErr) {
|
if strings.Contains(errStr, connErr) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -648,9 +672,12 @@ func isTimeoutError(err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
errStr := strings.ToLower(err.Error())
|
errStr := strings.ToLower(err.Error())
|
||||||
return strings.Contains(errStr, "timeout") ||
|
for _, tErr := range timeoutErrorSubstrings {
|
||||||
strings.Contains(errStr, "deadline exceeded") ||
|
if strings.Contains(errStr, tErr) {
|
||||||
strings.Contains(errStr, "context deadline exceeded")
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isRetryableStatusCode determines if an HTTP status code should trigger a retry
|
// isRetryableStatusCode determines if an HTTP status code should trigger a retry
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ type RetryBudget struct {
|
|||||||
maxTokens int64
|
maxTokens int64
|
||||||
currentTokens atomic.Int64
|
currentTokens atomic.Int64
|
||||||
lastRefill atomic.Int64 // Unix timestamp in nanoseconds
|
lastRefill atomic.Int64 // Unix timestamp in nanoseconds
|
||||||
mu sync.RWMutex
|
|
||||||
enabled bool
|
enabled bool
|
||||||
logger *libpack_logger.Logger
|
logger *libpack_logger.Logger
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -182,9 +181,6 @@ func (rb *RetryBudget) Reset() {
|
|||||||
|
|
||||||
// UpdateConfig updates the retry budget configuration
|
// UpdateConfig updates the retry budget configuration
|
||||||
func (rb *RetryBudget) UpdateConfig(config RetryBudgetConfig) {
|
func (rb *RetryBudget) UpdateConfig(config RetryBudgetConfig) {
|
||||||
rb.mu.Lock()
|
|
||||||
defer rb.mu.Unlock()
|
|
||||||
|
|
||||||
rb.tokensPerSecond = config.TokensPerSecond
|
rb.tokensPerSecond = config.TokensPerSecond
|
||||||
rb.maxTokens = int64(config.MaxTokens)
|
rb.maxTokens = int64(config.MaxTokens)
|
||||||
rb.enabled = config.Enabled
|
rb.enabled = config.Enabled
|
||||||
|
|||||||
+4
-10
@@ -2,7 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@@ -10,9 +9,8 @@ import (
|
|||||||
// RPSTracker tracks requests per second using periodic sampling
|
// RPSTracker tracks requests per second using periodic sampling
|
||||||
type RPSTracker struct {
|
type RPSTracker struct {
|
||||||
lastCount atomic.Int64
|
lastCount atomic.Int64
|
||||||
lastSampleTime atomic.Int64 // Unix nano
|
lastSampleTime atomic.Int64 // Unix nano
|
||||||
currentRPS uint64 // stored as uint64, accessed with atomic operations
|
currentRPS atomic.Uint64 // centirps (RPS * 100)
|
||||||
mu sync.RWMutex // for currentRPS updates
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
@@ -74,9 +72,7 @@ func (r *RPSTracker) sample() {
|
|||||||
if elapsed > 0 {
|
if elapsed > 0 {
|
||||||
rps := float64(currentCount) / elapsed
|
rps := float64(currentCount) / elapsed
|
||||||
// Store RPS as centirps for precision (multiply by 100)
|
// Store RPS as centirps for precision (multiply by 100)
|
||||||
r.mu.Lock()
|
r.currentRPS.Store(uint64(rps * 100))
|
||||||
atomic.StoreUint64(&r.currentRPS, uint64(rps*100))
|
|
||||||
r.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset for next sample
|
// Reset for next sample
|
||||||
@@ -86,9 +82,7 @@ func (r *RPSTracker) sample() {
|
|||||||
|
|
||||||
// GetCurrentRPS returns the current requests per second
|
// GetCurrentRPS returns the current requests per second
|
||||||
func (r *RPSTracker) GetCurrentRPS() float64 {
|
func (r *RPSTracker) GetCurrentRPS() float64 {
|
||||||
r.mu.RLock()
|
centirps := r.currentRPS.Load()
|
||||||
centirps := atomic.LoadUint64(&r.currentRPS)
|
|
||||||
r.mu.RUnlock()
|
|
||||||
return float64(centirps) / 100.0
|
return float64(centirps) / 100.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+46
-14
@@ -4,10 +4,46 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// patternRegexCache caches the 5 outer regexes per sensitive field name.
|
||||||
|
// Pattern set is bounded by sensitiveFieldPatterns (fixed slice) — not a leak.
|
||||||
|
var patternRegexCache sync.Map // map[string]*patternRegexSet
|
||||||
|
|
||||||
|
type patternRegexSet struct {
|
||||||
|
json *regexp.Regexp
|
||||||
|
xml *regexp.Regexp
|
||||||
|
quoted *regexp.Regexp
|
||||||
|
singleQuote *regexp.Regexp
|
||||||
|
form *regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant inner regexes, pattern-independent — compile once.
|
||||||
|
var (
|
||||||
|
jsonValueRe = regexp.MustCompile(`:\s*"[^"]*"`)
|
||||||
|
xmlValueRe = regexp.MustCompile(`>[^<]*<`)
|
||||||
|
formValueRe = regexp.MustCompile(`=([^&\s"']+)`)
|
||||||
|
)
|
||||||
|
|
||||||
|
func getPatternRegexSet(pattern string) *patternRegexSet {
|
||||||
|
if v, ok := patternRegexCache.Load(pattern); ok {
|
||||||
|
return v.(*patternRegexSet)
|
||||||
|
}
|
||||||
|
quoted := regexp.QuoteMeta(pattern)
|
||||||
|
set := &patternRegexSet{
|
||||||
|
json: regexp.MustCompile(`(?i)"` + quoted + `"\s*:\s*"[^"]*"`),
|
||||||
|
xml: regexp.MustCompile(`(?i)<` + quoted + `>[^<]*</` + quoted + `>`),
|
||||||
|
quoted: regexp.MustCompile(`(?i)` + quoted + `="[^"]*"`),
|
||||||
|
singleQuote: regexp.MustCompile(`(?i)` + quoted + `='[^']*'`),
|
||||||
|
form: regexp.MustCompile(`(?i)` + quoted + `=([^&\s"']+)(?:[&\s]|$)`),
|
||||||
|
}
|
||||||
|
actual, _ := patternRegexCache.LoadOrStore(pattern, set)
|
||||||
|
return actual.(*patternRegexSet)
|
||||||
|
}
|
||||||
|
|
||||||
// Sanitization constants
|
// Sanitization constants
|
||||||
const (
|
const (
|
||||||
// MaxLogBodySize is the maximum size of body content to include in logs
|
// MaxLogBodySize is the maximum size of body content to include in logs
|
||||||
@@ -110,18 +146,17 @@ func redactSensitiveFields(data map[string]any, fields []string) {
|
|||||||
func redactPatternInString(text string, pattern string) string {
|
func redactPatternInString(text string, pattern string) string {
|
||||||
// Use proper regex to capture and redact complete sensitive values
|
// Use proper regex to capture and redact complete sensitive values
|
||||||
// Order matters: process most specific patterns first
|
// Order matters: process most specific patterns first
|
||||||
|
set := getPatternRegexSet(pattern)
|
||||||
|
|
||||||
// 1. JSON pattern: "field":"value" → "field":"[REDACTED]"
|
// 1. JSON pattern: "field":"value" → "field":"[REDACTED]"
|
||||||
jsonPattern := regexp.MustCompile(`(?i)"` + regexp.QuoteMeta(pattern) + `"\s*:\s*"[^"]*"`)
|
text = set.json.ReplaceAllStringFunc(text, func(match string) string {
|
||||||
text = jsonPattern.ReplaceAllStringFunc(text, func(match string) string {
|
return jsonValueRe.ReplaceAllString(match, `:"[REDACTED]"`)
|
||||||
return regexp.MustCompile(`:\s*"[^"]*"`).ReplaceAllString(match, `:"[REDACTED]"`)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// 2. XML pattern: <field>value</field> → <field>[REDACTED]</field>
|
// 2. XML pattern: <field>value</field> → <field>[REDACTED]</field>
|
||||||
xmlPattern := regexp.MustCompile(`(?i)<` + regexp.QuoteMeta(pattern) + `>[^<]*</` + regexp.QuoteMeta(pattern) + `>`)
|
xmlMatched := set.xml.MatchString(text)
|
||||||
xmlMatched := xmlPattern.MatchString(text)
|
text = set.xml.ReplaceAllStringFunc(text, func(match string) string {
|
||||||
text = xmlPattern.ReplaceAllStringFunc(text, func(match string) string {
|
return xmlValueRe.ReplaceAllString(match, ">[REDACTED]<")
|
||||||
return regexp.MustCompile(`>[^<]*<`).ReplaceAllString(match, ">[REDACTED]<")
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// If XML pattern was matched, also add a standardized redaction marker for test compatibility
|
// If XML pattern was matched, also add a standardized redaction marker for test compatibility
|
||||||
@@ -133,22 +168,19 @@ func redactPatternInString(text string, pattern string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3. Double quoted pattern: field="value" → field="[REDACTED]"
|
// 3. Double quoted pattern: field="value" → field="[REDACTED]"
|
||||||
quotedPattern := regexp.MustCompile(`(?i)` + regexp.QuoteMeta(pattern) + `="[^"]*"`)
|
text = set.quoted.ReplaceAllString(text, pattern+`="[REDACTED]"`)
|
||||||
text = quotedPattern.ReplaceAllString(text, pattern+`="[REDACTED]"`)
|
|
||||||
|
|
||||||
// 4. Single quoted pattern: field='value' → field='[REDACTED]'
|
// 4. Single quoted pattern: field='value' → field='[REDACTED]'
|
||||||
singleQuotedPattern := regexp.MustCompile(`(?i)` + regexp.QuoteMeta(pattern) + `='[^']*'`)
|
text = set.singleQuote.ReplaceAllString(text, pattern+`='[REDACTED]'`)
|
||||||
text = singleQuotedPattern.ReplaceAllString(text, pattern+`='[REDACTED]'`)
|
|
||||||
|
|
||||||
// 5. Form/URL pattern: field=value& or field=value$ → field=[REDACTED]& or field=[REDACTED]$
|
// 5. Form/URL pattern: field=value& or field=value$ → field=[REDACTED]& or field=[REDACTED]$
|
||||||
// This must be last and should only match unquoted values
|
// This must be last and should only match unquoted values
|
||||||
formPattern := regexp.MustCompile(`(?i)` + regexp.QuoteMeta(pattern) + `=([^&\s"']+)(?:[&\s]|$)`)
|
text = set.form.ReplaceAllStringFunc(text, func(match string) string {
|
||||||
text = formPattern.ReplaceAllStringFunc(text, func(match string) string {
|
|
||||||
// Only replace if the value is not already [REDACTED]
|
// Only replace if the value is not already [REDACTED]
|
||||||
if strings.Contains(match, "[REDACTED]") {
|
if strings.Contains(match, "[REDACTED]") {
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
return regexp.MustCompile(`=([^&\s"']+)`).ReplaceAllString(match, "=[REDACTED]")
|
return formValueRe.ReplaceAllString(match, "=[REDACTED]")
|
||||||
})
|
})
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|||||||
@@ -293,7 +293,7 @@ func processGraphQLRequest(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle caching
|
// Handle caching
|
||||||
wasCached, err := handleCaching(c, parsedResult, extractedUserID)
|
wasCached, err := handleCaching(c, parsedResult, extractedUserID, extractedRoleName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -326,10 +326,7 @@ func extractUserInfo(c *fiber.Ctx) (string, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleCaching manages the caching logic for GraphQL requests
|
// handleCaching manages the caching logic for GraphQL requests
|
||||||
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID string) (bool, error) {
|
func handleCaching(c *fiber.Ctx, parsedResult *parseGraphQLQueryResult, userID, userRole string) (bool, error) {
|
||||||
// Extract user role for cache key (in addition to userID already passed)
|
|
||||||
_, userRole := extractUserInfo(c)
|
|
||||||
|
|
||||||
// Calculate query hash for cache key - now includes user context for security
|
// Calculate query hash for cache key - now includes user context for security
|
||||||
calculatedQueryHash := libpack_cache.CalculateHash(c, userID, userRole)
|
calculatedQueryHash := libpack_cache.CalculateHash(c, userID, userRole)
|
||||||
|
|
||||||
@@ -393,11 +390,10 @@ func proxyAndCacheTheRequest(c *fiber.Ctx, queryCacheHash string, cacheTime int,
|
|||||||
|
|
||||||
// logAndMonitorRequest logs and monitors the request processing.
|
// logAndMonitorRequest logs and monitors the request processing.
|
||||||
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
|
func logAndMonitorRequest(c *fiber.Ctx, userID, opType, opName string, wasCached bool, duration time.Duration, startTime time.Time) {
|
||||||
|
// Low-cardinality labels only: user_id and op_name dropped to prevent Prometheus explosion.
|
||||||
labels := map[string]string{
|
labels := map[string]string{
|
||||||
"op_type": opType,
|
"op_type": opType,
|
||||||
"op_name": opName,
|
|
||||||
"cached": strconv.FormatBool(wasCached),
|
"cached": strconv.FormatBool(wasCached),
|
||||||
"user_id": userID,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Server.AccessLog {
|
if cfg.Server.AccessLog {
|
||||||
|
|||||||
@@ -0,0 +1,601 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2"
|
||||||
|
libpack_cache "github.com/lukaszraczylo/graphql-monitoring-proxy/cache"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// AddRequestUUID
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestAddRequestUUID_SetsLocalsAndCallsNext(t *testing.T) {
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Use(AddRequestUUID)
|
||||||
|
|
||||||
|
var captured string
|
||||||
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
|
if v, ok := c.Locals("request_uuid").(string); ok {
|
||||||
|
captured = v
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
resp, err := app.Test(req, -1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Fatalf("want 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if captured == "" {
|
||||||
|
t.Fatal("request_uuid not set in Locals")
|
||||||
|
}
|
||||||
|
// UUIDs are 36 chars (xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx)
|
||||||
|
if len(captured) != 36 {
|
||||||
|
t.Errorf("unexpected UUID length: %q", captured)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddRequestUUID_UniquePerRequest(t *testing.T) {
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Use(AddRequestUUID)
|
||||||
|
|
||||||
|
seen := make([]string, 0, 5)
|
||||||
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
|
if v, ok := c.Locals("request_uuid").(string); ok {
|
||||||
|
seen = append(seen, v)
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range 5 {
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
resp, err := app.Test(req, -1)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("request %d: %v", i, err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
set := make(map[string]struct{}, len(seen))
|
||||||
|
for _, id := range seen {
|
||||||
|
set[id] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(set) != 5 {
|
||||||
|
t.Errorf("expected 5 unique UUIDs, got %d unique in %v", len(set), seen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// healthCheck
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHealthCheck_Returns200WithJSON(t *testing.T) {
|
||||||
|
// Ensure cfg is ready and GraphQL check is disabled via query param
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Get("/health", healthCheck)
|
||||||
|
|
||||||
|
// Pass check_graphql=false to avoid real network call
|
||||||
|
req := httptest.NewRequest("GET", "/health?check_graphql=false&check_redis=false", nil)
|
||||||
|
resp, err := app.Test(req, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Fatalf("want 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := body["status"]; !ok {
|
||||||
|
t.Error("response missing 'status' field")
|
||||||
|
}
|
||||||
|
if _, ok := body["timestamp"]; !ok {
|
||||||
|
t.Error("response missing 'timestamp' field")
|
||||||
|
}
|
||||||
|
if body["status"] != "healthy" {
|
||||||
|
t.Errorf("want status=healthy, got %v", body["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthCheck_UnhealthyWhenGraphQLDown(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
// Point to a server that refuses connections
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
cfg.Server.HostGraphQL = "http://127.0.0.1:1" // port 1 always refused
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Get("/health", healthCheck)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/health?check_redis=false", nil)
|
||||||
|
resp, err := app.Test(req, 15000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// Should return 503 when backend is unreachable
|
||||||
|
if resp.StatusCode != fiber.StatusServiceUnavailable {
|
||||||
|
t.Fatalf("want 503, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if body["status"] != "unhealthy" {
|
||||||
|
t.Errorf("want unhealthy, got %v", body["status"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// processGraphQLRequest
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestProcessGraphQLRequest_ValidBodyProxiesToBackend(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, _ = w.Write([]byte(`{"data":{"test":"ok"}}`))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
origHostRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
origCache := cfg.Cache.CacheEnable
|
||||||
|
cfg.Server.HostGraphQL = backend.URL
|
||||||
|
cfg.Server.HostGraphQLReadOnly = backend.URL
|
||||||
|
cfg.Cache.CacheEnable = false
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origHostRO
|
||||||
|
cfg.Cache.CacheEnable = origCache
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Post("/*", processGraphQLRequest)
|
||||||
|
|
||||||
|
body := `{"query":"query { __typename }"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/graphql", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("want 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcessGraphQLRequest_MalformedBodyStillHandled(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
// Backend that always returns 200 (malformed body is handled by proxy layer)
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, _ = w.Write([]byte(`{"errors":[{"message":"parse error"}]}`))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
origHostRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
origCache := cfg.Cache.CacheEnable
|
||||||
|
cfg.Server.HostGraphQL = backend.URL
|
||||||
|
cfg.Server.HostGraphQLReadOnly = backend.URL
|
||||||
|
cfg.Cache.CacheEnable = false
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origHostRO
|
||||||
|
cfg.Cache.CacheEnable = origCache
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Post("/*", processGraphQLRequest)
|
||||||
|
|
||||||
|
// Not valid JSON — proxy should still forward or return gracefully
|
||||||
|
body := `not-json-at-all`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/graphql", strings.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// Should not panic; any 2xx or 5xx is acceptable — just must not crash
|
||||||
|
if resp.StatusCode < 100 || resp.StatusCode > 599 {
|
||||||
|
t.Errorf("unexpected status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// handleCaching — wasCached=true path (cache hit)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleCaching_CacheHitReturnsStoredResponse(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
// Enable in-memory cache
|
||||||
|
libpack_cache.EnableCache(&libpack_cache.CacheConfig{
|
||||||
|
Logger: cfg.Logger,
|
||||||
|
TTL: 60,
|
||||||
|
})
|
||||||
|
libpack_cache.CacheClear()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origEnable := cfg.Cache.CacheEnable
|
||||||
|
cfg.Cache.CacheEnable = true
|
||||||
|
cfg.Cache.CacheTTL = 60
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Cache.CacheEnable = origEnable
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, _ = w.Write([]byte(`{"data":{"users":[]}}`))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
origHostRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
cfg.Server.HostGraphQL = backend.URL
|
||||||
|
cfg.Server.HostGraphQLReadOnly = backend.URL
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origHostRO
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Post("/*", processGraphQLRequest)
|
||||||
|
|
||||||
|
queryBody := `{"query":"query { users { id } }"}`
|
||||||
|
|
||||||
|
// First request — cache miss, hits backend
|
||||||
|
req1 := httptest.NewRequest("POST", "/v1/graphql", strings.NewReader(queryBody))
|
||||||
|
req1.Header.Set("Content-Type", "application/json")
|
||||||
|
resp1, err := app.Test(req1, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first request: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp1.Body.Close()
|
||||||
|
|
||||||
|
if resp1.StatusCode != 200 {
|
||||||
|
t.Fatalf("first request want 200, got %d", resp1.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second identical request — should hit cache
|
||||||
|
req2 := httptest.NewRequest("POST", "/v1/graphql", strings.NewReader(queryBody))
|
||||||
|
req2.Header.Set("Content-Type", "application/json")
|
||||||
|
resp2, err := app.Test(req2, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second request: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp2.Body.Close()
|
||||||
|
|
||||||
|
if resp2.StatusCode != 200 {
|
||||||
|
t.Fatalf("second request want 200, got %d", resp2.StatusCode)
|
||||||
|
}
|
||||||
|
if resp2.Header.Get("X-Cache-Hit") != "true" {
|
||||||
|
t.Error("second request should have X-Cache-Hit: true header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleCaching_CacheMissProxiesRequest(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
libpack_cache.EnableCache(&libpack_cache.CacheConfig{
|
||||||
|
Logger: cfg.Logger,
|
||||||
|
TTL: 60,
|
||||||
|
})
|
||||||
|
libpack_cache.CacheClear()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origEnable := cfg.Cache.CacheEnable
|
||||||
|
cfg.Cache.CacheEnable = true
|
||||||
|
cfg.Cache.CacheTTL = 60
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Cache.CacheEnable = origEnable
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
backendCalled := 0
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
backendCalled++
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, _ = fmt.Fprintf(w, `{"data":{"call":%d}}`, backendCalled)
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
origHostRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
cfg.Server.HostGraphQL = backend.URL
|
||||||
|
cfg.Server.HostGraphQLReadOnly = backend.URL
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origHostRO
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
app.Post("/*", processGraphQLRequest)
|
||||||
|
|
||||||
|
// Unique query so no prior cache entry
|
||||||
|
queryBody := `{"query":"query { uniqueMissTest_12345 { id } }"}`
|
||||||
|
req := httptest.NewRequest("POST", "/v1/graphql", strings.NewReader(queryBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := app.Test(req, 10000)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("app.Test: %v", err)
|
||||||
|
}
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Errorf("want 200, got %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if resp.Header.Get("X-Cache-Hit") == "true" {
|
||||||
|
t.Error("first request should not be a cache hit")
|
||||||
|
}
|
||||||
|
if backendCalled == 0 {
|
||||||
|
t.Error("backend should have been called on cache miss")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// handleCaching — direct unit test for wasCached=true branch
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestHandleCaching_DirectCacheHitBranch(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
libpack_cache.EnableCache(&libpack_cache.CacheConfig{
|
||||||
|
Logger: cfg.Logger,
|
||||||
|
TTL: 60,
|
||||||
|
})
|
||||||
|
libpack_cache.CacheClear()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origEnable := cfg.Cache.CacheEnable
|
||||||
|
cfg.Cache.CacheEnable = true
|
||||||
|
cfg.Cache.CacheTTL = 60
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Cache.CacheEnable = origEnable
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
|
||||||
|
var wasCachedResult bool
|
||||||
|
app.Post("/test", func(c *fiber.Ctx) error {
|
||||||
|
parsedResult := &parseGraphQLQueryResult{
|
||||||
|
cacheTime: 60,
|
||||||
|
cacheRequest: true,
|
||||||
|
activeEndpoint: cfg.Server.HostGraphQL,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pre-populate the cache so lookup hits
|
||||||
|
cacheKey := libpack_cache.CalculateHash(c, "-", "-")
|
||||||
|
libpack_cache.CacheStore(cacheKey, []byte(`{"data":{"cached":true}}`))
|
||||||
|
|
||||||
|
var err error
|
||||||
|
wasCachedResult, err = handleCaching(c, parsedResult, "-", "-")
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
reqCtx := &fasthttp.RequestCtx{}
|
||||||
|
reqCtx.Request.SetRequestURI("/test")
|
||||||
|
reqCtx.Request.Header.SetMethod("POST")
|
||||||
|
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
reqCtx.Request.SetBody([]byte(`{"query":"query { cachedQuery }"}`))
|
||||||
|
|
||||||
|
ctx := app.AcquireCtx(reqCtx)
|
||||||
|
defer app.ReleaseCtx(ctx)
|
||||||
|
|
||||||
|
parsedResult := &parseGraphQLQueryResult{
|
||||||
|
cacheTime: 60,
|
||||||
|
cacheRequest: true,
|
||||||
|
activeEndpoint: cfg.Server.HostGraphQL,
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := libpack_cache.CalculateHash(ctx, "-", "-")
|
||||||
|
libpack_cache.CacheStore(cacheKey, []byte(`{"data":{"cached":true}}`))
|
||||||
|
|
||||||
|
wasCached, err := handleCaching(ctx, parsedResult, "-", "-")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("handleCaching returned error: %v", err)
|
||||||
|
}
|
||||||
|
if !wasCached {
|
||||||
|
t.Error("expected wasCached=true when cache hit")
|
||||||
|
}
|
||||||
|
_ = wasCachedResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleCaching_NoCacheEnabled_ProxiesDirect(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(200)
|
||||||
|
_, _ = w.Write([]byte(`{"data":{"noCacheTest":true}}`))
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origEnable := cfg.Cache.CacheEnable
|
||||||
|
origRedis := cfg.Cache.CacheRedisEnable
|
||||||
|
origHost := cfg.Server.HostGraphQL
|
||||||
|
origHostRO := cfg.Server.HostGraphQLReadOnly
|
||||||
|
cfg.Cache.CacheEnable = false
|
||||||
|
cfg.Cache.CacheRedisEnable = false
|
||||||
|
cfg.Server.HostGraphQL = backend.URL
|
||||||
|
cfg.Server.HostGraphQLReadOnly = backend.URL
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Cache.CacheEnable = origEnable
|
||||||
|
cfg.Cache.CacheRedisEnable = origRedis
|
||||||
|
cfg.Server.HostGraphQL = origHost
|
||||||
|
cfg.Server.HostGraphQLReadOnly = origHostRO
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
app := fiber.New(fiber.Config{DisableStartupMessage: true})
|
||||||
|
|
||||||
|
reqCtx := &fasthttp.RequestCtx{}
|
||||||
|
reqCtx.Request.SetRequestURI("/v1/graphql")
|
||||||
|
reqCtx.Request.Header.SetMethod("POST")
|
||||||
|
reqCtx.Request.Header.Set("Content-Type", "application/json")
|
||||||
|
reqCtx.Request.SetBody([]byte(`{"query":"query { noCacheTest }"}`))
|
||||||
|
|
||||||
|
fCtx := app.AcquireCtx(reqCtx)
|
||||||
|
defer app.ReleaseCtx(fCtx)
|
||||||
|
|
||||||
|
parsedResult := &parseGraphQLQueryResult{
|
||||||
|
cacheRequest: false,
|
||||||
|
cacheTime: 0,
|
||||||
|
activeEndpoint: backend.URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
wasCached, err := handleCaching(fCtx, parsedResult, "-", "-")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("handleCaching error: %v", err)
|
||||||
|
}
|
||||||
|
if wasCached {
|
||||||
|
t.Error("expected wasCached=false when cache disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// StartHTTPProxy — starts then shuts down cleanly
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestStartHTTPProxy_StartsAndShutdown(t *testing.T) {
|
||||||
|
parseConfig()
|
||||||
|
_ = StartMonitoringServer()
|
||||||
|
|
||||||
|
// Grab a free port
|
||||||
|
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("net.Listen: %v", err)
|
||||||
|
}
|
||||||
|
port := l.Addr().(*net.TCPAddr).Port
|
||||||
|
_ = l.Close()
|
||||||
|
|
||||||
|
cfgMutex.Lock()
|
||||||
|
origPort := cfg.Server.PortGraphQL
|
||||||
|
origTimeout := cfg.Client.ClientTimeout
|
||||||
|
origWS := cfg.WebSocket.Enable
|
||||||
|
origAdmin := cfg.AdminDashboard.Enable
|
||||||
|
cfg.Server.PortGraphQL = port
|
||||||
|
cfg.Client.ClientTimeout = 5
|
||||||
|
cfg.WebSocket.Enable = false
|
||||||
|
cfg.AdminDashboard.Enable = false
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cfgMutex.Lock()
|
||||||
|
cfg.Server.PortGraphQL = origPort
|
||||||
|
cfg.Client.ClientTimeout = origTimeout
|
||||||
|
cfg.WebSocket.Enable = origWS
|
||||||
|
cfg.AdminDashboard.Enable = origAdmin
|
||||||
|
cfgMutex.Unlock()
|
||||||
|
})
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- StartHTTPProxy()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for server to bind
|
||||||
|
deadline := time.Now().Add(3 * time.Second)
|
||||||
|
var conn net.Conn
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
conn, err = net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 100*time.Millisecond)
|
||||||
|
if err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
t.Fatalf("server did not start on port %d within 3s", port)
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
// Send a health check to confirm it's serving
|
||||||
|
httpResp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/health?check_graphql=false&check_redis=false", port))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET /health: %v", err)
|
||||||
|
}
|
||||||
|
_ = httpResp.Body.Close()
|
||||||
|
if httpResp.StatusCode != 200 {
|
||||||
|
t.Errorf("want 200, got %d", httpResp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
+6
-5
@@ -18,11 +18,12 @@ type EndpointCBConfig struct {
|
|||||||
// config is a struct that holds the configuration of the application.
|
// config is a struct that holds the configuration of the application.
|
||||||
// It includes settings for logging, monitoring, client connections, security, and server behavior.
|
// It includes settings for logging, monitoring, client connections, security, and server behavior.
|
||||||
type config struct {
|
type config struct {
|
||||||
Logger *libpack_logging.Logger
|
Logger *libpack_logging.Logger
|
||||||
Monitoring *libpack_monitoring.MetricsSetup
|
Monitoring *libpack_monitoring.MetricsSetup
|
||||||
LogLevel string
|
LogLevel string
|
||||||
Api struct{ BannedUsersFile string }
|
EnableAllocationTracking bool
|
||||||
Tracing struct {
|
Api struct{ BannedUsersFile string }
|
||||||
|
Tracing struct {
|
||||||
Endpoint string
|
Endpoint string
|
||||||
Enable bool
|
Enable bool
|
||||||
}
|
}
|
||||||
|
|||||||
+13
-6
@@ -25,6 +25,14 @@ type TracingSetup struct {
|
|||||||
tracer trace.Tracer
|
tracer trace.Tracer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// constSpanAttrs holds attributes that are identical for every span created
|
||||||
|
// by this package. Building the slice once at package init avoids two
|
||||||
|
// allocations per StartSpan / StartSpanWithAttributes call.
|
||||||
|
var constSpanAttrs = []attribute.KeyValue{
|
||||||
|
semconv.ServiceName("graphql-monitoring-proxy"),
|
||||||
|
semconv.ServiceVersion("1.0"),
|
||||||
|
}
|
||||||
|
|
||||||
type TraceSpanInfo struct {
|
type TraceSpanInfo struct {
|
||||||
TraceParent string `json:"traceparent"`
|
TraceParent string `json:"traceparent"`
|
||||||
}
|
}
|
||||||
@@ -158,12 +166,11 @@ func (ts *TracingSetup) StartSpanWithAttributes(ctx context.Context, name string
|
|||||||
return trace.SpanFromContext(ctx), ctx
|
return trace.SpanFromContext(ctx), ctx
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert string attributes to KeyValue pairs
|
// Convert string attributes to KeyValue pairs.
|
||||||
attributes := make([]attribute.KeyValue, 0, len(attrs)+2)
|
// Pre-size with constants + per-call attrs, copy constant block in one shot,
|
||||||
attributes = append(attributes,
|
// then append the dynamic attributes.
|
||||||
semconv.ServiceName("graphql-monitoring-proxy"),
|
attributes := make([]attribute.KeyValue, len(constSpanAttrs), len(constSpanAttrs)+len(attrs))
|
||||||
semconv.ServiceVersion("1.0"),
|
copy(attributes, constSpanAttrs)
|
||||||
)
|
|
||||||
|
|
||||||
for k, v := range attrs {
|
for k, v := range attrs {
|
||||||
attributes = append(attributes, attribute.String(k, v))
|
attributes = append(attributes, attribute.String(k, v))
|
||||||
|
|||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package tracing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/propagation"
|
||||||
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
|
"go.opentelemetry.io/otel/sdk/trace/tracetest"
|
||||||
|
"go.opentelemetry.io/otel/trace/noop"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewTracing_NilContext covers the nil context early-return branch (line 34-36).
|
||||||
|
func TestNewTracing_NilContext_ReturnsError(t *testing.T) {
|
||||||
|
_, err := NewTracing(nil, "localhost:4317") //nolint:staticcheck // SA1012: intentional nil to test the error branch
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "context cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewTracing_InvalidEndpointFormats covers endpoint validation branches.
|
||||||
|
// Note: fmt.Sscanf("%s:%d") treats %s as greedy so any "host:port" string hits
|
||||||
|
// the format error (n!=2). The port-range branch (port>65535) requires n==2
|
||||||
|
// which Sscanf never produces for "host:port" strings — that's a source quirk.
|
||||||
|
func TestNewTracing_InvalidEndpointFormats_ReturnsError(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
endpoint string
|
||||||
|
}{
|
||||||
|
{name: "no port separator", endpoint: "localhost"},
|
||||||
|
{name: "port over max", endpoint: "localhost:999999"},
|
||||||
|
{name: "plain hostname only", endpoint: "myhost"},
|
||||||
|
{name: "just a number", endpoint: "12345"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
_, err := NewTracing(context.Background(), tt.endpoint)
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "invalid endpoint format")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestShutdown_WithRealProvider covers the non-nil tracerProvider shutdown path (line 133).
|
||||||
|
func TestShutdown_WithRealProvider_NoError(t *testing.T) {
|
||||||
|
// Use in-memory exporter so no network needed.
|
||||||
|
exporter := tracetest.NewInMemoryExporter()
|
||||||
|
tp := sdktrace.NewTracerProvider(
|
||||||
|
sdktrace.WithSyncer(exporter),
|
||||||
|
)
|
||||||
|
ts := &TracingSetup{
|
||||||
|
tracerProvider: tp,
|
||||||
|
tracer: tp.Tracer("shutdown-test"),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := ts.Shutdown(ctx)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStartSpan_WithRealTracer covers StartSpan with a real (noop) tracer — the non-nil path.
|
||||||
|
func TestStartSpan_WithRealTracer_ReturnsSpan(t *testing.T) {
|
||||||
|
tp := noop.NewTracerProvider()
|
||||||
|
ts := &TracingSetup{
|
||||||
|
tracer: tp.Tracer("start-span-test"),
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
span, newCtx := ts.StartSpan(ctx, "my-operation")
|
||||||
|
assert.NotNil(t, span)
|
||||||
|
assert.NotNil(t, newCtx)
|
||||||
|
span.End()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStartSpanWithAttributes_WithRealTracer covers the non-nil tracer path with attrs.
|
||||||
|
func TestStartSpanWithAttributes_WithRealTracer_RecordsSpan(t *testing.T) {
|
||||||
|
exporter := tracetest.NewInMemoryExporter()
|
||||||
|
tp := sdktrace.NewTracerProvider(
|
||||||
|
sdktrace.WithSyncer(exporter),
|
||||||
|
sdktrace.WithSampler(sdktrace.AlwaysSample()),
|
||||||
|
)
|
||||||
|
ts := &TracingSetup{
|
||||||
|
tracerProvider: tp,
|
||||||
|
tracer: tp.Tracer("attr-test"),
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
attrs := map[string]string{
|
||||||
|
"user.id": "u-42",
|
||||||
|
"operation": "query",
|
||||||
|
}
|
||||||
|
span, newCtx := ts.StartSpanWithAttributes(ctx, "graphql-query", attrs)
|
||||||
|
require.NotNil(t, span)
|
||||||
|
require.NotNil(t, newCtx)
|
||||||
|
span.End()
|
||||||
|
|
||||||
|
spans := exporter.GetSpans()
|
||||||
|
require.Len(t, spans, 1)
|
||||||
|
assert.Equal(t, "graphql-query", spans[0].Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExtractSpanContext_ValidTraceparent covers the valid span context branch (line 115-116).
|
||||||
|
// ExtractSpanContext uses otel.GetTextMapPropagator(); we must register the W3C
|
||||||
|
// TraceContext propagator before calling it (NewTracing normally does this).
|
||||||
|
func TestExtractSpanContext_ValidTraceparent_ReturnsValid(t *testing.T) {
|
||||||
|
otel.SetTextMapPropagator(propagation.TraceContext{})
|
||||||
|
|
||||||
|
tp := noop.NewTracerProvider()
|
||||||
|
ts := &TracingSetup{
|
||||||
|
tracer: tp.Tracer("extract-test"),
|
||||||
|
}
|
||||||
|
spanInfo := &TraceSpanInfo{
|
||||||
|
TraceParent: "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
|
||||||
|
}
|
||||||
|
spanCtx, err := ts.ExtractSpanContext(spanInfo)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, spanCtx.IsValid())
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user