Compare commits

...

3 Commits

+125 -4
View File
@@ -8,6 +8,7 @@ import (
"sync/atomic"
"time"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
gorillaws "github.com/gorilla/websocket"
@@ -90,9 +91,16 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
// Capture headers from the upgrade request to forward to backend
headers := make(http.Header)
var subprotocols []string
for key, value := range c.Request().Header.All() {
keyStr := string(key)
// Forward important headers (skip connection-specific ones)
// Capture subprotocol separately
if keyStr == "Sec-Websocket-Protocol" || keyStr == "Sec-WebSocket-Protocol" {
subprotocols = append(subprotocols, string(value))
}
// Forward important headers including WebSocket subprotocol
// Skip only connection-establishment headers that will be regenerated
if keyStr != "Connection" && keyStr != "Upgrade" &&
keyStr != "Sec-Websocket-Key" && keyStr != "Sec-Websocket-Version" &&
keyStr != "Sec-Websocket-Extensions" {
@@ -100,11 +108,16 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
}
}
// Configure WebSocket with subprotocol support
config := websocket.Config{
Subprotocols: subprotocols,
}
return websocket.New(func(clientConn *websocket.Conn) {
// Use background context for long-lived WebSocket connections
// The original request context expires after the upgrade
wsp.handleConnection(context.Background(), clientConn, headers)
})(c)
}, config)(c)
}
// handleConnection manages a single WebSocket connection
@@ -129,8 +142,29 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
// Set message size limit
clientConn.SetReadLimit(wsp.maxMessageSize)
// Connect to backend WebSocket with forwarded headers
backendConn, err := wsp.dialBackend(ctx, headers)
// Read first message to extract authentication from connection_init payload
// This bridges the gap between clients that send auth in payload vs Hasura expecting it in HTTP headers
messageType, message, err := clientConn.ReadMessage()
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
wsp.logger.Error(&libpack_logger.LogMessage{
Message: "Failed to read first message from client",
Pairs: map[string]interface{}{
"connection_id": connectionID,
"error": err.Error(),
},
})
}
clientConn.Close()
return
}
// Try to extract headers from connection_init payload (for GraphQL WebSocket protocols)
enrichedHeaders := wsp.extractAuthFromPayload(message, headers)
// Connect to backend WebSocket with enriched headers
backendConn, err := wsp.dialBackend(ctx, enrichedHeaders)
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
@@ -147,6 +181,32 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
}
defer backendConn.Close()
// Forward the first message (connection_init) to backend
if err := backendConn.WriteMessage(messageType, message); err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
wsp.logger.Error(&libpack_logger.LogMessage{
Message: "Failed to forward connection_init to backend",
Pairs: map[string]interface{}{
"connection_id": connectionID,
"error": err.Error(),
},
})
}
return
}
if wsp.logger != nil {
wsp.logger.Debug(&libpack_logger.LogMessage{
Message: "Backend WebSocket connection established",
Pairs: map[string]interface{}{
"connection_id": connectionID,
"subprotocol": backendConn.Subprotocol(),
"has_authorization": headers.Get("Authorization") != "",
},
})
}
// Set up bidirectional proxying
var wg sync.WaitGroup
wg.Add(2)
@@ -313,6 +373,58 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
}
}
// extractAuthFromPayload extracts authentication headers from GraphQL WebSocket connection_init payload
// This bridges the gap between clients sending auth in payload and Hasura expecting it in HTTP headers
func (wsp *WebSocketProxy) extractAuthFromPayload(message []byte, originalHeaders http.Header) http.Header {
// Create a copy of original headers
enrichedHeaders := make(http.Header)
for k, v := range originalHeaders {
enrichedHeaders[k] = v
}
// Try to parse as JSON to extract headers from payload
var msg map[string]interface{}
if err := json.Unmarshal(message, &msg); err != nil {
// Not JSON or parse error, return original headers
return enrichedHeaders
}
// Check if this is a connection_init message
msgType, ok := msg["type"].(string)
if !ok || (msgType != "connection_init" && msgType != "start") {
// Not a connection_init, return original headers
return enrichedHeaders
}
// Extract payload
payload, ok := msg["payload"].(map[string]interface{})
if !ok {
return enrichedHeaders
}
// Try to extract headers from payload.headers (graphql-ws format)
if payloadHeaders, ok := payload["headers"].(map[string]interface{}); ok {
for key, value := range payloadHeaders {
if strValue, ok := value.(string); ok {
enrichedHeaders.Set(key, strValue)
}
}
}
// Also check top-level payload keys that look like headers (Apollo format)
for key, value := range payload {
if strValue, ok := value.(string); ok {
// Common auth headers
if key == "Authorization" || key == "authorization" ||
key == "x-hasura-role" || key == "x-hasura-admin-secret" {
enrichedHeaders.Set(key, strValue)
}
}
}
return enrichedHeaders
}
// dialBackend establishes a WebSocket connection to the backend
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
// Convert http:// to ws:// or https:// to wss://
@@ -326,9 +438,18 @@ func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header)
// Append GraphQL WebSocket path
wsURL = wsURL + "/v1/graphql"
// Extract subprotocols from headers (e.g., graphql-ws, graphql-transport-ws)
var subprotocols []string
if proto := headers.Get("Sec-WebSocket-Protocol"); proto != "" {
subprotocols = []string{proto}
// Remove from headers since it will be set via Subprotocols field
headers.Del("Sec-WebSocket-Protocol")
}
// Use gorilla websocket dialer
dialer := gorillaws.Dialer{
HandshakeTimeout: 10 * time.Second,
Subprotocols: subprotocols,
}
// Dial the backend with forwarded headers