mirror of
https://github.com/lukaszraczylo/graphql-monitoring-proxy.git
synced 2026-06-14 02:32:10 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5ae4ea1e25 | |||
| fd30dc0890 | |||
| 2966661054 |
+125
-4
@@ -8,6 +8,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
gorillaws "github.com/gorilla/websocket"
|
||||
@@ -90,9 +91,16 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
|
||||
|
||||
// Capture headers from the upgrade request to forward to backend
|
||||
headers := make(http.Header)
|
||||
var subprotocols []string
|
||||
|
||||
for key, value := range c.Request().Header.All() {
|
||||
keyStr := string(key)
|
||||
// Forward important headers (skip connection-specific ones)
|
||||
// Capture subprotocol separately
|
||||
if keyStr == "Sec-Websocket-Protocol" || keyStr == "Sec-WebSocket-Protocol" {
|
||||
subprotocols = append(subprotocols, string(value))
|
||||
}
|
||||
// Forward important headers including WebSocket subprotocol
|
||||
// Skip only connection-establishment headers that will be regenerated
|
||||
if keyStr != "Connection" && keyStr != "Upgrade" &&
|
||||
keyStr != "Sec-Websocket-Key" && keyStr != "Sec-Websocket-Version" &&
|
||||
keyStr != "Sec-Websocket-Extensions" {
|
||||
@@ -100,11 +108,16 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure WebSocket with subprotocol support
|
||||
config := websocket.Config{
|
||||
Subprotocols: subprotocols,
|
||||
}
|
||||
|
||||
return websocket.New(func(clientConn *websocket.Conn) {
|
||||
// Use background context for long-lived WebSocket connections
|
||||
// The original request context expires after the upgrade
|
||||
wsp.handleConnection(context.Background(), clientConn, headers)
|
||||
})(c)
|
||||
}, config)(c)
|
||||
}
|
||||
|
||||
// handleConnection manages a single WebSocket connection
|
||||
@@ -129,8 +142,29 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
|
||||
// Set message size limit
|
||||
clientConn.SetReadLimit(wsp.maxMessageSize)
|
||||
|
||||
// Connect to backend WebSocket with forwarded headers
|
||||
backendConn, err := wsp.dialBackend(ctx, headers)
|
||||
// Read first message to extract authentication from connection_init payload
|
||||
// This bridges the gap between clients that send auth in payload vs Hasura expecting it in HTTP headers
|
||||
messageType, message, err := clientConn.ReadMessage()
|
||||
if err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to read first message from client",
|
||||
Pairs: map[string]interface{}{
|
||||
"connection_id": connectionID,
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
clientConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Try to extract headers from connection_init payload (for GraphQL WebSocket protocols)
|
||||
enrichedHeaders := wsp.extractAuthFromPayload(message, headers)
|
||||
|
||||
// Connect to backend WebSocket with enriched headers
|
||||
backendConn, err := wsp.dialBackend(ctx, enrichedHeaders)
|
||||
if err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
@@ -147,6 +181,32 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
// Forward the first message (connection_init) to backend
|
||||
if err := backendConn.WriteMessage(messageType, message); err != nil {
|
||||
wsp.errors.Add(1)
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Error(&libpack_logger.LogMessage{
|
||||
Message: "Failed to forward connection_init to backend",
|
||||
Pairs: map[string]interface{}{
|
||||
"connection_id": connectionID,
|
||||
"error": err.Error(),
|
||||
},
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if wsp.logger != nil {
|
||||
wsp.logger.Debug(&libpack_logger.LogMessage{
|
||||
Message: "Backend WebSocket connection established",
|
||||
Pairs: map[string]interface{}{
|
||||
"connection_id": connectionID,
|
||||
"subprotocol": backendConn.Subprotocol(),
|
||||
"has_authorization": headers.Get("Authorization") != "",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Set up bidirectional proxying
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
@@ -313,6 +373,58 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
|
||||
}
|
||||
}
|
||||
|
||||
// extractAuthFromPayload extracts authentication headers from GraphQL WebSocket connection_init payload
|
||||
// This bridges the gap between clients sending auth in payload and Hasura expecting it in HTTP headers
|
||||
func (wsp *WebSocketProxy) extractAuthFromPayload(message []byte, originalHeaders http.Header) http.Header {
|
||||
// Create a copy of original headers
|
||||
enrichedHeaders := make(http.Header)
|
||||
for k, v := range originalHeaders {
|
||||
enrichedHeaders[k] = v
|
||||
}
|
||||
|
||||
// Try to parse as JSON to extract headers from payload
|
||||
var msg map[string]interface{}
|
||||
if err := json.Unmarshal(message, &msg); err != nil {
|
||||
// Not JSON or parse error, return original headers
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Check if this is a connection_init message
|
||||
msgType, ok := msg["type"].(string)
|
||||
if !ok || (msgType != "connection_init" && msgType != "start") {
|
||||
// Not a connection_init, return original headers
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Extract payload
|
||||
payload, ok := msg["payload"].(map[string]interface{})
|
||||
if !ok {
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// Try to extract headers from payload.headers (graphql-ws format)
|
||||
if payloadHeaders, ok := payload["headers"].(map[string]interface{}); ok {
|
||||
for key, value := range payloadHeaders {
|
||||
if strValue, ok := value.(string); ok {
|
||||
enrichedHeaders.Set(key, strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check top-level payload keys that look like headers (Apollo format)
|
||||
for key, value := range payload {
|
||||
if strValue, ok := value.(string); ok {
|
||||
// Common auth headers
|
||||
if key == "Authorization" || key == "authorization" ||
|
||||
key == "x-hasura-role" || key == "x-hasura-admin-secret" {
|
||||
enrichedHeaders.Set(key, strValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return enrichedHeaders
|
||||
}
|
||||
|
||||
// dialBackend establishes a WebSocket connection to the backend
|
||||
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
|
||||
// Convert http:// to ws:// or https:// to wss://
|
||||
@@ -326,9 +438,18 @@ func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header)
|
||||
// Append GraphQL WebSocket path
|
||||
wsURL = wsURL + "/v1/graphql"
|
||||
|
||||
// Extract subprotocols from headers (e.g., graphql-ws, graphql-transport-ws)
|
||||
var subprotocols []string
|
||||
if proto := headers.Get("Sec-WebSocket-Protocol"); proto != "" {
|
||||
subprotocols = []string{proto}
|
||||
// Remove from headers since it will be set via Subprotocols field
|
||||
headers.Del("Sec-WebSocket-Protocol")
|
||||
}
|
||||
|
||||
// Use gorilla websocket dialer
|
||||
dialer := gorillaws.Dialer{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
Subprotocols: subprotocols,
|
||||
}
|
||||
|
||||
// Dial the backend with forwarded headers
|
||||
|
||||
Reference in New Issue
Block a user