Compare commits

...

2 Commits

2 changed files with 24 additions and 8 deletions
+21 -7
View File
@@ -88,15 +88,27 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
}
// Capture headers from the upgrade request to forward to backend
headers := make(http.Header)
for key, value := range c.Request().Header.All() {
keyStr := string(key)
// Forward important headers (skip connection-specific ones)
if keyStr != "Connection" && keyStr != "Upgrade" &&
keyStr != "Sec-Websocket-Key" && keyStr != "Sec-Websocket-Version" &&
keyStr != "Sec-Websocket-Extensions" {
headers.Add(keyStr, string(value))
}
}
return websocket.New(func(clientConn *websocket.Conn) {
// Use background context for long-lived WebSocket connections
// The original request context expires after the upgrade
wsp.handleConnection(context.Background(), clientConn)
wsp.handleConnection(context.Background(), clientConn, headers)
})(c)
}
// handleConnection manages a single WebSocket connection
func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *websocket.Conn) {
func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *websocket.Conn, headers http.Header) {
connectionID := fmt.Sprintf("%p", clientConn)
startTime := time.Now()
@@ -117,8 +129,8 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
// Set message size limit
clientConn.SetReadLimit(wsp.maxMessageSize)
// Connect to backend WebSocket
backendConn, err := wsp.dialBackend(ctx)
// Connect to backend WebSocket with forwarded headers
backendConn, err := wsp.dialBackend(ctx, headers)
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
@@ -302,7 +314,7 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
}
// dialBackend establishes a WebSocket connection to the backend
func (wsp *WebSocketProxy) dialBackend(ctx context.Context) (*gorillaws.Conn, error) {
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
// Convert http:// to ws:// or https:// to wss://
wsURL := wsp.backendURL
if len(wsURL) > 7 && wsURL[:7] == "http://" {
@@ -311,13 +323,15 @@ func (wsp *WebSocketProxy) dialBackend(ctx context.Context) (*gorillaws.Conn, er
wsURL = "wss://" + wsURL[8:]
}
// Append GraphQL WebSocket path
wsURL = wsURL + "/v1/graphql"
// Use gorilla websocket dialer
dialer := gorillaws.Dialer{
HandshakeTimeout: 10 * time.Second,
}
// Dial the backend with proper headers
headers := http.Header{}
// Dial the backend with forwarded headers
conn, _, err := dialer.DialContext(ctx, wsURL, headers)
if err != nil {
return nil, fmt.Errorf("failed to dial backend WebSocket: %w", err)
+3 -1
View File
@@ -2,6 +2,7 @@ package main
import (
"context"
"net/http"
"testing"
"time"
@@ -167,7 +168,8 @@ func TestWebSocketProxy_DialBackend_URLConversion(t *testing.T) {
// We can't fully test dialBackend without a real WebSocket server,
// but we can verify the URL conversion logic
ctx := context.Background()
_, err := wsp.dialBackend(ctx)
headers := http.Header{}
_, err := wsp.dialBackend(ctx, headers)
// We expect an error since there's no server, but we verify the conversion happened
assert.Error(t, err) // Should fail to connect to non-existent server