Files
kportal/internal/k8s/portforward.go
T
lukaszraczylo 7a33e01863 fix: 4 P0 concurrency races in forward + k8s
P0 #2 — currentConfig data race
  Manager.currentConfig was written without locking in Start/Reload but
  read from the health-checker callback goroutine. All accesses now go
  through workersMu (read or write as appropriate).

P0 #3 — Reload kills health checker permanently
  Reload's zero-forward branch called m.Stop() which tore down the
  health checker, watchdog, and event bus. After that, EnableForward
  silently registered callbacks against dead components. Now the branch
  stops only the running workers; the supervisory infrastructure stays
  alive across config changes.

P0 #4 — rest.Config write-write race
  executePortForward was mutating .Dial on the cached *rest.Config
  shared by all forwards in the same kube context. Cloning the config
  with rest.CopyConfig before mutation isolates per-forward dialers.

P0 #5 — ForwardWorker.Stop() double-close panic
  close(w.stopChan) is now wrapped in sync.Once, so concurrent Stop
  calls (Manager.Stop racing stopWorkerInternal) are safe.

New tests in internal/forward/concurrency_test.go exercise each fix
under -race: 16 concurrent worker Stops, repeated sequential Stops,
empty-Reload preserves infra pointers, and concurrent currentConfig
read/write.
2026-05-06 10:45:10 +01:00

293 lines
8.8 KiB
Go

