Compare commits

...

2 Commits

2 changed files with 166 additions and 18 deletions
+75 -16
View File
@@ -69,25 +69,42 @@ func ensureDefaultLabels(labels *map[string]string, podName string) {
}
func appendSortedLabels(buf *bytes.Buffer, labels map[string]string) {
if len(labels) == 0 {
// Add defer/recover to prevent panics from crashing the application
defer func() {
if r := recover(); r != nil {
// Log the panic but don't crash
fmt.Fprintf(os.Stderr, "Recovered from panic in appendSortedLabels: %v\n", r)
}
}()
if len(labels) == 0 || buf == nil {
return
}
// Create a snapshot to avoid concurrent access issues
labelsCopy := make(map[string]string, len(labels))
for k, v := range labels {
if k == "" {
continue // Skip empty keys
}
labelsCopy[k] = v
}
if len(labelsCopy) == 0 {
return
}
keys := getSortedKeys(labelsCopy)
for i, k := range keys {
if i > 0 {
buf.WriteByte(',')
if v, ok := labelsCopy[k]; ok {
if i > 0 {
buf.WriteByte(',')
}
buf.WriteString(k)
buf.WriteString(`="`)
buf.WriteString(v)
buf.WriteByte('"')
}
buf.WriteString(k)
buf.WriteString(`="`)
buf.WriteString(labelsCopy[k])
buf.WriteByte('"')
}
}
@@ -117,7 +134,15 @@ func getSortedKeys(labels map[string]string) []string {
}
func labelsToString(labels map[string]string) string {
if labels == nil {
// Add defer/recover to prevent panics from crashing the application
defer func() {
if r := recover(); r != nil {
// Log the panic but don't crash
fmt.Fprintf(os.Stderr, "Recovered from panic in labelsToString: %v\n", r)
}
}()
if len(labels) == 0 {
return ""
}
@@ -126,17 +151,34 @@ func labelsToString(labels map[string]string) string {
values := make(map[string]string, len(labels))
for k, v := range labels {
if k == "" {
continue // Skip empty keys
}
keys = append(keys, k)
values[k] = v
}
if len(keys) == 0 {
return ""
}
sort.Strings(keys)
// Pre-allocate the builder with estimated capacity to avoid reallocation
var sb strings.Builder
estimatedSize := 0
for _, k := range keys {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(values[k])
sb.WriteByte(';')
estimatedSize += len(k) + len(values[k]) + 2 // key + value + '=' + ';'
}
sb.Grow(estimatedSize)
for _, k := range keys {
if v, ok := values[k]; ok {
sb.WriteString(k)
sb.WriteByte('=')
sb.WriteString(v)
sb.WriteByte(';')
}
}
return sb.String()
}
@@ -186,6 +228,14 @@ func is_special_rune(r rune) bool {
}
func compile_metrics_with_labels(name string, labels map[string]string) string {
// Add defer/recover to prevent panics from crashing the application
defer func() {
if r := recover(); r != nil {
// Log the panic but don't crash
fmt.Fprintf(os.Stderr, "Recovered from panic in compile_metrics_with_labels: %v\n", r)
}
}()
var buf bytes.Buffer
buf.WriteString(name)
@@ -197,16 +247,25 @@ func compile_metrics_with_labels(name string, labels map[string]string) string {
// Create a snapshot to avoid concurrent access issues
labelsCopy := make(map[string]string, len(labels))
for k, v := range labels {
if k == "" {
continue // Skip empty keys
}
labelsCopy[k] = v
}
if len(labelsCopy) == 0 {
return buf.String()
}
keys := getSortedKeys(labelsCopy)
for _, k := range keys {
buf.WriteByte('_')
buf.WriteString(k)
buf.WriteByte('_')
buf.WriteString(labelsCopy[k])
if v, ok := labelsCopy[k]; ok {
buf.WriteByte('_')
buf.WriteString(k)
buf.WriteByte('_')
buf.WriteString(v)
}
}
return buf.String()
+91 -2
View File
@@ -8,6 +8,7 @@ import (
"sync/atomic"
"time"
"github.com/goccy/go-json"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
gorillaws "github.com/gorilla/websocket"
@@ -141,8 +142,29 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
// Set message size limit
clientConn.SetReadLimit(wsp.maxMessageSize)
// Connect to backend WebSocket with forwarded headers
backendConn, err := wsp.dialBackend(ctx, headers)
// Read first message to extract authentication from connection_init payload
// This bridges the gap between clients that send auth in payload vs Hasura expecting it in HTTP headers
messageType, message, err := clientConn.ReadMessage()
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
wsp.logger.Error(&libpack_logger.LogMessage{
Message: "Failed to read first message from client",
Pairs: map[string]interface{}{
"connection_id": connectionID,
"error": err.Error(),
},
})
}
clientConn.Close()
return
}
// Try to extract headers from connection_init payload (for GraphQL WebSocket protocols)
enrichedHeaders := wsp.extractAuthFromPayload(message, headers)
// Connect to backend WebSocket with enriched headers
backendConn, err := wsp.dialBackend(ctx, enrichedHeaders)
if err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
@@ -159,6 +181,21 @@ func (wsp *WebSocketProxy) handleConnection(ctx context.Context, clientConn *web
}
defer backendConn.Close()
// Forward the first message (connection_init) to backend
if err := backendConn.WriteMessage(messageType, message); err != nil {
wsp.errors.Add(1)
if wsp.logger != nil {
wsp.logger.Error(&libpack_logger.LogMessage{
Message: "Failed to forward connection_init to backend",
Pairs: map[string]interface{}{
"connection_id": connectionID,
"error": err.Error(),
},
})
}
return
}
if wsp.logger != nil {
wsp.logger.Debug(&libpack_logger.LogMessage{
Message: "Backend WebSocket connection established",
@@ -336,6 +373,58 @@ func (wsp *WebSocketProxy) proxyBackendToClient(ctx context.Context, backend *go
}
}
// extractAuthFromPayload extracts authentication headers from GraphQL WebSocket connection_init payload
// This bridges the gap between clients sending auth in payload and Hasura expecting it in HTTP headers
func (wsp *WebSocketProxy) extractAuthFromPayload(message []byte, originalHeaders http.Header) http.Header {
// Create a copy of original headers
enrichedHeaders := make(http.Header)
for k, v := range originalHeaders {
enrichedHeaders[k] = v
}
// Try to parse as JSON to extract headers from payload
var msg map[string]interface{}
if err := json.Unmarshal(message, &msg); err != nil {
// Not JSON or parse error, return original headers
return enrichedHeaders
}
// Check if this is a connection_init message
msgType, ok := msg["type"].(string)
if !ok || (msgType != "connection_init" && msgType != "start") {
// Not a connection_init, return original headers
return enrichedHeaders
}
// Extract payload
payload, ok := msg["payload"].(map[string]interface{})
if !ok {
return enrichedHeaders
}
// Try to extract headers from payload.headers (graphql-ws format)
if payloadHeaders, ok := payload["headers"].(map[string]interface{}); ok {
for key, value := range payloadHeaders {
if strValue, ok := value.(string); ok {
enrichedHeaders.Set(key, strValue)
}
}
}
// Also check top-level payload keys that look like headers (Apollo format)
for key, value := range payload {
if strValue, ok := value.(string); ok {
// Common auth headers
if key == "Authorization" || key == "authorization" ||
key == "x-hasura-role" || key == "x-hasura-admin-secret" {
enrichedHeaders.Set(key, strValue)
}
}
}
return enrichedHeaders
}
// dialBackend establishes a WebSocket connection to the backend
func (wsp *WebSocketProxy) dialBackend(ctx context.Context, headers http.Header) (*gorillaws.Conn, error) {
// Convert http:// to ws:// or https:// to wss://