Compare commits

...

4 Commits

3 changed files with 65 additions and 15 deletions
+39 -7
View File
@@ -69,19 +69,33 @@ func ensureDefaultLabels(labels *map[string]string, podName string) {
}
func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
keys := getSortedKeys(labels)
if len(labels) == 0 {
return
}
// Create a snapshot to avoid concurrent access issues
labelsCopy := make(map[string]string, len(labels))
for k, v := range labels {
labelsCopy[k] = v
}
keys := getSortedKeys(labelsCopy)
for i, k := range keys {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(k)
buf.WriteString(`="`)
buf.WriteString(labels[k])
buf.WriteString(labelsCopy[k])
buf.WriteByte('"')
}
}
func getSortedKeys(labels map[string]string) []string {
if labels == nil {
return []string{}
}
labelsKey := labelsToString(labels)
// Check if the sorted keys are already cached
@@ -89,7 +103,7 @@ func getSortedKeys(labels map[string]string) []string {
return keys.([]string)
}
// Compute the sorted keys
// Compute the sorted keys - create a snapshot to avoid concurrent access issues
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
@@ -103,9 +117,17 @@ func getSortedKeys(labels map[string]string) []string {
}
func labelsToString(labels map[string]string) string {
if labels == nil {
return ""
}
// Create a snapshot of the map to avoid concurrent access issues
keys := make([]string, 0, len(labels))
for k := range labels {
values := make(map[string]string, len(labels))
for k, v := range labels {
keys = append(keys, k)
values[k] = v
}
sort.Strings(keys)
@@ -113,7 +135,7 @@ func labelsToString(labels map[string]string) string {
for _, k := range keys {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(labels[k])
sb.WriteString(values[k])
sb.WriteByte(';')
}
return sb.String()
@@ -168,13 +190,23 @@ func compile_metrics_with_labels(name string, labels map[string]string) string {
buf.WriteString(name)
keys := getSortedKeys(labels)
if len(labels) == 0 {
return buf.String()
}
// Create a snapshot to avoid concurrent access issues
labelsCopy := make(map[string]string, len(labels))
for k, v := range labels {
labelsCopy[k] = v
}
keys := getSortedKeys(labelsCopy)
for _, k := range keys {
buf.WriteByte('_')
buf.WriteString(k)
buf.WriteByte('_')
buf.WriteString(labels[k])
buf.WriteString(labelsCopy[k])
}
return buf.String()
+23 -7
View File
@@ -88,13 +88,27 @@ func (wsp *WebSocketProxy) HandleWebSocket(c *fiber.Ctx) error {
return fiber.NewError(fiber.StatusUpgradeRequired, "WebSocket upgrade required")
}
// Capture headers from the upgrade request to forward to backend
headers := make(http.Header)
for key, value := range c.Request().Header.All() {
keyStr := string(key)
// Forward important headers (skip connection-specific ones)
if keyStr != "Connection" && keyStr != "Upgrade" &&
keyStr != "Sec-Websocket-Key" && keyStr != "Sec-Websocket-Version" &&
keyStr != "Sec-Websocket-Extensions" {
headers.Add(keyStr, string(value))
}
}
return websocket.New(func(clientConn *websocket.Conn) {
wsp.handleConnection(c.Context(), clientConn)
// Use background context for long-lived WebSocket connections
// The original request context expires after the upgrade
wsp.handleConnection(context.Background(), clientConn, headers)
})(c)
}
// handleConnection manages a single WebSocket connection
func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *websocket.Conn) {
func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *websocket.Conn, headers http.Header) {
connectionID := fmt.Sprintf("%p", clientConn)
startTime := time.Now()
@@ -115,8 +129,8 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
// Set message size limit
clientConn.SetReadLimit(wsp.maxMessageSize)
// Connect to backend WebSocket
backendConn, err := wsp.dialBackend(ctx)
// Connect to backend WebSocket with forwarded headers
backendConn, err := wsp.dialBackend(ctx, headers)
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
@@ -300,7 +314,7 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
}
// dialBackend establishes a WebSocket connection to the backend
func (wsp *WebSocketProxy) dialBackend(ctx context.Context) (*gorillaws.Conn, error) {
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
// Convert http:// to ws:// or https:// to wss://
wsURL := wsp.backendURL
if len(wsURL) > 7 && wsURL[:7] == "http://" {
@@ -309,13 +323,15 @@ func (wsp *WebSocketProxy) dialBackend(ctx context.Context) (*gorillaws.Conn, er
wsURL = "wss://" + wsURL[8:]
}
// Append GraphQL WebSocket path
wsURL = wsURL + "/v1/graphql"
// Use gorilla websocket dialer
dialer := gorillaws.Dialer{
HandshakeTimeout: 10 * time.Second,
}
// Dial the backend with proper headers
headers := http.Header{}
// Dial the backend with forwarded headers
conn, _, err := dialer.DialContext(ctx, wsURL, headers)
if err != nil {
return nil, fmt.Errorf("failed to dial backend WebSocket: %w", err)
+3 -1
View File
@@ -2,6 +2,7 @@ package main
import (
"context"
"net/http"
"testing"
"time"
@@ -167,7 +168,8 @@ func TestWebSocketProxy_DialBackend_URLConversion(t *testing.T) {
// We can't fully test dialBackend without a real WebSocket server,
// but we can verify the URL conversion logic
ctx := context.Background()
_, err := wsp.dialBackend(ctx)
headers := http.Header{}
_, err := wsp.dialBackend(ctx, headers)
// We expect an error since there's no server, but we verify the conversion happened
assert.Error(t, err) // Should fail to connect to non-existent server