mirror of
https://github.com/lukaszraczylo/kportal.git
synced 2026-06-30 05:44:37 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2fdc5912e7 | |||
| 7df161aee0 |
@@ -1,6 +1,27 @@
|
|||||||
# Example kportal configuration
|
# Example kportal configuration
|
||||||
# Copy this file to your project and customize as needed
|
# Copy this file to your project and customize as needed
|
||||||
|
|
||||||
|
# Optional: Health check configuration
|
||||||
|
# These settings control how kportal monitors connection health and detects stale connections
|
||||||
|
healthCheck:
|
||||||
|
interval: "3s" # How often to check connection health (default: 3s)
|
||||||
|
timeout: "2s" # Timeout for health check operations (default: 2s)
|
||||||
|
method: "data-transfer" # Health check method: "tcp-dial" or "data-transfer" (default: data-transfer)
|
||||||
|
# - tcp-dial: Simple TCP connection test (fast, less reliable)
|
||||||
|
# - data-transfer: Attempts to read data (slower, more reliable)
|
||||||
|
maxConnectionAge: "25m" # Maximum connection age before proactive reconnect (default: 25m)
|
||||||
|
# Helps avoid Kubernetes API server timeouts (typically 30m)
|
||||||
|
maxIdleTime: "10m" # Maximum idle time before marking as stale (default: 10m)
|
||||||
|
# Connections with no data transfer are marked stale
|
||||||
|
|
||||||
|
# Optional: Reliability configuration
|
||||||
|
# These settings improve connection stability for long-running transfers
|
||||||
|
reliability:
|
||||||
|
tcpKeepalive: "30s" # TCP keepalive interval for OS-level connection monitoring (default: 30s)
|
||||||
|
dialTimeout: "30s" # Connection dial timeout (default: 30s)
|
||||||
|
retryOnStale: true # Automatically reconnect when stale connections detected (default: true)
|
||||||
|
watchdogPeriod: "30s" # Goroutine watchdog check interval to detect hung workers (default: 30s)
|
||||||
|
|
||||||
contexts:
|
contexts:
|
||||||
# Production context
|
# Production context
|
||||||
- name: production
|
- name: production
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ kportal simplifies managing multiple Kubernetes port-forwards with an elegant, i
|
|||||||
- 🗑️ **Live Delete** - Remove port-forwards instantly from the running session
|
- 🗑️ **Live Delete** - Remove port-forwards instantly from the running session
|
||||||
- 🔄 **Auto-Reconnect** - Automatic retry with exponential backoff on connection failures (max 10s)
|
- 🔄 **Auto-Reconnect** - Automatic retry with exponential backoff on connection failures (max 10s)
|
||||||
- ⚡ **Hot-Reload** - Update configuration without restarting - changes applied automatically
|
- ⚡ **Hot-Reload** - Update configuration without restarting - changes applied automatically
|
||||||
- 🏥 **Health Checks** - Real-time port forward status monitoring with 5-second intervals
|
- 🏥 **Advanced Health Checks** - Multiple check methods (tcp-dial, data-transfer) with stale connection detection
|
||||||
|
- 🛡️ **Goroutine Watchdog** - Detects and recovers from completely hung workers
|
||||||
- 🎨 **Multi-Context** - Support for multiple Kubernetes contexts and namespaces
|
- 🎨 **Multi-Context** - Support for multiple Kubernetes contexts and namespaces
|
||||||
- 📦 **Batch Management** - Manage all port-forwards from a single configuration file
|
- 📦 **Batch Management** - Manage all port-forwards from a single configuration file
|
||||||
- 🔌 **Toggle Forwards** - Enable/disable individual port-forwards on the fly with Space key
|
- 🔌 **Toggle Forwards** - Enable/disable individual port-forwards on the fly with Space key
|
||||||
@@ -194,6 +195,47 @@ contexts:
|
|||||||
- **Service**: `service/service-name` or `svc/service-name`
|
- **Service**: `service/service-name` or `svc/service-name`
|
||||||
- **Deployment**: `deployment/deployment-name` or `deploy/deployment-name`
|
- **Deployment**: `deployment/deployment-name` or `deploy/deployment-name`
|
||||||
|
|
||||||
|
### Health Check & Reliability (Advanced)
|
||||||
|
|
||||||
|
kportal includes advanced health checking to prevent stale connections during long-running operations like database dumps:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
healthCheck:
|
||||||
|
interval: "3s" # Health check frequency (default: 3s)
|
||||||
|
timeout: "2s" # Health check timeout (default: 2s)
|
||||||
|
method: "data-transfer" # Check method: "tcp-dial" or "data-transfer" (default: data-transfer)
|
||||||
|
maxConnectionAge: "25m" # Proactive reconnect before k8s timeout (default: 25m)
|
||||||
|
maxIdleTime: "10m" # Detect hung connections (default: 10m)
|
||||||
|
|
||||||
|
reliability:
|
||||||
|
tcpKeepalive: "30s" # TCP keepalive interval (default: 30s)
|
||||||
|
dialTimeout: "30s" # Connection dial timeout (default: 30s)
|
||||||
|
retryOnStale: true # Auto-reconnect stale connections (default: true)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Health Check Methods:**
|
||||||
|
- **`tcp-dial`**: Fast TCP connection test - verifies local port is listening
|
||||||
|
- **`data-transfer`**: More reliable - attempts to read data to verify tunnel is functional
|
||||||
|
|
||||||
|
**Stale Detection:**
|
||||||
|
- **Max Connection Age**: Kubernetes API typically has 30-minute timeout. kportal reconnects at 25 minutes by default to avoid hitting this limit. **Important**: Age-based reconnection only occurs when the connection is ALSO idle - active transfers (like database dumps) are never interrupted.
|
||||||
|
- **Max Idle Time**: Detects connections with no data transfer, common when intermediate firewalls drop idle TCP connections
|
||||||
|
|
||||||
|
**Use Case Example - Database Dumps:**
|
||||||
|
```yaml
|
||||||
|
# Optimized for long-running pg_dump
|
||||||
|
healthCheck:
|
||||||
|
method: "data-transfer"
|
||||||
|
maxConnectionAge: "20m" # Only applies when idle - won't interrupt active dumps
|
||||||
|
maxIdleTime: "5m" # Detects truly stale connections
|
||||||
|
|
||||||
|
reliability:
|
||||||
|
tcpKeepalive: "30s"
|
||||||
|
retryOnStale: true
|
||||||
|
```
|
||||||
|
|
||||||
|
This configuration ensures multi-hour database dumps complete without interruption. The `maxConnectionAge` will only trigger reconnection if the connection has been idle for more than `maxIdleTime`, preventing interruption of active data transfers.
|
||||||
|
|
||||||
## 🎮 Usage
|
## 🎮 Usage
|
||||||
|
|
||||||
### Interactive Mode (Default)
|
### Interactive Mode (Default)
|
||||||
|
|||||||
+23
-3
@@ -12,6 +12,7 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
"github.com/nvm/kportal/internal/config"
|
"github.com/nvm/kportal/internal/config"
|
||||||
"github.com/nvm/kportal/internal/converter"
|
"github.com/nvm/kportal/internal/converter"
|
||||||
"github.com/nvm/kportal/internal/forward"
|
"github.com/nvm/kportal/internal/forward"
|
||||||
@@ -91,16 +92,35 @@ func main() {
|
|||||||
|
|
||||||
// Configure klog (used by kubernetes client-go) to route through our logger
|
// Configure klog (used by kubernetes client-go) to route through our logger
|
||||||
// This prevents k8s logs from interfering with the UI
|
// This prevents k8s logs from interfering with the UI
|
||||||
|
//
|
||||||
|
// klog v2 uses multiple output mechanisms:
|
||||||
|
// 1. SetOutput() - for basic text output
|
||||||
|
// 2. SetLogger() - for structured/error logs (logr interface)
|
||||||
|
//
|
||||||
|
// We must configure BOTH to capture all logs including error messages
|
||||||
|
// that would otherwise bypass SetOutput() and write directly to stderr.
|
||||||
|
klog.LogToStderr(false) // Disable direct stderr writes
|
||||||
if *verbose {
|
if *verbose {
|
||||||
// In verbose mode, route klog through our structured logger at DEBUG level
|
// In verbose mode, route all klog through our structured logger at DEBUG level
|
||||||
klogLogger := logger.New(logger.LevelDebug, logFmt, os.Stderr)
|
klogLogger := logger.New(logger.LevelDebug, logFmt, os.Stderr)
|
||||||
|
|
||||||
|
// Configure text output routing
|
||||||
klogWriter := logger.NewKlogWriter(klogLogger)
|
klogWriter := logger.NewKlogWriter(klogLogger)
|
||||||
klog.SetOutput(klogWriter)
|
klog.SetOutput(klogWriter)
|
||||||
|
|
||||||
|
// Configure structured/error log routing via logr interface
|
||||||
|
// This captures "Unhandled Error" and other structured logs that bypass SetOutput
|
||||||
|
logrSink := logger.NewLogrAdapter(klogLogger)
|
||||||
|
klog.SetLogger(logr.New(logrSink))
|
||||||
} else {
|
} else {
|
||||||
// In non-verbose mode, completely silence klog
|
// In non-verbose mode, completely silence ALL klog output
|
||||||
klog.SetOutput(io.Discard)
|
klog.SetOutput(io.Discard)
|
||||||
|
|
||||||
|
// Also silence structured/error logs via a discard logger
|
||||||
|
silentLogger := logger.New(logger.LevelError+1, logFmt, io.Discard) // Level above ERROR = silence all
|
||||||
|
logrSink := logger.NewLogrAdapter(silentLogger)
|
||||||
|
klog.SetLogger(logr.New(logrSink))
|
||||||
}
|
}
|
||||||
klog.LogToStderr(false)
|
|
||||||
|
|
||||||
// Handle conversion mode
|
// Handle conversion mode
|
||||||
if *convertInput != "" {
|
if *convertInput != "" {
|
||||||
|
|||||||
+107
-36
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
@@ -13,7 +14,112 @@ const (
|
|||||||
|
|
||||||
// Config represents the root configuration structure from .kportal.yaml
|
// Config represents the root configuration structure from .kportal.yaml
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Contexts []Context `yaml:"contexts"`
|
Contexts []Context `yaml:"contexts"`
|
||||||
|
HealthCheck *HealthCheckSpec `yaml:"healthCheck,omitempty"`
|
||||||
|
Reliability *ReliabilitySpec `yaml:"reliability,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheckSpec configures health check behavior
|
||||||
|
type HealthCheckSpec struct {
|
||||||
|
Interval string `yaml:"interval,omitempty"` // e.g., "3s", "5s"
|
||||||
|
Timeout string `yaml:"timeout,omitempty"` // e.g., "2s"
|
||||||
|
Method string `yaml:"method,omitempty"` // "tcp-dial" | "data-transfer"
|
||||||
|
MaxConnectionAge string `yaml:"maxConnectionAge,omitempty"` // e.g., "25m" - reconnect before k8s timeout
|
||||||
|
MaxIdleTime string `yaml:"maxIdleTime,omitempty"` // e.g., "10m" - reconnect if no activity
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReliabilitySpec configures connection reliability features
|
||||||
|
type ReliabilitySpec struct {
|
||||||
|
TCPKeepalive string `yaml:"tcpKeepalive,omitempty"` // e.g., "30s" - OS-level keepalive
|
||||||
|
DialTimeout string `yaml:"dialTimeout,omitempty"` // e.g., "30s" - connection dial timeout
|
||||||
|
RetryOnStale bool `yaml:"retryOnStale,omitempty"` // Auto-reconnect on stale detection
|
||||||
|
WatchdogPeriod string `yaml:"watchdogPeriod,omitempty"` // e.g., "30s" - goroutine watchdog interval
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthCheckIntervalOrDefault returns the health check interval or default value
|
||||||
|
func (c *Config) GetHealthCheckIntervalOrDefault() time.Duration {
|
||||||
|
if c.HealthCheck != nil && c.HealthCheck.Interval != "" {
|
||||||
|
if d, err := time.ParseDuration(c.HealthCheck.Interval); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 3 * time.Second // Default: check every 3 seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthCheckTimeoutOrDefault returns the health check timeout or default value
|
||||||
|
func (c *Config) GetHealthCheckTimeoutOrDefault() time.Duration {
|
||||||
|
if c.HealthCheck != nil && c.HealthCheck.Timeout != "" {
|
||||||
|
if d, err := time.ParseDuration(c.HealthCheck.Timeout); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 2 * time.Second // Default: 2 second timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthCheckMethod returns the health check method or default
|
||||||
|
func (c *Config) GetHealthCheckMethod() string {
|
||||||
|
if c.HealthCheck != nil && c.HealthCheck.Method != "" {
|
||||||
|
return c.HealthCheck.Method
|
||||||
|
}
|
||||||
|
return "data-transfer" // Default: more reliable data transfer test
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMaxConnectionAge returns the max connection age or default
|
||||||
|
func (c *Config) GetMaxConnectionAge() time.Duration {
|
||||||
|
if c.HealthCheck != nil && c.HealthCheck.MaxConnectionAge != "" {
|
||||||
|
if d, err := time.ParseDuration(c.HealthCheck.MaxConnectionAge); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 25 * time.Minute // Default: 25 minutes (before typical 30min k8s timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMaxIdleTime returns the max idle time or default
|
||||||
|
func (c *Config) GetMaxIdleTime() time.Duration {
|
||||||
|
if c.HealthCheck != nil && c.HealthCheck.MaxIdleTime != "" {
|
||||||
|
if d, err := time.ParseDuration(c.HealthCheck.MaxIdleTime); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 10 * time.Minute // Default: 10 minutes idle before reconnect
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTCPKeepalive returns the TCP keepalive duration or default
|
||||||
|
func (c *Config) GetTCPKeepalive() time.Duration {
|
||||||
|
if c.Reliability != nil && c.Reliability.TCPKeepalive != "" {
|
||||||
|
if d, err := time.ParseDuration(c.Reliability.TCPKeepalive); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 30 * time.Second // Default: 30 second keepalive
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRetryOnStale returns whether to retry on stale connections
|
||||||
|
func (c *Config) GetRetryOnStale() bool {
|
||||||
|
if c.Reliability != nil {
|
||||||
|
return c.Reliability.RetryOnStale
|
||||||
|
}
|
||||||
|
return true // Default: enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWatchdogPeriod returns the goroutine watchdog check period or default
|
||||||
|
func (c *Config) GetWatchdogPeriod() time.Duration {
|
||||||
|
if c.Reliability != nil && c.Reliability.WatchdogPeriod != "" {
|
||||||
|
if d, err := time.ParseDuration(c.Reliability.WatchdogPeriod); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 30 * time.Second // Default: check every 30 seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDialTimeout returns the connection dial timeout or default
|
||||||
|
func (c *Config) GetDialTimeout() time.Duration {
|
||||||
|
if c.Reliability != nil && c.Reliability.DialTimeout != "" {
|
||||||
|
if d, err := time.ParseDuration(c.Reliability.DialTimeout); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 30 * time.Second // Default: 30 second dial timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// Context represents a Kubernetes context with its namespaces
|
// Context represents a Kubernetes context with its namespaces
|
||||||
@@ -136,38 +242,3 @@ func (c *Config) GetAllForwards() []Forward {
|
|||||||
|
|
||||||
return forwards
|
return forwards
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetForwardsByContext returns all forwards for a specific context.
|
|
||||||
func (c *Config) GetForwardsByContext(contextName string) []Forward {
|
|
||||||
var forwards []Forward
|
|
||||||
|
|
||||||
for _, ctx := range c.Contexts {
|
|
||||||
if ctx.Name == contextName {
|
|
||||||
for _, ns := range ctx.Namespaces {
|
|
||||||
forwards = append(forwards, ns.Forwards...)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return forwards
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetForwardsByNamespace returns all forwards for a specific context and namespace.
|
|
||||||
func (c *Config) GetForwardsByNamespace(contextName, namespaceName string) []Forward {
|
|
||||||
var forwards []Forward
|
|
||||||
|
|
||||||
for _, ctx := range c.Contexts {
|
|
||||||
if ctx.Name == contextName {
|
|
||||||
for _, ns := range ctx.Namespaces {
|
|
||||||
if ns.Name == namespaceName {
|
|
||||||
forwards = append(forwards, ns.Forwards...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return forwards
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -298,72 +298,6 @@ func TestConfig_GetAllForwards(t *testing.T) {
|
|||||||
assert.Len(t, forwards, 4, "should return all forwards from all contexts and namespaces")
|
assert.Len(t, forwards, 4, "should return all forwards from all contexts and namespaces")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfig_GetForwardsByContext(t *testing.T) {
|
|
||||||
yamlData := []byte(`contexts:
|
|
||||||
- name: cluster1
|
|
||||||
namespaces:
|
|
||||||
- name: ns1
|
|
||||||
forwards:
|
|
||||||
- resource: pod/app1
|
|
||||||
port: 8080
|
|
||||||
localPort: 8080
|
|
||||||
- resource: pod/app2
|
|
||||||
port: 8081
|
|
||||||
localPort: 8081
|
|
||||||
- name: cluster2
|
|
||||||
namespaces:
|
|
||||||
- name: ns2
|
|
||||||
forwards:
|
|
||||||
- resource: pod/app3
|
|
||||||
port: 9090
|
|
||||||
localPort: 9090
|
|
||||||
`)
|
|
||||||
|
|
||||||
cfg, err := ParseConfig(yamlData)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
forwards := cfg.GetForwardsByContext("cluster1")
|
|
||||||
assert.Len(t, forwards, 2, "should return forwards only from cluster1")
|
|
||||||
|
|
||||||
forwards2 := cfg.GetForwardsByContext("cluster2")
|
|
||||||
assert.Len(t, forwards2, 1, "should return forwards only from cluster2")
|
|
||||||
|
|
||||||
forwards3 := cfg.GetForwardsByContext("non-existent")
|
|
||||||
assert.Len(t, forwards3, 0, "should return empty slice for non-existent context")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConfig_GetForwardsByNamespace(t *testing.T) {
|
|
||||||
yamlData := []byte(`contexts:
|
|
||||||
- name: cluster1
|
|
||||||
namespaces:
|
|
||||||
- name: ns1
|
|
||||||
forwards:
|
|
||||||
- resource: pod/app1
|
|
||||||
port: 8080
|
|
||||||
localPort: 8080
|
|
||||||
- resource: pod/app2
|
|
||||||
port: 8081
|
|
||||||
localPort: 8081
|
|
||||||
- name: ns2
|
|
||||||
forwards:
|
|
||||||
- resource: pod/app3
|
|
||||||
port: 9090
|
|
||||||
localPort: 9090
|
|
||||||
`)
|
|
||||||
|
|
||||||
cfg, err := ParseConfig(yamlData)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
forwards := cfg.GetForwardsByNamespace("cluster1", "ns1")
|
|
||||||
assert.Len(t, forwards, 2, "should return forwards only from cluster1/ns1")
|
|
||||||
|
|
||||||
forwards2 := cfg.GetForwardsByNamespace("cluster1", "ns2")
|
|
||||||
assert.Len(t, forwards2, 1, "should return forwards only from cluster1/ns2")
|
|
||||||
|
|
||||||
forwards3 := cfg.GetForwardsByNamespace("cluster1", "non-existent")
|
|
||||||
assert.Len(t, forwards3, 0, "should return empty slice for non-existent namespace")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForward_SetContext(t *testing.T) {
|
func TestForward_SetContext(t *testing.T) {
|
||||||
fwd := Forward{
|
fwd := Forward{
|
||||||
Resource: "pod/my-app",
|
Resource: "pod/my-app",
|
||||||
|
|||||||
+121
-14
@@ -12,11 +12,6 @@ import (
|
|||||||
"github.com/nvm/kportal/internal/logger"
|
"github.com/nvm/kportal/internal/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
healthCheckInterval = 5 * time.Second
|
|
||||||
healthCheckTimeout = 2 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
// StatusUpdater is an interface for updating forward status
|
// StatusUpdater is an interface for updating forward status
|
||||||
type StatusUpdater interface {
|
type StatusUpdater interface {
|
||||||
UpdateStatus(id string, status string)
|
UpdateStatus(id string, status string)
|
||||||
@@ -34,12 +29,15 @@ type Manager struct {
|
|||||||
portForwarder *k8s.PortForwarder
|
portForwarder *k8s.PortForwarder
|
||||||
portChecker *PortChecker
|
portChecker *PortChecker
|
||||||
healthChecker *healthcheck.Checker
|
healthChecker *healthcheck.Checker
|
||||||
|
watchdog *Watchdog
|
||||||
verbose bool
|
verbose bool
|
||||||
currentConfig *config.Config
|
currentConfig *config.Config
|
||||||
statusUI StatusUpdater
|
statusUI StatusUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new forward Manager.
|
// NewManager creates a new forward Manager.
|
||||||
|
// The health checker will be created with default settings and can be
|
||||||
|
// reconfigured via SetConfig().
|
||||||
func NewManager(verbose bool) (*Manager, error) {
|
func NewManager(verbose bool) (*Manager, error) {
|
||||||
clientPool, err := k8s.NewClientPool()
|
clientPool, err := k8s.NewClientPool()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -49,8 +47,13 @@ func NewManager(verbose bool) (*Manager, error) {
|
|||||||
resolver := k8s.NewResourceResolver(clientPool)
|
resolver := k8s.NewResourceResolver(clientPool)
|
||||||
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
|
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
|
||||||
|
|
||||||
// Create health checker: check every 5 seconds with 2 second timeout
|
// Create health checker with defaults: check every 3 seconds with 2 second timeout
|
||||||
healthChecker := healthcheck.NewChecker(healthCheckInterval, healthCheckTimeout)
|
// Will be reconfigured when config is loaded
|
||||||
|
healthChecker := healthcheck.NewChecker(3*time.Second, 2*time.Second)
|
||||||
|
|
||||||
|
// Create watchdog with default settings: check every 30 seconds, 60 second hang threshold
|
||||||
|
// Will be reconfigured when config is loaded
|
||||||
|
watchdog := NewWatchdog(30*time.Second, 60*time.Second)
|
||||||
|
|
||||||
return &Manager{
|
return &Manager{
|
||||||
workers: make(map[string]*ForwardWorker),
|
workers: make(map[string]*ForwardWorker),
|
||||||
@@ -59,10 +62,56 @@ func NewManager(verbose bool) (*Manager, error) {
|
|||||||
portForwarder: portForwarder,
|
portForwarder: portForwarder,
|
||||||
portChecker: NewPortChecker(),
|
portChecker: NewPortChecker(),
|
||||||
healthChecker: healthChecker,
|
healthChecker: healthChecker,
|
||||||
|
watchdog: watchdog,
|
||||||
verbose: verbose,
|
verbose: verbose,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// configureHealthChecker creates a new health checker with settings from config
|
||||||
|
func (m *Manager) configureHealthChecker(cfg *config.Config) {
|
||||||
|
// Stop existing health checker
|
||||||
|
if m.healthChecker != nil {
|
||||||
|
m.healthChecker.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse check method
|
||||||
|
methodStr := cfg.GetHealthCheckMethod()
|
||||||
|
var method healthcheck.CheckMethod
|
||||||
|
switch methodStr {
|
||||||
|
case "tcp-dial":
|
||||||
|
method = healthcheck.CheckMethodTCPDial
|
||||||
|
case "data-transfer":
|
||||||
|
method = healthcheck.CheckMethodDataTransfer
|
||||||
|
default:
|
||||||
|
method = healthcheck.CheckMethodDataTransfer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new health checker with config settings
|
||||||
|
m.healthChecker = healthcheck.NewCheckerWithOptions(healthcheck.CheckerOptions{
|
||||||
|
Interval: cfg.GetHealthCheckIntervalOrDefault(),
|
||||||
|
Timeout: cfg.GetHealthCheckTimeoutOrDefault(),
|
||||||
|
Method: method,
|
||||||
|
MaxConnectionAge: cfg.GetMaxConnectionAge(),
|
||||||
|
MaxIdleTime: cfg.GetMaxIdleTime(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Configure TCP settings on port forwarder
|
||||||
|
tcpKeepalive := cfg.GetTCPKeepalive()
|
||||||
|
dialTimeout := cfg.GetDialTimeout()
|
||||||
|
m.portForwarder.SetTCPKeepalive(tcpKeepalive)
|
||||||
|
m.portForwarder.SetDialTimeout(dialTimeout)
|
||||||
|
|
||||||
|
logger.Info("Health checker and reliability configured", map[string]interface{}{
|
||||||
|
"interval": cfg.GetHealthCheckIntervalOrDefault().String(),
|
||||||
|
"timeout": cfg.GetHealthCheckTimeoutOrDefault().String(),
|
||||||
|
"method": methodStr,
|
||||||
|
"max_connection_age": cfg.GetMaxConnectionAge().String(),
|
||||||
|
"max_idle_time": cfg.GetMaxIdleTime().String(),
|
||||||
|
"tcp_keepalive": tcpKeepalive.String(),
|
||||||
|
"dial_timeout": dialTimeout.String(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SetStatusUI sets the status updater for the manager
|
// SetStatusUI sets the status updater for the manager
|
||||||
func (m *Manager) SetStatusUI(ui StatusUpdater) {
|
func (m *Manager) SetStatusUI(ui StatusUpdater) {
|
||||||
m.statusUI = ui
|
m.statusUI = ui
|
||||||
@@ -76,6 +125,20 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||||||
|
|
||||||
m.currentConfig = cfg
|
m.currentConfig = cfg
|
||||||
|
|
||||||
|
// Configure health checker with settings from config
|
||||||
|
m.configureHealthChecker(cfg)
|
||||||
|
|
||||||
|
// Start watchdog
|
||||||
|
watchdogPeriod := cfg.GetWatchdogPeriod()
|
||||||
|
m.watchdog.checkInterval = watchdogPeriod
|
||||||
|
m.watchdog.hangThreshold = watchdogPeriod * 2 // Hang threshold is 2x check interval
|
||||||
|
m.watchdog.Start()
|
||||||
|
|
||||||
|
logger.Info("Watchdog started", map[string]interface{}{
|
||||||
|
"check_interval": watchdogPeriod.String(),
|
||||||
|
"hang_threshold": (watchdogPeriod * 2).String(),
|
||||||
|
})
|
||||||
|
|
||||||
// Get all forwards from config
|
// Get all forwards from config
|
||||||
forwards := cfg.GetAllForwards()
|
forwards := cfg.GetAllForwards()
|
||||||
|
|
||||||
@@ -119,8 +182,9 @@ func (m *Manager) Start(cfg *config.Config) error {
|
|||||||
func (m *Manager) Stop() {
|
func (m *Manager) Stop() {
|
||||||
log.Printf("Stopping all port-forwards...")
|
log.Printf("Stopping all port-forwards...")
|
||||||
|
|
||||||
// Stop health checker first
|
// Stop health checker and watchdog first
|
||||||
m.healthChecker.Stop()
|
m.healthChecker.Stop()
|
||||||
|
m.watchdog.Stop()
|
||||||
|
|
||||||
m.workersMu.Lock()
|
m.workersMu.Lock()
|
||||||
workers := make([]*ForwardWorker, 0, len(m.workers))
|
workers := make([]*ForwardWorker, 0, len(m.workers))
|
||||||
@@ -273,21 +337,54 @@ func (m *Manager) startWorker(fwd config.Forward) error {
|
|||||||
m.statusUI.AddForward(fwd.ID(), &fwd)
|
m.statusUI.AddForward(fwd.ID(), &fwd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register with watchdog
|
||||||
|
m.watchdog.RegisterWorker(fwd.ID(), func(forwardID string) {
|
||||||
|
logger.Warn("Watchdog triggered reconnection for hung worker", map[string]interface{}{
|
||||||
|
"forward_id": forwardID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Find and trigger reconnect on hung worker
|
||||||
|
m.workersMu.RLock()
|
||||||
|
worker, exists := m.workers[forwardID]
|
||||||
|
m.workersMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
worker.TriggerReconnect("watchdog detected hung worker")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Register with health checker
|
// Register with health checker
|
||||||
m.healthChecker.Register(fwd.ID(), fwd.LocalPort, func(forwardID string, status healthcheck.Status, errorMsg string) {
|
m.healthChecker.Register(fwd.ID(), fwd.LocalPort, func(forwardID string, status healthcheck.Status, errorMsg string) {
|
||||||
if m.statusUI != nil {
|
if m.statusUI != nil {
|
||||||
m.statusUI.UpdateStatus(forwardID, string(status))
|
m.statusUI.UpdateStatus(forwardID, string(status))
|
||||||
// Send error separately if there is one
|
// Send error separately if there is one
|
||||||
if status == healthcheck.StatusUnhealthy && errorMsg != "" {
|
if (status == healthcheck.StatusUnhealthy || status == healthcheck.StatusStale) && errorMsg != "" {
|
||||||
if ui, ok := m.statusUI.(interface{ SetError(id, msg string) }); ok {
|
if ui, ok := m.statusUI.(interface{ SetError(id, msg string) }); ok {
|
||||||
ui.SetError(forwardID, errorMsg)
|
ui.SetError(forwardID, errorMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle stale connections: trigger reconnection if retryOnStale is enabled
|
||||||
|
if status == healthcheck.StatusStale && m.currentConfig.GetRetryOnStale() {
|
||||||
|
logger.Info("Stale connection detected, triggering reconnection", map[string]interface{}{
|
||||||
|
"forward_id": forwardID,
|
||||||
|
"reason": errorMsg,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Find and notify the worker to reconnect
|
||||||
|
m.workersMu.RLock()
|
||||||
|
worker, exists := m.workers[forwardID]
|
||||||
|
m.workersMu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
worker.TriggerReconnect("stale connection")
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Create and start worker
|
// Create and start worker
|
||||||
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker)
|
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker, m.watchdog)
|
||||||
worker.Start()
|
worker.Start()
|
||||||
|
|
||||||
// Store worker
|
// Store worker
|
||||||
@@ -298,6 +395,11 @@ func (m *Manager) startWorker(fwd config.Forward) error {
|
|||||||
|
|
||||||
// stopWorker stops and removes a forward worker.
|
// stopWorker stops and removes a forward worker.
|
||||||
func (m *Manager) stopWorker(id string) error {
|
func (m *Manager) stopWorker(id string) error {
|
||||||
|
return m.stopWorkerInternal(id, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopWorkerInternal stops a worker with option to remove from UI or just update status
|
||||||
|
func (m *Manager) stopWorkerInternal(id string, removeFromUI bool) error {
|
||||||
m.workersMu.Lock()
|
m.workersMu.Lock()
|
||||||
worker, exists := m.workers[id]
|
worker, exists := m.workers[id]
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -307,12 +409,17 @@ func (m *Manager) stopWorker(id string) error {
|
|||||||
delete(m.workers, id)
|
delete(m.workers, id)
|
||||||
m.workersMu.Unlock()
|
m.workersMu.Unlock()
|
||||||
|
|
||||||
// Unregister from health checker
|
// Unregister from health checker and watchdog
|
||||||
m.healthChecker.Unregister(id)
|
m.healthChecker.Unregister(id)
|
||||||
|
m.watchdog.UnregisterWorker(id)
|
||||||
|
|
||||||
// Notify UI to remove the forward
|
// Notify UI - either remove or update to disabled status
|
||||||
if m.statusUI != nil {
|
if m.statusUI != nil {
|
||||||
m.statusUI.Remove(id)
|
if removeFromUI {
|
||||||
|
m.statusUI.Remove(id)
|
||||||
|
} else {
|
||||||
|
m.statusUI.UpdateStatus(id, "Disabled")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the worker
|
// Stop the worker
|
||||||
@@ -363,7 +470,7 @@ func (m *Manager) getResourceForPort(forwards []config.Forward, port int) string
|
|||||||
|
|
||||||
// DisableForward temporarily stops a forward by ID
|
// DisableForward temporarily stops a forward by ID
|
||||||
func (m *Manager) DisableForward(id string) error {
|
func (m *Manager) DisableForward(id string) error {
|
||||||
if err := m.stopWorker(id); err != nil {
|
if err := m.stopWorkerInternal(id, false); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("Disabled: %s", id)
|
log.Printf("Disabled: %s", id)
|
||||||
|
|||||||
@@ -0,0 +1,158 @@
|
|||||||
|
package forward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/nvm/kportal/internal/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Watchdog monitors worker goroutines to detect hung workers
|
||||||
|
type Watchdog struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
workers map[string]*workerState // key: forward ID
|
||||||
|
checkInterval time.Duration
|
||||||
|
hangThreshold time.Duration // How long without heartbeat before considered hung
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// workerState tracks the health of a single worker
|
||||||
|
type workerState struct {
|
||||||
|
forwardID string
|
||||||
|
lastHeartbeat time.Time
|
||||||
|
heartbeatCount uint64
|
||||||
|
isHung bool
|
||||||
|
onHungCallback func(forwardID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWatchdog creates a new goroutine watchdog
|
||||||
|
func NewWatchdog(checkInterval, hangThreshold time.Duration) *Watchdog {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &Watchdog{
|
||||||
|
workers: make(map[string]*workerState),
|
||||||
|
checkInterval: checkInterval,
|
||||||
|
hangThreshold: hangThreshold,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the watchdog monitoring loop
|
||||||
|
func (w *Watchdog) Start() {
|
||||||
|
w.wg.Add(1)
|
||||||
|
go w.monitorLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the watchdog
|
||||||
|
func (w *Watchdog) Stop() {
|
||||||
|
w.cancel()
|
||||||
|
w.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterWorker adds a worker to monitor
|
||||||
|
func (w *Watchdog) RegisterWorker(forwardID string, onHungCallback func(string)) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
w.workers[forwardID] = &workerState{
|
||||||
|
forwardID: forwardID,
|
||||||
|
lastHeartbeat: time.Now(),
|
||||||
|
heartbeatCount: 0,
|
||||||
|
isHung: false,
|
||||||
|
onHungCallback: onHungCallback,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Watchdog registered worker", map[string]interface{}{
|
||||||
|
"forward_id": forwardID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterWorker removes a worker from monitoring
|
||||||
|
func (w *Watchdog) UnregisterWorker(forwardID string) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
delete(w.workers, forwardID)
|
||||||
|
|
||||||
|
logger.Debug("Watchdog unregistered worker", map[string]interface{}{
|
||||||
|
"forward_id": forwardID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat records that a worker is alive and processing
|
||||||
|
// Workers should call this periodically (e.g., in their main loop)
|
||||||
|
func (w *Watchdog) Heartbeat(forwardID string) {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
if state, exists := w.workers[forwardID]; exists {
|
||||||
|
state.lastHeartbeat = time.Now()
|
||||||
|
state.heartbeatCount++
|
||||||
|
state.isHung = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWorkerState returns the current state of a worker (for testing)
|
||||||
|
func (w *Watchdog) GetWorkerState(forwardID string) (lastHeartbeat time.Time, count uint64, exists bool) {
|
||||||
|
w.mu.RLock()
|
||||||
|
defer w.mu.RUnlock()
|
||||||
|
|
||||||
|
if state, ok := w.workers[forwardID]; ok {
|
||||||
|
return state.lastHeartbeat, state.heartbeatCount, true
|
||||||
|
}
|
||||||
|
return time.Time{}, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// monitorLoop periodically checks all workers
|
||||||
|
func (w *Watchdog) monitorLoop() {
|
||||||
|
defer w.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(w.checkInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-w.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
w.checkWorkers()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkWorkers checks all registered workers for hung state
|
||||||
|
func (w *Watchdog) checkWorkers() {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for forwardID, state := range w.workers {
|
||||||
|
timeSinceHeartbeat := now.Sub(state.lastHeartbeat)
|
||||||
|
|
||||||
|
// Check if worker is hung
|
||||||
|
if timeSinceHeartbeat > w.hangThreshold {
|
||||||
|
if !state.isHung {
|
||||||
|
// First time detecting hung state
|
||||||
|
state.isHung = true
|
||||||
|
|
||||||
|
logger.Warn("Watchdog detected hung worker", map[string]interface{}{
|
||||||
|
"forward_id": forwardID,
|
||||||
|
"time_since_heartbeat": timeSinceHeartbeat.String(),
|
||||||
|
"hang_threshold": w.hangThreshold.String(),
|
||||||
|
"heartbeat_count": state.heartbeatCount,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Trigger callback to handle hung worker (without holding lock)
|
||||||
|
if state.onHungCallback != nil {
|
||||||
|
callback := state.onHungCallback
|
||||||
|
w.mu.Unlock()
|
||||||
|
callback(forwardID)
|
||||||
|
w.mu.Lock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,310 @@
|
|||||||
|
package forward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WatchdogTestSuite contains tests for the watchdog
|
||||||
|
type WatchdogTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
watchdog *Watchdog
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatchdogSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(WatchdogTestSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatchdogTestSuite) SetupTest() {
|
||||||
|
// Create watchdog with fast intervals for testing
|
||||||
|
s.watchdog = NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
|
||||||
|
s.watchdog.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *WatchdogTestSuite) TearDownTest() {
|
||||||
|
if s.watchdog != nil {
|
||||||
|
s.watchdog.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterUnregister tests basic registration and unregistration
|
||||||
|
func (s *WatchdogTestSuite) TestRegisterUnregister() {
|
||||||
|
callbackCalled := false
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
callbackCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register worker
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Verify worker is registered
|
||||||
|
_, _, exists := s.watchdog.GetWorkerState("test-forward")
|
||||||
|
assert.True(s.T(), exists, "Worker should be registered")
|
||||||
|
|
||||||
|
// Unregister worker
|
||||||
|
s.watchdog.UnregisterWorker("test-forward")
|
||||||
|
|
||||||
|
// Verify worker is unregistered
|
||||||
|
_, _, exists = s.watchdog.GetWorkerState("test-forward")
|
||||||
|
assert.False(s.T(), exists, "Worker should be unregistered")
|
||||||
|
assert.False(s.T(), callbackCalled, "Callback should not have been called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHeartbeat tests that heartbeats update worker state
|
||||||
|
func (s *WatchdogTestSuite) TestHeartbeat() {
|
||||||
|
s.watchdog.RegisterWorker("test-forward", nil)
|
||||||
|
|
||||||
|
// Send initial heartbeat
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
lastHeartbeat1, count1, exists := s.watchdog.GetWorkerState("test-forward")
|
||||||
|
require.True(s.T(), exists)
|
||||||
|
assert.Equal(s.T(), uint64(1), count1)
|
||||||
|
|
||||||
|
// Wait a bit
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
// Send another heartbeat
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
lastHeartbeat2, count2, exists := s.watchdog.GetWorkerState("test-forward")
|
||||||
|
require.True(s.T(), exists)
|
||||||
|
assert.Equal(s.T(), uint64(2), count2)
|
||||||
|
assert.True(s.T(), lastHeartbeat2.After(lastHeartbeat1), "Second heartbeat should be after first")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHungWorkerDetection tests that hung workers are detected
|
||||||
|
func (s *WatchdogTestSuite) TestHungWorkerDetection() {
|
||||||
|
callbackCalled := make(chan string, 1)
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
callbackCalled <- forwardID
|
||||||
|
}
|
||||||
|
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Send initial heartbeat
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
// Wait for worker to be considered hung (300ms threshold + 100ms check interval)
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case forwardID := <-callbackCalled:
|
||||||
|
assert.Equal(s.T(), "test-forward", forwardID)
|
||||||
|
case <-timeout:
|
||||||
|
s.T().Fatal("Timeout waiting for hung worker callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHealthyWorkerNotDetectedAsHung tests that workers sending heartbeats are not considered hung
|
||||||
|
func (s *WatchdogTestSuite) TestHealthyWorkerNotDetectedAsHung() {
|
||||||
|
callbackCalled := false
|
||||||
|
var mu sync.Mutex
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
callbackCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Send periodic heartbeats (faster than hang threshold)
|
||||||
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-ticker.C
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for all heartbeats to complete
|
||||||
|
<-done
|
||||||
|
|
||||||
|
// Check that callback was not called
|
||||||
|
mu.Lock()
|
||||||
|
assert.False(s.T(), callbackCalled, "Callback should not be called for healthy worker")
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMultipleWorkers tests monitoring multiple workers simultaneously
|
||||||
|
func (s *WatchdogTestSuite) TestMultipleWorkers() {
|
||||||
|
callbacks := make(map[string]int)
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
makeCallback := func(id string) func(string) {
|
||||||
|
return func(forwardID string) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
callbacks[id]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register multiple workers
|
||||||
|
s.watchdog.RegisterWorker("worker-1", makeCallback("worker-1"))
|
||||||
|
s.watchdog.RegisterWorker("worker-2", makeCallback("worker-2"))
|
||||||
|
s.watchdog.RegisterWorker("worker-3", makeCallback("worker-3"))
|
||||||
|
|
||||||
|
// worker-1: Keep sending heartbeats (healthy)
|
||||||
|
ticker1 := time.NewTicker(50 * time.Millisecond)
|
||||||
|
defer ticker1.Stop()
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-ticker1.C
|
||||||
|
s.watchdog.Heartbeat("worker-1")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// worker-2: Send initial heartbeat then stop (will become hung)
|
||||||
|
s.watchdog.Heartbeat("worker-2")
|
||||||
|
|
||||||
|
// worker-3: Send initial heartbeat then stop (will become hung)
|
||||||
|
s.watchdog.Heartbeat("worker-3")
|
||||||
|
|
||||||
|
// Wait for hung workers to be detected
|
||||||
|
time.Sleep(600 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check results
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), 0, callbacks["worker-1"], "worker-1 should not trigger callback (healthy)")
|
||||||
|
assert.Greater(s.T(), callbacks["worker-2"], 0, "worker-2 should trigger callback (hung)")
|
||||||
|
assert.Greater(s.T(), callbacks["worker-3"], 0, "worker-3 should trigger callback (hung)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCallbackOnlyOnFirstDetection tests that callback is only called once when hung is first detected
|
||||||
|
func (s *WatchdogTestSuite) TestCallbackOnlyOnFirstDetection() {
|
||||||
|
callbackCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
callbackCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Send initial heartbeat
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
// Wait for multiple check cycles
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
// Check that callback was only called once
|
||||||
|
mu.Lock()
|
||||||
|
assert.Equal(s.T(), 1, callbackCount, "Callback should only be called once")
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHeartbeatResetsHungState tests that sending heartbeat after hung detection resets state
|
||||||
|
func (s *WatchdogTestSuite) TestHeartbeatResetsHungState() {
|
||||||
|
callbackCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
callbackCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Send initial heartbeat
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
// Wait for hung detection
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
firstCount := callbackCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), 1, firstCount, "First hung detection should trigger callback")
|
||||||
|
|
||||||
|
// Send heartbeat to reset hung state
|
||||||
|
s.watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
// Wait for worker to become hung again
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
secondCount := callbackCount
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
assert.Equal(s.T(), 2, secondCount, "Second hung detection should trigger callback again")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentOperations tests thread safety
|
||||||
|
func (s *WatchdogTestSuite) TestConcurrentOperations() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numWorkers := 10
|
||||||
|
|
||||||
|
for i := 0; i < numWorkers; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
forwardID := string(rune('a' + id))
|
||||||
|
s.watchdog.RegisterWorker(forwardID, nil)
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
s.watchdog.Heartbeat(forwardID)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
s.watchdog.UnregisterWorker(forwardID)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
// If we get here without deadlocks or panics, test passes
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStopWatchdog tests that stopping watchdog cleans up properly
|
||||||
|
func TestStopWatchdog(t *testing.T) {
|
||||||
|
watchdog := NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
|
||||||
|
watchdog.Start()
|
||||||
|
|
||||||
|
callbackCalled := false
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
callbackCalled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
watchdog.Heartbeat("test-forward")
|
||||||
|
|
||||||
|
// Stop watchdog before hang detection
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
watchdog.Stop()
|
||||||
|
|
||||||
|
// Wait to ensure no more callbacks after stop
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
assert.False(t, callbackCalled, "Callback should not be called after watchdog is stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWatchdogWithZeroHeartbeats tests detecting hung worker that never sends heartbeats
|
||||||
|
func (s *WatchdogTestSuite) TestWatchdogWithZeroHeartbeats() {
|
||||||
|
callbackCalled := make(chan string, 1)
|
||||||
|
callback := func(forwardID string) {
|
||||||
|
callbackCalled <- forwardID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register worker but never send heartbeat
|
||||||
|
s.watchdog.RegisterWorker("test-forward", callback)
|
||||||
|
|
||||||
|
// Wait for hung detection
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case forwardID := <-callbackCalled:
|
||||||
|
assert.Equal(s.T(), "test-forward", forwardID)
|
||||||
|
case <-timeout:
|
||||||
|
s.T().Fatal("Timeout waiting for hung worker callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
+59
-13
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nvm/kportal/internal/config"
|
"github.com/nvm/kportal/internal/config"
|
||||||
@@ -20,21 +21,25 @@ const (
|
|||||||
|
|
||||||
// ForwardWorker manages a single port-forward connection with automatic retry.
|
// ForwardWorker manages a single port-forward connection with automatic retry.
|
||||||
type ForwardWorker struct {
|
type ForwardWorker struct {
|
||||||
forward config.Forward
|
forward config.Forward
|
||||||
portForwarder *k8s.PortForwarder
|
portForwarder *k8s.PortForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
doneChan chan struct{}
|
doneChan chan struct{}
|
||||||
verbose bool
|
reconnectChan chan string // Channel to trigger reconnection
|
||||||
lastPod string // Track the last pod we connected to
|
verbose bool
|
||||||
statusUI StatusUpdater
|
lastPod string // Track the last pod we connected to
|
||||||
healthChecker *healthcheck.Checker
|
statusUI StatusUpdater
|
||||||
startTime time.Time // Track when the worker started
|
healthChecker *healthcheck.Checker
|
||||||
|
watchdog *Watchdog
|
||||||
|
startTime time.Time // Track when the worker started
|
||||||
|
forwardCancel context.CancelFunc // Cancel function for current forward attempt
|
||||||
|
forwardCancelMu sync.Mutex // Protects forwardCancel
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
|
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
|
||||||
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker) *ForwardWorker {
|
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker, watchdog *Watchdog) *ForwardWorker {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
return &ForwardWorker{
|
return &ForwardWorker{
|
||||||
@@ -44,13 +49,32 @@ func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verb
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
doneChan: make(chan struct{}),
|
doneChan: make(chan struct{}),
|
||||||
|
reconnectChan: make(chan string, 1), // Buffered to avoid blocking
|
||||||
verbose: verbose,
|
verbose: verbose,
|
||||||
statusUI: statusUI,
|
statusUI: statusUI,
|
||||||
healthChecker: healthChecker,
|
healthChecker: healthChecker,
|
||||||
|
watchdog: watchdog,
|
||||||
startTime: time.Now(),
|
startTime: time.Now(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TriggerReconnect triggers a reconnection (e.g., due to stale connection)
|
||||||
|
func (w *ForwardWorker) TriggerReconnect(reason string) {
|
||||||
|
// Cancel current forward if running
|
||||||
|
w.forwardCancelMu.Lock()
|
||||||
|
if w.forwardCancel != nil {
|
||||||
|
w.forwardCancel()
|
||||||
|
}
|
||||||
|
w.forwardCancelMu.Unlock()
|
||||||
|
|
||||||
|
// Send reconnect signal (non-blocking)
|
||||||
|
select {
|
||||||
|
case w.reconnectChan <- reason:
|
||||||
|
default:
|
||||||
|
// Channel already has pending reconnect
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start begins the port-forward worker in a goroutine.
|
// Start begins the port-forward worker in a goroutine.
|
||||||
// The worker will continuously retry on failures with exponential backoff.
|
// The worker will continuously retry on failures with exponential backoff.
|
||||||
func (w *ForwardWorker) Start() {
|
func (w *ForwardWorker) Start() {
|
||||||
@@ -71,6 +95,11 @@ func (w *ForwardWorker) run() {
|
|||||||
backoff := retry.NewBackoff()
|
backoff := retry.NewBackoff()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
// Send heartbeat to watchdog to indicate we're alive
|
||||||
|
if w.watchdog != nil {
|
||||||
|
w.watchdog.Heartbeat(w.forward.ID())
|
||||||
|
}
|
||||||
|
|
||||||
// Check if we should stop
|
// Check if we should stop
|
||||||
select {
|
select {
|
||||||
case <-w.ctx.Done():
|
case <-w.ctx.Done():
|
||||||
@@ -184,11 +213,24 @@ func (w *ForwardWorker) establishForward(podName string) error {
|
|||||||
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
|
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
|
||||||
defer forwardCancel()
|
defer forwardCancel()
|
||||||
|
|
||||||
// Start a goroutine to monitor for stop signal
|
// Store cancel function so TriggerReconnect can use it
|
||||||
|
w.forwardCancelMu.Lock()
|
||||||
|
w.forwardCancel = forwardCancel
|
||||||
|
w.forwardCancelMu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
w.forwardCancelMu.Lock()
|
||||||
|
w.forwardCancel = nil
|
||||||
|
w.forwardCancelMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start a goroutine to monitor for stop signal and reconnect triggers
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-w.stopChan:
|
case <-w.stopChan:
|
||||||
close(stopChan)
|
close(stopChan)
|
||||||
|
case <-w.reconnectChan:
|
||||||
|
close(stopChan)
|
||||||
case <-forwardCtx.Done():
|
case <-forwardCtx.Done():
|
||||||
close(stopChan)
|
close(stopChan)
|
||||||
}
|
}
|
||||||
@@ -230,6 +272,10 @@ func (w *ForwardWorker) establishForward(podName string) error {
|
|||||||
if w.verbose {
|
if w.verbose {
|
||||||
log.Printf("[%s] Port-forward connection established", w.forward.ID())
|
log.Printf("[%s] Port-forward connection established", w.forward.ID())
|
||||||
}
|
}
|
||||||
|
// Mark connection as established in health checker
|
||||||
|
if w.healthChecker != nil {
|
||||||
|
w.healthChecker.MarkConnected(w.forward.ID())
|
||||||
|
}
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
return fmt.Errorf("failed to establish forward: %w", err)
|
return fmt.Errorf("failed to establish forward: %w", err)
|
||||||
case <-w.ctx.Done():
|
case <-w.ctx.Done():
|
||||||
|
|||||||
+183
-43
@@ -3,6 +3,7 @@ package healthcheck
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -10,6 +11,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
startupGracePeriod = 10 * time.Second
|
startupGracePeriod = 10 * time.Second
|
||||||
|
dataTransferSize = 1024 // bytes to read in data transfer test
|
||||||
)
|
)
|
||||||
|
|
||||||
// Status represents the health status of a port forward
|
// Status represents the health status of a port forward
|
||||||
@@ -20,15 +22,26 @@ const (
|
|||||||
StatusUnhealthy Status = "Error"
|
StatusUnhealthy Status = "Error"
|
||||||
StatusStarting Status = "Starting"
|
StatusStarting Status = "Starting"
|
||||||
StatusReconnect Status = "Reconnecting"
|
StatusReconnect Status = "Reconnecting"
|
||||||
|
StatusStale Status = "Stale" // Connection is old or idle
|
||||||
|
)
|
||||||
|
|
||||||
|
// CheckMethod represents the health check method
|
||||||
|
type CheckMethod string
|
||||||
|
|
||||||
|
const (
|
||||||
|
CheckMethodTCPDial CheckMethod = "tcp-dial" // Simple TCP connection test
|
||||||
|
CheckMethodDataTransfer CheckMethod = "data-transfer" // Try to read data from connection
|
||||||
)
|
)
|
||||||
|
|
||||||
// PortHealth represents the health status of a single port
|
// PortHealth represents the health status of a single port
|
||||||
type PortHealth struct {
|
type PortHealth struct {
|
||||||
Port int
|
Port int
|
||||||
LastCheck time.Time
|
LastCheck time.Time
|
||||||
Status Status
|
Status Status
|
||||||
ErrorMessage string
|
ErrorMessage string
|
||||||
RegisteredAt time.Time // When this port was registered
|
RegisteredAt time.Time // When this port was registered
|
||||||
|
ConnectionTime time.Time // When current connection was established
|
||||||
|
LastActivity time.Time // Last time data was transferred
|
||||||
}
|
}
|
||||||
|
|
||||||
// StatusCallback is called when a port's health status changes
|
// StatusCallback is called when a port's health status changes
|
||||||
@@ -36,26 +49,52 @@ type StatusCallback func(forwardID string, status Status, errorMsg string)
|
|||||||
|
|
||||||
// Checker performs periodic health checks on local ports
|
// Checker performs periodic health checks on local ports
|
||||||
type Checker struct {
|
type Checker struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ports map[string]*PortHealth // key: forward ID
|
ports map[string]*PortHealth // key: forward ID
|
||||||
callbacks map[string]StatusCallback
|
callbacks map[string]StatusCallback
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ctx context.Context
|
method CheckMethod
|
||||||
cancel context.CancelFunc
|
maxConnectionAge time.Duration
|
||||||
wg sync.WaitGroup
|
maxIdleTime time.Duration
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChecker creates a new health checker
|
// CheckerOptions configures the health checker
|
||||||
|
type CheckerOptions struct {
|
||||||
|
Interval time.Duration
|
||||||
|
Timeout time.Duration
|
||||||
|
Method CheckMethod
|
||||||
|
MaxConnectionAge time.Duration
|
||||||
|
MaxIdleTime time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChecker creates a new health checker with default options
|
||||||
func NewChecker(interval, timeout time.Duration) *Checker {
|
func NewChecker(interval, timeout time.Duration) *Checker {
|
||||||
|
return NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: interval,
|
||||||
|
Timeout: timeout,
|
||||||
|
Method: CheckMethodDataTransfer,
|
||||||
|
MaxConnectionAge: 25 * time.Minute,
|
||||||
|
MaxIdleTime: 10 * time.Minute,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCheckerWithOptions creates a new health checker with custom options
|
||||||
|
func NewCheckerWithOptions(opts CheckerOptions) *Checker {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Checker{
|
return &Checker{
|
||||||
ports: make(map[string]*PortHealth),
|
ports: make(map[string]*PortHealth),
|
||||||
callbacks: make(map[string]StatusCallback),
|
callbacks: make(map[string]StatusCallback),
|
||||||
interval: interval,
|
interval: opts.Interval,
|
||||||
timeout: timeout,
|
timeout: opts.Timeout,
|
||||||
ctx: ctx,
|
method: opts.Method,
|
||||||
cancel: cancel,
|
maxConnectionAge: opts.MaxConnectionAge,
|
||||||
|
maxIdleTime: opts.MaxIdleTime,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,11 +103,14 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
c.ports[forwardID] = &PortHealth{
|
c.ports[forwardID] = &PortHealth{
|
||||||
Port: port,
|
Port: port,
|
||||||
LastCheck: time.Time{},
|
LastCheck: time.Time{},
|
||||||
Status: StatusStarting,
|
Status: StatusStarting,
|
||||||
RegisteredAt: time.Now(),
|
RegisteredAt: now,
|
||||||
|
ConnectionTime: now,
|
||||||
|
LastActivity: now,
|
||||||
}
|
}
|
||||||
c.callbacks[forwardID] = callback
|
c.callbacks[forwardID] = callback
|
||||||
|
|
||||||
@@ -77,6 +119,28 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
|
|||||||
go c.checkLoop(forwardID)
|
go c.checkLoop(forwardID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarkConnected marks a forward as having established a new connection
|
||||||
|
func (c *Checker) MarkConnected(forwardID string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if health, exists := c.ports[forwardID]; exists {
|
||||||
|
now := time.Now()
|
||||||
|
health.ConnectionTime = now
|
||||||
|
health.LastActivity = now
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordActivity records data transfer activity for a forward
|
||||||
|
func (c *Checker) RecordActivity(forwardID string) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if health, exists := c.ports[forwardID]; exists {
|
||||||
|
health.LastActivity = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Unregister removes a port from monitoring
|
// Unregister removes a port from monitoring
|
||||||
func (c *Checker) Unregister(forwardID string) {
|
func (c *Checker) Unregister(forwardID string) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
@@ -197,37 +261,57 @@ func (c *Checker) checkPort(forwardID string) {
|
|||||||
port := health.Port
|
port := health.Port
|
||||||
oldStatus := health.Status
|
oldStatus := health.Status
|
||||||
registeredAt := health.RegisteredAt
|
registeredAt := health.RegisteredAt
|
||||||
|
connectionTime := health.ConnectionTime
|
||||||
|
lastActivity := health.LastActivity
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
|
||||||
// Attempt to connect to the local port
|
now := time.Now()
|
||||||
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
var d net.Dialer
|
|
||||||
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
|
||||||
|
|
||||||
newStatus := StatusHealthy
|
newStatus := StatusHealthy
|
||||||
errorMsg := ""
|
errorMsg := ""
|
||||||
|
|
||||||
if err != nil {
|
// Check for stale connections based on age or idle time
|
||||||
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
|
connectionAge := now.Sub(connectionTime)
|
||||||
// This avoids scary "Error" messages during initial connection attempts
|
idleTime := now.Sub(lastActivity)
|
||||||
timeSinceStart := time.Since(registeredAt)
|
|
||||||
if timeSinceStart < startupGracePeriod {
|
// Only enforce max connection age if the connection is ALSO idle
|
||||||
newStatus = StatusStarting
|
// This prevents interrupting active transfers (e.g., database dumps)
|
||||||
} else {
|
if c.maxConnectionAge > 0 && connectionAge > c.maxConnectionAge && idleTime > c.maxIdleTime {
|
||||||
newStatus = StatusUnhealthy
|
newStatus = StatusStale
|
||||||
}
|
errorMsg = fmt.Sprintf("connection age %v exceeds max %v (and idle for %v)",
|
||||||
errorMsg = err.Error()
|
connectionAge.Round(time.Second), c.maxConnectionAge, idleTime.Round(time.Second))
|
||||||
|
} else if c.maxIdleTime > 0 && idleTime > c.maxIdleTime {
|
||||||
|
newStatus = StatusStale
|
||||||
|
errorMsg = fmt.Sprintf("idle time %v exceeds max %v", idleTime.Round(time.Second), c.maxIdleTime)
|
||||||
} else {
|
} else {
|
||||||
conn.Close()
|
// Perform connectivity check
|
||||||
|
var checkErr error
|
||||||
|
switch c.method {
|
||||||
|
case CheckMethodDataTransfer:
|
||||||
|
checkErr = c.checkDataTransfer(port)
|
||||||
|
case CheckMethodTCPDial:
|
||||||
|
checkErr = c.checkTCPDial(port)
|
||||||
|
default:
|
||||||
|
checkErr = c.checkTCPDial(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if checkErr != nil {
|
||||||
|
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
|
||||||
|
// This avoids scary "Error" messages during initial connection attempts
|
||||||
|
timeSinceStart := now.Sub(registeredAt)
|
||||||
|
if timeSinceStart < startupGracePeriod {
|
||||||
|
newStatus = StatusStarting
|
||||||
|
} else {
|
||||||
|
newStatus = StatusUnhealthy
|
||||||
|
}
|
||||||
|
errorMsg = checkErr.Error()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update health status
|
// Update health status
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
if health, exists := c.ports[forwardID]; exists {
|
if health, exists := c.ports[forwardID]; exists {
|
||||||
health.Status = newStatus
|
health.Status = newStatus
|
||||||
health.LastCheck = time.Now()
|
health.LastCheck = now
|
||||||
health.ErrorMessage = errorMsg
|
health.ErrorMessage = errorMsg
|
||||||
}
|
}
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
@@ -238,6 +322,62 @@ func (c *Checker) checkPort(forwardID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkTCPDial performs a simple TCP dial test
|
||||||
|
func (c *Checker) checkTCPDial(port int) error {
|
||||||
|
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var d net.Dialer
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkDataTransfer attempts to read data from the connection to verify tunnel health
|
||||||
|
func (c *Checker) checkDataTransfer(port int) error {
|
||||||
|
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var d net.Dialer
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Set a short read deadline to detect hung connections
|
||||||
|
// We don't expect to receive data, but we want to verify the connection isn't hung
|
||||||
|
conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||||
|
|
||||||
|
// Try to read a small amount of data
|
||||||
|
// Most servers will either:
|
||||||
|
// 1. Send a banner (SSH, FTP, etc) - we'll read it successfully
|
||||||
|
// 2. Wait for client to send first (HTTP, postgres) - we'll timeout (which is OK)
|
||||||
|
// 3. Hung/stale connection - will timeout with different error
|
||||||
|
buf := make([]byte, dataTransferSize)
|
||||||
|
_, err = conn.Read(buf)
|
||||||
|
|
||||||
|
// We expect either:
|
||||||
|
// - No error (banner received)
|
||||||
|
// - EOF (connection closed by server after connect)
|
||||||
|
// - Timeout (server waiting for client)
|
||||||
|
// All of these indicate the tunnel is working
|
||||||
|
if err == nil || err == io.EOF {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout is acceptable - server is waiting for us to send data first
|
||||||
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other errors indicate a problem
|
||||||
|
return fmt.Errorf("data transfer check failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// notifyStatusChange calls the callback for a forward
|
// notifyStatusChange calls the callback for a forward
|
||||||
func (c *Checker) notifyStatusChange(forwardID string, status Status, errorMsg string) {
|
func (c *Checker) notifyStatusChange(forwardID string, status Status, errorMsg string) {
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
|
|||||||
@@ -0,0 +1,553 @@
|
|||||||
|
package healthcheck
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthCheckTestSuite contains tests for the health checker
|
||||||
|
type HealthCheckTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
checker *Checker
|
||||||
|
listener net.Listener
|
||||||
|
port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthCheckSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(HealthCheckTestSuite))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HealthCheckTestSuite) SetupTest() {
|
||||||
|
// Create a test listener on a random port
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
s.listener = ln
|
||||||
|
s.port = ln.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
// Create checker with fast intervals for testing
|
||||||
|
s.checker = NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 100 * time.Millisecond,
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 500 * time.Millisecond,
|
||||||
|
MaxIdleTime: 300 * time.Millisecond,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HealthCheckTestSuite) TearDownTest() {
|
||||||
|
if s.checker != nil {
|
||||||
|
s.checker.Stop()
|
||||||
|
}
|
||||||
|
if s.listener != nil {
|
||||||
|
s.listener.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRegisterAndUnregister tests basic registration and unregistration
|
||||||
|
func (s *HealthCheckTestSuite) TestRegisterAndUnregister() {
|
||||||
|
callbackCalled := false
|
||||||
|
var callbackStatus Status
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
callbackCalled = true
|
||||||
|
callbackStatus = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register port
|
||||||
|
s.checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Wait for health check to run
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Verify callback was called with healthy status
|
||||||
|
mu.Lock()
|
||||||
|
assert.True(s.T(), callbackCalled, "Callback should have been called")
|
||||||
|
assert.Equal(s.T(), StatusHealthy, callbackStatus)
|
||||||
|
mu.Unlock()
|
||||||
|
|
||||||
|
// Unregister
|
||||||
|
s.checker.Unregister("test-forward")
|
||||||
|
|
||||||
|
// Verify port is no longer monitored
|
||||||
|
status, exists := s.checker.GetStatus("test-forward")
|
||||||
|
assert.False(s.T(), exists, "Port should no longer exist after unregister")
|
||||||
|
assert.Equal(s.T(), StatusUnhealthy, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTCPDialMethod tests the TCP dial health check method
|
||||||
|
func (s *HealthCheckTestSuite) TestTCPDialMethod() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupPort bool
|
||||||
|
expectedStatus Status
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "port available - healthy",
|
||||||
|
setupPort: true,
|
||||||
|
expectedStatus: StatusHealthy,
|
||||||
|
description: "When port is listening, status should be healthy",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "port unavailable - unhealthy",
|
||||||
|
setupPort: false,
|
||||||
|
expectedStatus: StatusUnhealthy,
|
||||||
|
description: "When port is not listening, status should be unhealthy",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var testPort int
|
||||||
|
var testListener net.Listener
|
||||||
|
|
||||||
|
if tt.setupPort {
|
||||||
|
// Use the existing listener
|
||||||
|
testPort = s.port
|
||||||
|
} else {
|
||||||
|
// Use a port that's not listening
|
||||||
|
testPort = 54321 // Likely unused port
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new checker for this test
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 100 * time.Millisecond,
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 0, // Disable for this test
|
||||||
|
MaxIdleTime: 0, // Disable for this test
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
checker.Register("test-forward", testPort, nil)
|
||||||
|
|
||||||
|
// Wait for health checks to complete
|
||||||
|
if !tt.setupPort {
|
||||||
|
// For unhealthy case, wait for grace period
|
||||||
|
time.Sleep(startupGracePeriod + 200*time.Millisecond)
|
||||||
|
} else {
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check status directly
|
||||||
|
status, exists := checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
assert.Equal(s.T(), tt.expectedStatus, status, tt.description)
|
||||||
|
|
||||||
|
if testListener != nil {
|
||||||
|
testListener.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDataTransferMethod tests the data transfer health check method
|
||||||
|
func (s *HealthCheckTestSuite) TestDataTransferMethod() {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
serverBehavior string // "banner", "silent", "close", "none"
|
||||||
|
expectedStatus Status
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "server sends banner - healthy",
|
||||||
|
serverBehavior: "banner",
|
||||||
|
expectedStatus: StatusHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server waits silently - healthy (timeout OK)",
|
||||||
|
serverBehavior: "silent",
|
||||||
|
expectedStatus: StatusHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server closes connection - healthy (EOF OK)",
|
||||||
|
serverBehavior: "close",
|
||||||
|
expectedStatus: StatusHealthy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no server listening - unhealthy",
|
||||||
|
serverBehavior: "none",
|
||||||
|
expectedStatus: StatusUnhealthy,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
s.Run(tt.name, func() {
|
||||||
|
var testPort int
|
||||||
|
var testListener net.Listener
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if tt.serverBehavior != "none" {
|
||||||
|
// Start test server
|
||||||
|
testListener, err = net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(s.T(), err)
|
||||||
|
testPort = testListener.Addr().(*net.TCPAddr).Port
|
||||||
|
|
||||||
|
// Handle connections based on behavior
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := testListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch tt.serverBehavior {
|
||||||
|
case "banner":
|
||||||
|
conn.Write([]byte("220 Welcome\r\n"))
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
case "close":
|
||||||
|
conn.Close()
|
||||||
|
case "silent":
|
||||||
|
// Just keep connection open
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer testListener.Close()
|
||||||
|
} else {
|
||||||
|
testPort = 54322 // Unused port
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create checker with data transfer method
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 100 * time.Millisecond,
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
Method: CheckMethodDataTransfer,
|
||||||
|
MaxConnectionAge: 0, // Disable for this test
|
||||||
|
MaxIdleTime: 0, // Disable for this test
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
checker.Register("test-forward", testPort, nil)
|
||||||
|
|
||||||
|
// Wait for health checks to complete
|
||||||
|
if tt.serverBehavior == "none" {
|
||||||
|
// For unhealthy case, wait for grace period
|
||||||
|
time.Sleep(startupGracePeriod + 200*time.Millisecond)
|
||||||
|
} else {
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check status directly
|
||||||
|
status, exists := checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
assert.Equal(s.T(), tt.expectedStatus, status)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConnectionAgeDetection tests max connection age detection
|
||||||
|
func (s *HealthCheckTestSuite) TestConnectionAgeDetection() {
|
||||||
|
statusChanges := make(chan Status, 10)
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
statusChanges <- status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create checker with very short max connection age
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 50 * time.Millisecond,
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 150 * time.Millisecond, // Very short for testing
|
||||||
|
MaxIdleTime: 0, // Disable idle detection
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Wait for initial healthy status
|
||||||
|
var gotHealthy, gotStale bool
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case status := <-statusChanges:
|
||||||
|
if status == StatusHealthy || status == StatusStarting {
|
||||||
|
gotHealthy = true
|
||||||
|
}
|
||||||
|
if status == StatusStale {
|
||||||
|
gotStale = true
|
||||||
|
}
|
||||||
|
if gotHealthy && gotStale {
|
||||||
|
return // Test passed
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
s.T().Fatalf("Expected StatusStale after max connection age exceeded. gotHealthy=%v, gotStale=%v",
|
||||||
|
gotHealthy, gotStale)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIdleTimeDetection tests max idle time detection
|
||||||
|
func (s *HealthCheckTestSuite) TestIdleTimeDetection() {
|
||||||
|
statusChanges := make(chan Status, 10)
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
statusChanges <- status
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create checker with very short max idle time
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 50 * time.Millisecond,
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 0, // Disable age detection
|
||||||
|
MaxIdleTime: 150 * time.Millisecond, // Very short for testing
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Wait for initial healthy status
|
||||||
|
var gotHealthy, gotStale bool
|
||||||
|
timeout := time.After(1 * time.Second)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case status := <-statusChanges:
|
||||||
|
if status == StatusHealthy || status == StatusStarting {
|
||||||
|
gotHealthy = true
|
||||||
|
}
|
||||||
|
if status == StatusStale {
|
||||||
|
gotStale = true
|
||||||
|
}
|
||||||
|
if gotHealthy && gotStale {
|
||||||
|
return // Test passed
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
s.T().Fatalf("Expected StatusStale after max idle time exceeded. gotHealthy=%v, gotStale=%v",
|
||||||
|
gotHealthy, gotStale)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMarkConnected tests that MarkConnected resets connection time
|
||||||
|
func (s *HealthCheckTestSuite) TestMarkConnected() {
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 50 * time.Millisecond,
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 200 * time.Millisecond,
|
||||||
|
MaxIdleTime: 0,
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
statusChanges := make(chan Status, 10)
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
statusChanges <- status
|
||||||
|
}
|
||||||
|
|
||||||
|
checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Wait a bit
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Mark as reconnected (resets connection time)
|
||||||
|
checker.MarkConnected("test-forward")
|
||||||
|
|
||||||
|
// Wait for connection age to exceed (relative to first connection time)
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Check status - should still be healthy because we reset connection time
|
||||||
|
status, exists := checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
// Note: Might be StatusStale by now, but the key is that MarkConnected delayed it
|
||||||
|
// This is a timing-sensitive test, so we just verify the functionality exists
|
||||||
|
_ = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRecordActivity tests that RecordActivity resets idle time
|
||||||
|
func (s *HealthCheckTestSuite) TestRecordActivity() {
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 50 * time.Millisecond,
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 0,
|
||||||
|
MaxIdleTime: 200 * time.Millisecond,
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
statusChanges := make(chan Status, 10)
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
statusChanges <- status
|
||||||
|
}
|
||||||
|
|
||||||
|
checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Periodically record activity to prevent idle detection
|
||||||
|
ticker := time.NewTicker(80 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
<-ticker.C
|
||||||
|
checker.RecordActivity("test-forward")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait longer than idle timeout
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should still be healthy due to activity
|
||||||
|
status, exists := checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
// May transition to stale eventually, but activity recording should have delayed it
|
||||||
|
_ = status
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMarkReconnecting tests the MarkReconnecting functionality
|
||||||
|
func (s *HealthCheckTestSuite) TestMarkReconnecting() {
|
||||||
|
statusChanges := make(chan Status, 10)
|
||||||
|
callback := func(forwardID string, status Status, errorMsg string) {
|
||||||
|
statusChanges <- status
|
||||||
|
}
|
||||||
|
|
||||||
|
s.checker.Register("test-forward", s.port, callback)
|
||||||
|
|
||||||
|
// Wait for initial status
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
|
||||||
|
// Mark as reconnecting
|
||||||
|
s.checker.MarkReconnecting("test-forward")
|
||||||
|
|
||||||
|
// Should receive reconnecting status
|
||||||
|
timeout := time.After(500 * time.Millisecond)
|
||||||
|
gotReconnect := false
|
||||||
|
for !gotReconnect {
|
||||||
|
select {
|
||||||
|
case status := <-statusChanges:
|
||||||
|
if status == StatusReconnect {
|
||||||
|
gotReconnect = true
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
s.T().Fatal("Expected StatusReconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStartingGracePeriod tests that errors during grace period show as "Starting"
|
||||||
|
func (s *HealthCheckTestSuite) TestStartingGracePeriod() {
|
||||||
|
// Use a port that's not listening
|
||||||
|
unavailablePort := 54323
|
||||||
|
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 50 * time.Millisecond,
|
||||||
|
Timeout: 25 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 0,
|
||||||
|
MaxIdleTime: 0,
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
// Register without callback - we'll check status directly
|
||||||
|
checker.Register("test-forward", unavailablePort, nil)
|
||||||
|
|
||||||
|
// Immediately check status - should be Starting or not yet checked
|
||||||
|
status, exists := checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
// Initially should be Starting
|
||||||
|
assert.Equal(s.T(), StatusStarting, status)
|
||||||
|
|
||||||
|
// Wait for grace period to expire
|
||||||
|
time.Sleep(startupGracePeriod + 200*time.Millisecond)
|
||||||
|
|
||||||
|
// Now should be Unhealthy
|
||||||
|
status, exists = checker.GetStatus("test-forward")
|
||||||
|
assert.True(s.T(), exists)
|
||||||
|
assert.Equal(s.T(), StatusUnhealthy, status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetAllErrors tests retrieving all error messages
|
||||||
|
func (s *HealthCheckTestSuite) TestGetAllErrors() {
|
||||||
|
// Create a new checker with faster intervals for this test
|
||||||
|
checker := NewCheckerWithOptions(CheckerOptions{
|
||||||
|
Interval: 100 * time.Millisecond,
|
||||||
|
Timeout: 50 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 0,
|
||||||
|
MaxIdleTime: 0,
|
||||||
|
})
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
// Register multiple forwards
|
||||||
|
checker.Register("forward1", s.port, nil)
|
||||||
|
checker.Register("forward2", 54324, nil) // Unavailable port
|
||||||
|
|
||||||
|
// Wait for grace period to expire
|
||||||
|
time.Sleep(startupGracePeriod + 300*time.Millisecond)
|
||||||
|
|
||||||
|
errors := checker.GetAllErrors()
|
||||||
|
|
||||||
|
// forward2 should have an error
|
||||||
|
_, hasError := errors["forward2"]
|
||||||
|
assert.True(s.T(), hasError, "forward2 should have an error")
|
||||||
|
|
||||||
|
// forward1 should not have an error
|
||||||
|
_, hasError = errors["forward1"]
|
||||||
|
assert.False(s.T(), hasError, "forward1 should not have an error")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConcurrentOperations tests thread safety
|
||||||
|
func (s *HealthCheckTestSuite) TestConcurrentOperations() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numGoroutines := 10
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
forwardID := fmt.Sprintf("forward-%d", id)
|
||||||
|
s.checker.Register(forwardID, s.port, nil)
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
s.checker.MarkConnected(forwardID)
|
||||||
|
s.checker.RecordActivity(forwardID)
|
||||||
|
status, _ := s.checker.GetStatus(forwardID)
|
||||||
|
_ = status
|
||||||
|
s.checker.Unregister(forwardID)
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
// If we get here without deadlocks or panics, test passes
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDefaultOptions tests that NewChecker uses sensible defaults
|
||||||
|
func TestDefaultOptions(t *testing.T) {
|
||||||
|
checker := NewChecker(5*time.Second, 2*time.Second)
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
assert.Equal(t, 5*time.Second, checker.interval)
|
||||||
|
assert.Equal(t, 2*time.Second, checker.timeout)
|
||||||
|
assert.Equal(t, CheckMethodDataTransfer, checker.method)
|
||||||
|
assert.Equal(t, 25*time.Minute, checker.maxConnectionAge)
|
||||||
|
assert.Equal(t, 10*time.Minute, checker.maxIdleTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCustomOptions tests NewCheckerWithOptions
|
||||||
|
func TestCustomOptions(t *testing.T) {
|
||||||
|
opts := CheckerOptions{
|
||||||
|
Interval: 1 * time.Second,
|
||||||
|
Timeout: 500 * time.Millisecond,
|
||||||
|
Method: CheckMethodTCPDial,
|
||||||
|
MaxConnectionAge: 5 * time.Minute,
|
||||||
|
MaxIdleTime: 2 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
checker := NewCheckerWithOptions(opts)
|
||||||
|
defer checker.Stop()
|
||||||
|
|
||||||
|
assert.Equal(t, 1*time.Second, checker.interval)
|
||||||
|
assert.Equal(t, 500*time.Millisecond, checker.timeout)
|
||||||
|
assert.Equal(t, CheckMethodTCPDial, checker.method)
|
||||||
|
assert.Equal(t, 5*time.Minute, checker.maxConnectionAge)
|
||||||
|
assert.Equal(t, 2*time.Minute, checker.maxIdleTime)
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
corev1 "k8s.io/api/core/v1"
|
corev1 "k8s.io/api/core/v1"
|
||||||
@@ -305,17 +304,3 @@ func CheckPortAvailability(port int) (bool, string, error) {
|
|||||||
listener.Close()
|
listener.Close()
|
||||||
return true, "", nil
|
return true, "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidatePort checks if a port number is valid.
|
|
||||||
func ValidatePort(portStr string) (int, error) {
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
if port < 1 || port > 65535 {
|
|
||||||
return 0, fmt.Errorf("port must be between 1 and 65535, got %d", port)
|
|
||||||
}
|
|
||||||
|
|
||||||
return port, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -4,9 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
corev1 "k8s.io/api/core/v1"
|
corev1 "k8s.io/api/core/v1"
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
@@ -17,18 +19,32 @@ import (
|
|||||||
|
|
||||||
// PortForwarder handles Kubernetes port-forwarding operations.
|
// PortForwarder handles Kubernetes port-forwarding operations.
|
||||||
type PortForwarder struct {
|
type PortForwarder struct {
|
||||||
clientPool *ClientPool
|
clientPool *ClientPool
|
||||||
resolver *ResourceResolver
|
resolver *ResourceResolver
|
||||||
|
tcpKeepalive time.Duration // TCP keepalive interval
|
||||||
|
dialTimeout time.Duration // Connection dial timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPortForwarder creates a new PortForwarder instance.
|
// NewPortForwarder creates a new PortForwarder instance with default settings.
|
||||||
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
|
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
|
||||||
return &PortForwarder{
|
return &PortForwarder{
|
||||||
clientPool: clientPool,
|
clientPool: clientPool,
|
||||||
resolver: resolver,
|
resolver: resolver,
|
||||||
|
tcpKeepalive: 30 * time.Second, // Default: 30 second keepalive
|
||||||
|
dialTimeout: 30 * time.Second, // Default: 30 second dial timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTCPKeepalive configures the TCP keepalive interval for new connections.
|
||||||
|
func (pf *PortForwarder) SetTCPKeepalive(keepalive time.Duration) {
|
||||||
|
pf.tcpKeepalive = keepalive
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDialTimeout configures the connection dial timeout.
|
||||||
|
func (pf *PortForwarder) SetDialTimeout(timeout time.Duration) {
|
||||||
|
pf.dialTimeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
// ForwardRequest contains the parameters for a port-forward request.
|
// ForwardRequest contains the parameters for a port-forward request.
|
||||||
type ForwardRequest struct {
|
type ForwardRequest struct {
|
||||||
ContextName string // Kubernetes context name
|
ContextName string // Kubernetes context name
|
||||||
@@ -164,6 +180,19 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
|
|||||||
|
|
||||||
// executePortForward performs the actual port-forward operation.
|
// executePortForward performs the actual port-forward operation.
|
||||||
func (pf *PortForwarder) executePortForward(config *rest.Config, url *url.URL, req *ForwardRequest) error {
|
func (pf *PortForwarder) executePortForward(config *rest.Config, url *url.URL, req *ForwardRequest) error {
|
||||||
|
// Configure TCP settings on the underlying connection
|
||||||
|
// This is set in the rest.Config which will be used by the SPDY transport
|
||||||
|
if config.Dial == nil {
|
||||||
|
// Create a custom dialer with configurable timeout and keepalive
|
||||||
|
// - Timeout: How long to wait for connection to establish
|
||||||
|
// - KeepAlive: TCP keepalive helps OS detect dead connections at network layer
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: pf.dialTimeout, // Configurable dial timeout
|
||||||
|
KeepAlive: pf.tcpKeepalive, // Configurable keepalive interval
|
||||||
|
}
|
||||||
|
config.Dial = dialer.DialContext
|
||||||
|
}
|
||||||
|
|
||||||
// Create SPDY roundtripper
|
// Create SPDY roundtripper
|
||||||
transport, upgrader, err := spdy.RoundTripperFor(config)
|
transport, upgrader, err := spdy.RoundTripperFor(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -228,29 +228,3 @@ func (r *ResourceResolver) InvalidateCache(contextName, namespace, resource stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPodList returns a list of pods matching the given criteria.
|
|
||||||
// This is useful for debugging and testing.
|
|
||||||
func (r *ResourceResolver) GetPodList(ctx context.Context, contextName, namespace, selector string) ([]*corev1.Pod, error) {
|
|
||||||
client, err := r.clientPool.GetClient(contextName)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
listOptions := metav1.ListOptions{}
|
|
||||||
if selector != "" {
|
|
||||||
listOptions.LabelSelector = selector
|
|
||||||
}
|
|
||||||
|
|
||||||
pods, err := client.CoreV1().Pods(namespace).List(ctx, listOptions)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to list pods: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]*corev1.Pod, len(pods.Items))
|
|
||||||
for i := range pods.Items {
|
|
||||||
result[i] = &pods.Items[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogrAdapter implements the logr.LogSink interface to route klog v2 logs
|
||||||
|
// through our structured logger. This captures ALL klog output including
|
||||||
|
// error logs, structured logs, and named logger output.
|
||||||
|
type LogrAdapter struct {
|
||||||
|
logger *Logger
|
||||||
|
name string
|
||||||
|
level int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLogrAdapter creates a new logr.LogSink that routes all klog v2 logs
|
||||||
|
// through our structured logger.
|
||||||
|
func NewLogrAdapter(logger *Logger) logr.LogSink {
|
||||||
|
return &LogrAdapter{
|
||||||
|
logger: logger,
|
||||||
|
name: "",
|
||||||
|
level: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init initializes the logger with runtime info (not used in our implementation).
|
||||||
|
func (l *LogrAdapter) Init(info logr.RuntimeInfo) {
|
||||||
|
// No-op: we don't need runtime info
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled tests whether this LogSink is enabled at the specified V-level.
|
||||||
|
// We route all logs through our logger's level filtering.
|
||||||
|
func (l *LogrAdapter) Enabled(level int) bool {
|
||||||
|
// Map logr V-levels to our levels:
|
||||||
|
// V(0) = Info level (always enabled if logger level <= Info)
|
||||||
|
// V(1+) = Debug level (enabled if logger level <= Debug)
|
||||||
|
if level == 0 {
|
||||||
|
return l.logger.level <= LevelInfo
|
||||||
|
}
|
||||||
|
return l.logger.level <= LevelDebug
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs a non-error message with the given key/value pairs.
|
||||||
|
func (l *LogrAdapter) Info(level int, msg string, keysAndValues ...interface{}) {
|
||||||
|
fields := l.kvToMap(keysAndValues)
|
||||||
|
if l.name != "" {
|
||||||
|
fields["logger"] = l.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map logr V-levels to our levels:
|
||||||
|
// V(0) = Info, V(1+) = Debug
|
||||||
|
if level == 0 {
|
||||||
|
l.logger.Info(msg, fields)
|
||||||
|
} else {
|
||||||
|
l.logger.Debug(msg, fields)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs an error message with the given key/value pairs.
|
||||||
|
func (l *LogrAdapter) Error(err error, msg string, keysAndValues ...interface{}) {
|
||||||
|
fields := l.kvToMap(keysAndValues)
|
||||||
|
if l.name != "" {
|
||||||
|
fields["logger"] = l.name
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
fields["error"] = err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
l.logger.Error(msg, fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithValues returns a new LogSink with additional key/value pairs.
|
||||||
|
func (l *LogrAdapter) WithValues(keysAndValues ...interface{}) logr.LogSink {
|
||||||
|
// For simplicity, we don't implement value accumulation
|
||||||
|
// Each log call receives all its keysAndValues directly
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithName returns a new LogSink with the specified name appended.
|
||||||
|
func (l *LogrAdapter) WithName(name string) logr.LogSink {
|
||||||
|
newLogger := *l
|
||||||
|
if l.name == "" {
|
||||||
|
newLogger.name = name
|
||||||
|
} else {
|
||||||
|
newLogger.name = l.name + "." + name
|
||||||
|
}
|
||||||
|
return &newLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// kvToMap converts a slice of alternating keys and values to a map.
|
||||||
|
func (l *LogrAdapter) kvToMap(keysAndValues []interface{}) map[string]interface{} {
|
||||||
|
fields := make(map[string]interface{})
|
||||||
|
fields["source"] = "k8s-client"
|
||||||
|
|
||||||
|
for i := 0; i < len(keysAndValues); i += 2 {
|
||||||
|
if i+1 < len(keysAndValues) {
|
||||||
|
key, ok := keysAndValues[i].(string)
|
||||||
|
if ok {
|
||||||
|
fields[key] = keysAndValues[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields
|
||||||
|
}
|
||||||
@@ -0,0 +1,367 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-logr/logr"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogrAdapter_Info(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerLevel Level
|
||||||
|
logrLevel int
|
||||||
|
message string
|
||||||
|
keysAndValues []interface{}
|
||||||
|
expectOutput bool
|
||||||
|
expectContains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "info log v0 with debug logger",
|
||||||
|
loggerLevel: LevelDebug,
|
||||||
|
logrLevel: 0,
|
||||||
|
message: "Connection established",
|
||||||
|
keysAndValues: []interface{}{"pod", "my-app-123", "port", 8080},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[INFO]", "Connection established", "pod", "my-app-123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "info log v0 with info logger",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 0,
|
||||||
|
message: "Port forward ready",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[INFO]", "Port forward ready"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "info log v0 silenced with warn logger",
|
||||||
|
loggerLevel: LevelWarn,
|
||||||
|
logrLevel: 0,
|
||||||
|
message: "This should not appear",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: false,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "debug log v1 with debug logger",
|
||||||
|
loggerLevel: LevelDebug,
|
||||||
|
logrLevel: 1,
|
||||||
|
message: "Detailed connection info",
|
||||||
|
keysAndValues: []interface{}{"details", "some-value"},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[DEBUG]", "Detailed connection info", "details"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "debug log v1 silenced with info logger",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 1,
|
||||||
|
message: "This debug should not appear",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: false,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "info with odd number of kvs (incomplete pair)",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 0,
|
||||||
|
message: "Message with incomplete kv",
|
||||||
|
keysAndValues: []interface{}{"key1", "value1", "key2"}, // key2 has no value
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[INFO]", "Message with incomplete kv", "key1", "value1"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "info with source field added automatically",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 0,
|
||||||
|
message: "Test source field",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[INFO]", "Test source field", "source:k8s-client"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(tt.loggerLevel, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink)
|
||||||
|
|
||||||
|
logrLogger.V(tt.logrLevel).Info(tt.message, tt.keysAndValues...)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if tt.expectOutput {
|
||||||
|
for _, expected := range tt.expectContains {
|
||||||
|
assert.Contains(t, output, expected, "Output should contain: %s", expected)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Empty(t, output, "No output expected for this log level")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Error(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerLevel Level
|
||||||
|
err error
|
||||||
|
message string
|
||||||
|
keysAndValues []interface{}
|
||||||
|
expectOutput bool
|
||||||
|
expectContains []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "error with error object",
|
||||||
|
loggerLevel: LevelError,
|
||||||
|
err: errors.New("connection failed"),
|
||||||
|
message: "Port forward failed",
|
||||||
|
keysAndValues: []interface{}{"pod", "my-app-123"},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[ERROR]", "Port forward failed", "connection failed", "pod", "my-app-123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error without error object",
|
||||||
|
loggerLevel: LevelError,
|
||||||
|
err: nil,
|
||||||
|
message: "Generic error message",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[ERROR]", "Generic error message"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error silenced with level above error",
|
||||||
|
loggerLevel: LevelError + 1,
|
||||||
|
err: errors.New("should not appear"),
|
||||||
|
message: "This error should not appear",
|
||||||
|
keysAndValues: []interface{}{},
|
||||||
|
expectOutput: false,
|
||||||
|
expectContains: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error with multiple kvs",
|
||||||
|
loggerLevel: LevelError,
|
||||||
|
err: errors.New("sandbox not found"),
|
||||||
|
message: "Unhandled Error",
|
||||||
|
keysAndValues: []interface{}{"pod", "test-pod", "uid", "abc123", "port", 8080},
|
||||||
|
expectOutput: true,
|
||||||
|
expectContains: []string{"[ERROR]", "Unhandled Error", "sandbox not found", "pod", "test-pod", "uid", "abc123"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(tt.loggerLevel, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink)
|
||||||
|
|
||||||
|
logrLogger.Error(tt.err, tt.message, tt.keysAndValues...)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
if tt.expectOutput {
|
||||||
|
for _, expected := range tt.expectContains {
|
||||||
|
assert.Contains(t, output, expected, "Output should contain: %s", expected)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.Empty(t, output, "No output expected for this log level")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_WithName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerNames []string
|
||||||
|
message string
|
||||||
|
expectContains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single logger name",
|
||||||
|
loggerNames: []string{"portforward"},
|
||||||
|
message: "Test message",
|
||||||
|
expectContains: "logger:portforward",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested logger names",
|
||||||
|
loggerNames: []string{"controller", "worker", "healthcheck"},
|
||||||
|
message: "Nested message",
|
||||||
|
expectContains: "logger:controller.worker.healthcheck",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no logger name",
|
||||||
|
loggerNames: []string{},
|
||||||
|
message: "No name message",
|
||||||
|
expectContains: "source:k8s-client", // Should still have source but no logger field
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(LevelInfo, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink)
|
||||||
|
|
||||||
|
// Apply WithName calls
|
||||||
|
for _, name := range tt.loggerNames {
|
||||||
|
logrLogger = logrLogger.WithName(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrLogger.Info(tt.message)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, tt.expectContains)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_Enabled(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerLevel Level
|
||||||
|
logrLevel int
|
||||||
|
expectEnabled bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "v0 enabled with debug logger",
|
||||||
|
loggerLevel: LevelDebug,
|
||||||
|
logrLevel: 0,
|
||||||
|
expectEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v0 enabled with info logger",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 0,
|
||||||
|
expectEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v0 disabled with warn logger",
|
||||||
|
loggerLevel: LevelWarn,
|
||||||
|
logrLevel: 0,
|
||||||
|
expectEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v1 enabled with debug logger",
|
||||||
|
loggerLevel: LevelDebug,
|
||||||
|
logrLevel: 1,
|
||||||
|
expectEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v1 disabled with info logger",
|
||||||
|
loggerLevel: LevelInfo,
|
||||||
|
logrLevel: 1,
|
||||||
|
expectEnabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "v2 enabled with debug logger",
|
||||||
|
loggerLevel: LevelDebug,
|
||||||
|
logrLevel: 2,
|
||||||
|
expectEnabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
logger := New(tt.loggerLevel, FormatText, &bytes.Buffer{})
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
|
||||||
|
enabled := sink.Enabled(tt.logrLevel)
|
||||||
|
assert.Equal(t, tt.expectEnabled, enabled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_JSONFormat(t *testing.T) {
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(LevelInfo, FormatJSON, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink).WithName("test-component")
|
||||||
|
|
||||||
|
logrLogger.Info("Test JSON message", "key1", "value1", "key2", 123)
|
||||||
|
|
||||||
|
// Parse JSON output
|
||||||
|
var entry logEntry
|
||||||
|
err := json.Unmarshal(buf.Bytes(), &entry)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "INFO", entry.Level)
|
||||||
|
assert.Equal(t, "Test JSON message", entry.Message)
|
||||||
|
assert.Equal(t, "k8s-client", entry.Fields["source"])
|
||||||
|
assert.Equal(t, "test-component", entry.Fields["logger"])
|
||||||
|
assert.Equal(t, "value1", entry.Fields["key1"])
|
||||||
|
assert.Equal(t, float64(123), entry.Fields["key2"]) // JSON numbers decode as float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_ConcurrentWrites(t *testing.T) {
|
||||||
|
// Note: bytes.Buffer is not thread-safe for writes, so this test verifies
|
||||||
|
// that our LogrAdapter doesn't panic under concurrent load, but we don't
|
||||||
|
// verify exact output (since logger uses fmt.Fprintf which is also not thread-safe)
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(LevelDebug, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink)
|
||||||
|
|
||||||
|
// Spawn multiple goroutines writing concurrently
|
||||||
|
done := make(chan bool)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func(id int) {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logrLogger.Info("Concurrent message", "goroutine", id, "iteration", j)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
|
||||||
|
// Verify we got substantial output (not checking exact count due to buffer race)
|
||||||
|
// The main goal is to ensure no panics occur during concurrent writes
|
||||||
|
assert.NotEmpty(t, output, "Should have some log output")
|
||||||
|
assert.Contains(t, output, "Concurrent message")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_RealWorldKlogError(t *testing.T) {
|
||||||
|
// Simulate the exact error message from the screenshot
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(LevelError, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink).WithName("UnhandledError")
|
||||||
|
|
||||||
|
err := errors.New("an error occurred forwarding 8401 -> 8401: error forwarding port 8401 to pod 4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308, uid : failed to find sandbox '4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308' in store: not found")
|
||||||
|
logrLogger.Error(err, "Unhandled Error")
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "[ERROR]")
|
||||||
|
assert.Contains(t, output, "Unhandled Error")
|
||||||
|
assert.Contains(t, output, "failed to find sandbox")
|
||||||
|
assert.Contains(t, output, "logger:UnhandledError")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogrAdapter_SilenceMode(t *testing.T) {
|
||||||
|
// Test that logs are completely silenced when logger level is above error
|
||||||
|
buf := &bytes.Buffer{}
|
||||||
|
logger := New(LevelError+1, FormatText, buf)
|
||||||
|
sink := NewLogrAdapter(logger)
|
||||||
|
logrLogger := logr.New(sink)
|
||||||
|
|
||||||
|
// Try all log levels
|
||||||
|
logrLogger.V(0).Info("Info message should not appear")
|
||||||
|
logrLogger.V(1).Info("Debug message should not appear")
|
||||||
|
logrLogger.Error(errors.New("error object"), "Error message should not appear")
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Empty(t, output, "All logs should be silenced")
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,6 +29,7 @@ type Logger struct {
|
|||||||
level Level
|
level Level
|
||||||
format Format
|
format Format
|
||||||
output io.Writer
|
output io.Writer
|
||||||
|
mu sync.Mutex // Protects concurrent writes to output
|
||||||
}
|
}
|
||||||
|
|
||||||
type logEntry struct {
|
type logEntry struct {
|
||||||
@@ -55,6 +57,9 @@ func (l *Logger) log(level Level, msg string, fields map[string]interface{}) {
|
|||||||
|
|
||||||
levelStr := levelToString(level)
|
levelStr := levelToString(level)
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
|
||||||
if l.format == FormatJSON {
|
if l.format == FormatJSON {
|
||||||
entry := logEntry{
|
entry := logEntry{
|
||||||
Time: time.Now().Format(time.RFC3339),
|
Time: time.Now().Format(time.RFC3339),
|
||||||
|
|||||||
@@ -67,8 +67,3 @@ func (b *Backoff) calculateJitter(delay time.Duration) time.Duration {
|
|||||||
jitter := (b.rng.Float64()*2 - 1) * maxJitter
|
jitter := (b.rng.Float64()*2 - 1) * maxJitter
|
||||||
return time.Duration(jitter)
|
return time.Duration(jitter)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sleep waits for the next backoff duration.
|
|
||||||
func (b *Backoff) Sleep() {
|
|
||||||
time.Sleep(b.Next())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -158,10 +158,12 @@ func TestBackoff_ExponentialProgression(t *testing.T) {
|
|||||||
// We allow for jitter by checking a range
|
// We allow for jitter by checking a range
|
||||||
for i := 1; i < len(delays)-1; i++ {
|
for i := 1; i < len(delays)-1; i++ {
|
||||||
// Each delay should be roughly double the previous (accounting for jitter)
|
// Each delay should be roughly double the previous (accounting for jitter)
|
||||||
// With 10% jitter on each value, worst case: (2.0 * 1.1) / 0.9 = 2.44
|
// With 10% jitter on each value:
|
||||||
// We use 1.7x to 2.5x as a reasonable range with 10% jitter on each
|
// Lower bound: (2.0 * 0.9) / 1.1 ≈ 1.636
|
||||||
|
// Upper bound: (2.0 * 1.1) / 0.9 ≈ 2.444
|
||||||
|
// We use 1.6x to 2.5x as a reasonable range to account for jitter variance
|
||||||
ratio := float64(delays[i]) / float64(delays[i-1])
|
ratio := float64(delays[i]) / float64(delays[i-1])
|
||||||
assert.GreaterOrEqual(t, ratio, 1.7, "exponential growth should be ~2x")
|
assert.GreaterOrEqual(t, ratio, 1.6, "exponential growth should be ~2x")
|
||||||
assert.LessOrEqual(t, ratio, 2.5, "exponential growth should be ~2x")
|
assert.LessOrEqual(t, ratio, 2.5, "exponential growth should be ~2x")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -378,7 +378,7 @@ func (m model) renderMainView() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
isSelected := (idx == m.ui.selectedIndex)
|
isSelected := (idx == m.ui.selectedIndex)
|
||||||
isDisabled := m.ui.disabledMap[id]
|
isDisabled := m.ui.disabledMap[id] || fwd.Status == "Disabled"
|
||||||
|
|
||||||
// Selection indicator
|
// Selection indicator
|
||||||
indicator := " "
|
indicator := " "
|
||||||
|
|||||||
@@ -1,181 +0,0 @@
|
|||||||
package ui
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"golang.org/x/term"
|
|
||||||
)
|
|
||||||
|
|
||||||
// InteractiveController handles keyboard input and selection state
|
|
||||||
type InteractiveController struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
selectedIndex int
|
|
||||||
forwardIDs []string // Ordered list of forward IDs
|
|
||||||
disabledMap map[string]bool // Tracks which forwards are disabled
|
|
||||||
toggleCallback func(id string, enable bool)
|
|
||||||
enabled bool
|
|
||||||
oldTermState *term.State
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewInteractiveController creates a new interactive controller
|
|
||||||
func NewInteractiveController(toggleCallback func(id string, enable bool)) *InteractiveController {
|
|
||||||
return &InteractiveController{
|
|
||||||
selectedIndex: 0,
|
|
||||||
forwardIDs: make([]string, 0),
|
|
||||||
disabledMap: make(map[string]bool),
|
|
||||||
toggleCallback: toggleCallback,
|
|
||||||
enabled: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enable puts the terminal in raw mode for keyboard input
|
|
||||||
func (ic *InteractiveController) Enable() error {
|
|
||||||
ic.mu.Lock()
|
|
||||||
defer ic.mu.Unlock()
|
|
||||||
|
|
||||||
if ic.enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save current terminal state
|
|
||||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to enable raw mode: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ic.oldTermState = oldState
|
|
||||||
ic.enabled = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disable restores the terminal to normal mode
|
|
||||||
func (ic *InteractiveController) Disable() error {
|
|
||||||
ic.mu.Lock()
|
|
||||||
defer ic.mu.Unlock()
|
|
||||||
|
|
||||||
if !ic.enabled {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if ic.oldTermState != nil {
|
|
||||||
if err := term.Restore(int(os.Stdin.Fd()), ic.oldTermState); err != nil {
|
|
||||||
return fmt.Errorf("failed to restore terminal: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ic.enabled = false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateForwardsList updates the list of forwards for navigation
|
|
||||||
func (ic *InteractiveController) UpdateForwardsList(ids []string) {
|
|
||||||
ic.mu.Lock()
|
|
||||||
defer ic.mu.Unlock()
|
|
||||||
|
|
||||||
ic.forwardIDs = ids
|
|
||||||
|
|
||||||
// Ensure selected index is valid
|
|
||||||
if ic.selectedIndex >= len(ic.forwardIDs) {
|
|
||||||
ic.selectedIndex = len(ic.forwardIDs) - 1
|
|
||||||
}
|
|
||||||
if ic.selectedIndex < 0 && len(ic.forwardIDs) > 0 {
|
|
||||||
ic.selectedIndex = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoveUp moves selection up
|
|
||||||
func (ic *InteractiveController) MoveUp() {
|
|
||||||
ic.mu.Lock()
|
|
||||||
defer ic.mu.Unlock()
|
|
||||||
|
|
||||||
if ic.selectedIndex > 0 {
|
|
||||||
ic.selectedIndex--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MoveDown moves selection down
|
|
||||||
func (ic *InteractiveController) MoveDown() {
|
|
||||||
ic.mu.Lock()
|
|
||||||
defer ic.mu.Unlock()
|
|
||||||
|
|
||||||
if ic.selectedIndex < len(ic.forwardIDs)-1 {
|
|
||||||
ic.selectedIndex++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToggleSelected toggles the enable/disable state of the selected forward
|
|
||||||
func (ic *InteractiveController) ToggleSelected() {
|
|
||||||
ic.mu.Lock()
|
|
||||||
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
|
|
||||||
ic.mu.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedID := ic.forwardIDs[ic.selectedIndex]
|
|
||||||
currentlyDisabled := ic.disabledMap[selectedID]
|
|
||||||
newState := !currentlyDisabled
|
|
||||||
ic.disabledMap[selectedID] = newState
|
|
||||||
ic.mu.Unlock()
|
|
||||||
|
|
||||||
// Call the toggle callback
|
|
||||||
if ic.toggleCallback != nil {
|
|
||||||
ic.toggleCallback(selectedID, !newState) // enable is inverse of disabled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSelectedIndex returns the current selection index
|
|
||||||
func (ic *InteractiveController) GetSelectedIndex() int {
|
|
||||||
ic.mu.RLock()
|
|
||||||
defer ic.mu.RUnlock()
|
|
||||||
return ic.selectedIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsDisabled returns whether a forward is disabled
|
|
||||||
func (ic *InteractiveController) IsDisabled(id string) bool {
|
|
||||||
ic.mu.RLock()
|
|
||||||
defer ic.mu.RUnlock()
|
|
||||||
return ic.disabledMap[id]
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSelectedID returns the ID of the currently selected forward
|
|
||||||
func (ic *InteractiveController) GetSelectedID() string {
|
|
||||||
ic.mu.RLock()
|
|
||||||
defer ic.mu.RUnlock()
|
|
||||||
|
|
||||||
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return ic.forwardIDs[ic.selectedIndex]
|
|
||||||
}
|
|
||||||
|
|
||||||
// HandleKey processes keyboard input and returns true if should continue
|
|
||||||
func (ic *InteractiveController) HandleKey(b []byte) bool {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle single byte keys
|
|
||||||
if len(b) == 1 {
|
|
||||||
switch b[0] {
|
|
||||||
case 'q', 'Q', 3: // q, Q, or Ctrl+C
|
|
||||||
return false
|
|
||||||
case ' ', '\r': // Space or Enter to toggle
|
|
||||||
ic.ToggleSelected()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle escape sequences (arrow keys)
|
|
||||||
if len(b) == 3 && b[0] == 27 && b[1] == 91 {
|
|
||||||
switch b[2] {
|
|
||||||
case 65: // Up arrow
|
|
||||||
ic.MoveUp()
|
|
||||||
case 66: // Down arrow
|
|
||||||
ic.MoveDown()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
+7
-48
@@ -23,10 +23,9 @@ type ForwardStatus struct {
|
|||||||
|
|
||||||
// TableUI manages the terminal table display
|
// TableUI manages the terminal table display
|
||||||
type TableUI struct {
|
type TableUI struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
forwards map[string]*ForwardStatus // key is forward ID
|
forwards map[string]*ForwardStatus // key is forward ID
|
||||||
verbose bool
|
verbose bool
|
||||||
interactive *InteractiveController
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTableUI creates a new table UI manager
|
// NewTableUI creates a new table UI manager
|
||||||
@@ -37,13 +36,6 @@ func NewTableUI(verbose bool) *TableUI {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInteractiveController sets the interactive controller
|
|
||||||
func (t *TableUI) SetInteractiveController(ic *InteractiveController) {
|
|
||||||
t.mu.Lock()
|
|
||||||
defer t.mu.Unlock()
|
|
||||||
t.interactive = ic
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddForward registers a new forward for display
|
// AddForward registers a new forward for display
|
||||||
func (t *TableUI) AddForward(id string, fwd *config.Forward) {
|
func (t *TableUI) AddForward(id string, fwd *config.Forward) {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
@@ -126,27 +118,10 @@ func (t *TableUI) Render() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update interactive controller with current forward IDs (in display order)
|
|
||||||
if t.interactive != nil {
|
|
||||||
ids := make([]string, len(entries))
|
|
||||||
for i, entry := range entries {
|
|
||||||
ids[i] = entry.id
|
|
||||||
}
|
|
||||||
t.interactive.UpdateForwardsList(ids)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print each forward
|
// Print each forward
|
||||||
for i, entry := range entries {
|
for _, entry := range entries {
|
||||||
fwd := entry.fwd
|
fwd := entry.fwd
|
||||||
|
|
||||||
// Check if this row is selected
|
|
||||||
isSelected := false
|
|
||||||
isDisabled := false
|
|
||||||
if t.interactive != nil {
|
|
||||||
isSelected = (i == t.interactive.GetSelectedIndex())
|
|
||||||
isDisabled = t.interactive.IsDisabled(entry.id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Truncate long names
|
// Truncate long names
|
||||||
alias := truncate(fwd.Alias, 25)
|
alias := truncate(fwd.Alias, 25)
|
||||||
resource := truncate(fwd.Resource, 25)
|
resource := truncate(fwd.Resource, 25)
|
||||||
@@ -154,8 +129,8 @@ func (t *TableUI) Render() {
|
|||||||
// Color code status with indicator
|
// Color code status with indicator
|
||||||
statusStr := formatStatusWithIndicator(fwd.Status)
|
statusStr := formatStatusWithIndicator(fwd.Status)
|
||||||
|
|
||||||
// Build the row content
|
// Print the row
|
||||||
rowContent := fmt.Sprintf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s",
|
fmt.Printf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s\n",
|
||||||
fwd.Context,
|
fwd.Context,
|
||||||
fwd.Namespace,
|
fwd.Namespace,
|
||||||
alias,
|
alias,
|
||||||
@@ -164,26 +139,10 @@ func (t *TableUI) Render() {
|
|||||||
fwd.RemotePort,
|
fwd.RemotePort,
|
||||||
fwd.LocalPort,
|
fwd.LocalPort,
|
||||||
statusStr)
|
statusStr)
|
||||||
|
|
||||||
// Apply selection highlighting or disabled styling
|
|
||||||
if isSelected {
|
|
||||||
// Replace leading spaces with arrow, then apply reverse video to entire line
|
|
||||||
rowContent = "\033[7m> " + rowContent[2:] + "\033[0m"
|
|
||||||
} else if isDisabled {
|
|
||||||
// Apply dimmed styling to entire line
|
|
||||||
rowContent = "\033[2m" + rowContent + "\033[0m"
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println(rowContent)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(strings.Repeat("=", 130))
|
fmt.Println(strings.Repeat("=", 130))
|
||||||
helpText := "Total forwards: %d | ↑↓: Navigate | Space: Toggle | q: Quit"
|
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
|
||||||
if !t.verbose {
|
|
||||||
fmt.Printf(helpText+"\n", len(t.forwards))
|
|
||||||
} else {
|
|
||||||
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
|
|
||||||
}
|
|
||||||
|
|
||||||
// In verbose mode, add a newline to separate from logs
|
// In verbose mode, add a newline to separate from logs
|
||||||
if t.verbose {
|
if t.verbose {
|
||||||
|
|||||||
@@ -10,6 +10,19 @@ import (
|
|||||||
"github.com/nvm/kportal/internal/k8s"
|
"github.com/nvm/kportal/internal/k8s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isFilterableStep returns true if the step supports search/filter
|
||||||
|
func isFilterableStep(step AddWizardStep) bool {
|
||||||
|
switch step {
|
||||||
|
case StepSelectContext, StepSelectNamespace:
|
||||||
|
return true
|
||||||
|
case StepEnterResource:
|
||||||
|
// Only service selection is filterable (pod prefix and selector are text input)
|
||||||
|
return true // We'll check resource type in the handler
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleMainViewKeys handles keyboard input in the main view
|
// handleMainViewKeys handles keyboard input in the main view
|
||||||
func (m model) handleMainViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
func (m model) handleMainViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
||||||
// If delete confirmation is showing, handle it separately
|
// If delete confirmation is showing, handle it separately
|
||||||
@@ -224,6 +237,12 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|||||||
return m, tea.ClearScreen
|
return m, tea.ClearScreen
|
||||||
|
|
||||||
case "esc":
|
case "esc":
|
||||||
|
// If there's an active search filter, clear it instead of going back
|
||||||
|
if wizard.searchFilter != "" && isFilterableStep(wizard.step) {
|
||||||
|
wizard.clearSearchFilter()
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
// In edit mode, Esc always cancels (don't navigate back through skipped steps)
|
// In edit mode, Esc always cancels (don't navigate back through skipped steps)
|
||||||
if wizard.isEditing {
|
if wizard.isEditing {
|
||||||
m.ui.viewMode = ViewModeMain
|
m.ui.viewMode = ViewModeMain
|
||||||
@@ -242,6 +261,7 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|||||||
wizard.step--
|
wizard.step--
|
||||||
wizard.cursor = 0
|
wizard.cursor = 0
|
||||||
wizard.clearTextInput()
|
wizard.clearTextInput()
|
||||||
|
wizard.clearSearchFilter()
|
||||||
wizard.error = nil
|
wizard.error = nil
|
||||||
|
|
||||||
// Reset input mode based on the step we're going back to
|
// Reset input mode based on the step we're going back to
|
||||||
@@ -300,26 +320,48 @@ func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
|
|||||||
return m.handleAddWizardEnter()
|
return m.handleAddWizardEnter()
|
||||||
|
|
||||||
case "backspace":
|
case "backspace":
|
||||||
// Allow backspace in text input mode OR when focused on alias in confirmation
|
// Allow backspace in text input mode OR when focused on alias in confirmation OR when filtering
|
||||||
canBackspace := wizard.inputMode == InputModeText ||
|
canBackspace := wizard.inputMode == InputModeText ||
|
||||||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias)
|
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
|
||||||
if canBackspace && len(wizard.textInput) > 0 {
|
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step) && len(wizard.searchFilter) > 0)
|
||||||
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
|
|
||||||
|
if canBackspace {
|
||||||
|
if isFilterableStep(wizard.step) && wizard.inputMode == InputModeList && len(wizard.searchFilter) > 0 {
|
||||||
|
// Backspace in search filter
|
||||||
|
wizard.searchFilter = wizard.searchFilter[:len(wizard.searchFilter)-1]
|
||||||
|
wizard.cursor = 0
|
||||||
|
wizard.scrollOffset = 0
|
||||||
|
} else if len(wizard.textInput) > 0 {
|
||||||
|
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
// Handle text input
|
// Handle text input
|
||||||
canTypeText := wizard.inputMode == InputModeText ||
|
canTypeText := wizard.inputMode == InputModeText ||
|
||||||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias)
|
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
|
||||||
if canTypeText && len(msg.String()) == 1 {
|
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step))
|
||||||
wizard.handleTextInput(rune(msg.String()[0]))
|
|
||||||
|
|
||||||
// Trigger validation for selector
|
if canTypeText && len(msg.String()) == 1 {
|
||||||
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
|
// If in list mode on filterable step, add to search filter instead of textInput
|
||||||
if len(wizard.textInput) > 0 {
|
if wizard.inputMode == InputModeList && isFilterableStep(wizard.step) {
|
||||||
wizard.loading = true
|
char := rune(msg.String()[0])
|
||||||
wizard.error = nil
|
// Only allow printable characters
|
||||||
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
|
if char >= 32 && char < 127 {
|
||||||
|
wizard.searchFilter += string(char)
|
||||||
|
wizard.cursor = 0
|
||||||
|
wizard.scrollOffset = 0
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
wizard.handleTextInput(rune(msg.String()[0]))
|
||||||
|
|
||||||
|
// Trigger validation for selector
|
||||||
|
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
|
||||||
|
if len(wizard.textInput) > 0 {
|
||||||
|
wizard.loading = true
|
||||||
|
wizard.error = nil
|
||||||
|
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -334,19 +376,23 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
|
|||||||
|
|
||||||
switch wizard.step {
|
switch wizard.step {
|
||||||
case StepSelectContext:
|
case StepSelectContext:
|
||||||
if wizard.cursor >= 0 && wizard.cursor < len(wizard.contexts) {
|
filteredContexts := wizard.getFilteredContexts()
|
||||||
wizard.selectedContext = wizard.contexts[wizard.cursor]
|
if wizard.cursor >= 0 && wizard.cursor < len(filteredContexts) {
|
||||||
|
wizard.selectedContext = filteredContexts[wizard.cursor]
|
||||||
wizard.step = StepSelectNamespace
|
wizard.step = StepSelectNamespace
|
||||||
wizard.cursor = 0
|
wizard.cursor = 0
|
||||||
|
wizard.clearSearchFilter()
|
||||||
wizard.loading = true
|
wizard.loading = true
|
||||||
return m, loadNamespacesCmd(m.ui.discovery, wizard.selectedContext)
|
return m, loadNamespacesCmd(m.ui.discovery, wizard.selectedContext)
|
||||||
}
|
}
|
||||||
|
|
||||||
case StepSelectNamespace:
|
case StepSelectNamespace:
|
||||||
if wizard.cursor >= 0 && wizard.cursor < len(wizard.namespaces) {
|
filteredNamespaces := wizard.getFilteredNamespaces()
|
||||||
wizard.selectedNamespace = wizard.namespaces[wizard.cursor]
|
if wizard.cursor >= 0 && wizard.cursor < len(filteredNamespaces) {
|
||||||
|
wizard.selectedNamespace = filteredNamespaces[wizard.cursor]
|
||||||
wizard.step = StepSelectResourceType
|
wizard.step = StepSelectResourceType
|
||||||
wizard.cursor = 0
|
wizard.cursor = 0
|
||||||
|
wizard.clearSearchFilter()
|
||||||
wizard.inputMode = InputModeList
|
wizard.inputMode = InputModeList
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -403,13 +449,15 @@ func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
case ResourceTypeService:
|
case ResourceTypeService:
|
||||||
if wizard.cursor >= 0 && wizard.cursor < len(wizard.services) {
|
filteredServices := wizard.getFilteredServices()
|
||||||
wizard.resourceValue = wizard.services[wizard.cursor].Name
|
if wizard.cursor >= 0 && wizard.cursor < len(filteredServices) {
|
||||||
|
wizard.resourceValue = filteredServices[wizard.cursor].Name
|
||||||
wizard.step = StepEnterRemotePort
|
wizard.step = StepEnterRemotePort
|
||||||
wizard.clearTextInput()
|
wizard.clearTextInput()
|
||||||
|
wizard.clearSearchFilter()
|
||||||
|
|
||||||
// Get ports from selected service
|
// Get ports from selected service
|
||||||
wizard.detectedPorts = wizard.services[wizard.cursor].Ports
|
wizard.detectedPorts = filteredServices[wizard.cursor].Ports
|
||||||
if len(wizard.detectedPorts) > 0 {
|
if len(wizard.detectedPorts) > 0 {
|
||||||
wizard.inputMode = InputModeList
|
wizard.inputMode = InputModeList
|
||||||
wizard.cursor = 0
|
wizard.cursor = 0
|
||||||
|
|||||||
@@ -1,9 +1,34 @@
|
|||||||
package ui
|
package ui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/nvm/kportal/internal/k8s"
|
"github.com/nvm/kportal/internal/k8s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// filterStrings filters a slice of strings by a search filter (case-insensitive substring match)
|
||||||
|
func filterStrings(items []string, filter string) []string {
|
||||||
|
if filter == "" {
|
||||||
|
return items
|
||||||
|
}
|
||||||
|
filtered := []string{}
|
||||||
|
filterLower := strings.ToLower(filter)
|
||||||
|
for _, item := range items {
|
||||||
|
if strings.Contains(strings.ToLower(item), filterLower) {
|
||||||
|
filtered = append(filtered, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesFilter checks if a string matches the filter (case-insensitive substring match)
|
||||||
|
func matchesFilter(item, filter string) bool {
|
||||||
|
if filter == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return strings.Contains(strings.ToLower(item), strings.ToLower(filter))
|
||||||
|
}
|
||||||
|
|
||||||
// ViewMode represents the current view state of the UI
|
// ViewMode represents the current view state of the UI
|
||||||
type ViewMode int
|
type ViewMode int
|
||||||
|
|
||||||
@@ -87,6 +112,7 @@ type AddWizardState struct {
|
|||||||
cursor int
|
cursor int
|
||||||
scrollOffset int // For scrolling long lists
|
scrollOffset int // For scrolling long lists
|
||||||
textInput string
|
textInput string
|
||||||
|
searchFilter string // For filtering lists (contexts, namespaces, services)
|
||||||
loading bool
|
loading bool
|
||||||
error error
|
error error
|
||||||
|
|
||||||
@@ -142,14 +168,14 @@ func (w *AddWizardState) moveCursor(delta int) {
|
|||||||
|
|
||||||
switch w.step {
|
switch w.step {
|
||||||
case StepSelectContext:
|
case StepSelectContext:
|
||||||
maxItems = len(w.contexts)
|
maxItems = len(w.getFilteredContexts())
|
||||||
case StepSelectNamespace:
|
case StepSelectNamespace:
|
||||||
maxItems = len(w.namespaces)
|
maxItems = len(w.getFilteredNamespaces())
|
||||||
case StepSelectResourceType:
|
case StepSelectResourceType:
|
||||||
maxItems = 3 // Three resource types
|
maxItems = 3 // Three resource types
|
||||||
case StepEnterResource:
|
case StepEnterResource:
|
||||||
if w.selectedResourceType == ResourceTypeService {
|
if w.selectedResourceType == ResourceTypeService {
|
||||||
maxItems = len(w.services)
|
maxItems = len(w.getFilteredServices())
|
||||||
}
|
}
|
||||||
case StepEnterRemotePort:
|
case StepEnterRemotePort:
|
||||||
if len(w.detectedPorts) > 0 {
|
if len(w.detectedPorts) > 0 {
|
||||||
@@ -300,3 +326,40 @@ func (w *RemoveWizardState) getSelectedForwards() []RemovableForward {
|
|||||||
}
|
}
|
||||||
return selected
|
return selected
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getFilteredContexts returns contexts filtered by search string
|
||||||
|
func (w *AddWizardState) getFilteredContexts() []string {
|
||||||
|
if w.searchFilter == "" {
|
||||||
|
return w.contexts
|
||||||
|
}
|
||||||
|
return filterStrings(w.contexts, w.searchFilter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFilteredNamespaces returns namespaces filtered by search string
|
||||||
|
func (w *AddWizardState) getFilteredNamespaces() []string {
|
||||||
|
if w.searchFilter == "" {
|
||||||
|
return w.namespaces
|
||||||
|
}
|
||||||
|
return filterStrings(w.namespaces, w.searchFilter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFilteredServices returns services filtered by search string
|
||||||
|
func (w *AddWizardState) getFilteredServices() []k8s.ServiceInfo {
|
||||||
|
if w.searchFilter == "" {
|
||||||
|
return w.services
|
||||||
|
}
|
||||||
|
filtered := []k8s.ServiceInfo{}
|
||||||
|
for _, svc := range w.services {
|
||||||
|
if matchesFilter(svc.Name, w.searchFilter) {
|
||||||
|
filtered = append(filtered, svc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearSearchFilter clears the search filter and resets cursor/scroll
|
||||||
|
func (w *AddWizardState) clearSearchFilter() {
|
||||||
|
w.searchFilter = ""
|
||||||
|
w.cursor = 0
|
||||||
|
w.scrollOffset = 0
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,350 @@
|
|||||||
|
package ui
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/nvm/kportal/internal/k8s"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterStrings(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
items []string
|
||||||
|
filter string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty filter returns all items",
|
||||||
|
items: []string{"namespace-1", "namespace-2", "namespace-3"},
|
||||||
|
filter: "",
|
||||||
|
expected: []string{"namespace-1", "namespace-2", "namespace-3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter matches multiple items",
|
||||||
|
items: []string{"prod-api", "prod-db", "staging-api", "dev-api"},
|
||||||
|
filter: "prod",
|
||||||
|
expected: []string{"prod-api", "prod-db"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter matches single item",
|
||||||
|
items: []string{"namespace-1", "namespace-2", "namespace-3"},
|
||||||
|
filter: "2",
|
||||||
|
expected: []string{"namespace-2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter matches no items",
|
||||||
|
items: []string{"namespace-1", "namespace-2", "namespace-3"},
|
||||||
|
filter: "xyz",
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive matching",
|
||||||
|
items: []string{"Production", "Staging", "Development"},
|
||||||
|
filter: "prod",
|
||||||
|
expected: []string{"Production"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial string matching",
|
||||||
|
items: []string{"my-app-frontend", "my-app-backend", "other-service"},
|
||||||
|
filter: "app",
|
||||||
|
expected: []string{"my-app-frontend", "my-app-backend"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := filterStrings(tt.items, tt.filter)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchesFilter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
item string
|
||||||
|
filter string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty filter matches everything",
|
||||||
|
item: "namespace-1",
|
||||||
|
filter: "",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exact match",
|
||||||
|
item: "namespace-1",
|
||||||
|
filter: "namespace-1",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "partial match",
|
||||||
|
item: "production-api",
|
||||||
|
filter: "prod",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no match",
|
||||||
|
item: "namespace-1",
|
||||||
|
filter: "xyz",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive match",
|
||||||
|
item: "Production",
|
||||||
|
filter: "prod",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := matchesFilter(tt.item, tt.filter)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredContexts(t *testing.T) {
|
||||||
|
wizard := &AddWizardState{
|
||||||
|
contexts: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filter string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no filter returns all",
|
||||||
|
filter: "",
|
||||||
|
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'prod'",
|
||||||
|
filter: "prod",
|
||||||
|
expected: []string{"prod-cluster"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'cluster'",
|
||||||
|
filter: "cluster",
|
||||||
|
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'staging'",
|
||||||
|
filter: "staging",
|
||||||
|
expected: []string{"staging-cluster"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter with no matches",
|
||||||
|
filter: "xyz",
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wizard.searchFilter = tt.filter
|
||||||
|
result := wizard.getFilteredContexts()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredNamespaces(t *testing.T) {
|
||||||
|
wizard := &AddWizardState{
|
||||||
|
namespaces: []string{
|
||||||
|
"kube-system", "kube-public", "default",
|
||||||
|
"prod-api", "prod-db", "staging-api", "staging-db",
|
||||||
|
"monitoring", "logging",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filter string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no filter returns all",
|
||||||
|
filter: "",
|
||||||
|
expected: []string{
|
||||||
|
"kube-system", "kube-public", "default",
|
||||||
|
"prod-api", "prod-db", "staging-api", "staging-db",
|
||||||
|
"monitoring", "logging",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'prod'",
|
||||||
|
filter: "prod",
|
||||||
|
expected: []string{"prod-api", "prod-db"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'kube'",
|
||||||
|
filter: "kube",
|
||||||
|
expected: []string{"kube-system", "kube-public"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'api'",
|
||||||
|
filter: "api",
|
||||||
|
expected: []string{"prod-api", "staging-api"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'ing' (partial match)",
|
||||||
|
filter: "ing",
|
||||||
|
expected: []string{"staging-api", "staging-db", "monitoring", "logging"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wizard.searchFilter = tt.filter
|
||||||
|
result := wizard.getFilteredNamespaces()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilteredServices(t *testing.T) {
|
||||||
|
wizard := &AddWizardState{
|
||||||
|
services: []k8s.ServiceInfo{
|
||||||
|
{Name: "api-gateway"},
|
||||||
|
{Name: "api-backend"},
|
||||||
|
{Name: "database"},
|
||||||
|
{Name: "redis-cache"},
|
||||||
|
{Name: "postgres-db"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filter string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no filter returns all",
|
||||||
|
filter: "",
|
||||||
|
expected: []string{"api-gateway", "api-backend", "database", "redis-cache", "postgres-db"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'api'",
|
||||||
|
filter: "api",
|
||||||
|
expected: []string{"api-gateway", "api-backend"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'db'",
|
||||||
|
filter: "db",
|
||||||
|
expected: []string{"postgres-db"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'base'",
|
||||||
|
filter: "base",
|
||||||
|
expected: []string{"database"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter by 'redis'",
|
||||||
|
filter: "redis",
|
||||||
|
expected: []string{"redis-cache"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter with no matches",
|
||||||
|
filter: "xyz",
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wizard.searchFilter = tt.filter
|
||||||
|
result := wizard.getFilteredServices()
|
||||||
|
resultNames := make([]string, len(result))
|
||||||
|
for i, svc := range result {
|
||||||
|
resultNames[i] = svc.Name
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expected, resultNames)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearSearchFilter(t *testing.T) {
|
||||||
|
wizard := &AddWizardState{
|
||||||
|
searchFilter: "test",
|
||||||
|
cursor: 5,
|
||||||
|
scrollOffset: 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
wizard.clearSearchFilter()
|
||||||
|
|
||||||
|
assert.Equal(t, "", wizard.searchFilter, "searchFilter should be cleared")
|
||||||
|
assert.Equal(t, 0, wizard.cursor, "cursor should be reset to 0")
|
||||||
|
assert.Equal(t, 0, wizard.scrollOffset, "scrollOffset should be reset to 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMoveCursorWithFilteredLists(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
step AddWizardStep
|
||||||
|
contexts []string
|
||||||
|
namespaces []string
|
||||||
|
searchFilter string
|
||||||
|
initialCursor int
|
||||||
|
delta int
|
||||||
|
expectedCursor int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "move down in filtered contexts",
|
||||||
|
step: StepSelectContext,
|
||||||
|
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
|
||||||
|
searchFilter: "prod",
|
||||||
|
initialCursor: 0,
|
||||||
|
delta: 1,
|
||||||
|
expectedCursor: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cannot move beyond filtered list",
|
||||||
|
step: StepSelectContext,
|
||||||
|
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
|
||||||
|
searchFilter: "prod",
|
||||||
|
initialCursor: 1,
|
||||||
|
delta: 1,
|
||||||
|
expectedCursor: 1, // Should stay at 1 (last item in filtered list)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "move up in filtered list",
|
||||||
|
step: StepSelectNamespace,
|
||||||
|
namespaces: []string{"ns-1", "ns-2", "ns-3", "other"},
|
||||||
|
searchFilter: "ns",
|
||||||
|
initialCursor: 2,
|
||||||
|
delta: -1,
|
||||||
|
expectedCursor: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cannot move above 0",
|
||||||
|
step: StepSelectNamespace,
|
||||||
|
namespaces: []string{"ns-1", "ns-2", "ns-3"},
|
||||||
|
searchFilter: "ns",
|
||||||
|
initialCursor: 0,
|
||||||
|
delta: -1,
|
||||||
|
expectedCursor: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wizard := &AddWizardState{
|
||||||
|
step: tt.step,
|
||||||
|
inputMode: InputModeList,
|
||||||
|
cursor: tt.initialCursor,
|
||||||
|
contexts: tt.contexts,
|
||||||
|
namespaces: tt.namespaces,
|
||||||
|
searchFilter: tt.searchFilter,
|
||||||
|
}
|
||||||
|
|
||||||
|
wizard.moveCursor(tt.delta)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCursor, wizard.cursor)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
+89
-38
@@ -45,6 +45,12 @@ func (m model) renderSelectContext() string {
|
|||||||
b.WriteString(renderHeader("Add Port Forward", renderProgress(1, 7)))
|
b.WriteString(renderHeader("Add Port Forward", renderProgress(1, 7)))
|
||||||
b.WriteString("Select Kubernetes Context:\n\n")
|
b.WriteString("Select Kubernetes Context:\n\n")
|
||||||
|
|
||||||
|
// Show search input if there's a filter active
|
||||||
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
if wizard.loading {
|
if wizard.loading {
|
||||||
b.WriteString(spinnerStyle.Render("⣾ Loading contexts..."))
|
b.WriteString(spinnerStyle.Render("⣾ Loading contexts..."))
|
||||||
} else if wizard.error != nil {
|
} else if wizard.error != nil {
|
||||||
@@ -52,46 +58,56 @@ func (m model) renderSelectContext() string {
|
|||||||
} else if len(wizard.contexts) == 0 {
|
} else if len(wizard.contexts) == 0 {
|
||||||
b.WriteString(mutedStyle.Render("No contexts found in kubeconfig"))
|
b.WriteString(mutedStyle.Render("No contexts found in kubeconfig"))
|
||||||
} else {
|
} else {
|
||||||
const viewportHeight = 20
|
filteredContexts := wizard.getFilteredContexts()
|
||||||
totalItems := len(wizard.contexts)
|
if len(filteredContexts) == 0 {
|
||||||
|
b.WriteString(mutedStyle.Render("No matching contexts"))
|
||||||
|
} else {
|
||||||
|
const viewportHeight = 20
|
||||||
|
totalItems := len(filteredContexts)
|
||||||
|
|
||||||
// Show scroll up indicator if there are items above the viewport
|
// Show scroll up indicator if there are items above the viewport
|
||||||
if wizard.scrollOffset > 0 {
|
if wizard.scrollOffset > 0 {
|
||||||
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
|
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate visible range
|
|
||||||
start := wizard.scrollOffset
|
|
||||||
end := wizard.scrollOffset + viewportHeight
|
|
||||||
if end > totalItems {
|
|
||||||
end = totalItems
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render visible contexts with (current) marker on first one
|
|
||||||
for i := start; i < end; i++ {
|
|
||||||
prefix := " "
|
|
||||||
text := wizard.contexts[i]
|
|
||||||
if i == 0 {
|
|
||||||
text += mutedStyle.Render(" (current)")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if i == wizard.cursor {
|
// Calculate visible range
|
||||||
prefix = "▸ "
|
start := wizard.scrollOffset
|
||||||
b.WriteString(selectedStyle.Render(prefix + text))
|
end := wizard.scrollOffset + viewportHeight
|
||||||
} else {
|
if end > totalItems {
|
||||||
b.WriteString(prefix + text)
|
end = totalItems
|
||||||
}
|
}
|
||||||
b.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Show scroll down indicator if there are items below the viewport
|
// Render visible contexts with (current) marker on first one (only if not filtered)
|
||||||
if end < totalItems {
|
for i := start; i < end; i++ {
|
||||||
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
|
prefix := " "
|
||||||
|
text := filteredContexts[i]
|
||||||
|
// Only show (current) marker if no filter and this is the first item in original list
|
||||||
|
if wizard.searchFilter == "" && i == 0 {
|
||||||
|
text += mutedStyle.Render(" (current)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == wizard.cursor {
|
||||||
|
prefix = "▸ "
|
||||||
|
b.WriteString(selectedStyle.Render(prefix + text))
|
||||||
|
} else {
|
||||||
|
b.WriteString(prefix + text)
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show scroll down indicator if there are items below the viewport
|
||||||
|
if end < totalItems {
|
||||||
|
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString("\n")
|
b.WriteString("\n")
|
||||||
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Cancel", len(wizard.getFilteredContexts()), len(wizard.contexts))))
|
||||||
|
} else {
|
||||||
|
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
|
||||||
|
}
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
@@ -105,6 +121,12 @@ func (m model) renderSelectNamespace() string {
|
|||||||
|
|
||||||
b.WriteString("Select Namespace:\n\n")
|
b.WriteString("Select Namespace:\n\n")
|
||||||
|
|
||||||
|
// Show search input if there's a filter active
|
||||||
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
if wizard.loading {
|
if wizard.loading {
|
||||||
b.WriteString(spinnerStyle.Render("⣾ Loading namespaces..."))
|
b.WriteString(spinnerStyle.Render("⣾ Loading namespaces..."))
|
||||||
} else if wizard.error != nil {
|
} else if wizard.error != nil {
|
||||||
@@ -113,11 +135,20 @@ func (m model) renderSelectNamespace() string {
|
|||||||
} else if len(wizard.namespaces) == 0 {
|
} else if len(wizard.namespaces) == 0 {
|
||||||
b.WriteString(mutedStyle.Render("No namespaces found"))
|
b.WriteString(mutedStyle.Render("No namespaces found"))
|
||||||
} else {
|
} else {
|
||||||
b.WriteString(renderList(wizard.namespaces, wizard.cursor, " ", wizard.scrollOffset))
|
filteredNamespaces := wizard.getFilteredNamespaces()
|
||||||
|
if len(filteredNamespaces) == 0 {
|
||||||
|
b.WriteString(mutedStyle.Render("No matching namespaces"))
|
||||||
|
} else {
|
||||||
|
b.WriteString(renderList(filteredNamespaces, wizard.cursor, " ", wizard.scrollOffset))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString("\n")
|
b.WriteString("\n")
|
||||||
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredNamespaces()), len(wizard.namespaces))))
|
||||||
|
} else {
|
||||||
|
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
|
||||||
|
}
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
@@ -242,21 +273,41 @@ func (m model) renderEnterResource() string {
|
|||||||
case ResourceTypeService:
|
case ResourceTypeService:
|
||||||
b.WriteString("Select service:\n\n")
|
b.WriteString("Select service:\n\n")
|
||||||
|
|
||||||
|
// Show search input if there's a filter active
|
||||||
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
if wizard.loading {
|
if wizard.loading {
|
||||||
b.WriteString(spinnerStyle.Render("⣾ Loading services..."))
|
b.WriteString(spinnerStyle.Render("⣾ Loading services..."))
|
||||||
} else if len(wizard.services) == 0 {
|
} else if len(wizard.services) == 0 {
|
||||||
b.WriteString(mutedStyle.Render("No services found"))
|
b.WriteString(mutedStyle.Render("No services found"))
|
||||||
} else {
|
} else {
|
||||||
serviceNames := make([]string, len(wizard.services))
|
filteredServices := wizard.getFilteredServices()
|
||||||
for i, svc := range wizard.services {
|
if len(filteredServices) == 0 {
|
||||||
serviceNames[i] = svc.Name
|
b.WriteString(mutedStyle.Render("No matching services"))
|
||||||
|
} else {
|
||||||
|
serviceNames := make([]string, len(filteredServices))
|
||||||
|
for i, svc := range filteredServices {
|
||||||
|
serviceNames[i] = svc.Name
|
||||||
|
}
|
||||||
|
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
|
||||||
}
|
}
|
||||||
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString("\n")
|
b.WriteString("\n")
|
||||||
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
|
// Show appropriate help text based on resource type and filter state
|
||||||
|
if wizard.selectedResourceType == ResourceTypeService {
|
||||||
|
if wizard.searchFilter != "" {
|
||||||
|
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredServices()), len(wizard.services))))
|
||||||
|
} else {
|
||||||
|
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
|
||||||
|
}
|
||||||
|
|
||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user