package k8s
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/lukaszraczylo/kportal/internal/config"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
)
// PortForwarder opens Kubernetes port-forwards, either to a named pod or to a
// pod selected from a service's backing pods.
//
// Construct it with NewPortForwarder. The dial settings below are plain fields
// read when a forward starts and written by SetTCPKeepalive/SetDialTimeout
// without any locking. NOTE(review): presumably they are configured before any
// forward runs — confirm callers never invoke the setters concurrently with
// Forward.
type PortForwarder struct {
	// clientPool hands out the per-context Kubernetes client (GetClient) and
	// *rest.Config (GetRestConfig).
	clientPool *ClientPool
	// resolver turns a configured resource/selector into a "<type>/<name>" target.
	resolver *ResourceResolver

	// tcpKeepalive is the KeepAlive interval placed on the net.Dialer built for
	// forward connections.
	tcpKeepalive time.Duration
	// dialTimeout is the Timeout placed on that same net.Dialer.
	dialTimeout time.Duration
}
// NewPortForwarder returns a PortForwarder that draws clients from clientPool
// and resolves targets with resolver. The TCP keepalive interval and dial
// timeout start at config.DefaultTCPKeepalive and config.DefaultDialTimeout;
// override them with SetTCPKeepalive and SetDialTimeout.
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
	pf := &PortForwarder{
		clientPool: clientPool,
		resolver:   resolver,
	}
	pf.tcpKeepalive = config.DefaultTCPKeepalive
	pf.dialTimeout = config.DefaultDialTimeout
	return pf
}
// SetTCPKeepalive replaces the keepalive interval used when a forward builds
// its own net.Dialer (only done when the context's rest.Config has no Dial
// set). Forwards that are already running are unaffected. Following net.Dialer
// semantics, zero means "use the Go default" and a negative value disables
// keep-alive probes.
func (pf *PortForwarder) SetTCPKeepalive(keepalive time.Duration) {
	pf.tcpKeepalive = keepalive
}
// SetDialTimeout replaces the connect timeout used when a forward builds its
// own net.Dialer (only done when the context's rest.Config has no Dial set).
// Forwards that are already running are unaffected. Following net.Dialer
// semantics, zero means no timeout is imposed by the dialer.
func (pf *PortForwarder) SetDialTimeout(timeout time.Duration) {
	pf.dialTimeout = timeout
}
// ForwardRequest describes one port-forward: which cluster object to reach,
// which ports to bridge, and the channels/writers that control and observe it.
type ForwardRequest struct {
	// Out and ErrOut receive the forwarder's standard and error output.
	// A nil writer causes that stream to be discarded.
	Out    io.Writer
	ErrOut io.Writer

	// StopChan stops the forward when closed. ReadyChan is passed to
	// portforward.New as its ready channel and may be nil.
	StopChan  chan struct{}
	ReadyChan chan struct{}

	// ContextName picks the client and rest.Config from the ClientPool;
	// Namespace is the namespace the target lives in.
	ContextName string
	Namespace   string

	// Resource and Selector are handed to the ResourceResolver, which turns
	// them into a concrete "<type>/<name>" target.
	Resource string
	Selector string

	// LocalPort and RemotePort form the "<local>:<remote>" port mapping.
	LocalPort  int
	RemotePort int
}
// Forward resolves req.Resource (together with req.Selector) to a concrete
// "<type>/<name>" target via the ResourceResolver and port-forwards to it.
// Targets of type "pod" and "service" are supported; any other type yields an
// error.
//
// ctx bounds the resolution and Kubernetes API lookups only. Once forwarding
// has started the call blocks until req.StopChan is closed or the forwarder
// fails.
func (pf *PortForwarder) Forward(ctx context.Context, req *ForwardRequest) error {
	target, err := pf.resolver.Resolve(ctx, req.ContextName, req.Namespace, req.Resource, req.Selector)
	if err != nil {
		return fmt.Errorf("failed to resolve resource: %w", err)
	}

	kind, name, found := strings.Cut(target, "/")
	if !found {
		return fmt.Errorf("invalid resolved resource format: %s", target)
	}

	switch kind {
	case "pod":
		return pf.forwardToPod(ctx, req, name)
	case "service":
		return pf.forwardToService(ctx, req, name)
	}
	return fmt.Errorf("unsupported resource type: %s", kind)
}
// forwardToPod port-forwards req.LocalPort to req.RemotePort on podName in
// req.Namespace. The pod is fetched first; if it is missing or not in the
// Running phase, an error is returned and no forward is attempted.
func (pf *PortForwarder) forwardToPod(ctx context.Context, req *ForwardRequest, podName string) error {
	client, err := pf.clientPool.GetClient(req.ContextName)
	if err != nil {
		return fmt.Errorf("failed to get client: %w", err)
	}
	restCfg, err := pf.clientPool.GetRestConfig(req.ContextName)
	if err != nil {
		return fmt.Errorf("failed to get rest config: %w", err)
	}

	// Check the pod up front so a missing or non-Running pod produces a
	// descriptive error instead of a failed forward.
	pod, err := client.CoreV1().Pods(req.Namespace).Get(ctx, podName, metav1.GetOptions{})
	switch {
	case err != nil:
		return fmt.Errorf("failed to get pod: %w", err)
	case pod.Status.Phase != corev1.PodRunning:
		return fmt.Errorf("pod is not running (current phase: %s)", pod.Status.Phase)
	}

	portForwardURL := client.CoreV1().RESTClient().
		Post().
		Resource("pods").
		Namespace(req.Namespace).
		Name(podName).
		SubResource("portforward").
		URL()
	return pf.executePortForward(restCfg, portForwardURL, req)
}
// forwardToService port-forwards to a service by listing the pods matched by
// the service's selector and forwarding to one of them, chosen by
// selectServiceForwardPod.
//
// NOTE(review): req.RemotePort is used verbatim as the pod port; no
// service-port -> targetPort translation happens here (unlike
// `kubectl port-forward svc/...`). Confirm this matches the config contract.
func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardRequest, serviceName string) error {
	client, err := pf.clientPool.GetClient(req.ContextName)
	if err != nil {
		return fmt.Errorf("failed to get client: %w", err)
	}

	service, err := client.CoreV1().Services(req.Namespace).Get(ctx, serviceName, metav1.GetOptions{})
	if err != nil {
		return fmt.Errorf("failed to get service: %w", err)
	}
	// Without a selector there is no way to discover backing pods from here.
	if len(service.Spec.Selector) == 0 {
		return fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", serviceName)
	}
	labelSelector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
	pods, err := client.CoreV1().Pods(req.Namespace).List(ctx, metav1.ListOptions{
		LabelSelector: labelSelector,
	})
	if err != nil {
		return fmt.Errorf("failed to list pods for service: %w", err)
	}

	targetPod := selectServiceForwardPod(pods.Items)
	if targetPod == nil {
		return fmt.Errorf("no running pods found for service %s", serviceName)
	}

	restConfig, err := pf.clientPool.GetRestConfig(req.ContextName)
	if err != nil {
		return fmt.Errorf("failed to get rest config: %w", err)
	}
	reqURL := client.CoreV1().RESTClient().Post().
		Resource("pods").
		Namespace(req.Namespace).
		Name(targetPod.Name).
		SubResource("portforward").
		URL()
	return pf.executePortForward(restConfig, reqURL, req)
}

