Compare commits

...

5 Commits

Author SHA1 Message Date
lukaszraczylo 3a7cc6f502 bugfixes nov2025 pt3 (#6)
* Minor improvements.
* DRY the codebase.
* Add version checker / updater.
2025-11-25 01:28:23 +00:00
lukaszraczylo 49acba5679 Bugfixes nov2025 pt2 (#5)
* UI bugfixes.
* Fix open port check during new fwd setup wizard
2025-11-25 00:09:32 +00:00
lukaszraczylo 39fe4286b4 Fix the watchdog being too aggressive. 2025-11-24 13:19:44 +00:00
lukaszraczylo 2fdc5912e7 healtcheck improvements (#4)
* Advanced healtchecks.
* Add watchdog for stale connections handling.
2025-11-24 13:00:19 +00:00
lukaszraczylo 7df161aee0 bugfixes nov2025 (#3)
* Fix enter misbehaving.
* Cleanup after previous tui implementation.
* Fix race condition and improve logging
* Add filtering of the namespaces by text input in the wizard UI
2025-11-24 11:09:23 +00:00
35 changed files with 3351 additions and 647 deletions
+1 -1
View File
@@ -19,7 +19,7 @@ builds:
- arm64
ldflags:
- -s -w
- -X main.version={{.Version}}
- -X main.appVersion={{.Version}}
archives:
- id: kportal
+21
View File
@@ -1,6 +1,27 @@
# Example kportal configuration
# Copy this file to your project and customize as needed
# Optional: Health check configuration
# These settings control how kportal monitors connection health and detects stale connections
healthCheck:
interval: "3s" # How often to check connection health (default: 3s)
timeout: "2s" # Timeout for health check operations (default: 2s)
method: "data-transfer" # Health check method: "tcp-dial" or "data-transfer" (default: data-transfer)
# - tcp-dial: Simple TCP connection test (fast, less reliable)
# - data-transfer: Attempts to read data (slower, more reliable)
maxConnectionAge: "25m" # Maximum connection age before proactive reconnect (default: 25m)
# Helps avoid Kubernetes API server timeouts (typically 30m)
maxIdleTime: "10m" # Maximum idle time before marking as stale (default: 10m)
# Connections with no data transfer are marked stale
# Optional: Reliability configuration
# These settings improve connection stability for long-running transfers
reliability:
tcpKeepalive: "30s" # TCP keepalive interval for OS-level connection monitoring (default: 30s)
dialTimeout: "30s" # Connection dial timeout (default: 30s)
retryOnStale: true # Automatically reconnect when stale connections detected (default: true)
watchdogPeriod: "30s" # Goroutine watchdog check interval to detect hung workers (default: 30s)
contexts:
# Production context
- name: production
+1 -1
View File
@@ -37,7 +37,7 @@ GOFMT=$(GOCMD) fmt
# Build flags
BUILD_FLAGS=-buildvcs=false
LDFLAGS=-ldflags="-s -w -X main.version=$(VERSION)"
LDFLAGS=-ldflags="-s -w -X main.appVersion=$(VERSION)"
all: fmt vet staticcheck test build
+43 -1
View File
@@ -24,7 +24,8 @@ kportal simplifies managing multiple Kubernetes port-forwards with an elegant, i
- 🗑️ **Live Delete** - Remove port-forwards instantly from the running session
- 🔄 **Auto-Reconnect** - Automatic retry with exponential backoff on connection failures (max 10s)
-**Hot-Reload** - Update configuration without restarting - changes applied automatically
- 🏥 **Health Checks** - Real-time port forward status monitoring with 5-second intervals
- 🏥 **Advanced Health Checks** - Multiple check methods (tcp-dial, data-transfer) with stale connection detection
- 🛡️ **Goroutine Watchdog** - Detects and recovers from completely hung workers
- 🎨 **Multi-Context** - Support for multiple Kubernetes contexts and namespaces
- 📦 **Batch Management** - Manage all port-forwards from a single configuration file
- 🔌 **Toggle Forwards** - Enable/disable individual port-forwards on the fly with Space key
@@ -194,6 +195,47 @@ contexts:
- **Service**: `service/service-name` or `svc/service-name`
- **Deployment**: `deployment/deployment-name` or `deploy/deployment-name`
### Health Check & Reliability (Advanced)
kportal includes advanced health checking to prevent stale connections during long-running operations like database dumps:
```yaml
healthCheck:
interval: "3s" # Health check frequency (default: 3s)
timeout: "2s" # Health check timeout (default: 2s)
method: "data-transfer" # Check method: "tcp-dial" or "data-transfer" (default: data-transfer)
maxConnectionAge: "25m" # Proactive reconnect before k8s timeout (default: 25m)
maxIdleTime: "10m" # Detect hung connections (default: 10m)
reliability:
tcpKeepalive: "30s" # TCP keepalive interval (default: 30s)
dialTimeout: "30s" # Connection dial timeout (default: 30s)
retryOnStale: true # Auto-reconnect stale connections (default: true)
```
**Health Check Methods:**
- **`tcp-dial`**: Fast TCP connection test - verifies local port is listening
- **`data-transfer`**: More reliable - attempts to read data to verify tunnel is functional
**Stale Detection:**
- **Max Connection Age**: Kubernetes API typically has 30-minute timeout. kportal reconnects at 25 minutes by default to avoid hitting this limit. **Important**: Age-based reconnection only occurs when the connection is ALSO idle - active transfers (like database dumps) are never interrupted.
- **Max Idle Time**: Detects connections with no data transfer, common when intermediate firewalls drop idle TCP connections
**Use Case Example - Database Dumps:**
```yaml
# Optimized for long-running pg_dump
healthCheck:
method: "data-transfer"
maxConnectionAge: "20m" # Only applies when idle - won't interrupt active dumps
maxIdleTime: "5m" # Detects truly stale connections
reliability:
tcpKeepalive: "30s"
retryOnStale: true
```
This configuration ensures multi-hour database dumps complete without interruption. The `maxConnectionAge` will only trigger reconnection if the connection has been idle for more than `maxIdleTime`, preventing interruption of active data transfers.
## 🎮 Usage
### Interactive Mode (Default)
+83 -7
View File
@@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"io"
@@ -12,12 +13,14 @@ import (
"syscall"
"time"
"github.com/go-logr/logr"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/converter"
"github.com/nvm/kportal/internal/forward"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/ui"
"github.com/nvm/kportal/internal/version"
"k8s.io/klog/v2"
)
@@ -25,6 +28,10 @@ const (
defaultConfigFile = ".kportal.yaml"
initialForwardSettleTime = 100 * time.Millisecond
tableUpdateInterval = 2 * time.Second
// GitHub repository info for update checks
githubOwner = "lukaszraczylo"
githubRepo = "kportal"
)
var (
@@ -33,16 +40,22 @@ var (
logFormat = flag.String("log-format", "text", "Log format: text or json")
check = flag.Bool("check", false, "Validate configuration and exit")
showVersion = flag.Bool("version", false, "Show version and exit")
checkUpdate = flag.Bool("update", false, "Check for updates and exit")
convertInput = flag.String("convert", "", "Convert kftray JSON config to kportal YAML (provide input file path)")
convertOutput = flag.String("convert-output", ".kportal.yaml", "Output file for converted configuration")
version = "0.1.0" // Set via ldflags during build
appVersion = "0.1.0" // Set via ldflags during build
)
func main() {
flag.Parse()
if *showVersion {
fmt.Printf("kportal version %s\n", version)
fmt.Printf("kportal version %s\n", appVersion)
os.Exit(0)
}
if *checkUpdate {
checkForUpdates()
os.Exit(0)
}
@@ -91,16 +104,35 @@ func main() {
// Configure klog (used by kubernetes client-go) to route through our logger
// This prevents k8s logs from interfering with the UI
//
// klog v2 uses multiple output mechanisms:
// 1. SetOutput() - for basic text output
// 2. SetLogger() - for structured/error logs (logr interface)
//
// We must configure BOTH to capture all logs including error messages
// that would otherwise bypass SetOutput() and write directly to stderr.
klog.LogToStderr(false) // Disable direct stderr writes
if *verbose {
// In verbose mode, route klog through our structured logger at DEBUG level
// In verbose mode, route all klog through our structured logger at DEBUG level
klogLogger := logger.New(logger.LevelDebug, logFmt, os.Stderr)
// Configure text output routing
klogWriter := logger.NewKlogWriter(klogLogger)
klog.SetOutput(klogWriter)
// Configure structured/error log routing via logr interface
// This captures "Unhandled Error" and other structured logs that bypass SetOutput
logrSink := logger.NewLogrAdapter(klogLogger)
klog.SetLogger(logr.New(logrSink))
} else {
// In non-verbose mode, completely silence klog
// In non-verbose mode, completely silence ALL klog output
klog.SetOutput(io.Discard)
// Also silence structured/error logs via a discard logger
silentLogger := logger.New(logger.LevelError+1, logFmt, io.Discard) // Level above ERROR = silence all
logrSink := logger.NewLogrAdapter(silentLogger)
klog.SetLogger(logr.New(logrSink))
}
klog.LogToStderr(false)
// Handle conversion mode
if *convertInput != "" {
@@ -157,7 +189,7 @@ func main() {
// Only log startup messages in verbose mode
if *verbose {
log.Printf("kportal v%s", version)
log.Printf("kportal v%s", appVersion)
log.Printf("Loading configuration from: %s", *configFile)
}
@@ -189,17 +221,40 @@ func main() {
} else {
manager.DisableForward(id)
}
}, version)
}, appVersion)
// Set wizard dependencies
// Note: mutator is always available (for delete/edit), discovery requires valid kubeconfig (for add)
bubbleTeaUI.SetWizardDependencies(discovery, mutator, *configFile)
// Check for updates in background (non-blocking)
go func() {
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if update := checker.CheckForUpdate(ctx); update != nil {
bubbleTeaUI.SetUpdateAvailable(update.LatestVersion, update.ReleaseURL)
}
}()
manager.SetStatusUI(bubbleTeaUI)
} else {
// Verbose mode with simple table
tableUI = ui.NewTableUI(*verbose)
manager.SetStatusUI(tableUI)
// Check for updates and print to log
go func() {
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if update := checker.CheckForUpdate(ctx); update != nil {
log.Printf("Update available: v%s (current: v%s) - %s",
update.LatestVersion, update.CurrentVersion, update.ReleaseURL)
}
}()
}
// Start forwards
@@ -302,3 +357,24 @@ func main() {
manager.Stop()
}
}
// checkForUpdates checks for available updates and prints the result
func checkForUpdates() {
fmt.Printf("kportal version %s\n", appVersion)
fmt.Println("Checking for updates...")
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
update := checker.CheckForUpdate(ctx)
if update == nil {
fmt.Println("You are running the latest version.")
return
}
fmt.Printf("\nUpdate available: v%s\n", update.LatestVersion)
fmt.Printf("Download: %s\n", update.ReleaseURL)
fmt.Println("\nTo update, download the latest release from the URL above")
fmt.Println("or use your package manager (e.g., 'brew upgrade kportal').")
}
+1
View File
@@ -25,6 +25,7 @@ require (
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emicklei/go-restful/v3 v3.13.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
+2
View File
@@ -23,6 +23,8 @@ github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsV
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes=
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
+126 -38
View File
@@ -1,19 +1,136 @@
package config
import (
"bytes"
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
const (
maxConfigSize = 10 * 1024 * 1024 // 10MB
// maxConfigSize is the maximum allowed configuration file size (10MB)
maxConfigSize = 10 * 1024 * 1024
// Default health check settings
DefaultHealthCheckInterval = 3 * time.Second // How often to check connection health
DefaultHealthCheckTimeout = 2 * time.Second // Timeout for health check probes
DefaultHealthCheckMethod = "data-transfer" // More reliable than tcp-dial
DefaultMaxConnectionAge = 25 * time.Minute // Reconnect before k8s 30min timeout
DefaultMaxIdleTime = 10 * time.Minute // Reconnect if no activity
// Default reliability settings
DefaultTCPKeepalive = 30 * time.Second // OS-level TCP keepalive interval
DefaultDialTimeout = 30 * time.Second // Connection establishment timeout
DefaultWatchdogPeriod = 30 * time.Second // Goroutine health check interval
)
// Config represents the root configuration structure from .kportal.yaml
type Config struct {
Contexts []Context `yaml:"contexts"`
Contexts []Context `yaml:"contexts"`
HealthCheck *HealthCheckSpec `yaml:"healthCheck,omitempty"`
Reliability *ReliabilitySpec `yaml:"reliability,omitempty"`
}
// HealthCheckSpec configures health check behavior
type HealthCheckSpec struct {
Interval string `yaml:"interval,omitempty"` // e.g., "3s", "5s"
Timeout string `yaml:"timeout,omitempty"` // e.g., "2s"
Method string `yaml:"method,omitempty"` // "tcp-dial" | "data-transfer"
MaxConnectionAge string `yaml:"maxConnectionAge,omitempty"` // e.g., "25m" - reconnect before k8s timeout
MaxIdleTime string `yaml:"maxIdleTime,omitempty"` // e.g., "10m" - reconnect if no activity
}
// ReliabilitySpec configures connection reliability features
type ReliabilitySpec struct {
TCPKeepalive string `yaml:"tcpKeepalive,omitempty"` // e.g., "30s" - OS-level keepalive
DialTimeout string `yaml:"dialTimeout,omitempty"` // e.g., "30s" - connection dial timeout
RetryOnStale bool `yaml:"retryOnStale,omitempty"` // Auto-reconnect on stale detection
WatchdogPeriod string `yaml:"watchdogPeriod,omitempty"` // e.g., "30s" - goroutine watchdog interval
}
// parseDurationOrDefault parses a duration string and returns the default if empty or invalid.
func parseDurationOrDefault(value string, defaultDur time.Duration) time.Duration {
if value == "" {
return defaultDur
}
if d, err := time.ParseDuration(value); err == nil {
return d
}
return defaultDur
}
// GetHealthCheckIntervalOrDefault returns the health check interval or default value
func (c *Config) GetHealthCheckIntervalOrDefault() time.Duration {
if c.HealthCheck == nil {
return DefaultHealthCheckInterval
}
return parseDurationOrDefault(c.HealthCheck.Interval, DefaultHealthCheckInterval)
}
// GetHealthCheckTimeoutOrDefault returns the health check timeout or default value
func (c *Config) GetHealthCheckTimeoutOrDefault() time.Duration {
if c.HealthCheck == nil {
return DefaultHealthCheckTimeout
}
return parseDurationOrDefault(c.HealthCheck.Timeout, DefaultHealthCheckTimeout)
}
// GetHealthCheckMethod returns the health check method or default
func (c *Config) GetHealthCheckMethod() string {
if c.HealthCheck != nil && c.HealthCheck.Method != "" {
return c.HealthCheck.Method
}
return DefaultHealthCheckMethod
}
// GetMaxConnectionAge returns the max connection age or default
func (c *Config) GetMaxConnectionAge() time.Duration {
if c.HealthCheck == nil {
return DefaultMaxConnectionAge
}
return parseDurationOrDefault(c.HealthCheck.MaxConnectionAge, DefaultMaxConnectionAge)
}
// GetMaxIdleTime returns the max idle time or default
func (c *Config) GetMaxIdleTime() time.Duration {
if c.HealthCheck == nil {
return DefaultMaxIdleTime
}
return parseDurationOrDefault(c.HealthCheck.MaxIdleTime, DefaultMaxIdleTime)
}
// GetTCPKeepalive returns the TCP keepalive duration or default
func (c *Config) GetTCPKeepalive() time.Duration {
if c.Reliability == nil {
return DefaultTCPKeepalive
}
return parseDurationOrDefault(c.Reliability.TCPKeepalive, DefaultTCPKeepalive)
}
// GetRetryOnStale returns whether to retry on stale connections
func (c *Config) GetRetryOnStale() bool {
if c.Reliability != nil {
return c.Reliability.RetryOnStale
}
return true // Default: enabled
}
// GetWatchdogPeriod returns the goroutine watchdog check period or default
func (c *Config) GetWatchdogPeriod() time.Duration {
if c.Reliability == nil {
return DefaultWatchdogPeriod
}
return parseDurationOrDefault(c.Reliability.WatchdogPeriod, DefaultWatchdogPeriod)
}
// GetDialTimeout returns the connection dial timeout or default
func (c *Config) GetDialTimeout() time.Duration {
if c.Reliability == nil {
return DefaultDialTimeout
}
return parseDurationOrDefault(c.Reliability.DialTimeout, DefaultDialTimeout)
}
// Context represents a Kubernetes context with its namespaces
@@ -103,9 +220,15 @@ func LoadConfig(path string) (*Config, error) {
}
// ParseConfig parses YAML configuration data into a Config struct.
// It uses strict parsing that rejects unknown keys to catch typos.
func ParseConfig(data []byte) (*Config, error) {
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
// Use decoder with KnownFields to reject unknown keys (catches typos)
decoder := yaml.NewDecoder(bytes.NewReader(data))
decoder.KnownFields(true)
if err := decoder.Decode(&cfg); err != nil {
return nil, fmt.Errorf("failed to parse YAML: %w", err)
}
@@ -136,38 +259,3 @@ func (c *Config) GetAllForwards() []Forward {
return forwards
}
// GetForwardsByContext returns all forwards for a specific context.
func (c *Config) GetForwardsByContext(contextName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
forwards = append(forwards, ns.Forwards...)
}
break
}
}
return forwards
}
// GetForwardsByNamespace returns all forwards for a specific context and namespace.
func (c *Config) GetForwardsByNamespace(contextName, namespaceName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
if ns.Name == namespaceName {
forwards = append(forwards, ns.Forwards...)
break
}
}
break
}
}
return forwards
}
-66
View File
@@ -298,72 +298,6 @@ func TestConfig_GetAllForwards(t *testing.T) {
assert.Len(t, forwards, 4, "should return all forwards from all contexts and namespaces")
}
func TestConfig_GetForwardsByContext(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: cluster2
namespaces:
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByContext("cluster1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1")
forwards2 := cfg.GetForwardsByContext("cluster2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster2")
forwards3 := cfg.GetForwardsByContext("non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent context")
}
func TestConfig_GetForwardsByNamespace(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByNamespace("cluster1", "ns1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1/ns1")
forwards2 := cfg.GetForwardsByNamespace("cluster1", "ns2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster1/ns2")
forwards3 := cfg.GetForwardsByNamespace("cluster1", "non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent namespace")
}
func TestForward_SetContext(t *testing.T) {
fwd := Forward{
Resource: "pod/my-app",
+12 -7
View File
@@ -6,10 +6,15 @@ import (
)
const (
minPort = 1
maxPort = 65535
MinPort = 1
MaxPort = 65535
)
// IsValidPort returns true if the port number is within the valid range (1-65535).
func IsValidPort(port int) bool {
return port >= MinPort && port <= MaxPort
}
// ValidationError represents a configuration validation error with context.
type ValidationError struct {
Field string // The field that failed validation
@@ -84,7 +89,7 @@ func (v *Validator) validateStructure(cfg *Config) []ValidationError {
Field: fmt.Sprintf("contexts[%d].namespaces", i),
Message: fmt.Sprintf("Context '%s' must have at least one namespace", ctx.Name),
})
continue
// Don't continue - still validate other aspects of the context if any
}
for j, ns := range ctx.Namespaces {
@@ -130,17 +135,17 @@ func (v *Validator) validateForward(fwd *Forward) []ValidationError {
}
// Validate ports
if fwd.Port < minPort || fwd.Port > maxPort {
if fwd.Port < MinPort || fwd.Port > MaxPort {
errs = append(errs, ValidationError{
Field: "port",
Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), minPort, maxPort),
Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), MinPort, MaxPort),
})
}
if fwd.LocalPort < minPort || fwd.LocalPort > maxPort {
if fwd.LocalPort < MinPort || fwd.LocalPort > MaxPort {
errs = append(errs, ValidationError{
Field: "localPort",
Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), minPort, maxPort),
Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), MinPort, MaxPort),
})
}
+122 -14
View File
@@ -12,11 +12,6 @@ import (
"github.com/nvm/kportal/internal/logger"
)
const (
healthCheckInterval = 5 * time.Second
healthCheckTimeout = 2 * time.Second
)
// StatusUpdater is an interface for updating forward status
type StatusUpdater interface {
UpdateStatus(id string, status string)
@@ -34,12 +29,15 @@ type Manager struct {
portForwarder *k8s.PortForwarder
portChecker *PortChecker
healthChecker *healthcheck.Checker
watchdog *Watchdog
verbose bool
currentConfig *config.Config
statusUI StatusUpdater
}
// NewManager creates a new forward Manager.
// The health checker will be created with default settings and can be
// reconfigured via SetConfig().
func NewManager(verbose bool) (*Manager, error) {
clientPool, err := k8s.NewClientPool()
if err != nil {
@@ -49,8 +47,13 @@ func NewManager(verbose bool) (*Manager, error) {
resolver := k8s.NewResourceResolver(clientPool)
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
// Create health checker: check every 5 seconds with 2 second timeout
healthChecker := healthcheck.NewChecker(healthCheckInterval, healthCheckTimeout)
// Create health checker with defaults: check every 3 seconds with 2 second timeout
// Will be reconfigured when config is loaded
healthChecker := healthcheck.NewChecker(3*time.Second, 2*time.Second)
// Create watchdog with default settings: check every 30 seconds, 60 second hang threshold
// Will be reconfigured when config is loaded
watchdog := NewWatchdog(30*time.Second, 60*time.Second)
return &Manager{
workers: make(map[string]*ForwardWorker),
@@ -59,10 +62,56 @@ func NewManager(verbose bool) (*Manager, error) {
portForwarder: portForwarder,
portChecker: NewPortChecker(),
healthChecker: healthChecker,
watchdog: watchdog,
verbose: verbose,
}, nil
}
// configureHealthChecker creates a new health checker with settings from config
func (m *Manager) configureHealthChecker(cfg *config.Config) {
// Stop existing health checker
if m.healthChecker != nil {
m.healthChecker.Stop()
}
// Parse check method
methodStr := cfg.GetHealthCheckMethod()
var method healthcheck.CheckMethod
switch methodStr {
case "tcp-dial":
method = healthcheck.CheckMethodTCPDial
case "data-transfer":
method = healthcheck.CheckMethodDataTransfer
default:
method = healthcheck.CheckMethodDataTransfer
}
// Create new health checker with config settings
m.healthChecker = healthcheck.NewCheckerWithOptions(healthcheck.CheckerOptions{
Interval: cfg.GetHealthCheckIntervalOrDefault(),
Timeout: cfg.GetHealthCheckTimeoutOrDefault(),
Method: method,
MaxConnectionAge: cfg.GetMaxConnectionAge(),
MaxIdleTime: cfg.GetMaxIdleTime(),
})
// Configure TCP settings on port forwarder
tcpKeepalive := cfg.GetTCPKeepalive()
dialTimeout := cfg.GetDialTimeout()
m.portForwarder.SetTCPKeepalive(tcpKeepalive)
m.portForwarder.SetDialTimeout(dialTimeout)
logger.Info("Health checker and reliability configured", map[string]interface{}{
"interval": cfg.GetHealthCheckIntervalOrDefault().String(),
"timeout": cfg.GetHealthCheckTimeoutOrDefault().String(),
"method": methodStr,
"max_connection_age": cfg.GetMaxConnectionAge().String(),
"max_idle_time": cfg.GetMaxIdleTime().String(),
"tcp_keepalive": tcpKeepalive.String(),
"dial_timeout": dialTimeout.String(),
})
}
// SetStatusUI sets the status updater for the manager
func (m *Manager) SetStatusUI(ui StatusUpdater) {
m.statusUI = ui
@@ -76,6 +125,20 @@ func (m *Manager) Start(cfg *config.Config) error {
m.currentConfig = cfg
// Configure health checker with settings from config
m.configureHealthChecker(cfg)
// Start watchdog
watchdogPeriod := cfg.GetWatchdogPeriod()
m.watchdog.checkInterval = watchdogPeriod
m.watchdog.hangThreshold = watchdogPeriod * 2 // Hang threshold is 2x check interval
m.watchdog.Start()
logger.Info("Watchdog started", map[string]interface{}{
"check_interval": watchdogPeriod.String(),
"hang_threshold": (watchdogPeriod * 2).String(),
})
// Get all forwards from config
forwards := cfg.GetAllForwards()
@@ -119,8 +182,9 @@ func (m *Manager) Start(cfg *config.Config) error {
func (m *Manager) Stop() {
log.Printf("Stopping all port-forwards...")
// Stop health checker first
// Stop health checker and watchdog first
m.healthChecker.Stop()
m.watchdog.Stop()
m.workersMu.Lock()
workers := make([]*ForwardWorker, 0, len(m.workers))
@@ -273,21 +337,55 @@ func (m *Manager) startWorker(fwd config.Forward) error {
m.statusUI.AddForward(fwd.ID(), &fwd)
}
// Register with watchdog
m.watchdog.RegisterWorker(fwd.ID(), func(forwardID string) {
logger.Warn("Watchdog triggered reconnection for hung worker", map[string]interface{}{
"forward_id": forwardID,
})
// Find and trigger reconnect on hung worker
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("watchdog detected hung worker")
}
})
// Register with health checker
m.healthChecker.Register(fwd.ID(), fwd.LocalPort, func(forwardID string, status healthcheck.Status, errorMsg string) {
if m.statusUI != nil {
m.statusUI.UpdateStatus(forwardID, string(status))
// Send error separately if there is one
if status == healthcheck.StatusUnhealthy && errorMsg != "" {
if (status == healthcheck.StatusUnhealthy || status == healthcheck.StatusStale) && errorMsg != "" {
if ui, ok := m.statusUI.(interface{ SetError(id, msg string) }); ok {
ui.SetError(forwardID, errorMsg)
}
}
}
// Handle stale connections: trigger reconnection if retryOnStale is enabled
if status == healthcheck.StatusStale && m.currentConfig.GetRetryOnStale() {
logger.Info("Stale connection detected, triggering reconnection", map[string]interface{}{
"forward_id": forwardID,
"reason": errorMsg,
})
// Find and notify the worker to reconnect
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("stale connection")
}
}
})
// Create and start worker
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker)
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker, m.watchdog)
worker.Start()
// Store worker
@@ -298,6 +396,11 @@ func (m *Manager) startWorker(fwd config.Forward) error {
// stopWorker stops and removes a forward worker.
func (m *Manager) stopWorker(id string) error {
return m.stopWorkerInternal(id, true)
}
// stopWorkerInternal stops a worker with option to remove from UI or just update status
func (m *Manager) stopWorkerInternal(id string, removeFromUI bool) error {
m.workersMu.Lock()
worker, exists := m.workers[id]
if !exists {
@@ -307,12 +410,17 @@ func (m *Manager) stopWorker(id string) error {
delete(m.workers, id)
m.workersMu.Unlock()
// Unregister from health checker
// Unregister from health checker and watchdog
m.healthChecker.Unregister(id)
m.watchdog.UnregisterWorker(id)
// Notify UI to remove the forward
// Notify UI - either remove or update to disabled status
if m.statusUI != nil {
m.statusUI.Remove(id)
if removeFromUI {
m.statusUI.Remove(id)
} else {
m.statusUI.UpdateStatus(id, "Disabled")
}
}
// Stop the worker
@@ -363,7 +471,7 @@ func (m *Manager) getResourceForPort(forwards []config.Forward, port int) string
// DisableForward temporarily stops a forward by ID
func (m *Manager) DisableForward(id string) error {
if err := m.stopWorker(id); err != nil {
if err := m.stopWorkerInternal(id, false); err != nil {
return err
}
log.Printf("Disabled: %s", id)
+144 -41
View File
@@ -6,11 +6,20 @@ import (
"os/exec"
"runtime"
"strings"
"github.com/nvm/kportal/internal/logger"
)
const (
// maxPIDLength is the maximum length of a valid PID string (9 digits covers PIDs up to 999,999,999)
maxPIDLength = 9
// minNetstatFields is the minimum number of fields expected in netstat output
minNetstatFields = 5
)
// isValidPID validates that a PID string contains only digits
func isValidPID(pid string) bool {
if len(pid) == 0 || len(pid) > 9 {
if len(pid) == 0 || len(pid) > maxPIDLength {
return false
}
for _, c := range pid {
@@ -21,6 +30,72 @@ func isValidPID(pid string) bool {
return true
}
// processInfo holds information about a process using a port
type processInfo struct {
pid string
name string
isValid bool
}
// formatProcessInfo formats process information for display
func formatProcessInfo(info processInfo) string {
if !info.isValid {
return "unknown"
}
if info.name != "" {
return fmt.Sprintf("%s (PID %s)", info.name, info.pid)
}
return fmt.Sprintf("PID %s", info.pid)
}
// formatProcessList formats a list of processes into a human-readable string.
// Returns "unknown" if the list is empty.
func formatProcessList(processes []processInfo) string {
if len(processes) == 0 {
return "unknown"
}
if len(processes) == 1 {
return formatProcessInfo(processes[0])
}
// Multiple processes - format as comma-separated list
parts := make([]string, len(processes))
for i, p := range processes {
parts[i] = formatProcessInfo(p)
}
return strings.Join(parts, ", ")
}
// getProcessNameByPID retrieves the process name for a given PID on Unix systems
func getProcessNameByPID(pid string) string {
cmd := exec.Command("ps", "-p", pid, "-o", "comm=")
output, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(output))
}
// getProcessNameByPIDWindows retrieves the process name for a given PID on Windows
func getProcessNameByPIDWindows(pid string) string {
cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
output, err := cmd.Output()
if err != nil {
return ""
}
// Parse CSV output: "process.exe","1234","Console","1","12,345 K"
csvLine := strings.TrimSpace(string(output))
if csvLine == "" {
return ""
}
parts := strings.Split(csvLine, ",")
if len(parts) > 0 {
return strings.Trim(parts[0], "\"")
}
return ""
}
// PortConflict represents a local port that is already in use.
type PortConflict struct {
Port int // The conflicting port number
@@ -102,27 +177,55 @@ func (pc *PortChecker) getProcessUsingPortUnix(port int) string {
return "unknown"
}
// Get the first PID if multiple are returned
// Handle multiple PIDs (multiple processes on same port)
pids := strings.Split(pidStr, "\n")
pid := pids[0]
var validProcesses []processInfo
if !isValidPID(pid) {
return "unknown"
for _, pid := range pids {
pid = strings.TrimSpace(pid)
if pid == "" {
continue
}
if !isValidPID(pid) {
logger.Debug("Invalid PID format from lsof output", map[string]interface{}{
"port": port,
"raw_pid": pid,
})
continue
}
procName := getProcessNameByPID(pid)
validProcesses = append(validProcesses, processInfo{
pid: pid,
name: procName,
isValid: true,
})
}
// Get process name using ps
cmd = exec.Command("ps", "-p", pid, "-o", "comm=")
output, err = cmd.Output()
if err != nil {
return fmt.Sprintf("PID %s", pid)
return formatProcessList(validProcesses)
}
// isListeningState checks if a netstat line indicates a listening state.
// This handles both English and potentially other locales by checking for common patterns.
func isListeningState(line string, fields []string) bool {
upperLine := strings.ToUpper(line)
// Check for common listening state indicators across locales
// English: LISTENING, German: ABHÖREN, French: ÉCOUTE, etc.
// The most reliable check is the state field position (4th field, 0-indexed = 3)
// and that it's a TCP connection with 0.0.0.0:0 or *:* as foreign address
if len(fields) >= minNetstatFields {
state := strings.ToUpper(fields[3])
// Common listening state values across Windows locales
if state == "LISTENING" || state == "ABHÖREN" || state == "ÉCOUTE" ||
state == "ESCUCHANDO" || state == "ASCOLTO" || state == "NASŁUCHIWANIE" {
return true
}
}
procName := strings.TrimSpace(string(output))
if procName == "" {
return fmt.Sprintf("PID %s", pid)
}
return fmt.Sprintf("%s (PID %s)", procName, pid)
// Fallback: check if line contains LISTENING (most common case)
return strings.Contains(upperLine, "LISTENING")
}
// getProcessUsingPortWindows uses netstat to find the process using a port on Windows.
@@ -138,6 +241,8 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
lines := strings.Split(string(output), "\n")
portStr := fmt.Sprintf(":%d", port)
var validProcesses []processInfo
for _, line := range lines {
if !strings.Contains(line, portStr) {
continue
@@ -146,44 +251,42 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
// Parse the line to extract PID
// Format: TCP 0.0.0.0:8080 0.0.0.0:0 LISTENING 1234
fields := strings.Fields(line)
if len(fields) < 5 {
if len(fields) < minNetstatFields {
continue
}
// Check if this is a LISTENING state
if !strings.Contains(strings.ToUpper(line), "LISTENING") {
// Check if this is a LISTENING state (locale-aware)
if !isListeningState(line, fields) {
continue
}
// Verify the local address field actually contains our port
// (avoid matching port in foreign address)
localAddr := fields[1]
if !strings.HasSuffix(localAddr, portStr) {
continue
}
pid := fields[len(fields)-1]
if !isValidPID(pid) {
return "unknown"
logger.Debug("Invalid PID format from netstat output", map[string]interface{}{
"port": port,
"raw_pid": pid,
"line": line,
})
continue
}
// Get process name using tasklist
cmd = exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
output, err = cmd.Output()
if err != nil {
return fmt.Sprintf("PID %s", pid)
}
// Parse CSV output: "process.exe","1234","Console","1","12,345 K"
csvLine := strings.TrimSpace(string(output))
if csvLine == "" {
return fmt.Sprintf("PID %s", pid)
}
parts := strings.Split(csvLine, ",")
if len(parts) > 0 {
procName := strings.Trim(parts[0], "\"")
return fmt.Sprintf("%s (PID %s)", procName, pid)
}
return fmt.Sprintf("PID %s", pid)
procName := getProcessNameByPIDWindows(pid)
validProcesses = append(validProcesses, processInfo{
pid: pid,
name: procName,
isValid: true,
})
}
return "unknown"
return formatProcessList(validProcesses)
}
// FormatConflicts formats port conflicts into a human-readable error message.
+174
View File
@@ -0,0 +1,174 @@
package forward
import (
"context"
"sync"
"time"
"github.com/nvm/kportal/internal/logger"
)
// Watchdog monitors worker goroutines to detect hung workers
type Watchdog struct {
mu sync.RWMutex
workers map[string]*workerState // key: forward ID
checkInterval time.Duration
hangThreshold time.Duration // How long without heartbeat before considered hung
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// workerState tracks the health of a single worker
type workerState struct {
forwardID string
lastHeartbeat time.Time
heartbeatCount uint64
isHung bool
onHungCallback func(forwardID string)
}
// NewWatchdog creates a new goroutine watchdog
func NewWatchdog(checkInterval, hangThreshold time.Duration) *Watchdog {
ctx, cancel := context.WithCancel(context.Background())
return &Watchdog{
workers: make(map[string]*workerState),
checkInterval: checkInterval,
hangThreshold: hangThreshold,
ctx: ctx,
cancel: cancel,
}
}
// Start begins the watchdog monitoring loop
func (w *Watchdog) Start() {
w.wg.Add(1)
go w.monitorLoop()
}
// Stop stops the watchdog
func (w *Watchdog) Stop() {
w.cancel()
w.wg.Wait()
}
// RegisterWorker adds a worker to monitor
func (w *Watchdog) RegisterWorker(forwardID string, onHungCallback func(string)) {
w.mu.Lock()
defer w.mu.Unlock()
w.workers[forwardID] = &workerState{
forwardID: forwardID,
lastHeartbeat: time.Now(),
heartbeatCount: 0,
isHung: false,
onHungCallback: onHungCallback,
}
logger.Debug("Watchdog registered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// UnregisterWorker removes a worker from monitoring
func (w *Watchdog) UnregisterWorker(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.workers, forwardID)
logger.Debug("Watchdog unregistered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// Heartbeat records that a worker is alive and processing
// Workers should call this periodically (e.g., in their main loop)
func (w *Watchdog) Heartbeat(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
if state, exists := w.workers[forwardID]; exists {
state.lastHeartbeat = time.Now()
state.heartbeatCount++
state.isHung = false
}
}
// GetWorkerState returns the current state of a worker (for testing)
func (w *Watchdog) GetWorkerState(forwardID string) (lastHeartbeat time.Time, count uint64, exists bool) {
w.mu.RLock()
defer w.mu.RUnlock()
if state, ok := w.workers[forwardID]; ok {
return state.lastHeartbeat, state.heartbeatCount, true
}
return time.Time{}, 0, false
}
// monitorLoop periodically checks all workers
func (w *Watchdog) monitorLoop() {
defer w.wg.Done()
ticker := time.NewTicker(w.checkInterval)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
w.checkWorkers()
}
}
}
// hungWorkerInfo stores information about a hung worker for deferred callback execution
type hungWorkerInfo struct {
forwardID string
callback func(string)
}
// checkWorkers checks all registered workers for hung state
func (w *Watchdog) checkWorkers() {
// Collect hung workers while holding the lock
var hungWorkers []hungWorkerInfo
w.mu.Lock()
now := time.Now()
for forwardID, state := range w.workers {
timeSinceHeartbeat := now.Sub(state.lastHeartbeat)
// Check if worker is hung
if timeSinceHeartbeat > w.hangThreshold {
if !state.isHung {
// First time detecting hung state
state.isHung = true
logger.Warn("Watchdog detected hung worker", map[string]interface{}{
"forward_id": forwardID,
"time_since_heartbeat": timeSinceHeartbeat.String(),
"hang_threshold": w.hangThreshold.String(),
"heartbeat_count": state.heartbeatCount,
})
// Collect callback for deferred execution outside the lock
if state.onHungCallback != nil {
hungWorkers = append(hungWorkers, hungWorkerInfo{
forwardID: forwardID,
callback: state.onHungCallback,
})
}
}
}
}
w.mu.Unlock()
// Execute callbacks outside the lock to prevent deadlocks and ensure
// consistent state during callback execution. Callbacks are idempotent
// (they trigger reconnection via channels), so concurrent state changes
// between detection and callback execution are safe.
for _, hw := range hungWorkers {
hw.callback(hw.forwardID)
}
}
+310
View File
@@ -0,0 +1,310 @@
package forward
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// WatchdogTestSuite contains tests for the watchdog
type WatchdogTestSuite struct {
suite.Suite
watchdog *Watchdog
}
func TestWatchdogSuite(t *testing.T) {
suite.Run(t, new(WatchdogTestSuite))
}
func (s *WatchdogTestSuite) SetupTest() {
// Create watchdog with fast intervals for testing
s.watchdog = NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
s.watchdog.Start()
}
func (s *WatchdogTestSuite) TearDownTest() {
if s.watchdog != nil {
s.watchdog.Stop()
}
}
// TestRegisterUnregister tests basic registration and unregistration
func (s *WatchdogTestSuite) TestRegisterUnregister() {
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
// Register worker
s.watchdog.RegisterWorker("test-forward", callback)
// Verify worker is registered
_, _, exists := s.watchdog.GetWorkerState("test-forward")
assert.True(s.T(), exists, "Worker should be registered")
// Unregister worker
s.watchdog.UnregisterWorker("test-forward")
// Verify worker is unregistered
_, _, exists = s.watchdog.GetWorkerState("test-forward")
assert.False(s.T(), exists, "Worker should be unregistered")
assert.False(s.T(), callbackCalled, "Callback should not have been called")
}
// TestHeartbeat tests that heartbeats update worker state
func (s *WatchdogTestSuite) TestHeartbeat() {
s.watchdog.RegisterWorker("test-forward", nil)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat1, count1, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(1), count1)
// Wait a bit
time.Sleep(50 * time.Millisecond)
// Send another heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat2, count2, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(2), count2)
assert.True(s.T(), lastHeartbeat2.After(lastHeartbeat1), "Second heartbeat should be after first")
}
// TestHungWorkerDetection tests that hung workers are detected
func (s *WatchdogTestSuite) TestHungWorkerDetection() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for worker to be considered hung (300ms threshold + 100ms check interval)
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
// TestHealthyWorkerNotDetectedAsHung tests that workers sending heartbeats are not considered hung
func (s *WatchdogTestSuite) TestHealthyWorkerNotDetectedAsHung() {
callbackCalled := false
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send periodic heartbeats (faster than hang threshold)
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
done := make(chan bool)
go func() {
for i := 0; i < 10; i++ {
<-ticker.C
s.watchdog.Heartbeat("test-forward")
}
done <- true
}()
// Wait for all heartbeats to complete
<-done
// Check that callback was not called
mu.Lock()
assert.False(s.T(), callbackCalled, "Callback should not be called for healthy worker")
mu.Unlock()
}
// TestMultipleWorkers tests monitoring multiple workers simultaneously
func (s *WatchdogTestSuite) TestMultipleWorkers() {
callbacks := make(map[string]int)
var mu sync.Mutex
makeCallback := func(id string) func(string) {
return func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbacks[id]++
}
}
// Register multiple workers
s.watchdog.RegisterWorker("worker-1", makeCallback("worker-1"))
s.watchdog.RegisterWorker("worker-2", makeCallback("worker-2"))
s.watchdog.RegisterWorker("worker-3", makeCallback("worker-3"))
// worker-1: Keep sending heartbeats (healthy)
ticker1 := time.NewTicker(50 * time.Millisecond)
defer ticker1.Stop()
go func() {
for i := 0; i < 10; i++ {
<-ticker1.C
s.watchdog.Heartbeat("worker-1")
}
}()
// worker-2: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-2")
// worker-3: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-3")
// Wait for hung workers to be detected
time.Sleep(600 * time.Millisecond)
// Check results
mu.Lock()
defer mu.Unlock()
assert.Equal(s.T(), 0, callbacks["worker-1"], "worker-1 should not trigger callback (healthy)")
assert.Greater(s.T(), callbacks["worker-2"], 0, "worker-2 should trigger callback (hung)")
assert.Greater(s.T(), callbacks["worker-3"], 0, "worker-3 should trigger callback (hung)")
}
// TestCallbackOnlyOnFirstDetection tests that callback is only called once when hung is first detected
func (s *WatchdogTestSuite) TestCallbackOnlyOnFirstDetection() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for multiple check cycles
time.Sleep(1 * time.Second)
// Check that callback was only called once
mu.Lock()
assert.Equal(s.T(), 1, callbackCount, "Callback should only be called once")
mu.Unlock()
}
// TestHeartbeatResetsHungState tests that sending heartbeat after hung detection resets state
func (s *WatchdogTestSuite) TestHeartbeatResetsHungState() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for hung detection
time.Sleep(500 * time.Millisecond)
mu.Lock()
firstCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 1, firstCount, "First hung detection should trigger callback")
// Send heartbeat to reset hung state
s.watchdog.Heartbeat("test-forward")
// Wait for worker to become hung again
time.Sleep(500 * time.Millisecond)
mu.Lock()
secondCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 2, secondCount, "Second hung detection should trigger callback again")
}
// TestConcurrentOperations tests thread safety
func (s *WatchdogTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numWorkers := 10
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := string(rune('a' + id))
s.watchdog.RegisterWorker(forwardID, nil)
for j := 0; j < 10; j++ {
s.watchdog.Heartbeat(forwardID)
time.Sleep(10 * time.Millisecond)
}
s.watchdog.UnregisterWorker(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestStopWatchdog tests that stopping watchdog cleans up properly
func TestStopWatchdog(t *testing.T) {
watchdog := NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
watchdog.Start()
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
watchdog.RegisterWorker("test-forward", callback)
watchdog.Heartbeat("test-forward")
// Stop watchdog before hang detection
time.Sleep(100 * time.Millisecond)
watchdog.Stop()
// Wait to ensure no more callbacks after stop
time.Sleep(500 * time.Millisecond)
assert.False(t, callbackCalled, "Callback should not be called after watchdog is stopped")
}
// TestWatchdogWithZeroHeartbeats tests detecting hung worker that never sends heartbeats
func (s *WatchdogTestSuite) TestWatchdogWithZeroHeartbeats() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
// Register worker but never send heartbeat
s.watchdog.RegisterWorker("test-forward", callback)
// Wait for hung detection
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
+80 -13
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"log"
"sync"
"time"
"github.com/nvm/kportal/internal/config"
@@ -20,21 +21,25 @@ const (
// ForwardWorker manages a single port-forward connection with automatic retry.
type ForwardWorker struct {
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
startTime time.Time // Track when the worker started
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
reconnectChan chan string // Channel to trigger reconnection
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
watchdog *Watchdog
startTime time.Time // Track when the worker started
forwardCancel context.CancelFunc // Cancel function for current forward attempt
forwardCancelMu sync.Mutex // Protects forwardCancel
}
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker) *ForwardWorker {
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker, watchdog *Watchdog) *ForwardWorker {
ctx, cancel := context.WithCancel(context.Background())
return &ForwardWorker{
@@ -44,13 +49,32 @@ func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verb
cancel: cancel,
stopChan: make(chan struct{}),
doneChan: make(chan struct{}),
reconnectChan: make(chan string, 1), // Buffered to avoid blocking
verbose: verbose,
statusUI: statusUI,
healthChecker: healthChecker,
watchdog: watchdog,
startTime: time.Now(),
}
}
// TriggerReconnect triggers a reconnection (e.g., due to stale connection)
func (w *ForwardWorker) TriggerReconnect(reason string) {
// Cancel current forward if running
w.forwardCancelMu.Lock()
if w.forwardCancel != nil {
w.forwardCancel()
}
w.forwardCancelMu.Unlock()
// Send reconnect signal (non-blocking)
select {
case w.reconnectChan <- reason:
default:
// Channel already has pending reconnect
}
}
// Start begins the port-forward worker in a goroutine.
// The worker will continuously retry on failures with exponential backoff.
func (w *ForwardWorker) Start() {
@@ -68,6 +92,12 @@ func (w *ForwardWorker) Stop() {
func (w *ForwardWorker) run() {
defer close(w.doneChan)
// Start heartbeat goroutine to continuously send heartbeats to watchdog
// This prevents false "hung worker" detection when connections are long-lived
if w.watchdog != nil {
go w.heartbeatLoop()
}
backoff := retry.NewBackoff()
for {
@@ -173,6 +203,26 @@ func (w *ForwardWorker) run() {
}
}
// heartbeatLoop sends periodic heartbeats to the watchdog to prove the worker is alive
// This runs in a separate goroutine and continues throughout the worker's lifetime
func (w *ForwardWorker) heartbeatLoop() {
// Send heartbeats every 15 seconds (well within typical 60s watchdog timeout)
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
// Send immediate heartbeat
w.watchdog.Heartbeat(w.forward.ID())
for {
select {
case <-ticker.C:
w.watchdog.Heartbeat(w.forward.ID())
case <-w.ctx.Done():
return
}
}
}
// establishForward establishes a port-forward connection.
// This blocks until the connection is closed or an error occurs.
func (w *ForwardWorker) establishForward(podName string) error {
@@ -184,11 +234,24 @@ func (w *ForwardWorker) establishForward(podName string) error {
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
defer forwardCancel()
// Start a goroutine to monitor for stop signal
// Store cancel function so TriggerReconnect can use it
w.forwardCancelMu.Lock()
w.forwardCancel = forwardCancel
w.forwardCancelMu.Unlock()
defer func() {
w.forwardCancelMu.Lock()
w.forwardCancel = nil
w.forwardCancelMu.Unlock()
}()
// Start a goroutine to monitor for stop signal and reconnect triggers
go func() {
select {
case <-w.stopChan:
close(stopChan)
case <-w.reconnectChan:
close(stopChan)
case <-forwardCtx.Done():
close(stopChan)
}
@@ -230,6 +293,10 @@ func (w *ForwardWorker) establishForward(podName string) error {
if w.verbose {
log.Printf("[%s] Port-forward connection established", w.forward.ID())
}
// Mark connection as established in health checker
if w.healthChecker != nil {
w.healthChecker.MarkConnected(w.forward.ID())
}
case err := <-errChan:
return fmt.Errorf("failed to establish forward: %w", err)
case <-w.ctx.Done():
+219 -70
View File
@@ -3,13 +3,17 @@ package healthcheck
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/nvm/kportal/internal/config"
)
const (
startupGracePeriod = 10 * time.Second
dataTransferSize = 1024 // bytes to read in data transfer test
)
// Status represents the health status of a port forward
@@ -20,15 +24,26 @@ const (
StatusUnhealthy Status = "Error"
StatusStarting Status = "Starting"
StatusReconnect Status = "Reconnecting"
StatusStale Status = "Stale" // Connection is old or idle
)
// CheckMethod represents the health check method
type CheckMethod string
const (
CheckMethodTCPDial CheckMethod = "tcp-dial" // Simple TCP connection test
CheckMethodDataTransfer CheckMethod = "data-transfer" // Try to read data from connection
)
// PortHealth represents the health status of a single port
type PortHealth struct {
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
ConnectionTime time.Time // When current connection was established
LastActivity time.Time // Last time data was transferred
}
// StatusCallback is called when a port's health status changes
@@ -36,26 +51,52 @@ type StatusCallback func(forwardID string, status Status, errorMsg string)
// Checker performs periodic health checks on local ports
type Checker struct {
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
method CheckMethod
maxConnectionAge time.Duration
maxIdleTime time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewChecker creates a new health checker
// CheckerOptions configures the health checker
type CheckerOptions struct {
Interval time.Duration
Timeout time.Duration
Method CheckMethod
MaxConnectionAge time.Duration
MaxIdleTime time.Duration
}
// NewChecker creates a new health checker with default options
func NewChecker(interval, timeout time.Duration) *Checker {
return NewCheckerWithOptions(CheckerOptions{
Interval: interval,
Timeout: timeout,
Method: CheckMethodDataTransfer,
MaxConnectionAge: config.DefaultMaxConnectionAge,
MaxIdleTime: config.DefaultMaxIdleTime,
})
}
// NewCheckerWithOptions creates a new health checker with custom options
func NewCheckerWithOptions(opts CheckerOptions) *Checker {
ctx, cancel := context.WithCancel(context.Background())
return &Checker{
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: interval,
timeout: timeout,
ctx: ctx,
cancel: cancel,
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: opts.Interval,
timeout: opts.Timeout,
method: opts.Method,
maxConnectionAge: opts.MaxConnectionAge,
maxIdleTime: opts.MaxIdleTime,
ctx: ctx,
cancel: cancel,
}
}
@@ -64,11 +105,14 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
c.ports[forwardID] = &PortHealth{
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: time.Now(),
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: now,
ConnectionTime: now,
LastActivity: now,
}
c.callbacks[forwardID] = callback
@@ -77,6 +121,28 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
go c.checkLoop(forwardID)
}
// MarkConnected marks a forward as having established a new connection
func (c *Checker) MarkConnected(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
now := time.Now()
health.ConnectionTime = now
health.LastActivity = now
}
}
// RecordActivity records data transfer activity for a forward
func (c *Checker) RecordActivity(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
health.LastActivity = time.Now()
}
}
// Unregister removes a port from monitoring
func (c *Checker) Unregister(forwardID string) {
c.mu.Lock()
@@ -86,44 +152,34 @@ func (c *Checker) Unregister(forwardID string) {
delete(c.callbacks, forwardID)
}
// MarkReconnecting marks a forward as reconnecting (called by worker)
func (c *Checker) MarkReconnecting(forwardID string) {
// markStatus is a helper to set a forward's status and notify on change.
func (c *Checker) markStatus(forwardID string, newStatus Status) {
c.mu.Lock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusReconnect
health.LastCheck = time.Now()
health, exists := c.ports[forwardID]
if !exists {
c.mu.Unlock()
if oldStatus != StatusReconnect {
c.notifyStatusChange(forwardID, StatusReconnect, "")
}
return
}
oldStatus := health.Status
health.Status = newStatus
health.LastCheck = time.Now()
c.mu.Unlock()
if oldStatus != newStatus {
c.notifyStatusChange(forwardID, newStatus, "")
}
}
// MarkReconnecting marks a forward as reconnecting (called by worker)
func (c *Checker) MarkReconnecting(forwardID string) {
c.markStatus(forwardID, StatusReconnect)
}
// MarkStarting marks a forward as starting (called by worker)
func (c *Checker) MarkStarting(forwardID string) {
c.mu.Lock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusStarting
health.LastCheck = time.Now()
c.mu.Unlock()
if oldStatus != StatusStarting {
c.notifyStatusChange(forwardID, StatusStarting, "")
}
return
}
c.mu.Unlock()
c.markStatus(forwardID, StatusStarting)
}
// GetStatus returns the current health status of a forward
@@ -137,6 +193,17 @@ func (c *Checker) GetStatus(forwardID string) (Status, bool) {
return StatusUnhealthy, false
}
// GetLastCheckTime returns the last health check time for a forward
func (c *Checker) GetLastCheckTime(forwardID string) (time.Time, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if health, exists := c.ports[forwardID]; exists {
return health.LastCheck, true
}
return time.Time{}, false
}
// GetAllErrors returns all forwards with errors and their error messages
func (c *Checker) GetAllErrors() map[string]string {
c.mu.RLock()
@@ -197,38 +264,64 @@ func (c *Checker) checkPort(forwardID string) {
port := health.Port
oldStatus := health.Status
registeredAt := health.RegisteredAt
connectionTime := health.ConnectionTime
lastActivity := health.LastActivity
c.mu.RUnlock()
// Attempt to connect to the local port
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
now := time.Now()
newStatus := StatusHealthy
errorMsg := ""
if err != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := time.Since(registeredAt)
if timeSinceStart < startupGracePeriod {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = err.Error()
// Check for stale connections based on age or idle time
connectionAge := now.Sub(connectionTime)
idleTime := now.Sub(lastActivity)
// Only enforce max connection age if the connection is ALSO idle
// This prevents interrupting active transfers (e.g., database dumps)
if c.maxConnectionAge > 0 && connectionAge > c.maxConnectionAge && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("connection age %v exceeds max %v (and idle for %v)",
connectionAge.Round(time.Second), c.maxConnectionAge, idleTime.Round(time.Second))
} else if c.maxIdleTime > 0 && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("idle time %v exceeds max %v", idleTime.Round(time.Second), c.maxIdleTime)
} else {
conn.Close()
// Perform connectivity check
var checkErr error
switch c.method {
case CheckMethodDataTransfer:
checkErr = c.checkDataTransfer(port)
case CheckMethodTCPDial:
checkErr = c.checkTCPDial(port)
default:
checkErr = c.checkTCPDial(port)
}
if checkErr != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := now.Sub(registeredAt)
if timeSinceStart < startupGracePeriod {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = checkErr.Error()
}
}
// Update health status
c.mu.Lock()
if health, exists := c.ports[forwardID]; exists {
health.Status = newStatus
health.LastCheck = time.Now()
health.LastCheck = now
health.ErrorMessage = errorMsg
// Successful health check indicates connection is active
// This prevents false positives where healthy connections are marked as idle
if newStatus == StatusHealthy {
health.LastActivity = now
}
}
c.mu.Unlock()
@@ -238,6 +331,62 @@ func (c *Checker) checkPort(forwardID string) {
}
}
// checkTCPDial performs a simple TCP dial test
func (c *Checker) checkTCPDial(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
conn.Close()
return nil
}
// checkDataTransfer attempts to read data from the connection to verify tunnel health
func (c *Checker) checkDataTransfer(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
defer conn.Close()
// Set a short read deadline to detect hung connections
// We don't expect to receive data, but we want to verify the connection isn't hung
conn.SetReadDeadline(time.Now().Add(c.timeout))
// Try to read a small amount of data
// Most servers will either:
// 1. Send a banner (SSH, FTP, etc) - we'll read it successfully
// 2. Wait for client to send first (HTTP, postgres) - we'll timeout (which is OK)
// 3. Hung/stale connection - will timeout with different error
buf := make([]byte, dataTransferSize)
_, err = conn.Read(buf)
// We expect either:
// - No error (banner received)
// - EOF (connection closed by server after connect)
// - Timeout (server waiting for client)
// All of these indicate the tunnel is working
if err == nil || err == io.EOF {
return nil
}
// Timeout is acceptable - server is waiting for us to send data first
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil
}
// Other errors indicate a problem
return fmt.Errorf("data transfer check failed: %w", err)
}
// notifyStatusChange calls the callback for a forward
func (c *Checker) notifyStatusChange(forwardID string, status Status, errorMsg string) {
c.mu.RLock()
+551
View File
@@ -0,0 +1,551 @@
package healthcheck
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// HealthCheckTestSuite contains tests for the health checker
type HealthCheckTestSuite struct {
suite.Suite
checker *Checker
listener net.Listener
port int
}
func TestHealthCheckSuite(t *testing.T) {
suite.Run(t, new(HealthCheckTestSuite))
}
func (s *HealthCheckTestSuite) SetupTest() {
// Create a test listener on a random port
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
s.listener = ln
s.port = ln.Addr().(*net.TCPAddr).Port
// Create checker with fast intervals for testing
s.checker = NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 500 * time.Millisecond,
MaxIdleTime: 300 * time.Millisecond,
})
}
func (s *HealthCheckTestSuite) TearDownTest() {
if s.checker != nil {
s.checker.Stop()
}
if s.listener != nil {
s.listener.Close()
}
}
// TestRegisterAndUnregister tests basic registration and unregistration
func (s *HealthCheckTestSuite) TestRegisterAndUnregister() {
callbackCalled := false
var callbackStatus Status
var mu sync.Mutex
callback := func(forwardID string, status Status, errorMsg string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
callbackStatus = status
}
// Register port
s.checker.Register("test-forward", s.port, callback)
// Wait for health check to run
time.Sleep(200 * time.Millisecond)
// Verify callback was called with healthy status
mu.Lock()
assert.True(s.T(), callbackCalled, "Callback should have been called")
assert.Equal(s.T(), StatusHealthy, callbackStatus)
mu.Unlock()
// Unregister
s.checker.Unregister("test-forward")
// Verify port is no longer monitored
status, exists := s.checker.GetStatus("test-forward")
assert.False(s.T(), exists, "Port should no longer exist after unregister")
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestTCPDialMethod tests the TCP dial health check method
func (s *HealthCheckTestSuite) TestTCPDialMethod() {
tests := []struct {
name string
setupPort bool
expectedStatus Status
description string
}{
{
name: "port available - healthy",
setupPort: true,
expectedStatus: StatusHealthy,
description: "When port is listening, status should be healthy",
},
{
name: "port unavailable - unhealthy",
setupPort: false,
expectedStatus: StatusUnhealthy,
description: "When port is not listening, status should be unhealthy",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
if tt.setupPort {
// Use the existing listener
testPort = s.port
} else {
// Use a port that's not listening
testPort = 54321 // Likely unused port
}
// Create a new checker for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if !tt.setupPort {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(200 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status, tt.description)
if testListener != nil {
testListener.Close()
}
})
}
}
// TestDataTransferMethod tests the data transfer health check method
func (s *HealthCheckTestSuite) TestDataTransferMethod() {
tests := []struct {
name string
serverBehavior string // "banner", "silent", "close", "none"
expectedStatus Status
}{
{
name: "server sends banner - healthy",
serverBehavior: "banner",
expectedStatus: StatusHealthy,
},
{
name: "server waits silently - healthy (timeout OK)",
serverBehavior: "silent",
expectedStatus: StatusHealthy,
},
{
name: "server closes connection - healthy (EOF OK)",
serverBehavior: "close",
expectedStatus: StatusHealthy,
},
{
name: "no server listening - unhealthy",
serverBehavior: "none",
expectedStatus: StatusUnhealthy,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
var err error
if tt.serverBehavior != "none" {
// Start test server
testListener, err = net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
testPort = testListener.Addr().(*net.TCPAddr).Port
// Handle connections based on behavior
go func() {
for {
conn, err := testListener.Accept()
if err != nil {
return
}
switch tt.serverBehavior {
case "banner":
conn.Write([]byte("220 Welcome\r\n"))
time.Sleep(50 * time.Millisecond)
conn.Close()
case "close":
conn.Close()
case "silent":
// Just keep connection open
time.Sleep(200 * time.Millisecond)
conn.Close()
}
}
}()
defer testListener.Close()
} else {
testPort = 54322 // Unused port
}
// Create checker with data transfer method
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodDataTransfer,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if tt.serverBehavior == "none" {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(300 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status)
})
}
}
// TestConnectionAgeDetection tests max connection age detection
func (s *HealthCheckTestSuite) TestConnectionAgeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max connection age
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 150 * time.Millisecond, // Very short for testing
MaxIdleTime: 0, // Disable idle detection
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait for initial healthy status
var gotHealthy, gotStale bool
timeout := time.After(1 * time.Second)
for {
select {
case status := <-statusChanges:
if status == StatusHealthy || status == StatusStarting {
gotHealthy = true
}
if status == StatusStale {
gotStale = true
}
if gotHealthy && gotStale {
return // Test passed
}
case <-timeout:
s.T().Fatalf("Expected StatusStale after max connection age exceeded. gotHealthy=%v, gotStale=%v",
gotHealthy, gotStale)
}
}
}
// TestIdleTimeDetection tests that connections with passing health checks are NOT marked as stale
// This verifies that successful health checks update LastActivity, preventing false idle detection
func (s *HealthCheckTestSuite) TestIdleTimeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max idle time
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable age detection
MaxIdleTime: 150 * time.Millisecond, // Very short for testing
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait long enough that idle time WOULD be exceeded if health checks didn't update LastActivity
time.Sleep(500 * time.Millisecond)
// Verify connection is still healthy, not stale
// This proves that successful health checks are updating LastActivity
status, exists := checker.GetStatus("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), StatusHealthy, status, "Connection with passing health checks should NOT be marked as stale")
// Verify we never received a StatusStale callback
select {
case status := <-statusChanges:
if status == StatusStale {
s.T().Fatal("Connection should NOT be marked as stale when health checks are passing")
}
default:
// No stale status - this is correct
}
}
// TestMarkConnected tests that MarkConnected resets connection time
func (s *HealthCheckTestSuite) TestMarkConnected() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 200 * time.Millisecond,
MaxIdleTime: 0,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Mark as reconnected (resets connection time)
checker.MarkConnected("test-forward")
// Wait for connection age to exceed (relative to first connection time)
time.Sleep(200 * time.Millisecond)
// Check status - should still be healthy because we reset connection time
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Note: Might be StatusStale by now, but the key is that MarkConnected delayed it
// This is a timing-sensitive test, so we just verify the functionality exists
_ = status
}
// TestRecordActivity tests that RecordActivity resets idle time
func (s *HealthCheckTestSuite) TestRecordActivity() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 200 * time.Millisecond,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Periodically record activity to prevent idle detection
ticker := time.NewTicker(80 * time.Millisecond)
defer ticker.Stop()
go func() {
for i := 0; i < 5; i++ {
<-ticker.C
checker.RecordActivity("test-forward")
}
}()
// Wait longer than idle timeout
time.Sleep(500 * time.Millisecond)
// Should still be healthy due to activity
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// May transition to stale eventually, but activity recording should have delayed it
_ = status
}
// TestMarkReconnecting tests the MarkReconnecting functionality
func (s *HealthCheckTestSuite) TestMarkReconnecting() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
s.checker.Register("test-forward", s.port, callback)
// Wait for initial status
time.Sleep(150 * time.Millisecond)
// Mark as reconnecting
s.checker.MarkReconnecting("test-forward")
// Should receive reconnecting status
timeout := time.After(500 * time.Millisecond)
gotReconnect := false
for !gotReconnect {
select {
case status := <-statusChanges:
if status == StatusReconnect {
gotReconnect = true
}
case <-timeout:
s.T().Fatal("Expected StatusReconnect")
}
}
}
// TestStartingGracePeriod tests that errors during grace period show as "Starting"
func (s *HealthCheckTestSuite) TestStartingGracePeriod() {
// Use a port that's not listening
unavailablePort := 54323
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register without callback - we'll check status directly
checker.Register("test-forward", unavailablePort, nil)
// Immediately check status - should be Starting or not yet checked
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Initially should be Starting
assert.Equal(s.T(), StatusStarting, status)
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 200*time.Millisecond)
// Now should be Unhealthy
status, exists = checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestGetAllErrors tests retrieving all error messages
func (s *HealthCheckTestSuite) TestGetAllErrors() {
// Create a new checker with faster intervals for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register multiple forwards
checker.Register("forward1", s.port, nil)
checker.Register("forward2", 54324, nil) // Unavailable port
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 300*time.Millisecond)
errors := checker.GetAllErrors()
// forward2 should have an error
_, hasError := errors["forward2"]
assert.True(s.T(), hasError, "forward2 should have an error")
// forward1 should not have an error
_, hasError = errors["forward1"]
assert.False(s.T(), hasError, "forward1 should not have an error")
}
// TestConcurrentOperations tests thread safety
func (s *HealthCheckTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := fmt.Sprintf("forward-%d", id)
s.checker.Register(forwardID, s.port, nil)
time.Sleep(50 * time.Millisecond)
s.checker.MarkConnected(forwardID)
s.checker.RecordActivity(forwardID)
status, _ := s.checker.GetStatus(forwardID)
_ = status
s.checker.Unregister(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestDefaultOptions tests that NewChecker uses sensible defaults
func TestDefaultOptions(t *testing.T) {
checker := NewChecker(5*time.Second, 2*time.Second)
defer checker.Stop()
assert.Equal(t, 5*time.Second, checker.interval)
assert.Equal(t, 2*time.Second, checker.timeout)
assert.Equal(t, CheckMethodDataTransfer, checker.method)
assert.Equal(t, 25*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 10*time.Minute, checker.maxIdleTime)
}
// TestCustomOptions tests NewCheckerWithOptions
func TestCustomOptions(t *testing.T) {
opts := CheckerOptions{
Interval: 1 * time.Second,
Timeout: 500 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 5 * time.Minute,
MaxIdleTime: 2 * time.Minute,
}
checker := NewCheckerWithOptions(opts)
defer checker.Stop()
assert.Equal(t, 1*time.Second, checker.interval)
assert.Equal(t, 500*time.Millisecond, checker.timeout)
assert.Equal(t, CheckMethodTCPDial, checker.method)
assert.Equal(t, 5*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 2*time.Minute, checker.maxIdleTime)
}
+2 -21
View File
@@ -5,7 +5,6 @@ import (
"fmt"
"net"
"sort"
"strconv"
"strings"
corev1 "k8s.io/api/core/v1"
@@ -293,29 +292,11 @@ func CheckPortAvailability(port int) (bool, string, error) {
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
// Port is in use
// Try to get process info (best-effort)
processInfo := "unknown process"
// Note: Getting process info requires platform-specific code
// For now, just return a generic message
return false, processInfo, nil
// Port is in use - return error details
return false, err.Error(), nil
}
// Port is available, close the listener
listener.Close()
return true, "", nil
}
// ValidatePort checks if a port number is valid.
func ValidatePort(portStr string) (int, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, fmt.Errorf("invalid port number: %s", portStr)
}
if port < 1 || port > 65535 {
return 0, fmt.Errorf("port must be between 1 and 65535, got %d", port)
}
return port, nil
}
+42 -5
View File
@@ -4,9 +4,13 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/nvm/kportal/internal/config"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -17,18 +21,32 @@ import (
// PortForwarder handles Kubernetes port-forwarding operations.
type PortForwarder struct {
clientPool *ClientPool
resolver *ResourceResolver
clientPool *ClientPool
resolver *ResourceResolver
tcpKeepalive time.Duration // TCP keepalive interval
dialTimeout time.Duration // Connection dial timeout
}
// NewPortForwarder creates a new PortForwarder instance.
// NewPortForwarder creates a new PortForwarder instance with default settings.
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
return &PortForwarder{
clientPool: clientPool,
resolver: resolver,
clientPool: clientPool,
resolver: resolver,
tcpKeepalive: config.DefaultTCPKeepalive,
dialTimeout: config.DefaultDialTimeout,
}
}
// SetTCPKeepalive configures the TCP keepalive interval for new connections.
func (pf *PortForwarder) SetTCPKeepalive(keepalive time.Duration) {
pf.tcpKeepalive = keepalive
}
// SetDialTimeout configures the connection dial timeout.
func (pf *PortForwarder) SetDialTimeout(timeout time.Duration) {
pf.dialTimeout = timeout
}
// ForwardRequest contains the parameters for a port-forward request.
type ForwardRequest struct {
ContextName string // Kubernetes context name
@@ -124,6 +142,9 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
}
// Get pods backing the service using label selector
if len(service.Spec.Selector) == 0 {
return fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", serviceName)
}
selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
pods, err := client.CoreV1().Pods(req.Namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
@@ -164,6 +185,19 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
// executePortForward performs the actual port-forward operation.
func (pf *PortForwarder) executePortForward(config *rest.Config, url *url.URL, req *ForwardRequest) error {
// Configure TCP settings on the underlying connection
// This is set in the rest.Config which will be used by the SPDY transport
if config.Dial == nil {
// Create a custom dialer with configurable timeout and keepalive
// - Timeout: How long to wait for connection to establish
// - KeepAlive: TCP keepalive helps OS detect dead connections at network layer
dialer := &net.Dialer{
Timeout: pf.dialTimeout, // Configurable dial timeout
KeepAlive: pf.tcpKeepalive, // Configurable keepalive interval
}
config.Dial = dialer.DialContext
}
// Create SPDY roundtripper
transport, upgrader, err := spdy.RoundTripperFor(config)
if err != nil {
@@ -228,6 +262,9 @@ func (pf *PortForwarder) GetPodForResource(ctx context.Context, contextName, nam
return "", fmt.Errorf("failed to get service: %w", err)
}
if len(service.Spec.Selector) == 0 {
return "", fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", resourceName)
}
selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
-26
View File
@@ -228,29 +228,3 @@ func (r *ResourceResolver) InvalidateCache(contextName, namespace, resource stri
}
}
}
// GetPodList returns a list of pods matching the given criteria.
// This is useful for debugging and testing.
func (r *ResourceResolver) GetPodList(ctx context.Context, contextName, namespace, selector string) ([]*corev1.Pod, error) {
client, err := r.clientPool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
listOptions := metav1.ListOptions{}
if selector != "" {
listOptions.LabelSelector = selector
}
pods, err := client.CoreV1().Pods(namespace).List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
result := make([]*corev1.Pod, len(pods.Items))
for i := range pods.Items {
result[i] = &pods.Items[i]
}
return result, nil
}
+105
View File
@@ -0,0 +1,105 @@
package logger
import (
"github.com/go-logr/logr"
)
// LogrAdapter implements the logr.LogSink interface to route klog v2 logs
// through our structured logger. This captures ALL klog output including
// error logs, structured logs, and named logger output.
type LogrAdapter struct {
logger *Logger
name string
level int
}
// NewLogrAdapter creates a new logr.LogSink that routes all klog v2 logs
// through our structured logger.
func NewLogrAdapter(logger *Logger) logr.LogSink {
return &LogrAdapter{
logger: logger,
name: "",
level: 0,
}
}
// Init initializes the logger with runtime info (not used in our implementation).
func (l *LogrAdapter) Init(info logr.RuntimeInfo) {
// No-op: we don't need runtime info
}
// Enabled tests whether this LogSink is enabled at the specified V-level.
// We route all logs through our logger's level filtering.
func (l *LogrAdapter) Enabled(level int) bool {
// Map logr V-levels to our levels:
// V(0) = Info level (always enabled if logger level <= Info)
// V(1+) = Debug level (enabled if logger level <= Debug)
if level == 0 {
return l.logger.level <= LevelInfo
}
return l.logger.level <= LevelDebug
}
// Info logs a non-error message with the given key/value pairs.
func (l *LogrAdapter) Info(level int, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
// Map logr V-levels to our levels:
// V(0) = Info, V(1+) = Debug
if level == 0 {
l.logger.Info(msg, fields)
} else {
l.logger.Debug(msg, fields)
}
}
// Error logs an error message with the given key/value pairs.
func (l *LogrAdapter) Error(err error, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
if err != nil {
fields["error"] = err.Error()
}
l.logger.Error(msg, fields)
}
// WithValues returns a new LogSink with additional key/value pairs.
func (l *LogrAdapter) WithValues(keysAndValues ...interface{}) logr.LogSink {
// For simplicity, we don't implement value accumulation
// Each log call receives all its keysAndValues directly
return l
}
// WithName returns a new LogSink with the specified name appended.
func (l *LogrAdapter) WithName(name string) logr.LogSink {
newLogger := *l
if l.name == "" {
newLogger.name = name
} else {
newLogger.name = l.name + "." + name
}
return &newLogger
}
// kvToMap converts a slice of alternating keys and values to a map.
func (l *LogrAdapter) kvToMap(keysAndValues []interface{}) map[string]interface{} {
fields := make(map[string]interface{})
fields["source"] = "k8s-client"
for i := 0; i < len(keysAndValues); i += 2 {
if i+1 < len(keysAndValues) {
key, ok := keysAndValues[i].(string)
if ok {
fields[key] = keysAndValues[i+1]
}
}
}
return fields
}
+367
View File
@@ -0,0 +1,367 @@
package logger
import (
"bytes"
"encoding/json"
"errors"
"testing"
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogrAdapter_Info(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "info log v0 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
message: "Connection established",
keysAndValues: []interface{}{"pod", "my-app-123", "port", 8080},
expectOutput: true,
expectContains: []string{"[INFO]", "Connection established", "pod", "my-app-123"},
},
{
name: "info log v0 with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Port forward ready",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Port forward ready"},
},
{
name: "info log v0 silenced with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
message: "This should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "debug log v1 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
message: "Detailed connection info",
keysAndValues: []interface{}{"details", "some-value"},
expectOutput: true,
expectContains: []string{"[DEBUG]", "Detailed connection info", "details"},
},
{
name: "debug log v1 silenced with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
message: "This debug should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "info with odd number of kvs (incomplete pair)",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Message with incomplete kv",
keysAndValues: []interface{}{"key1", "value1", "key2"}, // key2 has no value
expectOutput: true,
expectContains: []string{"[INFO]", "Message with incomplete kv", "key1", "value1"},
},
{
name: "info with source field added automatically",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Test source field",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Test source field", "source:k8s-client"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.V(tt.logrLevel).Info(tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_Error(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
err error
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "error with error object",
loggerLevel: LevelError,
err: errors.New("connection failed"),
message: "Port forward failed",
keysAndValues: []interface{}{"pod", "my-app-123"},
expectOutput: true,
expectContains: []string{"[ERROR]", "Port forward failed", "connection failed", "pod", "my-app-123"},
},
{
name: "error without error object",
loggerLevel: LevelError,
err: nil,
message: "Generic error message",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[ERROR]", "Generic error message"},
},
{
name: "error silenced with level above error",
loggerLevel: LevelError + 1,
err: errors.New("should not appear"),
message: "This error should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "error with multiple kvs",
loggerLevel: LevelError,
err: errors.New("sandbox not found"),
message: "Unhandled Error",
keysAndValues: []interface{}{"pod", "test-pod", "uid", "abc123", "port", 8080},
expectOutput: true,
expectContains: []string{"[ERROR]", "Unhandled Error", "sandbox not found", "pod", "test-pod", "uid", "abc123"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.Error(tt.err, tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_WithName(t *testing.T) {
tests := []struct {
name string
loggerNames []string
message string
expectContains string
}{
{
name: "single logger name",
loggerNames: []string{"portforward"},
message: "Test message",
expectContains: "logger:portforward",
},
{
name: "nested logger names",
loggerNames: []string{"controller", "worker", "healthcheck"},
message: "Nested message",
expectContains: "logger:controller.worker.healthcheck",
},
{
name: "no logger name",
loggerNames: []string{},
message: "No name message",
expectContains: "source:k8s-client", // Should still have source but no logger field
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Apply WithName calls
for _, name := range tt.loggerNames {
logrLogger = logrLogger.WithName(name)
}
logrLogger.Info(tt.message)
output := buf.String()
assert.Contains(t, output, tt.expectContains)
})
}
}
func TestLogrAdapter_Enabled(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
expectEnabled bool
}{
{
name: "v0 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 enabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 disabled with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
expectEnabled: false,
},
{
name: "v1 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
expectEnabled: true,
},
{
name: "v1 disabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
expectEnabled: false,
},
{
name: "v2 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 2,
expectEnabled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := New(tt.loggerLevel, FormatText, &bytes.Buffer{})
sink := NewLogrAdapter(logger)
enabled := sink.Enabled(tt.logrLevel)
assert.Equal(t, tt.expectEnabled, enabled)
})
}
}
func TestLogrAdapter_JSONFormat(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatJSON, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("test-component")
logrLogger.Info("Test JSON message", "key1", "value1", "key2", 123)
// Parse JSON output
var entry logEntry
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "INFO", entry.Level)
assert.Equal(t, "Test JSON message", entry.Message)
assert.Equal(t, "k8s-client", entry.Fields["source"])
assert.Equal(t, "test-component", entry.Fields["logger"])
assert.Equal(t, "value1", entry.Fields["key1"])
assert.Equal(t, float64(123), entry.Fields["key2"]) // JSON numbers decode as float64
}
func TestLogrAdapter_ConcurrentWrites(t *testing.T) {
// Note: bytes.Buffer is not thread-safe for writes, so this test verifies
// that our LogrAdapter doesn't panic under concurrent load, but we don't
// verify exact output (since logger uses fmt.Fprintf which is also not thread-safe)
buf := &bytes.Buffer{}
logger := New(LevelDebug, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Spawn multiple goroutines writing concurrently
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
logrLogger.Info("Concurrent message", "goroutine", id, "iteration", j)
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
output := buf.String()
// Verify we got substantial output (not checking exact count due to buffer race)
// The main goal is to ensure no panics occur during concurrent writes
assert.NotEmpty(t, output, "Should have some log output")
assert.Contains(t, output, "Concurrent message")
}
func TestLogrAdapter_RealWorldKlogError(t *testing.T) {
// Simulate the exact error message from the screenshot
buf := &bytes.Buffer{}
logger := New(LevelError, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("UnhandledError")
err := errors.New("an error occurred forwarding 8401 -> 8401: error forwarding port 8401 to pod 4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308, uid : failed to find sandbox '4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308' in store: not found")
logrLogger.Error(err, "Unhandled Error")
output := buf.String()
assert.Contains(t, output, "[ERROR]")
assert.Contains(t, output, "Unhandled Error")
assert.Contains(t, output, "failed to find sandbox")
assert.Contains(t, output, "logger:UnhandledError")
}
func TestLogrAdapter_SilenceMode(t *testing.T) {
// Test that logs are completely silenced when logger level is above error
buf := &bytes.Buffer{}
logger := New(LevelError+1, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Try all log levels
logrLogger.V(0).Info("Info message should not appear")
logrLogger.V(1).Info("Debug message should not appear")
logrLogger.Error(errors.New("error object"), "Error message should not appear")
output := buf.String()
assert.Empty(t, output, "All logs should be silenced")
}
+5
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"io"
"os"
"sync"
"time"
)
@@ -28,6 +29,7 @@ type Logger struct {
level Level
format Format
output io.Writer
mu sync.Mutex // Protects concurrent writes to output
}
type logEntry struct {
@@ -55,6 +57,9 @@ func (l *Logger) log(level Level, msg string, fields map[string]interface{}) {
levelStr := levelToString(level)
l.mu.Lock()
defer l.mu.Unlock()
if l.format == FormatJSON {
entry := logEntry{
Time: time.Now().Format(time.RFC3339),
-5
View File
@@ -67,8 +67,3 @@ func (b *Backoff) calculateJitter(delay time.Duration) time.Duration {
jitter := (b.rng.Float64()*2 - 1) * maxJitter
return time.Duration(jitter)
}
// Sleep waits for the next backoff duration.
func (b *Backoff) Sleep() {
time.Sleep(b.Next())
}
+5 -3
View File
@@ -158,10 +158,12 @@ func TestBackoff_ExponentialProgression(t *testing.T) {
// We allow for jitter by checking a range
for i := 1; i < len(delays)-1; i++ {
// Each delay should be roughly double the previous (accounting for jitter)
// With 10% jitter on each value, worst case: (2.0 * 1.1) / 0.9 = 2.44
// We use 1.7x to 2.5x as a reasonable range with 10% jitter on each
// With 10% jitter on each value:
// Lower bound: (2.0 * 0.9) / 1.1 ≈ 1.636
// Upper bound: (2.0 * 1.1) / 0.9 ≈ 2.444
// We use 1.6x to 2.5x as a reasonable range to account for jitter variance
ratio := float64(delays[i]) / float64(delays[i-1])
assert.GreaterOrEqual(t, ratio, 1.7, "exponential growth should be ~2x")
assert.GreaterOrEqual(t, ratio, 1.6, "exponential growth should be ~2x")
assert.LessOrEqual(t, ratio, 2.5, "exponential growth should be ~2x")
}
}
+38 -4
View File
@@ -46,6 +46,11 @@ type BubbleTeaUI struct {
version string
errors map[string]string // Track error messages by forward ID
// Update notification
updateAvailable bool
updateVersion string
updateURL string
// Modal wizard state
viewMode ViewMode
addWizard *AddWizardState
@@ -96,6 +101,16 @@ func (ui *BubbleTeaUI) SetWizardDependencies(discovery *k8s.Discovery, mutator *
ui.configPath = configPath
}
// SetUpdateAvailable sets the update notification to be displayed
func (ui *BubbleTeaUI) SetUpdateAvailable(version, url string) {
ui.mu.Lock()
defer ui.mu.Unlock()
ui.updateAvailable = true
ui.updateVersion = version
ui.updateURL = url
}
// Start starts the bubbletea application
func (ui *BubbleTeaUI) Start() error {
m := model{ui: ui}
@@ -169,8 +184,9 @@ func (ui *BubbleTeaUI) UpdateStatus(id string, status string) {
if fwd, ok := ui.forwards[id]; ok {
fwd.Status = status
}
// Clear error if status is not Error
if status != "Error" {
// Only clear error when forward becomes Active again
// This keeps error visible during Reconnecting/Starting states
if status == "Active" {
delete(ui.errors, id)
}
ui.mu.Unlock()
@@ -266,7 +282,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.ui.addWizard = nil
m.ui.removeWizard = nil
m.ui.mu.Unlock()
return m, nil
return m, tea.ClearScreen
}
return m, nil
@@ -356,6 +372,15 @@ func (m model) renderMainView() string {
// Title with version
title := fmt.Sprintf("kportal v%s - Port Forwarding Status", m.ui.version)
b.WriteString(titleStyle.Render(title))
// Show update notification if available
if m.ui.updateAvailable {
updateStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("42")). // Green
Bold(true)
updateMsg := fmt.Sprintf(" Update available: v%s", m.ui.updateVersion)
b.WriteString(updateStyle.Render(updateMsg))
}
b.WriteString("\n\n")
// Header
@@ -378,7 +403,7 @@ func (m model) renderMainView() string {
}
isSelected := (idx == m.ui.selectedIndex)
isDisabled := m.ui.disabledMap[id]
isDisabled := m.ui.disabledMap[id] || fwd.Status == "Disabled"
// Selection indicator
indicator := " "
@@ -574,6 +599,15 @@ func (ui *BubbleTeaUI) moveSelection(delta int) {
}
}
// resetDeleteConfirmation resets the delete confirmation dialog state.
// Caller must hold ui.mu lock.
func (ui *BubbleTeaUI) resetDeleteConfirmation() {
ui.deleteConfirming = false
ui.deleteConfirmID = ""
ui.deleteConfirmAlias = ""
ui.deleteConfirmCursor = 0
}
// renderDeleteConfirmation renders the delete confirmation dialog
func (m model) renderDeleteConfirmation() string {
m.ui.mu.RLock()
-181
View File
@@ -1,181 +0,0 @@
package ui
import (
"fmt"
"os"
"sync"
"golang.org/x/term"
)
// InteractiveController handles keyboard input and selection state
type InteractiveController struct {
mu sync.RWMutex
selectedIndex int
forwardIDs []string // Ordered list of forward IDs
disabledMap map[string]bool // Tracks which forwards are disabled
toggleCallback func(id string, enable bool)
enabled bool
oldTermState *term.State
}
// NewInteractiveController creates a new interactive controller
func NewInteractiveController(toggleCallback func(id string, enable bool)) *InteractiveController {
return &InteractiveController{
selectedIndex: 0,
forwardIDs: make([]string, 0),
disabledMap: make(map[string]bool),
toggleCallback: toggleCallback,
enabled: false,
}
}
// Enable puts the terminal in raw mode for keyboard input
func (ic *InteractiveController) Enable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.enabled {
return nil
}
// Save current terminal state
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return fmt.Errorf("failed to enable raw mode: %w", err)
}
ic.oldTermState = oldState
ic.enabled = true
return nil
}
// Disable restores the terminal to normal mode
func (ic *InteractiveController) Disable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if !ic.enabled {
return nil
}
if ic.oldTermState != nil {
if err := term.Restore(int(os.Stdin.Fd()), ic.oldTermState); err != nil {
return fmt.Errorf("failed to restore terminal: %w", err)
}
}
ic.enabled = false
return nil
}
// UpdateForwardsList updates the list of forwards for navigation
func (ic *InteractiveController) UpdateForwardsList(ids []string) {
ic.mu.Lock()
defer ic.mu.Unlock()
ic.forwardIDs = ids
// Ensure selected index is valid
if ic.selectedIndex >= len(ic.forwardIDs) {
ic.selectedIndex = len(ic.forwardIDs) - 1
}
if ic.selectedIndex < 0 && len(ic.forwardIDs) > 0 {
ic.selectedIndex = 0
}
}
// MoveUp moves selection up
func (ic *InteractiveController) MoveUp() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex > 0 {
ic.selectedIndex--
}
}
// MoveDown moves selection down
func (ic *InteractiveController) MoveDown() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex < len(ic.forwardIDs)-1 {
ic.selectedIndex++
}
}
// ToggleSelected toggles the enable/disable state of the selected forward
func (ic *InteractiveController) ToggleSelected() {
ic.mu.Lock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
ic.mu.Unlock()
return
}
selectedID := ic.forwardIDs[ic.selectedIndex]
currentlyDisabled := ic.disabledMap[selectedID]
newState := !currentlyDisabled
ic.disabledMap[selectedID] = newState
ic.mu.Unlock()
// Call the toggle callback
if ic.toggleCallback != nil {
ic.toggleCallback(selectedID, !newState) // enable is inverse of disabled
}
}
// GetSelectedIndex returns the current selection index
func (ic *InteractiveController) GetSelectedIndex() int {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.selectedIndex
}
// IsDisabled returns whether a forward is disabled
func (ic *InteractiveController) IsDisabled(id string) bool {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.disabledMap[id]
}
// GetSelectedID returns the ID of the currently selected forward
func (ic *InteractiveController) GetSelectedID() string {
ic.mu.RLock()
defer ic.mu.RUnlock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
return ""
}
return ic.forwardIDs[ic.selectedIndex]
}
// HandleKey processes keyboard input and returns true if should continue
func (ic *InteractiveController) HandleKey(b []byte) bool {
if len(b) == 0 {
return true
}
// Handle single byte keys
if len(b) == 1 {
switch b[0] {
case 'q', 'Q', 3: // q, Q, or Ctrl+C
return false
case ' ', '\r': // Space or Enter to toggle
ic.ToggleSelected()
return true
}
}
// Handle escape sequences (arrow keys)
if len(b) == 3 && b[0] == 27 && b[1] == 91 {
switch b[2] {
case 65: // Up arrow
ic.MoveUp()
case 66: // Down arrow
ic.MoveDown()
}
}
return true
}
+7 -48
View File
@@ -23,10 +23,9 @@ type ForwardStatus struct {
// TableUI manages the terminal table display
type TableUI struct {
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
interactive *InteractiveController
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
}
// NewTableUI creates a new table UI manager
@@ -37,13 +36,6 @@ func NewTableUI(verbose bool) *TableUI {
}
}
// SetInteractiveController sets the interactive controller
func (t *TableUI) SetInteractiveController(ic *InteractiveController) {
t.mu.Lock()
defer t.mu.Unlock()
t.interactive = ic
}
// AddForward registers a new forward for display
func (t *TableUI) AddForward(id string, fwd *config.Forward) {
t.mu.Lock()
@@ -126,27 +118,10 @@ func (t *TableUI) Render() {
}
}
// Update interactive controller with current forward IDs (in display order)
if t.interactive != nil {
ids := make([]string, len(entries))
for i, entry := range entries {
ids[i] = entry.id
}
t.interactive.UpdateForwardsList(ids)
}
// Print each forward
for i, entry := range entries {
for _, entry := range entries {
fwd := entry.fwd
// Check if this row is selected
isSelected := false
isDisabled := false
if t.interactive != nil {
isSelected = (i == t.interactive.GetSelectedIndex())
isDisabled = t.interactive.IsDisabled(entry.id)
}
// Truncate long names
alias := truncate(fwd.Alias, 25)
resource := truncate(fwd.Resource, 25)
@@ -154,8 +129,8 @@ func (t *TableUI) Render() {
// Color code status with indicator
statusStr := formatStatusWithIndicator(fwd.Status)
// Build the row content
rowContent := fmt.Sprintf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s",
// Print the row
fmt.Printf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s\n",
fwd.Context,
fwd.Namespace,
alias,
@@ -164,26 +139,10 @@ func (t *TableUI) Render() {
fwd.RemotePort,
fwd.LocalPort,
statusStr)
// Apply selection highlighting or disabled styling
if isSelected {
// Replace leading spaces with arrow, then apply reverse video to entire line
rowContent = "\033[7m> " + rowContent[2:] + "\033[0m"
} else if isDisabled {
// Apply dimmed styling to entire line
rowContent = "\033[2m" + rowContent + "\033[0m"
}
fmt.Println(rowContent)
}
fmt.Println(strings.Repeat("=", 130))
helpText := "Total forwards: %d | ↑↓: Navigate | Space: Toggle | q: Quit"
if !t.verbose {
fmt.Printf(helpText+"\n", len(t.forwards))
} else {
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
}
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
// In verbose mode, add a newline to separate from logs
if t.verbose {
+18 -1
View File
@@ -144,8 +144,25 @@ func validateSelectorCmd(discovery *k8s.Discovery, contextName, namespace, selec
}
// checkPortCmd checks if a local port is available
func checkPortCmd(port int) tea.Cmd {
func checkPortCmd(port int, configPath string) tea.Cmd {
return func() tea.Msg {
// First check if port is already in the configuration
cfg, err := config.LoadConfig(configPath)
if err == nil {
// Check all forwards in config for this port
allForwards := cfg.GetAllForwards()
for _, fwd := range allForwards {
if fwd.LocalPort == port {
return PortCheckedMsg{
port: port,
available: false,
message: fmt.Sprintf("✗ Port %d already assigned to %s", port, fwd.ID()),
}
}
}
}
// Then check if port is available at OS level
available, processInfo, err := k8s.CheckPortAvailability(port)
msg := ""
+107 -51
View File
@@ -10,6 +10,19 @@ import (
"github.com/nvm/kportal/internal/k8s"
)
// isFilterableStep returns true if the step supports search/filter
func isFilterableStep(step AddWizardStep) bool {
switch step {
case StepSelectContext, StepSelectNamespace:
return true
case StepEnterResource:
// Only service selection is filterable (pod prefix and selector are text input)
return true // We'll check resource type in the handler
default:
return false
}
}
// handleMainViewKeys handles keyboard input in the main view
func (m model) handleMainViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// If delete confirmation is showing, handle it separately
@@ -160,12 +173,8 @@ func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
switch msg.String() {
case "ctrl+c", "esc":
// Cancel deletion
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
// Force a repaint by returning the model
return m, tea.ClearScreen
case "left", "h", "right", "l":
@@ -178,26 +187,18 @@ func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// Confirm deletion (either Enter on Yes or pressing 'y')
if m.ui.deleteConfirmCursor == 0 || msg.String() == "y" {
id := m.ui.deleteConfirmID
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, removeForwardByIDCmd(m.ui.mutator, id)
}
// Enter on No = cancel
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, tea.ClearScreen
case "n":
// Quick 'n' for no
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, tea.ClearScreen
}
@@ -224,6 +225,12 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m, tea.ClearScreen
case "esc":
// If there's an active search filter, clear it instead of going back
if wizard.searchFilter != "" && isFilterableStep(wizard.step) {
wizard.clearSearchFilter()
return m, nil
}
// In edit mode, Esc always cancels (don't navigate back through skipped steps)
if wizard.isEditing {
m.ui.viewMode = ViewModeMain
@@ -240,9 +247,7 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
} else {
// Go back one step
wizard.step--
wizard.cursor = 0
wizard.clearTextInput()
wizard.error = nil
wizard.resetInput()
// Reset input mode based on the step we're going back to
switch wizard.step {
@@ -300,26 +305,48 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
return m.handleAddWizardEnter()
case "backspace":
// Allow backspace in text input mode OR when focused on alias in confirmation
// Allow backspace in text input mode OR when focused on alias in confirmation OR when filtering
canBackspace := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias)
if canBackspace && len(wizard.textInput) > 0 {
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step) && len(wizard.searchFilter) > 0)
if canBackspace {
if isFilterableStep(wizard.step) && wizard.inputMode == InputModeList && len(wizard.searchFilter) > 0 {
// Backspace in search filter
wizard.searchFilter = wizard.searchFilter[:len(wizard.searchFilter)-1]
wizard.cursor = 0
wizard.scrollOffset = 0
} else if len(wizard.textInput) > 0 {
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
}
}
default:
// Handle text input
canTypeText := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias)
if canTypeText && len(msg.String()) == 1 {
wizard.handleTextInput(rune(msg.String()[0]))
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step))
// Trigger validation for selector
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
if len(wizard.textInput) > 0 {
wizard.loading = true
wizard.error = nil
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
if canTypeText && len(msg.String()) == 1 {
// If in list mode on filterable step, add to search filter instead of textInput
if wizard.inputMode == InputModeList && isFilterableStep(wizard.step) {
char := rune(msg.String()[0])
// Only allow printable characters
if char >= 32 && char < 127 {
wizard.searchFilter += string(char)
wizard.cursor = 0
wizard.scrollOffset = 0
}
} else {
wizard.handleTextInput(rune(msg.String()[0]))
// Trigger validation for selector
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
if len(wizard.textInput) > 0 {
wizard.loading = true
wizard.error = nil
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
}
}
}
}
@@ -332,21 +359,30 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
wizard := m.ui.addWizard
// Don't process Enter if we're currently loading
if wizard.loading {
return m, nil
}
switch wizard.step {
case StepSelectContext:
if wizard.cursor >= 0 && wizard.cursor < len(wizard.contexts) {
wizard.selectedContext = wizard.contexts[wizard.cursor]
filteredContexts := wizard.getFilteredContexts()
if wizard.cursor >= 0 && wizard.cursor < len(filteredContexts) {
wizard.selectedContext = filteredContexts[wizard.cursor]
wizard.step = StepSelectNamespace
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.loading = true
return m, loadNamespacesCmd(m.ui.discovery, wizard.selectedContext)
}
case StepSelectNamespace:
if wizard.cursor >= 0 && wizard.cursor < len(wizard.namespaces) {
wizard.selectedNamespace = wizard.namespaces[wizard.cursor]
filteredNamespaces := wizard.getFilteredNamespaces()
if wizard.cursor >= 0 && wizard.cursor < len(filteredNamespaces) {
wizard.selectedNamespace = filteredNamespaces[wizard.cursor]
wizard.step = StepSelectResourceType
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.inputMode = InputModeList
}
@@ -403,13 +439,17 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
}
case ResourceTypeService:
if wizard.cursor >= 0 && wizard.cursor < len(wizard.services) {
wizard.resourceValue = wizard.services[wizard.cursor].Name
filteredServices := wizard.getFilteredServices()
if wizard.cursor >= 0 && wizard.cursor < len(filteredServices) {
wizard.resourceValue = filteredServices[wizard.cursor].Name
// Get ports from selected service (must do this BEFORE clearing search filter)
wizard.detectedPorts = filteredServices[wizard.cursor].Ports
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
wizard.clearSearchFilter()
// Get ports from selected service
wizard.detectedPorts = wizard.services[wizard.cursor].Ports
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
@@ -437,7 +477,7 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
} else {
// Text mode - manual entry
port, err := strconv.Atoi(wizard.textInput)
if err != nil || port < 1 || port > 65535 {
if err != nil || !config.IsValidPort(port) {
wizard.error = fmt.Errorf("invalid port number")
} else {
wizard.remotePort = port
@@ -449,17 +489,14 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
case StepEnterLocalPort:
port, err := strconv.Atoi(wizard.textInput)
if err != nil || port < 1 || port > 65535 {
if err != nil || !config.IsValidPort(port) {
wizard.error = fmt.Errorf("invalid port number")
} else {
// Check port availability before proceeding
wizard.localPort = port
wizard.step = StepConfirmation
wizard.clearTextInput()
wizard.cursor = 0
wizard.inputMode = InputModeList
wizard.error = nil
wizard.loading = true
return m, checkPortCmd(port)
wizard.error = nil
return m, checkPortCmd(port, m.ui.configPath)
}
case StepConfirmation:
@@ -472,6 +509,12 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
// Handle button selection
if wizard.cursor == 0 {
// Check if port is available before saving
if !wizard.portAvailable {
wizard.error = fmt.Errorf("port %d is not available. Please choose a different port", wizard.localPort)
return m, nil
}
// Confirmed - save the forward
wizard.alias = wizard.textInput
@@ -501,9 +544,10 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
return m, saveForwardCmd(m.ui.mutator, wizard.selectedContext, wizard.selectedNamespace, fwd)
} else {
// Cancelled
// Cancelled - return to main view with screen clear
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
case StepSuccess:
@@ -513,9 +557,10 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
m.ui.addWizard.loading = true
return m, loadContextsCmd(m.ui.discovery)
} else {
// Return to main view
// Return to main view with screen clear
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
}
@@ -723,6 +768,17 @@ func (m model) handlePortChecked(msg PortCheckedMsg) (tea.Model, tea.Cmd) {
m.ui.addWizard.loading = false
m.ui.addWizard.portAvailable = msg.available
m.ui.addWizard.portCheckMsg = msg.message
// Only proceed to confirmation if port is available
if msg.available {
m.ui.addWizard.step = StepConfirmation
m.ui.addWizard.clearTextInput()
m.ui.addWizard.cursor = 0
m.ui.addWizard.inputMode = InputModeList
} else {
// Port is not available - show error and stay on local port step
m.ui.addWizard.error = fmt.Errorf("port %d is in use, please choose another port", msg.port)
}
}
return m, nil
@@ -759,5 +815,5 @@ func (m model) handleForwardsRemoved(msg ForwardsRemovedMsg) (tea.Model, tea.Cmd
// If there was an error, it will be logged but we don't show it in UI for now
// The config watcher will either reload (success) or keep old config (failure)
return m, nil
return m, tea.ClearScreen
}
+76 -3
View File
@@ -1,9 +1,34 @@
package ui
import (
"strings"
"github.com/nvm/kportal/internal/k8s"
)
// filterStrings filters a slice of strings by a search filter (case-insensitive substring match)
func filterStrings(items []string, filter string) []string {
if filter == "" {
return items
}
filtered := []string{}
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item), filterLower) {
filtered = append(filtered, item)
}
}
return filtered
}
// matchesFilter checks if a string matches the filter (case-insensitive substring match)
func matchesFilter(item, filter string) bool {
if filter == "" {
return true
}
return strings.Contains(strings.ToLower(item), strings.ToLower(filter))
}
// ViewMode represents the current view state of the UI
type ViewMode int
@@ -87,6 +112,7 @@ type AddWizardState struct {
cursor int
scrollOffset int // For scrolling long lists
textInput string
searchFilter string // For filtering lists (contexts, namespaces, services)
loading bool
error error
@@ -142,14 +168,14 @@ func (w *AddWizardState) moveCursor(delta int) {
switch w.step {
case StepSelectContext:
maxItems = len(w.contexts)
maxItems = len(w.getFilteredContexts())
case StepSelectNamespace:
maxItems = len(w.namespaces)
maxItems = len(w.getFilteredNamespaces())
case StepSelectResourceType:
maxItems = 3 // Three resource types
case StepEnterResource:
if w.selectedResourceType == ResourceTypeService {
maxItems = len(w.services)
maxItems = len(w.getFilteredServices())
}
case StepEnterRemotePort:
if len(w.detectedPorts) > 0 {
@@ -300,3 +326,50 @@ func (w *RemoveWizardState) getSelectedForwards() []RemovableForward {
}
return selected
}
// getFilteredContexts returns contexts filtered by search string
func (w *AddWizardState) getFilteredContexts() []string {
if w.searchFilter == "" {
return w.contexts
}
return filterStrings(w.contexts, w.searchFilter)
}
// getFilteredNamespaces returns namespaces filtered by search string
func (w *AddWizardState) getFilteredNamespaces() []string {
if w.searchFilter == "" {
return w.namespaces
}
return filterStrings(w.namespaces, w.searchFilter)
}
// getFilteredServices returns services filtered by search string
func (w *AddWizardState) getFilteredServices() []k8s.ServiceInfo {
if w.searchFilter == "" {
return w.services
}
filtered := []k8s.ServiceInfo{}
for _, svc := range w.services {
if matchesFilter(svc.Name, w.searchFilter) {
filtered = append(filtered, svc)
}
}
return filtered
}
// clearSearchFilter clears the search filter and resets cursor/scroll
func (w *AddWizardState) clearSearchFilter() {
w.searchFilter = ""
w.cursor = 0
w.scrollOffset = 0
}
// resetInput clears text input, search filter, and error state.
// Use this when navigating between wizard steps.
func (w *AddWizardState) resetInput() {
w.textInput = ""
w.searchFilter = ""
w.cursor = 0
w.scrollOffset = 0
w.error = nil
}
+350
View File
@@ -0,0 +1,350 @@
package ui
import (
"testing"
"github.com/nvm/kportal/internal/k8s"
"github.com/stretchr/testify/assert"
)
func TestFilterStrings(t *testing.T) {
tests := []struct {
name string
items []string
filter string
expected []string
}{
{
name: "empty filter returns all items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "",
expected: []string{"namespace-1", "namespace-2", "namespace-3"},
},
{
name: "filter matches multiple items",
items: []string{"prod-api", "prod-db", "staging-api", "dev-api"},
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter matches single item",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "2",
expected: []string{"namespace-2"},
},
{
name: "filter matches no items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "xyz",
expected: []string{},
},
{
name: "case insensitive matching",
items: []string{"Production", "Staging", "Development"},
filter: "prod",
expected: []string{"Production"},
},
{
name: "partial string matching",
items: []string{"my-app-frontend", "my-app-backend", "other-service"},
filter: "app",
expected: []string{"my-app-frontend", "my-app-backend"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterStrings(tt.items, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMatchesFilter(t *testing.T) {
tests := []struct {
name string
item string
filter string
expected bool
}{
{
name: "empty filter matches everything",
item: "namespace-1",
filter: "",
expected: true,
},
{
name: "exact match",
item: "namespace-1",
filter: "namespace-1",
expected: true,
},
{
name: "partial match",
item: "production-api",
filter: "prod",
expected: true,
},
{
name: "no match",
item: "namespace-1",
filter: "xyz",
expected: false,
},
{
name: "case insensitive match",
item: "Production",
filter: "prod",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchesFilter(tt.item, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredContexts(t *testing.T) {
wizard := &AddWizardState{
contexts: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-cluster"},
},
{
name: "filter by 'cluster'",
filter: "cluster",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'staging'",
filter: "staging",
expected: []string{"staging-cluster"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredContexts()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredNamespaces(t *testing.T) {
wizard := &AddWizardState{
namespaces: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter by 'kube'",
filter: "kube",
expected: []string{"kube-system", "kube-public"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"prod-api", "staging-api"},
},
{
name: "filter by 'ing' (partial match)",
filter: "ing",
expected: []string{"staging-api", "staging-db", "monitoring", "logging"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredNamespaces()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredServices(t *testing.T) {
wizard := &AddWizardState{
services: []k8s.ServiceInfo{
{Name: "api-gateway"},
{Name: "api-backend"},
{Name: "database"},
{Name: "redis-cache"},
{Name: "postgres-db"},
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"api-gateway", "api-backend", "database", "redis-cache", "postgres-db"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"api-gateway", "api-backend"},
},
{
name: "filter by 'db'",
filter: "db",
expected: []string{"postgres-db"},
},
{
name: "filter by 'base'",
filter: "base",
expected: []string{"database"},
},
{
name: "filter by 'redis'",
filter: "redis",
expected: []string{"redis-cache"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredServices()
resultNames := make([]string, len(result))
for i, svc := range result {
resultNames[i] = svc.Name
}
assert.Equal(t, tt.expected, resultNames)
})
}
}
func TestClearSearchFilter(t *testing.T) {
wizard := &AddWizardState{
searchFilter: "test",
cursor: 5,
scrollOffset: 10,
}
wizard.clearSearchFilter()
assert.Equal(t, "", wizard.searchFilter, "searchFilter should be cleared")
assert.Equal(t, 0, wizard.cursor, "cursor should be reset to 0")
assert.Equal(t, 0, wizard.scrollOffset, "scrollOffset should be reset to 0")
}
func TestMoveCursorWithFilteredLists(t *testing.T) {
tests := []struct {
name string
step AddWizardStep
contexts []string
namespaces []string
searchFilter string
initialCursor int
delta int
expectedCursor int
}{
{
name: "move down in filtered contexts",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 0,
delta: 1,
expectedCursor: 1,
},
{
name: "cannot move beyond filtered list",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 1,
delta: 1,
expectedCursor: 1, // Should stay at 1 (last item in filtered list)
},
{
name: "move up in filtered list",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3", "other"},
searchFilter: "ns",
initialCursor: 2,
delta: -1,
expectedCursor: 1,
},
{
name: "cannot move above 0",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3"},
searchFilter: "ns",
initialCursor: 0,
delta: -1,
expectedCursor: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard := &AddWizardState{
step: tt.step,
inputMode: InputModeList,
cursor: tt.initialCursor,
contexts: tt.contexts,
namespaces: tt.namespaces,
searchFilter: tt.searchFilter,
}
wizard.moveCursor(tt.delta)
assert.Equal(t, tt.expectedCursor, wizard.cursor)
})
}
}
+91 -40
View File
@@ -45,6 +45,12 @@ func (m model) renderSelectContext() string {
b.WriteString(renderHeader("Add Port Forward", renderProgress(1, 7)))
b.WriteString("Select Kubernetes Context:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading contexts..."))
} else if wizard.error != nil {
@@ -52,46 +58,56 @@ func (m model) renderSelectContext() string {
} else if len(wizard.contexts) == 0 {
b.WriteString(mutedStyle.Render("No contexts found in kubeconfig"))
} else {
const viewportHeight = 20
totalItems := len(wizard.contexts)
filteredContexts := wizard.getFilteredContexts()
if len(filteredContexts) == 0 {
b.WriteString(mutedStyle.Render("No matching contexts"))
} else {
const viewportHeight = 20
totalItems := len(filteredContexts)
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render visible contexts with (current) marker on first one
for i := start; i < end; i++ {
prefix := " "
text := wizard.contexts[i]
if i == 0 {
text += mutedStyle.Render(" (current)")
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + text))
} else {
b.WriteString(prefix + text)
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
// Render visible contexts with (current) marker on first one (only if not filtered)
for i := start; i < end; i++ {
prefix := " "
text := filteredContexts[i]
// Only show (current) marker if no filter and this is the first item in original list
if wizard.searchFilter == "" && i == 0 {
text += mutedStyle.Render(" (current)")
}
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + text))
} else {
b.WriteString(prefix + text)
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
}
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Cancel", len(wizard.getFilteredContexts()), len(wizard.contexts))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
}
return b.String()
}
@@ -105,6 +121,12 @@ func (m model) renderSelectNamespace() string {
b.WriteString("Select Namespace:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading namespaces..."))
} else if wizard.error != nil {
@@ -113,11 +135,20 @@ func (m model) renderSelectNamespace() string {
} else if len(wizard.namespaces) == 0 {
b.WriteString(mutedStyle.Render("No namespaces found"))
} else {
b.WriteString(renderList(wizard.namespaces, wizard.cursor, " ", wizard.scrollOffset))
filteredNamespaces := wizard.getFilteredNamespaces()
if len(filteredNamespaces) == 0 {
b.WriteString(mutedStyle.Render("No matching namespaces"))
} else {
b.WriteString(renderList(filteredNamespaces, wizard.cursor, " ", wizard.scrollOffset))
}
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredNamespaces()), len(wizard.namespaces))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
@@ -242,21 +273,41 @@ func (m model) renderEnterResource() string {
case ResourceTypeService:
b.WriteString("Select service:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading services..."))
} else if len(wizard.services) == 0 {
b.WriteString(mutedStyle.Render("No services found"))
} else {
serviceNames := make([]string, len(wizard.services))
for i, svc := range wizard.services {
serviceNames[i] = svc.Name
filteredServices := wizard.getFilteredServices()
if len(filteredServices) == 0 {
b.WriteString(mutedStyle.Render("No matching services"))
} else {
serviceNames := make([]string, len(filteredServices))
for i, svc := range filteredServices {
serviceNames[i] = svc.Name
}
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
}
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
}
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
// Show appropriate help text based on resource type and filter state
if wizard.selectedResourceType == ResourceTypeService {
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredServices()), len(wizard.services))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
} else {
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
@@ -322,7 +373,7 @@ func (m model) renderEnterRemotePort() string {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + manualOption))
} else {
b.WriteString(mutedStyle.Render(prefix + manualOption))
b.WriteString(prefix + mutedStyle.Render(manualOption))
}
b.WriteString("\n")
}
@@ -392,7 +443,7 @@ func (m model) renderEnterLocalPort() string {
} else {
b.WriteString(errorStyle.Render(wizard.portCheckMsg))
}
} else if wizard.textInput != "" && wizard.localPort > 0 {
} else if wizard.textInput != "" {
b.WriteString(mutedStyle.Render("Press Enter to check availability"))
}
+158
View File
@@ -0,0 +1,158 @@
package version
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
const (
// GitHubAPIURL is the GitHub API endpoint for releases
githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest"
// requestTimeout is the timeout for HTTP requests
requestTimeout = 5 * time.Second
)
// ReleaseInfo contains information about a GitHub release
type ReleaseInfo struct {
TagName string `json:"tag_name"`
HTMLURL string `json:"html_url"`
Name string `json:"name"`
}
// UpdateInfo contains information about an available update
type UpdateInfo struct {
CurrentVersion string
LatestVersion string
ReleaseURL string
ReleaseName string
}
// Checker checks for new versions on GitHub
type Checker struct {
owner string
repo string
current string
client *http.Client
}
// NewChecker creates a new version checker
func NewChecker(owner, repo, currentVersion string) *Checker {
return &Checker{
owner: owner,
repo: repo,
current: normalizeVersion(currentVersion),
client: &http.Client{
Timeout: requestTimeout,
},
}
}
// CheckForUpdate checks if a newer version is available.
// Returns nil if current version is up to date or if check fails.
// This is designed to fail silently - network errors should not impact the user.
func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo {
release, err := c.fetchLatestRelease(ctx)
if err != nil {
return nil
}
latestVersion := normalizeVersion(release.TagName)
if isNewerVersion(latestVersion, c.current) {
return &UpdateInfo{
CurrentVersion: c.current,
LatestVersion: latestVersion,
ReleaseURL: release.HTMLURL,
ReleaseName: release.Name,
}
}
return nil
}
// fetchLatestRelease fetches the latest release info from GitHub API
func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) {
url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "kportal-version-checker")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
}
var release ReleaseInfo
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
// normalizeVersion removes 'v' or 'V' prefix and trims whitespace
func normalizeVersion(v string) string {
v = strings.TrimSpace(v)
v = strings.TrimPrefix(v, "v")
v = strings.TrimPrefix(v, "V")
return v
}
// isNewerVersion compares two semver-like versions.
// Returns true if latest is newer than current.
func isNewerVersion(latest, current string) bool {
latestParts := parseVersion(latest)
currentParts := parseVersion(current)
// Compare each part
for i := 0; i < len(latestParts) && i < len(currentParts); i++ {
if latestParts[i] > currentParts[i] {
return true
}
if latestParts[i] < currentParts[i] {
return false
}
}
// If all compared parts are equal, longer version is newer
// e.g., 1.0.1 > 1.0
return len(latestParts) > len(currentParts)
}
// parseVersion splits a version string into numeric parts
func parseVersion(v string) []int {
// Remove any suffix like -beta, -rc1, etc.
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.Split(v, ".")
result := make([]int, 0, len(parts))
for _, p := range parts {
var num int
fmt.Sscanf(p, "%d", &num)
result = append(result, num)
}
return result
}
// FormatUpdateMessage formats a user-friendly update notification
func (u *UpdateInfo) FormatUpdateMessage() string {
return fmt.Sprintf("New version available: %s (current: %s) - %s",
u.LatestVersion, u.CurrentVersion, u.ReleaseURL)
}
+90
View File
@@ -0,0 +1,90 @@
package version
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNormalizeVersion(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"v1.0.0", "1.0.0"},
{"1.0.0", "1.0.0"},
{" v2.1.3 ", "2.1.3"},
{"V1.0.0", "1.0.0"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseVersion(t *testing.T) {
tests := []struct {
input string
expected []int
}{
{"1.0.0", []int{1, 0, 0}},
{"2.1.3", []int{2, 1, 3}},
{"1.0", []int{1, 0}},
{"10.20.30", []int{10, 20, 30}},
{"1.0.0-beta", []int{1, 0, 0}},
{"1.0.0-rc1", []int{1, 0, 0}},
{"1.0.0+build123", []int{1, 0, 0}},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := parseVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
latest string
current string
expected bool
}{
{"major version bump", "2.0.0", "1.0.0", true},
{"minor version bump", "1.1.0", "1.0.0", true},
{"patch version bump", "1.0.1", "1.0.0", true},
{"same version", "1.0.0", "1.0.0", false},
{"current is newer major", "1.0.0", "2.0.0", false},
{"current is newer minor", "1.0.0", "1.1.0", false},
{"current is newer patch", "1.0.0", "1.0.1", false},
{"multi-digit versions", "1.10.0", "1.9.0", true},
{"longer version is newer", "1.0.1", "1.0", true},
{"shorter version is older", "1.0", "1.0.1", false},
{"complex comparison", "2.1.3", "2.1.2", true},
{"real world example", "0.2.0", "0.1.0", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNewerVersion(tt.latest, tt.current)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUpdateInfo_FormatUpdateMessage(t *testing.T) {
info := &UpdateInfo{
CurrentVersion: "0.1.0",
LatestVersion: "0.2.0",
ReleaseURL: "https://github.com/nvm/kportal/releases/tag/v0.2.0",
}
msg := info.FormatUpdateMessage()
assert.Contains(t, msg, "0.2.0")
assert.Contains(t, msg, "0.1.0")
assert.Contains(t, msg, "https://github.com/nvm/kportal/releases/tag/v0.2.0")
}