From 3a7cc6f5029fa0997096f14514b95b36e37ae7d3 Mon Sep 17 00:00:00 2001 From: Lukasz Raczylo Date: Tue, 25 Nov 2025 01:28:23 +0000 Subject: [PATCH] bugfixes nov2025 pt3 (#6) * Minor improvements. * DRY the codebase. * Add version checker / updater. --- .goreleaser.yaml | 2 +- Makefile | 2 +- cmd/kportal/main.go | 64 ++++++++++- internal/config/config.go | 93 +++++++++------- internal/config/validator.go | 19 ++-- internal/forward/portcheck.go | 185 ++++++++++++++++++++++++------- internal/forward/watchdog.go | 30 +++-- internal/healthcheck/checker.go | 50 ++++----- internal/k8s/portforward.go | 12 +- internal/ui/bubbletea_ui.go | 40 ++++++- internal/ui/wizard_handlers.go | 37 ++----- internal/ui/wizard_state.go | 10 ++ internal/version/checker.go | 158 ++++++++++++++++++++++++++ internal/version/checker_test.go | 90 +++++++++++++++ 14 files changed, 634 insertions(+), 158 deletions(-) create mode 100644 internal/version/checker.go create mode 100644 internal/version/checker_test.go diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 3dc9890..d2fc480 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -19,7 +19,7 @@ builds: - arm64 ldflags: - -s -w - - -X main.version={{.Version}} + - -X main.appVersion={{.Version}} archives: - id: kportal diff --git a/Makefile b/Makefile index f9b0fb1..31024a1 100644 --- a/Makefile +++ b/Makefile @@ -37,7 +37,7 @@ GOFMT=$(GOCMD) fmt # Build flags BUILD_FLAGS=-buildvcs=false -LDFLAGS=-ldflags="-s -w -X main.version=$(VERSION)" +LDFLAGS=-ldflags="-s -w -X main.appVersion=$(VERSION)" all: fmt vet staticcheck test build diff --git a/cmd/kportal/main.go b/cmd/kportal/main.go index 3361a3f..74ee843 100644 --- a/cmd/kportal/main.go +++ b/cmd/kportal/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "io" @@ -19,6 +20,7 @@ import ( "github.com/nvm/kportal/internal/k8s" "github.com/nvm/kportal/internal/logger" "github.com/nvm/kportal/internal/ui" + "github.com/nvm/kportal/internal/version" "k8s.io/klog/v2" ) @@ -26,6 +28,10 @@ const ( defaultConfigFile = ".kportal.yaml" initialForwardSettleTime = 100 * time.Millisecond tableUpdateInterval = 2 * time.Second + + // GitHub repository info for update checks + githubOwner = "lukaszraczylo" + githubRepo = "kportal" ) var ( @@ -34,16 +40,22 @@ var ( logFormat = flag.String("log-format", "text", "Log format: text or json") check = flag.Bool("check", false, "Validate configuration and exit") showVersion = flag.Bool("version", false, "Show version and exit") + checkUpdate = flag.Bool("update", false, "Check for updates and exit") convertInput = flag.String("convert", "", "Convert kftray JSON config to kportal YAML (provide input file path)") convertOutput = flag.String("convert-output", ".kportal.yaml", "Output file for converted configuration") - version = "0.1.0" // Set via ldflags during build + appVersion = "0.1.0" // Set via ldflags during build ) func main() { flag.Parse() if *showVersion { - fmt.Printf("kportal version %s\n", version) + fmt.Printf("kportal version %s\n", appVersion) + os.Exit(0) + } + + if *checkUpdate { + checkForUpdates() os.Exit(0) } @@ -177,7 +189,7 @@ func main() { // Only log startup messages in verbose mode if *verbose { - log.Printf("kportal v%s", version) + log.Printf("kportal v%s", appVersion) log.Printf("Loading configuration from: %s", *configFile) } @@ -209,17 +221,40 @@ func main() { } else { manager.DisableForward(id) } - }, version) + }, appVersion) // Set wizard dependencies // Note: mutator is always available (for delete/edit), discovery requires valid kubeconfig (for add) bubbleTeaUI.SetWizardDependencies(discovery, mutator, *configFile) + // Check for updates in background (non-blocking) + go func() { + checker := version.NewChecker(githubOwner, githubRepo, appVersion) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if update := checker.CheckForUpdate(ctx); update != nil { + bubbleTeaUI.SetUpdateAvailable(update.LatestVersion, update.ReleaseURL) + } + }() + manager.SetStatusUI(bubbleTeaUI) } else { // Verbose mode with simple table tableUI = ui.NewTableUI(*verbose) manager.SetStatusUI(tableUI) + + // Check for updates and print to log + go func() { + checker := version.NewChecker(githubOwner, githubRepo, appVersion) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if update := checker.CheckForUpdate(ctx); update != nil { + log.Printf("Update available: v%s (current: v%s) - %s", + update.LatestVersion, update.CurrentVersion, update.ReleaseURL) + } + }() } // Start forwards @@ -322,3 +357,24 @@ func main() { manager.Stop() } } + +// checkForUpdates checks for available updates and prints the result +func checkForUpdates() { + fmt.Printf("kportal version %s\n", appVersion) + fmt.Println("Checking for updates...") + + checker := version.NewChecker(githubOwner, githubRepo, appVersion) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + update := checker.CheckForUpdate(ctx) + if update == nil { + fmt.Println("You are running the latest version.") + return + } + + fmt.Printf("\nUpdate available: v%s\n", update.LatestVersion) + fmt.Printf("Download: %s\n", update.ReleaseURL) + fmt.Println("\nTo update, download the latest release from the URL above") + fmt.Println("or use your package manager (e.g., 'brew upgrade kportal').") +} diff --git a/internal/config/config.go b/internal/config/config.go index 7db62fb..fbc6573 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "bytes" "fmt" "os" "time" @@ -9,7 +10,20 @@ import ( ) const ( - maxConfigSize = 10 * 1024 * 1024 // 10MB + // maxConfigSize is the maximum allowed configuration file size (10MB) + maxConfigSize = 10 * 1024 * 1024 + + // Default health check settings + DefaultHealthCheckInterval = 3 * time.Second // How often to check connection health + DefaultHealthCheckTimeout = 2 * time.Second // Timeout for health check probes + DefaultHealthCheckMethod = "data-transfer" // More reliable than tcp-dial + DefaultMaxConnectionAge = 25 * time.Minute // Reconnect before k8s 30min timeout + DefaultMaxIdleTime = 10 * time.Minute // Reconnect if no activity + + // Default reliability settings + DefaultTCPKeepalive = 30 * time.Second // OS-level TCP keepalive interval + DefaultDialTimeout = 30 * time.Second // Connection establishment timeout + DefaultWatchdogPeriod = 30 * time.Second // Goroutine health check interval ) // Config represents the root configuration structure from .kportal.yaml @@ -36,24 +50,31 @@ type ReliabilitySpec struct { WatchdogPeriod string `yaml:"watchdogPeriod,omitempty"` // e.g., "30s" - goroutine watchdog interval } +// parseDurationOrDefault parses a duration string and returns the default if empty or invalid. +func parseDurationOrDefault(value string, defaultDur time.Duration) time.Duration { + if value == "" { + return defaultDur + } + if d, err := time.ParseDuration(value); err == nil { + return d + } + return defaultDur +} + // GetHealthCheckIntervalOrDefault returns the health check interval or default value func (c *Config) GetHealthCheckIntervalOrDefault() time.Duration { - if c.HealthCheck != nil && c.HealthCheck.Interval != "" { - if d, err := time.ParseDuration(c.HealthCheck.Interval); err == nil { - return d - } + if c.HealthCheck == nil { + return DefaultHealthCheckInterval } - return 3 * time.Second // Default: check every 3 seconds + return parseDurationOrDefault(c.HealthCheck.Interval, DefaultHealthCheckInterval) } // GetHealthCheckTimeoutOrDefault returns the health check timeout or default value func (c *Config) GetHealthCheckTimeoutOrDefault() time.Duration { - if c.HealthCheck != nil && c.HealthCheck.Timeout != "" { - if d, err := time.ParseDuration(c.HealthCheck.Timeout); err == nil { - return d - } + if c.HealthCheck == nil { + return DefaultHealthCheckTimeout } - return 2 * time.Second // Default: 2 second timeout + return parseDurationOrDefault(c.HealthCheck.Timeout, DefaultHealthCheckTimeout) } // GetHealthCheckMethod returns the health check method or default @@ -61,37 +82,31 @@ func (c *Config) GetHealthCheckMethod() string { if c.HealthCheck != nil && c.HealthCheck.Method != "" { return c.HealthCheck.Method } - return "data-transfer" // Default: more reliable data transfer test + return DefaultHealthCheckMethod } // GetMaxConnectionAge returns the max connection age or default func (c *Config) GetMaxConnectionAge() time.Duration { - if c.HealthCheck != nil && c.HealthCheck.MaxConnectionAge != "" { - if d, err := time.ParseDuration(c.HealthCheck.MaxConnectionAge); err == nil { - return d - } + if c.HealthCheck == nil { + return DefaultMaxConnectionAge } - return 25 * time.Minute // Default: 25 minutes (before typical 30min k8s timeout) + return parseDurationOrDefault(c.HealthCheck.MaxConnectionAge, DefaultMaxConnectionAge) } // GetMaxIdleTime returns the max idle time or default func (c *Config) GetMaxIdleTime() time.Duration { - if c.HealthCheck != nil && c.HealthCheck.MaxIdleTime != "" { - if d, err := time.ParseDuration(c.HealthCheck.MaxIdleTime); err == nil { - return d - } + if c.HealthCheck == nil { + return DefaultMaxIdleTime } - return 10 * time.Minute // Default: 10 minutes idle before reconnect + return parseDurationOrDefault(c.HealthCheck.MaxIdleTime, DefaultMaxIdleTime) } // GetTCPKeepalive returns the TCP keepalive duration or default func (c *Config) GetTCPKeepalive() time.Duration { - if c.Reliability != nil && c.Reliability.TCPKeepalive != "" { - if d, err := time.ParseDuration(c.Reliability.TCPKeepalive); err == nil { - return d - } + if c.Reliability == nil { + return DefaultTCPKeepalive } - return 30 * time.Second // Default: 30 second keepalive + return parseDurationOrDefault(c.Reliability.TCPKeepalive, DefaultTCPKeepalive) } // GetRetryOnStale returns whether to retry on stale connections @@ -104,22 +119,18 @@ func (c *Config) GetRetryOnStale() bool { // GetWatchdogPeriod returns the goroutine watchdog check period or default func (c *Config) GetWatchdogPeriod() time.Duration { - if c.Reliability != nil && c.Reliability.WatchdogPeriod != "" { - if d, err := time.ParseDuration(c.Reliability.WatchdogPeriod); err == nil { - return d - } + if c.Reliability == nil { + return DefaultWatchdogPeriod } - return 30 * time.Second // Default: check every 30 seconds + return parseDurationOrDefault(c.Reliability.WatchdogPeriod, DefaultWatchdogPeriod) } // GetDialTimeout returns the connection dial timeout or default func (c *Config) GetDialTimeout() time.Duration { - if c.Reliability != nil && c.Reliability.DialTimeout != "" { - if d, err := time.ParseDuration(c.Reliability.DialTimeout); err == nil { - return d - } + if c.Reliability == nil { + return DefaultDialTimeout } - return 30 * time.Second // Default: 30 second dial timeout + return parseDurationOrDefault(c.Reliability.DialTimeout, DefaultDialTimeout) } // Context represents a Kubernetes context with its namespaces @@ -209,9 +220,15 @@ func LoadConfig(path string) (*Config, error) { } // ParseConfig parses YAML configuration data into a Config struct. +// It uses strict parsing that rejects unknown keys to catch typos. func ParseConfig(data []byte) (*Config, error) { var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { + + // Use decoder with KnownFields to reject unknown keys (catches typos) + decoder := yaml.NewDecoder(bytes.NewReader(data)) + decoder.KnownFields(true) + + if err := decoder.Decode(&cfg); err != nil { return nil, fmt.Errorf("failed to parse YAML: %w", err) } diff --git a/internal/config/validator.go b/internal/config/validator.go index a510300..bab8a4a 100644 --- a/internal/config/validator.go +++ b/internal/config/validator.go @@ -6,10 +6,15 @@ import ( ) const ( - minPort = 1 - maxPort = 65535 + MinPort = 1 + MaxPort = 65535 ) +// IsValidPort returns true if the port number is within the valid range (1-65535). +func IsValidPort(port int) bool { + return port >= MinPort && port <= MaxPort +} + // ValidationError represents a configuration validation error with context. type ValidationError struct { Field string // The field that failed validation @@ -84,7 +89,7 @@ func (v *Validator) validateStructure(cfg *Config) []ValidationError { Field: fmt.Sprintf("contexts[%d].namespaces", i), Message: fmt.Sprintf("Context '%s' must have at least one namespace", ctx.Name), }) - continue + // Don't continue - still validate other aspects of the context if any } for j, ns := range ctx.Namespaces { @@ -130,17 +135,17 @@ func (v *Validator) validateForward(fwd *Forward) []ValidationError { } // Validate ports - if fwd.Port < minPort || fwd.Port > maxPort { + if fwd.Port < MinPort || fwd.Port > MaxPort { errs = append(errs, ValidationError{ Field: "port", - Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), minPort, maxPort), + Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), MinPort, MaxPort), }) } - if fwd.LocalPort < minPort || fwd.LocalPort > maxPort { + if fwd.LocalPort < MinPort || fwd.LocalPort > MaxPort { errs = append(errs, ValidationError{ Field: "localPort", - Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), minPort, maxPort), + Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), MinPort, MaxPort), }) } diff --git a/internal/forward/portcheck.go b/internal/forward/portcheck.go index 6bf9d2d..7c690b6 100644 --- a/internal/forward/portcheck.go +++ b/internal/forward/portcheck.go @@ -6,11 +6,20 @@ import ( "os/exec" "runtime" "strings" + + "github.com/nvm/kportal/internal/logger" +) + +const ( + // maxPIDLength is the maximum length of a valid PID string (9 digits covers PIDs up to 999,999,999) + maxPIDLength = 9 + // minNetstatFields is the minimum number of fields expected in netstat output + minNetstatFields = 5 ) // isValidPID validates that a PID string contains only digits func isValidPID(pid string) bool { - if len(pid) == 0 || len(pid) > 9 { + if len(pid) == 0 || len(pid) > maxPIDLength { return false } for _, c := range pid { @@ -21,6 +30,72 @@ func isValidPID(pid string) bool { return true } +// processInfo holds information about a process using a port +type processInfo struct { + pid string + name string + isValid bool +} + +// formatProcessInfo formats process information for display +func formatProcessInfo(info processInfo) string { + if !info.isValid { + return "unknown" + } + if info.name != "" { + return fmt.Sprintf("%s (PID %s)", info.name, info.pid) + } + return fmt.Sprintf("PID %s", info.pid) +} + +// formatProcessList formats a list of processes into a human-readable string. +// Returns "unknown" if the list is empty. +func formatProcessList(processes []processInfo) string { + if len(processes) == 0 { + return "unknown" + } + if len(processes) == 1 { + return formatProcessInfo(processes[0]) + } + // Multiple processes - format as comma-separated list + parts := make([]string, len(processes)) + for i, p := range processes { + parts[i] = formatProcessInfo(p) + } + return strings.Join(parts, ", ") +} + +// getProcessNameByPID retrieves the process name for a given PID on Unix systems +func getProcessNameByPID(pid string) string { + cmd := exec.Command("ps", "-p", pid, "-o", "comm=") + output, err := cmd.Output() + if err != nil { + return "" + } + return strings.TrimSpace(string(output)) +} + +// getProcessNameByPIDWindows retrieves the process name for a given PID on Windows +func getProcessNameByPIDWindows(pid string) string { + cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH") + output, err := cmd.Output() + if err != nil { + return "" + } + + // Parse CSV output: "process.exe","1234","Console","1","12,345 K" + csvLine := strings.TrimSpace(string(output)) + if csvLine == "" { + return "" + } + + parts := strings.Split(csvLine, ",") + if len(parts) > 0 { + return strings.Trim(parts[0], "\"") + } + return "" +} + // PortConflict represents a local port that is already in use. type PortConflict struct { Port int // The conflicting port number @@ -102,27 +177,55 @@ func (pc *PortChecker) getProcessUsingPortUnix(port int) string { return "unknown" } - // Get the first PID if multiple are returned + // Handle multiple PIDs (multiple processes on same port) pids := strings.Split(pidStr, "\n") - pid := pids[0] + var validProcesses []processInfo - if !isValidPID(pid) { - return "unknown" + for _, pid := range pids { + pid = strings.TrimSpace(pid) + if pid == "" { + continue + } + + if !isValidPID(pid) { + logger.Debug("Invalid PID format from lsof output", map[string]interface{}{ + "port": port, + "raw_pid": pid, + }) + continue + } + + procName := getProcessNameByPID(pid) + validProcesses = append(validProcesses, processInfo{ + pid: pid, + name: procName, + isValid: true, + }) } - // Get process name using ps - cmd = exec.Command("ps", "-p", pid, "-o", "comm=") - output, err = cmd.Output() - if err != nil { - return fmt.Sprintf("PID %s", pid) + return formatProcessList(validProcesses) +} + +// isListeningState checks if a netstat line indicates a listening state. +// This handles both English and potentially other locales by checking for common patterns. +func isListeningState(line string, fields []string) bool { + upperLine := strings.ToUpper(line) + + // Check for common listening state indicators across locales + // English: LISTENING, German: ABHÖREN, French: ÉCOUTE, etc. + // The most reliable check is the state field position (4th field, 0-indexed = 3) + // and that it's a TCP connection with 0.0.0.0:0 or *:* as foreign address + if len(fields) >= minNetstatFields { + state := strings.ToUpper(fields[3]) + // Common listening state values across Windows locales + if state == "LISTENING" || state == "ABHÖREN" || state == "ÉCOUTE" || + state == "ESCUCHANDO" || state == "ASCOLTO" || state == "NASŁUCHIWANIE" { + return true + } } - procName := strings.TrimSpace(string(output)) - if procName == "" { - return fmt.Sprintf("PID %s", pid) - } - - return fmt.Sprintf("%s (PID %s)", procName, pid) + // Fallback: check if line contains LISTENING (most common case) + return strings.Contains(upperLine, "LISTENING") } // getProcessUsingPortWindows uses netstat to find the process using a port on Windows. @@ -138,6 +241,8 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string { lines := strings.Split(string(output), "\n") portStr := fmt.Sprintf(":%d", port) + var validProcesses []processInfo + for _, line := range lines { if !strings.Contains(line, portStr) { continue @@ -146,44 +251,42 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string { // Parse the line to extract PID // Format: TCP 0.0.0.0:8080 0.0.0.0:0 LISTENING 1234 fields := strings.Fields(line) - if len(fields) < 5 { + if len(fields) < minNetstatFields { continue } - // Check if this is a LISTENING state - if !strings.Contains(strings.ToUpper(line), "LISTENING") { + // Check if this is a LISTENING state (locale-aware) + if !isListeningState(line, fields) { + continue + } + + // Verify the local address field actually contains our port + // (avoid matching port in foreign address) + localAddr := fields[1] + if !strings.HasSuffix(localAddr, portStr) { continue } pid := fields[len(fields)-1] if !isValidPID(pid) { - return "unknown" + logger.Debug("Invalid PID format from netstat output", map[string]interface{}{ + "port": port, + "raw_pid": pid, + "line": line, + }) + continue } - // Get process name using tasklist - cmd = exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH") - output, err = cmd.Output() - if err != nil { - return fmt.Sprintf("PID %s", pid) - } - - // Parse CSV output: "process.exe","1234","Console","1","12,345 K" - csvLine := strings.TrimSpace(string(output)) - if csvLine == "" { - return fmt.Sprintf("PID %s", pid) - } - - parts := strings.Split(csvLine, ",") - if len(parts) > 0 { - procName := strings.Trim(parts[0], "\"") - return fmt.Sprintf("%s (PID %s)", procName, pid) - } - - return fmt.Sprintf("PID %s", pid) + procName := getProcessNameByPIDWindows(pid) + validProcesses = append(validProcesses, processInfo{ + pid: pid, + name: procName, + isValid: true, + }) } - return "unknown" + return formatProcessList(validProcesses) } // FormatConflicts formats port conflicts into a human-readable error message. diff --git a/internal/forward/watchdog.go b/internal/forward/watchdog.go index 3a03dbe..714425a 100644 --- a/internal/forward/watchdog.go +++ b/internal/forward/watchdog.go @@ -123,11 +123,18 @@ func (w *Watchdog) monitorLoop() { } } +// hungWorkerInfo stores information about a hung worker for deferred callback execution +type hungWorkerInfo struct { + forwardID string + callback func(string) +} + // checkWorkers checks all registered workers for hung state func (w *Watchdog) checkWorkers() { - w.mu.Lock() - defer w.mu.Unlock() + // Collect hung workers while holding the lock + var hungWorkers []hungWorkerInfo + w.mu.Lock() now := time.Now() for forwardID, state := range w.workers { timeSinceHeartbeat := now.Sub(state.lastHeartbeat) @@ -145,14 +152,23 @@ func (w *Watchdog) checkWorkers() { "heartbeat_count": state.heartbeatCount, }) - // Trigger callback to handle hung worker (without holding lock) + // Collect callback for deferred execution outside the lock if state.onHungCallback != nil { - callback := state.onHungCallback - w.mu.Unlock() - callback(forwardID) - w.mu.Lock() + hungWorkers = append(hungWorkers, hungWorkerInfo{ + forwardID: forwardID, + callback: state.onHungCallback, + }) } } } } + w.mu.Unlock() + + // Execute callbacks outside the lock to prevent deadlocks and ensure + // consistent state during callback execution. Callbacks are idempotent + // (they trigger reconnection via channels), so concurrent state changes + // between detection and callback execution are safe. + for _, hw := range hungWorkers { + hw.callback(hw.forwardID) + } } diff --git a/internal/healthcheck/checker.go b/internal/healthcheck/checker.go index db0455c..bc44499 100644 --- a/internal/healthcheck/checker.go +++ b/internal/healthcheck/checker.go @@ -7,6 +7,8 @@ import ( "net" "sync" "time" + + "github.com/nvm/kportal/internal/config" ) const ( @@ -77,8 +79,8 @@ func NewChecker(interval, timeout time.Duration) *Checker { Interval: interval, Timeout: timeout, Method: CheckMethodDataTransfer, - MaxConnectionAge: 25 * time.Minute, - MaxIdleTime: 10 * time.Minute, + MaxConnectionAge: config.DefaultMaxConnectionAge, + MaxIdleTime: config.DefaultMaxIdleTime, }) } @@ -150,44 +152,34 @@ func (c *Checker) Unregister(forwardID string) { delete(c.callbacks, forwardID) } -// MarkReconnecting marks a forward as reconnecting (called by worker) -func (c *Checker) MarkReconnecting(forwardID string) { +// markStatus is a helper to set a forward's status and notify on change. +func (c *Checker) markStatus(forwardID string, newStatus Status) { c.mu.Lock() - if health, exists := c.ports[forwardID]; exists { - oldStatus := health.Status - health.Status = StatusReconnect - health.LastCheck = time.Now() - + health, exists := c.ports[forwardID] + if !exists { c.mu.Unlock() - - if oldStatus != StatusReconnect { - c.notifyStatusChange(forwardID, StatusReconnect, "") - } return } + oldStatus := health.Status + health.Status = newStatus + health.LastCheck = time.Now() c.mu.Unlock() + + if oldStatus != newStatus { + c.notifyStatusChange(forwardID, newStatus, "") + } +} + +// MarkReconnecting marks a forward as reconnecting (called by worker) +func (c *Checker) MarkReconnecting(forwardID string) { + c.markStatus(forwardID, StatusReconnect) } // MarkStarting marks a forward as starting (called by worker) func (c *Checker) MarkStarting(forwardID string) { - c.mu.Lock() - - if health, exists := c.ports[forwardID]; exists { - oldStatus := health.Status - health.Status = StatusStarting - health.LastCheck = time.Now() - - c.mu.Unlock() - - if oldStatus != StatusStarting { - c.notifyStatusChange(forwardID, StatusStarting, "") - } - return - } - - c.mu.Unlock() + c.markStatus(forwardID, StatusStarting) } // GetStatus returns the current health status of a forward diff --git a/internal/k8s/portforward.go b/internal/k8s/portforward.go index 452d732..5ffd9d3 100644 --- a/internal/k8s/portforward.go +++ b/internal/k8s/portforward.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/nvm/kportal/internal/config" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/rest" @@ -30,8 +32,8 @@ func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortF return &PortForwarder{ clientPool: clientPool, resolver: resolver, - tcpKeepalive: 30 * time.Second, // Default: 30 second keepalive - dialTimeout: 30 * time.Second, // Default: 30 second dial timeout + tcpKeepalive: config.DefaultTCPKeepalive, + dialTimeout: config.DefaultDialTimeout, } } @@ -140,6 +142,9 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque } // Get pods backing the service using label selector + if len(service.Spec.Selector) == 0 { + return fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", serviceName) + } selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector}) pods, err := client.CoreV1().Pods(req.Namespace).List(ctx, metav1.ListOptions{ LabelSelector: selector, @@ -257,6 +262,9 @@ func (pf *PortForwarder) GetPodForResource(ctx context.Context, contextName, nam return "", fmt.Errorf("failed to get service: %w", err) } + if len(service.Spec.Selector) == 0 { + return "", fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", resourceName) + } selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector}) pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{ LabelSelector: selector, diff --git a/internal/ui/bubbletea_ui.go b/internal/ui/bubbletea_ui.go index c060d0e..a041282 100644 --- a/internal/ui/bubbletea_ui.go +++ b/internal/ui/bubbletea_ui.go @@ -46,6 +46,11 @@ type BubbleTeaUI struct { version string errors map[string]string // Track error messages by forward ID + // Update notification + updateAvailable bool + updateVersion string + updateURL string + // Modal wizard state viewMode ViewMode addWizard *AddWizardState @@ -96,6 +101,16 @@ func (ui *BubbleTeaUI) SetWizardDependencies(discovery *k8s.Discovery, mutator * ui.configPath = configPath } +// SetUpdateAvailable sets the update notification to be displayed +func (ui *BubbleTeaUI) SetUpdateAvailable(version, url string) { + ui.mu.Lock() + defer ui.mu.Unlock() + + ui.updateAvailable = true + ui.updateVersion = version + ui.updateURL = url +} + // Start starts the bubbletea application func (ui *BubbleTeaUI) Start() error { m := model{ui: ui} @@ -169,8 +184,9 @@ func (ui *BubbleTeaUI) UpdateStatus(id string, status string) { if fwd, ok := ui.forwards[id]; ok { fwd.Status = status } - // Clear error if status is not Error - if status != "Error" { + // Only clear error when forward becomes Active again + // This keeps error visible during Reconnecting/Starting states + if status == "Active" { delete(ui.errors, id) } ui.mu.Unlock() @@ -266,7 +282,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.ui.addWizard = nil m.ui.removeWizard = nil m.ui.mu.Unlock() - return m, nil + return m, tea.ClearScreen } return m, nil @@ -356,6 +372,15 @@ func (m model) renderMainView() string { // Title with version title := fmt.Sprintf("kportal v%s - Port Forwarding Status", m.ui.version) b.WriteString(titleStyle.Render(title)) + + // Show update notification if available + if m.ui.updateAvailable { + updateStyle := lipgloss.NewStyle(). + Foreground(lipgloss.Color("42")). // Green + Bold(true) + updateMsg := fmt.Sprintf(" Update available: v%s", m.ui.updateVersion) + b.WriteString(updateStyle.Render(updateMsg)) + } b.WriteString("\n\n") // Header @@ -574,6 +599,15 @@ func (ui *BubbleTeaUI) moveSelection(delta int) { } } +// resetDeleteConfirmation resets the delete confirmation dialog state. +// Caller must hold ui.mu lock. +func (ui *BubbleTeaUI) resetDeleteConfirmation() { + ui.deleteConfirming = false + ui.deleteConfirmID = "" + ui.deleteConfirmAlias = "" + ui.deleteConfirmCursor = 0 +} + // renderDeleteConfirmation renders the delete confirmation dialog func (m model) renderDeleteConfirmation() string { m.ui.mu.RLock() diff --git a/internal/ui/wizard_handlers.go b/internal/ui/wizard_handlers.go index a202924..b268148 100644 --- a/internal/ui/wizard_handlers.go +++ b/internal/ui/wizard_handlers.go @@ -173,12 +173,8 @@ func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) { switch msg.String() { case "ctrl+c", "esc": // Cancel deletion - m.ui.deleteConfirming = false - m.ui.deleteConfirmID = "" - m.ui.deleteConfirmAlias = "" - m.ui.deleteConfirmCursor = 0 // Reset cursor + m.ui.resetDeleteConfirmation() m.ui.mu.Unlock() - // Force a repaint by returning the model return m, tea.ClearScreen case "left", "h", "right", "l": @@ -191,26 +187,18 @@ func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) { // Confirm deletion (either Enter on Yes or pressing 'y') if m.ui.deleteConfirmCursor == 0 || msg.String() == "y" { id := m.ui.deleteConfirmID - m.ui.deleteConfirming = false - m.ui.deleteConfirmID = "" - m.ui.deleteConfirmAlias = "" + m.ui.resetDeleteConfirmation() m.ui.mu.Unlock() return m, removeForwardByIDCmd(m.ui.mutator, id) } // Enter on No = cancel - m.ui.deleteConfirming = false - m.ui.deleteConfirmID = "" - m.ui.deleteConfirmAlias = "" - m.ui.deleteConfirmCursor = 0 // Reset cursor + m.ui.resetDeleteConfirmation() m.ui.mu.Unlock() return m, tea.ClearScreen case "n": // Quick 'n' for no - m.ui.deleteConfirming = false - m.ui.deleteConfirmID = "" - m.ui.deleteConfirmAlias = "" - m.ui.deleteConfirmCursor = 0 // Reset cursor + m.ui.resetDeleteConfirmation() m.ui.mu.Unlock() return m, tea.ClearScreen } @@ -259,10 +247,7 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } else { // Go back one step wizard.step-- - wizard.cursor = 0 - wizard.clearTextInput() - wizard.clearSearchFilter() - wizard.error = nil + wizard.resetInput() // Reset input mode based on the step we're going back to switch wizard.step { @@ -492,7 +477,7 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) { } else { // Text mode - manual entry port, err := strconv.Atoi(wizard.textInput) - if err != nil || port < 1 || port > 65535 { + if err != nil || !config.IsValidPort(port) { wizard.error = fmt.Errorf("invalid port number") } else { wizard.remotePort = port @@ -504,7 +489,7 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) { case StepEnterLocalPort: port, err := strconv.Atoi(wizard.textInput) - if err != nil || port < 1 || port > 65535 { + if err != nil || !config.IsValidPort(port) { wizard.error = fmt.Errorf("invalid port number") } else { // Check port availability before proceeding @@ -559,9 +544,10 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) { return m, saveForwardCmd(m.ui.mutator, wizard.selectedContext, wizard.selectedNamespace, fwd) } else { - // Cancelled + // Cancelled - return to main view with screen clear m.ui.viewMode = ViewModeMain m.ui.addWizard = nil + return m, tea.ClearScreen } case StepSuccess: @@ -571,9 +557,10 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) { m.ui.addWizard.loading = true return m, loadContextsCmd(m.ui.discovery) } else { - // Return to main view + // Return to main view with screen clear m.ui.viewMode = ViewModeMain m.ui.addWizard = nil + return m, tea.ClearScreen } } @@ -828,5 +815,5 @@ func (m model) handleForwardsRemoved(msg ForwardsRemovedMsg) (tea.Model, tea.Cmd // If there was an error, it will be logged but we don't show it in UI for now // The config watcher will either reload (success) or keep old config (failure) - return m, nil + return m, tea.ClearScreen } diff --git a/internal/ui/wizard_state.go b/internal/ui/wizard_state.go index 6473a2f..9163a72 100644 --- a/internal/ui/wizard_state.go +++ b/internal/ui/wizard_state.go @@ -363,3 +363,13 @@ func (w *AddWizardState) clearSearchFilter() { w.cursor = 0 w.scrollOffset = 0 } + +// resetInput clears text input, search filter, and error state. +// Use this when navigating between wizard steps. +func (w *AddWizardState) resetInput() { + w.textInput = "" + w.searchFilter = "" + w.cursor = 0 + w.scrollOffset = 0 + w.error = nil +} diff --git a/internal/version/checker.go b/internal/version/checker.go new file mode 100644 index 0000000..fa42633 --- /dev/null +++ b/internal/version/checker.go @@ -0,0 +1,158 @@ +package version + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +const ( + // GitHubAPIURL is the GitHub API endpoint for releases + githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest" + // requestTimeout is the timeout for HTTP requests + requestTimeout = 5 * time.Second +) + +// ReleaseInfo contains information about a GitHub release +type ReleaseInfo struct { + TagName string `json:"tag_name"` + HTMLURL string `json:"html_url"` + Name string `json:"name"` +} + +// UpdateInfo contains information about an available update +type UpdateInfo struct { + CurrentVersion string + LatestVersion string + ReleaseURL string + ReleaseName string +} + +// Checker checks for new versions on GitHub +type Checker struct { + owner string + repo string + current string + client *http.Client +} + +// NewChecker creates a new version checker +func NewChecker(owner, repo, currentVersion string) *Checker { + return &Checker{ + owner: owner, + repo: repo, + current: normalizeVersion(currentVersion), + client: &http.Client{ + Timeout: requestTimeout, + }, + } +} + +// CheckForUpdate checks if a newer version is available. +// Returns nil if current version is up to date or if check fails. +// This is designed to fail silently - network errors should not impact the user. +func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo { + release, err := c.fetchLatestRelease(ctx) + if err != nil { + return nil + } + + latestVersion := normalizeVersion(release.TagName) + if isNewerVersion(latestVersion, c.current) { + return &UpdateInfo{ + CurrentVersion: c.current, + LatestVersion: latestVersion, + ReleaseURL: release.HTMLURL, + ReleaseName: release.Name, + } + } + + return nil +} + +// fetchLatestRelease fetches the latest release info from GitHub API +func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) { + url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "kportal-version-checker") + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release ReleaseInfo + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, err + } + + return &release, nil +} + +// normalizeVersion removes 'v' or 'V' prefix and trims whitespace +func normalizeVersion(v string) string { + v = strings.TrimSpace(v) + v = strings.TrimPrefix(v, "v") + v = strings.TrimPrefix(v, "V") + return v +} + +// isNewerVersion compares two semver-like versions. +// Returns true if latest is newer than current. +func isNewerVersion(latest, current string) bool { + latestParts := parseVersion(latest) + currentParts := parseVersion(current) + + // Compare each part + for i := 0; i < len(latestParts) && i < len(currentParts); i++ { + if latestParts[i] > currentParts[i] { + return true + } + if latestParts[i] < currentParts[i] { + return false + } + } + + // If all compared parts are equal, longer version is newer + // e.g., 1.0.1 > 1.0 + return len(latestParts) > len(currentParts) +} + +// parseVersion splits a version string into numeric parts +func parseVersion(v string) []int { + // Remove any suffix like -beta, -rc1, etc. + if idx := strings.IndexAny(v, "-+"); idx != -1 { + v = v[:idx] + } + + parts := strings.Split(v, ".") + result := make([]int, 0, len(parts)) + + for _, p := range parts { + var num int + fmt.Sscanf(p, "%d", &num) + result = append(result, num) + } + + return result +} + +// FormatUpdateMessage formats a user-friendly update notification +func (u *UpdateInfo) FormatUpdateMessage() string { + return fmt.Sprintf("New version available: %s (current: %s) - %s", + u.LatestVersion, u.CurrentVersion, u.ReleaseURL) +} diff --git a/internal/version/checker_test.go b/internal/version/checker_test.go new file mode 100644 index 0000000..7d34f5a --- /dev/null +++ b/internal/version/checker_test.go @@ -0,0 +1,90 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeVersion(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"v1.0.0", "1.0.0"}, + {"1.0.0", "1.0.0"}, + {" v2.1.3 ", "2.1.3"}, + {"V1.0.0", "1.0.0"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeVersion(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + expected []int + }{ + {"1.0.0", []int{1, 0, 0}}, + {"2.1.3", []int{2, 1, 3}}, + {"1.0", []int{1, 0}}, + {"10.20.30", []int{10, 20, 30}}, + {"1.0.0-beta", []int{1, 0, 0}}, + {"1.0.0-rc1", []int{1, 0, 0}}, + {"1.0.0+build123", []int{1, 0, 0}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := parseVersion(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsNewerVersion(t *testing.T) { + tests := []struct { + name string + latest string + current string + expected bool + }{ + {"major version bump", "2.0.0", "1.0.0", true}, + {"minor version bump", "1.1.0", "1.0.0", true}, + {"patch version bump", "1.0.1", "1.0.0", true}, + {"same version", "1.0.0", "1.0.0", false}, + {"current is newer major", "1.0.0", "2.0.0", false}, + {"current is newer minor", "1.0.0", "1.1.0", false}, + {"current is newer patch", "1.0.0", "1.0.1", false}, + {"multi-digit versions", "1.10.0", "1.9.0", true}, + {"longer version is newer", "1.0.1", "1.0", true}, + {"shorter version is older", "1.0", "1.0.1", false}, + {"complex comparison", "2.1.3", "2.1.2", true}, + {"real world example", "0.2.0", "0.1.0", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNewerVersion(tt.latest, tt.current) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateInfo_FormatUpdateMessage(t *testing.T) { + info := &UpdateInfo{ + CurrentVersion: "0.1.0", + LatestVersion: "0.2.0", + ReleaseURL: "https://github.com/nvm/kportal/releases/tag/v0.2.0", + } + + msg := info.FormatUpdateMessage() + assert.Contains(t, msg, "0.2.0") + assert.Contains(t, msg, "0.1.0") + assert.Contains(t, msg, "https://github.com/nvm/kportal/releases/tag/v0.2.0") +}