// selectServiceForwardPod picks the pod a service port-forward should target
// from the pods matched by the service's selector.
//
// Only pods in the Running phase are eligible; nil is returned when there are
// none. Among eligible pods, the first one (in list order) of the best
// available category wins, in this preference order:
//  1. not terminating and reporting the Ready condition as True,
//  2. not terminating,
//  3. terminating (DeletionTimestamp set).
//
// Pods being deleted can keep reporting phase Running while they shut down, so
// "first Running pod" may pick one that is about to disappear. Terminating pods
// remain a last resort so a service whose pods are all terminating still gets
// a forward.
func selectServiceForwardPod(pods []corev1.Pod) *corev1.Pod {
	const (
		rankTerminating = iota
		rankNotReady
		rankReady
	)

	var best *corev1.Pod
	bestRank := -1
	for i := range pods {
		pod := &pods[i]
		if pod.Status.Phase != corev1.PodRunning {
			continue
		}

		rank := rankNotReady
		if pod.DeletionTimestamp != nil {
			rank = rankTerminating
		} else {
			for _, cond := range pod.Status.Conditions {
				if cond.Type == corev1.PodReady && cond.Status == corev1.ConditionTrue {
					rank = rankReady
					break
				}
			}
		}

		if rank == rankReady {
			return pod // best possible category; keep the first such pod
		}
		if rank > bestRank {
			best, bestRank = pod, rank
		}
	}
	return best
}

// executePortForward performs the port-forward handshake against reqURL (a
// pod's "portforward" subresource URL) and then blocks, bridging
// req.LocalPort to req.RemotePort, until req.StopChan is closed or forwarding
// fails.
//
// baseConfig is never modified. ClientPool.GetRestConfig returns a cached
// pointer shared by every forward on the same context, so setting Dial on it
// directly would race with concurrent forwards. A copy is configured instead.
func (pf *PortForwarder) executePortForward(baseConfig *rest.Config, reqURL *url.URL, req *ForwardRequest) error {
	cfg := rest.CopyConfig(baseConfig)

	// Install our own dialer only when the context did not already provide one;
	// an existing Dial takes precedence over the configured timeout/keepalive.
	if cfg.Dial == nil {
		// NOTE(review): depending on the client-go version, spdy.RoundTripperFor
		// builds its own SPDY round tripper from the TLS settings and may not
		// consult cfg.Dial at all. In that case these settings never reach the
		// socket. Confirm against the client-go version in go.mod.
		dialer := &net.Dialer{
			Timeout:   pf.dialTimeout,  // how long connection establishment may take
			KeepAlive: pf.tcpKeepalive, // TCP keepalive helps the OS detect dead peers
		}
		cfg.Dial = dialer.DialContext
	}

	transport, upgrader, err := spdy.RoundTripperFor(cfg)
	if err != nil {
		return fmt.Errorf("failed to create round tripper: %w", err)
	}
	spdyDialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, reqURL)

	ports := []string{fmt.Sprintf("%d:%d", req.LocalPort, req.RemotePort)}

	// Substitute io.Discard for any output stream the request left nil.
	out, errOut := req.Out, req.ErrOut
	if out == nil {
		out = io.Discard
	}
	if errOut == nil {
		errOut = io.Discard
	}

	fw, err := portforward.New(spdyDialer, ports, req.StopChan, req.ReadyChan, out, errOut)
	if err != nil {
		return fmt.Errorf("failed to create port forwarder: %w", err)
	}
	// ForwardPorts blocks until the forward is stopped or fails.
	if err := fw.ForwardPorts(); err != nil {
		return fmt.Errorf("port forward failed: %w", err)
	}
	return nil
}

// GetPodForResource returns the name of the pod that forwarding the given
// resource would target. It is intended for logging and debugging.
//
// If the resource resolves to anything other than a service, the resolved name
// is returned as-is, without checking its phase. A service is mapped to a
// backing pod using the same rule as forwardToService
// (selectServiceForwardPod). Pods are listed independently of any forward, so
// the answer can differ if the pod set changes in between.
func (pf *PortForwarder) GetPodForResource(ctx context.Context, contextName, namespace, resource, selector string) (string, error) {
	resolvedResource, err := pf.resolver.Resolve(ctx, contextName, namespace, resource, selector)
	if err != nil {
		return "", err
	}
	resourceType, resourceName, ok := strings.Cut(resolvedResource, "/")
	if !ok {
		return "", fmt.Errorf("invalid resolved resource format: %s", resolvedResource)
	}
	if resourceType != "service" {
		return resourceName, nil
	}

	client, err := pf.clientPool.GetClient(contextName)
	if err != nil {
		return "", fmt.Errorf("failed to get client: %w", err)
	}
	service, err := client.CoreV1().Services(namespace).Get(ctx, resourceName, metav1.GetOptions{})
	if err != nil {
		return "", fmt.Errorf("failed to get service: %w", err)
	}
	if len(service.Spec.Selector) == 0 {
		return "", fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", resourceName)
	}
	// Named labelSelector so it does not shadow the selector parameter.
	labelSelector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
	pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
		LabelSelector: labelSelector,
	})
	if err != nil {
		return "", fmt.Errorf("failed to list pods: %w", err)
	}
	if pod := selectServiceForwardPod(pods.Items); pod != nil {
		return pod.Name, nil
	}
	return "", fmt.Errorf("no running pods found for service")
}