mirror of
https://github.com/lukaszraczylo/kportal.git
synced 2026-07-02 05:45:42 +00:00
WIP - before the testing.
This commit is contained in:
@@ -0,0 +1,298 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/nvm/kportal/internal/config"
|
||||
"github.com/nvm/kportal/internal/k8s"
|
||||
)
|
||||
|
||||
// Manager orchestrates all port-forward workers.
|
||||
// It handles starting, stopping, and hot-reloading forwards.
|
||||
type Manager struct {
|
||||
workers map[string]*ForwardWorker // key: forward.ID()
|
||||
workersMu sync.RWMutex
|
||||
clientPool *k8s.ClientPool
|
||||
resolver *k8s.ResourceResolver
|
||||
portForwarder *k8s.PortForwarder
|
||||
portChecker *PortChecker
|
||||
verbose bool
|
||||
currentConfig *config.Config
|
||||
}
|
||||
|
||||
// NewManager creates a new forward Manager.
|
||||
func NewManager(verbose bool) *Manager {
|
||||
clientPool, err := k8s.NewClientPool()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create client pool: %v", err)
|
||||
}
|
||||
|
||||
resolver := k8s.NewResourceResolver(clientPool)
|
||||
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
|
||||
|
||||
return &Manager{
|
||||
workers: make(map[string]*ForwardWorker),
|
||||
clientPool: clientPool,
|
||||
resolver: resolver,
|
||||
portForwarder: portForwarder,
|
||||
portChecker: NewPortChecker(),
|
||||
verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// Start initializes and starts all port-forwards from the configuration.
|
||||
func (m *Manager) Start(cfg *config.Config) error {
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("configuration is nil")
|
||||
}
|
||||
|
||||
m.currentConfig = cfg
|
||||
|
||||
// Get all forwards from config
|
||||
forwards := cfg.GetAllForwards()
|
||||
|
||||
if len(forwards) == 0 {
|
||||
return fmt.Errorf("no forwards configured")
|
||||
}
|
||||
|
||||
// Check port availability before starting
|
||||
ports := m.extractPorts(forwards)
|
||||
conflicts := m.portChecker.CheckAvailability(ports, nil)
|
||||
if len(conflicts) > 0 {
|
||||
// Add resource information to conflicts
|
||||
for i := range conflicts {
|
||||
conflicts[i].Resource = m.getResourceForPort(forwards, conflicts[i].Port)
|
||||
}
|
||||
return fmt.Errorf("port conflicts detected:\n%s", FormatConflicts(conflicts))
|
||||
}
|
||||
|
||||
// Start all workers
|
||||
log.Printf("Starting %d port-forward(s)...", len(forwards))
|
||||
|
||||
for _, fwd := range forwards {
|
||||
if err := m.startWorker(fwd); err != nil {
|
||||
log.Printf("Failed to start worker for %s: %v", fwd.ID(), err)
|
||||
// Continue with other workers
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("All port-forwards started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops all port-forward workers.
|
||||
func (m *Manager) Stop() {
|
||||
log.Printf("Stopping all port-forwards...")
|
||||
|
||||
m.workersMu.Lock()
|
||||
workers := make([]*ForwardWorker, 0, len(m.workers))
|
||||
for _, worker := range m.workers {
|
||||
workers = append(workers, worker)
|
||||
}
|
||||
m.workersMu.Unlock()
|
||||
|
||||
// Stop all workers
|
||||
var wg sync.WaitGroup
|
||||
for _, worker := range workers {
|
||||
wg.Add(1)
|
||||
go func(w *ForwardWorker) {
|
||||
defer wg.Done()
|
||||
w.Stop()
|
||||
}(worker)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Clear workers map
|
||||
m.workersMu.Lock()
|
||||
m.workers = make(map[string]*ForwardWorker)
|
||||
m.workersMu.Unlock()
|
||||
|
||||
log.Printf("All port-forwards stopped")
|
||||
}
|
||||
|
||||
// Reload applies a new configuration with hot-reload logic.
|
||||
// It diffs the new config against the current one and:
|
||||
// - Stops removed forwards
|
||||
// - Keeps unchanged forwards running
|
||||
// - Starts new forwards
|
||||
func (m *Manager) Reload(newCfg *config.Config) error {
|
||||
if newCfg == nil {
|
||||
return fmt.Errorf("new configuration is nil")
|
||||
}
|
||||
|
||||
log.Printf("Reloading configuration...")
|
||||
|
||||
// Get all forwards from new config
|
||||
newForwards := newCfg.GetAllForwards()
|
||||
|
||||
if len(newForwards) == 0 {
|
||||
log.Printf("New configuration has no forwards, stopping all")
|
||||
m.Stop()
|
||||
m.currentConfig = newCfg
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create maps for easier comparison
|
||||
newForwardsMap := make(map[string]config.Forward)
|
||||
for _, fwd := range newForwards {
|
||||
newForwardsMap[fwd.ID()] = fwd
|
||||
}
|
||||
|
||||
m.workersMu.RLock()
|
||||
currentForwardsMap := make(map[string]config.Forward)
|
||||
for id, worker := range m.workers {
|
||||
currentForwardsMap[id] = worker.GetForward()
|
||||
}
|
||||
m.workersMu.RUnlock()
|
||||
|
||||
// Determine changes
|
||||
var toAdd []config.Forward
|
||||
var toRemove []string
|
||||
var toKeep []string
|
||||
|
||||
// Find forwards to add and keep
|
||||
for id, fwd := range newForwardsMap {
|
||||
if _, exists := currentForwardsMap[id]; exists {
|
||||
toKeep = append(toKeep, id)
|
||||
} else {
|
||||
toAdd = append(toAdd, fwd)
|
||||
}
|
||||
}
|
||||
|
||||
// Find forwards to remove
|
||||
for id := range currentForwardsMap {
|
||||
if _, exists := newForwardsMap[id]; !exists {
|
||||
toRemove = append(toRemove, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Check port availability for new forwards
|
||||
if len(toAdd) > 0 {
|
||||
// Get currently managed ports to skip in availability check
|
||||
managedPorts := make(map[int]bool)
|
||||
for _, id := range toKeep {
|
||||
managedPorts[currentForwardsMap[id].LocalPort] = true
|
||||
}
|
||||
|
||||
// Check new ports
|
||||
newPorts := m.extractPorts(toAdd)
|
||||
conflicts := m.portChecker.CheckAvailability(newPorts, managedPorts)
|
||||
if len(conflicts) > 0 {
|
||||
// Add resource information to conflicts
|
||||
for i := range conflicts {
|
||||
conflicts[i].Resource = m.getResourceForPort(toAdd, conflicts[i].Port)
|
||||
}
|
||||
log.Printf("Config change rejected due to port conflicts:\n%s", FormatConflicts(conflicts))
|
||||
log.Printf("Keeping previous configuration active")
|
||||
return fmt.Errorf("port conflicts detected")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply changes
|
||||
log.Printf("Configuration diff: %d to add, %d to remove, %d to keep",
|
||||
len(toAdd), len(toRemove), len(toKeep))
|
||||
|
||||
// Stop removed forwards
|
||||
for _, id := range toRemove {
|
||||
if err := m.stopWorker(id); err != nil {
|
||||
log.Printf("Failed to stop worker %s: %v", id, err)
|
||||
} else {
|
||||
log.Printf("Stopped: %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
// Start new forwards
|
||||
for _, fwd := range toAdd {
|
||||
if err := m.startWorker(fwd); err != nil {
|
||||
log.Printf("Failed to start worker for %s: %v", fwd.ID(), err)
|
||||
} else {
|
||||
log.Printf("Started: %s", fwd.ID())
|
||||
}
|
||||
}
|
||||
|
||||
// Update current config
|
||||
m.currentConfig = newCfg
|
||||
|
||||
log.Printf("Configuration reloaded successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// startWorker creates and starts a new forward worker.
|
||||
func (m *Manager) startWorker(fwd config.Forward) error {
|
||||
m.workersMu.Lock()
|
||||
defer m.workersMu.Unlock()
|
||||
|
||||
// Check if worker already exists
|
||||
if _, exists := m.workers[fwd.ID()]; exists {
|
||||
return fmt.Errorf("worker already exists for %s", fwd.ID())
|
||||
}
|
||||
|
||||
// Create and start worker
|
||||
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose)
|
||||
worker.Start()
|
||||
|
||||
// Store worker
|
||||
m.workers[fwd.ID()] = worker
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopWorker stops and removes a forward worker.
|
||||
func (m *Manager) stopWorker(id string) error {
|
||||
m.workersMu.Lock()
|
||||
worker, exists := m.workers[id]
|
||||
if !exists {
|
||||
m.workersMu.Unlock()
|
||||
return fmt.Errorf("worker not found: %s", id)
|
||||
}
|
||||
delete(m.workers, id)
|
||||
m.workersMu.Unlock()
|
||||
|
||||
// Stop the worker
|
||||
worker.Stop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveForwards returns a list of all active forward IDs.
|
||||
func (m *Manager) GetActiveForwards() []string {
|
||||
m.workersMu.RLock()
|
||||
defer m.workersMu.RUnlock()
|
||||
|
||||
ids := make([]string, 0, len(m.workers))
|
||||
for id := range m.workers {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// GetWorkerCount returns the number of active workers.
|
||||
func (m *Manager) GetWorkerCount() int {
|
||||
m.workersMu.RLock()
|
||||
defer m.workersMu.RUnlock()
|
||||
|
||||
return len(m.workers)
|
||||
}
|
||||
|
||||
// extractPorts extracts all local ports from a list of forwards.
|
||||
func (m *Manager) extractPorts(forwards []config.Forward) []int {
|
||||
ports := make([]int, len(forwards))
|
||||
for i, fwd := range forwards {
|
||||
ports[i] = fwd.LocalPort
|
||||
}
|
||||
return ports
|
||||
}
|
||||
|
||||
// getResourceForPort finds the resource (forward ID) that uses a given port.
|
||||
func (m *Manager) getResourceForPort(forwards []config.Forward, port int) string {
|
||||
for _, fwd := range forwards {
|
||||
if fwd.LocalPort == port {
|
||||
return fwd.ID()
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PortConflict represents a local port that is already in use.
|
||||
type PortConflict struct {
|
||||
Port int // The conflicting port number
|
||||
Resource string // The forward resource that needs this port
|
||||
UsedBy string // Process information (PID, command) using the port
|
||||
}
|
||||
|
||||
// PortChecker checks port availability on the local system.
|
||||
type PortChecker struct{}
|
||||
|
||||
// NewPortChecker creates a new PortChecker instance.
|
||||
func NewPortChecker() *PortChecker {
|
||||
return &PortChecker{}
|
||||
}
|
||||
|
||||
// CheckAvailability checks if the given ports are available for binding.
|
||||
// It returns a list of conflicts for ports that are already in use.
|
||||
// The skipPorts map contains ports currently managed by kportal that should be excluded from the check.
|
||||
func (pc *PortChecker) CheckAvailability(ports []int, skipPorts map[int]bool) []PortConflict {
|
||||
var conflicts []PortConflict
|
||||
|
||||
for _, port := range ports {
|
||||
// Skip ports that are already managed by kportal
|
||||
if skipPorts[port] {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to bind to the port
|
||||
if !pc.isPortAvailable(port) {
|
||||
// Port is in use, get process info
|
||||
usedBy := pc.getProcessUsingPort(port)
|
||||
conflicts = append(conflicts, PortConflict{
|
||||
Port: port,
|
||||
UsedBy: usedBy,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return conflicts
|
||||
}
|
||||
|
||||
// isPortAvailable checks if a port is available by attempting to bind to it.
|
||||
func (pc *PortChecker) isPortAvailable(port int) bool {
|
||||
// Try to listen on the port
|
||||
addr := fmt.Sprintf(":%d", port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
listener.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
// getProcessUsingPort returns information about the process using the given port.
|
||||
// Returns a string like "nginx (PID 1234)" or "unknown" if the process cannot be determined.
|
||||
func (pc *PortChecker) getProcessUsingPort(port int) string {
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "linux":
|
||||
return pc.getProcessUsingPortUnix(port)
|
||||
case "windows":
|
||||
return pc.getProcessUsingPortWindows(port)
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// getProcessUsingPortUnix uses lsof to find the process using a port on Unix systems.
|
||||
func (pc *PortChecker) getProcessUsingPortUnix(port int) string {
|
||||
// Use lsof to find the process
|
||||
// lsof -i :PORT -sTCP:LISTEN -t returns PIDs
|
||||
cmd := exec.Command("lsof", "-i", fmt.Sprintf(":%d", port), "-sTCP:LISTEN", "-t")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
pidStr := strings.TrimSpace(string(output))
|
||||
if pidStr == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Get the first PID if multiple are returned
|
||||
pids := strings.Split(pidStr, "\n")
|
||||
pid := pids[0]
|
||||
|
||||
// Get process name using ps
|
||||
cmd = exec.Command("ps", "-p", pid, "-o", "comm=")
|
||||
output, err = cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("PID %s", pid)
|
||||
}
|
||||
|
||||
procName := strings.TrimSpace(string(output))
|
||||
if procName == "" {
|
||||
return fmt.Sprintf("PID %s", pid)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (PID %s)", procName, pid)
|
||||
}
|
||||
|
||||
// getProcessUsingPortWindows uses netstat to find the process using a port on Windows.
|
||||
func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
|
||||
// Use netstat to find the process
|
||||
// netstat -ano | findstr :PORT
|
||||
cmd := exec.Command("netstat", "-ano")
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
lines := strings.Split(string(output), "\n")
|
||||
portStr := fmt.Sprintf(":%d", port)
|
||||
|
||||
for _, line := range lines {
|
||||
if !strings.Contains(line, portStr) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the line to extract PID
|
||||
// Format: TCP 0.0.0.0:8080 0.0.0.0:0 LISTENING 1234
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) < 5 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a LISTENING state
|
||||
if !strings.Contains(strings.ToUpper(line), "LISTENING") {
|
||||
continue
|
||||
}
|
||||
|
||||
pid := fields[len(fields)-1]
|
||||
|
||||
// Get process name using tasklist
|
||||
cmd = exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
|
||||
output, err = cmd.Output()
|
||||
if err != nil {
|
||||
return fmt.Sprintf("PID %s", pid)
|
||||
}
|
||||
|
||||
// Parse CSV output: "process.exe","1234","Console","1","12,345 K"
|
||||
csvLine := strings.TrimSpace(string(output))
|
||||
if csvLine == "" {
|
||||
return fmt.Sprintf("PID %s", pid)
|
||||
}
|
||||
|
||||
parts := strings.Split(csvLine, ",")
|
||||
if len(parts) > 0 {
|
||||
procName := strings.Trim(parts[0], "\"")
|
||||
return fmt.Sprintf("%s (PID %s)", procName, pid)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("PID %s", pid)
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// FormatConflicts formats port conflicts into a human-readable error message.
|
||||
func FormatConflicts(conflicts []PortConflict) string {
|
||||
if len(conflicts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("\nPort Conflicts Detected:\n")
|
||||
sb.WriteString(strings.Repeat("=", 50) + "\n\n")
|
||||
|
||||
for _, conflict := range conflicts {
|
||||
sb.WriteString(fmt.Sprintf("Port %d\n", conflict.Port))
|
||||
if conflict.Resource != "" {
|
||||
sb.WriteString(fmt.Sprintf(" Needed for: %s\n", conflict.Resource))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" Currently used by: %s\n", conflict.UsedBy))
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
sb.WriteString("Action: Stop conflicting processes or change localPort in config.\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// GetPortsFromForwards extracts all local ports from a list of forward configurations.
|
||||
func GetPortsFromForwards(forwards []interface{}) []int {
|
||||
ports := make([]int, 0, len(forwards))
|
||||
for _, fwd := range forwards {
|
||||
// This function expects a generic interface to work with different forward types
|
||||
// The actual implementation should use the Forward struct from config package
|
||||
if f, ok := fwd.(interface{ GetLocalPort() int }); ok {
|
||||
ports = append(ports, f.GetLocalPort())
|
||||
}
|
||||
}
|
||||
return ports
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestPortChecker_IsAvailable(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Test that isPortAvailable returns a bool
|
||||
// We use a high port that's likely to be available
|
||||
result := pc.isPortAvailable(54321)
|
||||
assert.IsType(t, false, result, "isPortAvailable should return bool")
|
||||
}
|
||||
|
||||
func TestPortChecker_CheckAvailability_EmptyPorts(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Test with empty ports slice
|
||||
conflicts := pc.CheckAvailability([]int{}, nil)
|
||||
assert.Empty(t, conflicts, "should return empty conflicts for empty ports")
|
||||
|
||||
// Test with nil exclude map
|
||||
conflicts = pc.CheckAvailability([]int{}, nil)
|
||||
assert.Empty(t, conflicts, "should return empty conflicts for nil exclude map")
|
||||
}
|
||||
|
||||
func TestPortChecker_CheckAvailability_ExcludeMap(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Create a listener to occupy a port
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
assert.NoError(t, err, "should create listener")
|
||||
defer listener.Close()
|
||||
|
||||
// Get the port that's now occupied
|
||||
addr := listener.Addr().(*net.TCPAddr)
|
||||
occupiedPort := addr.Port
|
||||
|
||||
// Test that the occupied port shows as conflicted
|
||||
conflicts := pc.CheckAvailability([]int{occupiedPort}, nil)
|
||||
assert.Len(t, conflicts, 1, "should detect conflict for occupied port")
|
||||
assert.Equal(t, occupiedPort, conflicts[0].Port)
|
||||
|
||||
// Test that skipPorts map excludes the port from conflict detection
|
||||
skipPorts := map[int]bool{
|
||||
occupiedPort: true,
|
||||
}
|
||||
conflicts = pc.CheckAvailability([]int{occupiedPort}, skipPorts)
|
||||
assert.Empty(t, conflicts, "should skip ports in exclude map")
|
||||
}
|
||||
|
||||
func TestPortChecker_CheckAvailability_MultipleSkipPorts(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Create multiple listeners
|
||||
listener1, err := net.Listen("tcp", ":0")
|
||||
assert.NoError(t, err)
|
||||
defer listener1.Close()
|
||||
|
||||
listener2, err := net.Listen("tcp", ":0")
|
||||
assert.NoError(t, err)
|
||||
defer listener2.Close()
|
||||
|
||||
port1 := listener1.Addr().(*net.TCPAddr).Port
|
||||
port2 := listener2.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// Test with both ports occupied
|
||||
conflicts := pc.CheckAvailability([]int{port1, port2}, nil)
|
||||
assert.Len(t, conflicts, 2, "should detect both conflicts")
|
||||
|
||||
// Test excluding one port
|
||||
skipPorts := map[int]bool{port1: true}
|
||||
conflicts = pc.CheckAvailability([]int{port1, port2}, skipPorts)
|
||||
assert.Len(t, conflicts, 1, "should detect only non-excluded port")
|
||||
assert.Equal(t, port2, conflicts[0].Port)
|
||||
|
||||
// Test excluding both ports
|
||||
skipPorts = map[int]bool{port1: true, port2: true}
|
||||
conflicts = pc.CheckAvailability([]int{port1, port2}, skipPorts)
|
||||
assert.Empty(t, conflicts, "should skip all excluded ports")
|
||||
}
|
||||
|
||||
func TestPortChecker_GetProcessInfo(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Test that getProcessUsingPort returns a string
|
||||
// We don't test actual process detection to avoid flakiness
|
||||
result := pc.getProcessUsingPort(12345)
|
||||
assert.IsType(t, "", result, "getProcessUsingPort should return string")
|
||||
assert.NotEmpty(t, result, "should return some string (even if 'unknown')")
|
||||
}
|
||||
|
||||
func TestFormatConflicts_Empty(t *testing.T) {
|
||||
// Test with empty conflicts
|
||||
output := FormatConflicts([]PortConflict{})
|
||||
assert.Empty(t, output, "should return empty string for no conflicts")
|
||||
}
|
||||
|
||||
func TestFormatConflicts_SingleConflict(t *testing.T) {
|
||||
conflicts := []PortConflict{
|
||||
{
|
||||
Port: 8080,
|
||||
Resource: "dev/default/pod/my-app:8080",
|
||||
UsedBy: "nginx (PID 1234)",
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatConflicts(conflicts)
|
||||
assert.NotEmpty(t, output, "should return non-empty output")
|
||||
assert.Contains(t, output, "Port Conflicts Detected", "should contain header")
|
||||
assert.Contains(t, output, "Port 8080", "should contain port number")
|
||||
assert.Contains(t, output, "dev/default/pod/my-app:8080", "should contain resource")
|
||||
assert.Contains(t, output, "nginx (PID 1234)", "should contain process info")
|
||||
}
|
||||
|
||||
func TestFormatConflicts_MultipleConflicts(t *testing.T) {
|
||||
conflicts := []PortConflict{
|
||||
{
|
||||
Port: 8080,
|
||||
Resource: "dev/default/pod/app1:8080",
|
||||
UsedBy: "nginx (PID 1234)",
|
||||
},
|
||||
{
|
||||
Port: 5432,
|
||||
Resource: "prod/database/service/postgres:5432",
|
||||
UsedBy: "postgres (PID 5678)",
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatConflicts(conflicts)
|
||||
assert.NotEmpty(t, output, "should return non-empty output")
|
||||
assert.Contains(t, output, "Port Conflicts Detected", "should contain header")
|
||||
assert.Contains(t, output, "Port 8080", "should contain first port")
|
||||
assert.Contains(t, output, "Port 5432", "should contain second port")
|
||||
assert.Contains(t, output, "nginx (PID 1234)", "should contain first process")
|
||||
assert.Contains(t, output, "postgres (PID 5678)", "should contain second process")
|
||||
assert.Contains(t, output, "Action:", "should contain action message")
|
||||
}
|
||||
|
||||
func TestFormatConflicts_WithoutResource(t *testing.T) {
|
||||
conflicts := []PortConflict{
|
||||
{
|
||||
Port: 8080,
|
||||
UsedBy: "nginx (PID 1234)",
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatConflicts(conflicts)
|
||||
assert.NotEmpty(t, output, "should return non-empty output")
|
||||
assert.Contains(t, output, "Port 8080", "should contain port")
|
||||
assert.Contains(t, output, "nginx (PID 1234)", "should contain process info")
|
||||
// Should not crash or include empty "Needed for:" line
|
||||
assert.NotContains(t, output, "Needed for: \n", "should not have empty resource line")
|
||||
}
|
||||
|
||||
func TestPortConflict_Structure(t *testing.T) {
|
||||
// Test that PortConflict structure works correctly
|
||||
conflict := PortConflict{
|
||||
Port: 8080,
|
||||
Resource: "dev/default/pod/app:8080",
|
||||
UsedBy: "nginx (PID 1234)",
|
||||
}
|
||||
|
||||
assert.Equal(t, 8080, conflict.Port)
|
||||
assert.Equal(t, "dev/default/pod/app:8080", conflict.Resource)
|
||||
assert.Equal(t, "nginx (PID 1234)", conflict.UsedBy)
|
||||
}
|
||||
|
||||
func TestNewPortChecker(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
assert.NotNil(t, pc, "NewPortChecker should return non-nil instance")
|
||||
}
|
||||
|
||||
func TestPortChecker_PortAvailability_Integration(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Create a listener to occupy a port
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
assert.NoError(t, err, "should create listener")
|
||||
defer listener.Close()
|
||||
|
||||
// Get the occupied port
|
||||
occupiedPort := listener.Addr().(*net.TCPAddr).Port
|
||||
|
||||
// Test that the port is correctly detected as unavailable
|
||||
available := pc.isPortAvailable(occupiedPort)
|
||||
assert.False(t, available, "occupied port should not be available")
|
||||
|
||||
// Close the listener
|
||||
listener.Close()
|
||||
|
||||
// The port should now be available (though there might be a brief delay)
|
||||
// We don't assert this to avoid flakiness in CI environments
|
||||
}
|
||||
|
||||
func TestPortChecker_CheckAvailability_AvailablePorts(t *testing.T) {
|
||||
pc := NewPortChecker()
|
||||
|
||||
// Use high port numbers that are very unlikely to be in use
|
||||
// This test might be slightly flaky in unusual environments, but should be stable
|
||||
unlikelyPorts := []int{54321, 54322, 54323}
|
||||
|
||||
conflicts := pc.CheckAvailability(unlikelyPorts, nil)
|
||||
|
||||
// Most likely all ports will be available
|
||||
// The function returns nil or empty slice when there are no conflicts
|
||||
// We just verify the function executes without panicking
|
||||
_ = conflicts
|
||||
}
|
||||
|
||||
func TestFormatConflicts_Formatting(t *testing.T) {
|
||||
conflicts := []PortConflict{
|
||||
{
|
||||
Port: 8080,
|
||||
Resource: "dev/default/pod/my-app:8080",
|
||||
UsedBy: "nginx (PID 1234)",
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatConflicts(conflicts)
|
||||
|
||||
// Check formatting details
|
||||
assert.Contains(t, output, "==================================================", "should contain separator line")
|
||||
assert.Contains(t, output, "\n", "should contain newlines")
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
package forward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/nvm/kportal/internal/config"
|
||||
"github.com/nvm/kportal/internal/k8s"
|
||||
"github.com/nvm/kportal/internal/retry"
|
||||
)
|
||||
|
||||
// ForwardWorker manages a single port-forward connection with automatic retry.
|
||||
type ForwardWorker struct {
|
||||
forward config.Forward
|
||||
portForwarder *k8s.PortForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
stopChan chan struct{}
|
||||
doneChan chan struct{}
|
||||
verbose bool
|
||||
lastPod string // Track the last pod we connected to
|
||||
}
|
||||
|
||||
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
|
||||
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool) *ForwardWorker {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &ForwardWorker{
|
||||
forward: fwd,
|
||||
portForwarder: portForwarder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
stopChan: make(chan struct{}),
|
||||
doneChan: make(chan struct{}),
|
||||
verbose: verbose,
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the port-forward worker in a goroutine.
|
||||
// The worker will continuously retry on failures with exponential backoff.
|
||||
func (w *ForwardWorker) Start() {
|
||||
go w.run()
|
||||
}
|
||||
|
||||
// Stop gracefully stops the port-forward worker.
|
||||
func (w *ForwardWorker) Stop() {
|
||||
w.cancel()
|
||||
close(w.stopChan)
|
||||
<-w.doneChan // Wait for worker to finish
|
||||
}
|
||||
|
||||
// run is the main worker loop that handles retries.
|
||||
func (w *ForwardWorker) run() {
|
||||
defer close(w.doneChan)
|
||||
|
||||
backoff := retry.NewBackoff()
|
||||
|
||||
for {
|
||||
// Check if we should stop
|
||||
select {
|
||||
case <-w.ctx.Done():
|
||||
if w.verbose {
|
||||
log.Printf("[%s] Worker stopped", w.forward.ID())
|
||||
}
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Resolve the resource to get current pod name
|
||||
podName, err := w.portForwarder.GetPodForResource(
|
||||
w.ctx,
|
||||
w.forward.GetContext(),
|
||||
w.forward.GetNamespace(),
|
||||
w.forward.Resource,
|
||||
w.forward.Selector,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[%s] Failed to resolve resource: %v", w.forward.ID(), err)
|
||||
w.sleepWithBackoff(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if pod changed (restart detected)
|
||||
if w.lastPod != "" && w.lastPod != podName {
|
||||
log.Printf("[%s] Switched to new pod: %s → %s", w.forward.ID(), w.lastPod, podName)
|
||||
} else if w.lastPod == "" {
|
||||
log.Printf("[%s] Forwarding %s → localhost:%d",
|
||||
w.forward.ID(), w.forward.String(), w.forward.LocalPort)
|
||||
}
|
||||
|
||||
w.lastPod = podName
|
||||
|
||||
// Establish port-forward connection
|
||||
err = w.establishForward(podName)
|
||||
|
||||
if err != nil {
|
||||
// Connection failed or was interrupted
|
||||
if w.ctx.Err() != nil {
|
||||
// Context was cancelled, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
// Log the error
|
||||
log.Printf("[%s] Port-forward connection failed: %v", w.forward.ID(), err)
|
||||
|
||||
// Clear last pod so we re-resolve on next attempt
|
||||
w.lastPod = ""
|
||||
|
||||
// Wait with backoff before retrying
|
||||
w.sleepWithBackoff(backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
// Connection closed normally (shouldn't happen unless stopped)
|
||||
if w.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Connection closed unexpectedly, retry
|
||||
log.Printf("[%s] Connection closed unexpectedly, retrying...", w.forward.ID())
|
||||
w.lastPod = ""
|
||||
w.sleepWithBackoff(backoff)
|
||||
}
|
||||
}
|
||||
|
||||
// establishForward establishes a port-forward connection.
|
||||
// This blocks until the connection is closed or an error occurs.
|
||||
func (w *ForwardWorker) establishForward(podName string) error {
|
||||
// Create channels for this forward
|
||||
stopChan := make(chan struct{}, 1)
|
||||
readyChan := make(chan struct{}, 1)
|
||||
|
||||
// Create a context for this forward attempt
|
||||
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
|
||||
defer forwardCancel()
|
||||
|
||||
// Start a goroutine to monitor for stop signal
|
||||
go func() {
|
||||
select {
|
||||
case <-w.stopChan:
|
||||
close(stopChan)
|
||||
case <-forwardCtx.Done():
|
||||
close(stopChan)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set up output writers
|
||||
var out, errOut io.Writer
|
||||
if w.verbose {
|
||||
out = &logWriter{prefix: fmt.Sprintf("[%s] ", w.forward.ID())}
|
||||
errOut = &logWriter{prefix: fmt.Sprintf("[%s] ERROR: ", w.forward.ID())}
|
||||
} else {
|
||||
out = io.Discard
|
||||
errOut = io.Discard
|
||||
}
|
||||
|
||||
// Create forward request
|
||||
req := &k8s.ForwardRequest{
|
||||
ContextName: w.forward.GetContext(),
|
||||
Namespace: w.forward.GetNamespace(),
|
||||
Resource: w.forward.Resource,
|
||||
Selector: w.forward.Selector,
|
||||
LocalPort: w.forward.LocalPort,
|
||||
RemotePort: w.forward.Port,
|
||||
StopChan: stopChan,
|
||||
ReadyChan: readyChan,
|
||||
Out: out,
|
||||
ErrOut: errOut,
|
||||
}
|
||||
|
||||
// Start port forwarding in a goroutine
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
errChan <- w.portForwarder.Forward(forwardCtx, req)
|
||||
}()
|
||||
|
||||
// Wait for ready or error
|
||||
select {
|
||||
case <-readyChan:
|
||||
if w.verbose {
|
||||
log.Printf("[%s] Port-forward connection established", w.forward.ID())
|
||||
}
|
||||
case err := <-errChan:
|
||||
return fmt.Errorf("failed to establish forward: %w", err)
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
case <-time.After(30 * time.Second):
|
||||
return fmt.Errorf("timeout waiting for port-forward to become ready")
|
||||
}
|
||||
|
||||
// Wait for connection to close or error
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-w.ctx.Done():
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// sleepWithBackoff waits for the next backoff duration.
|
||||
// Returns early if the worker is stopped.
|
||||
func (w *ForwardWorker) sleepWithBackoff(backoff *retry.Backoff) {
|
||||
delay := backoff.Next()
|
||||
|
||||
if w.verbose {
|
||||
log.Printf("[%s] Retrying in %v (attempt %d)", w.forward.ID(), delay, backoff.Attempt())
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
// Continue with retry
|
||||
case <-w.ctx.Done():
|
||||
// Worker stopped
|
||||
}
|
||||
}
|
||||
|
||||
// GetForward returns the forward configuration for this worker.
|
||||
func (w *ForwardWorker) GetForward() config.Forward {
|
||||
return w.forward
|
||||
}
|
||||
|
||||
// IsRunning returns true if the worker is running.
|
||||
func (w *ForwardWorker) IsRunning() bool {
|
||||
select {
|
||||
case <-w.doneChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// logWriter implements io.Writer to write log messages with a prefix.
|
||||
type logWriter struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (lw *logWriter) Write(p []byte) (n int, err error) {
|
||||
log.Printf("%s%s", lw.prefix, string(p))
|
||||
return len(p), nil
|
||||
}
|
||||
Reference in New Issue
Block a user