Compare commits

...

4 Commits

Author SHA1 Message Date
lukaszraczylo 39fe4286b4 Fix the watchdog being too aggressive. 2025-11-24 13:19:44 +00:00
lukaszraczylo 2fdc5912e7 healtcheck improvements (#4)
* Advanced healtchecks.
* Add watchdog for stale connections handling.
2025-11-24 13:00:19 +00:00
lukaszraczylo 7df161aee0 bugfixes nov2025 (#3)
* Fix enter misbehaving.
* Cleanup after previous tui implementation.
* Fix race condition and improve logging
* Add filtering of the namespaces by text input in the wizard UI
2025-11-24 11:09:23 +00:00
lukaszraczylo f41c316b2b Add configuration wizard. (#2)
* Add configuration wizard.
2025-11-24 02:28:08 +00:00
41 changed files with 8588 additions and 928 deletions
+21
View File
@@ -1,6 +1,27 @@
# Example kportal configuration
# Copy this file to your project and customize as needed
# Optional: Health check configuration
# These settings control how kportal monitors connection health and detects stale connections
healthCheck:
interval: "3s" # How often to check connection health (default: 3s)
timeout: "2s" # Timeout for health check operations (default: 2s)
method: "data-transfer" # Health check method: "tcp-dial" or "data-transfer" (default: data-transfer)
# - tcp-dial: Simple TCP connection test (fast, less reliable)
# - data-transfer: Attempts to read data (slower, more reliable)
maxConnectionAge: "25m" # Maximum connection age before proactive reconnect (default: 25m)
# Helps avoid Kubernetes API server timeouts (typically 30m)
maxIdleTime: "10m" # Maximum idle time before marking as stale (default: 10m)
# Connections with no data transfer are marked stale
# Optional: Reliability configuration
# These settings improve connection stability for long-running transfers
reliability:
tcpKeepalive: "30s" # TCP keepalive interval for OS-level connection monitoring (default: 30s)
dialTimeout: "30s" # Connection dial timeout (default: 30s)
retryOnStale: true # Automatically reconnect when stale connections detected (default: true)
watchdogPeriod: "30s" # Goroutine watchdog check interval to detect hung workers (default: 30s)
contexts:
# Production context
- name: production
+60 -6
View File
@@ -1,10 +1,16 @@
# kportal
<p align="center">
<img src="docs/kportal-logo-dark.svg" alt="kportal logo" width="400">
</p>
[![Release](https://img.shields.io/github/v/release/lukaszraczylo/kportal)](https://github.com/lukaszraczylo/kportal/releases)
[![License](https://img.shields.io/github/license/lukaszraczylo/kportal)](LICENSE)
[![Go Report Card](https://goreportcard.com/badge/github.com/lukaszraczylo/kportal)](https://goreportcard.com/report/github.com/lukaszraczylo/kportal)
<p align="center">
<a href="https://github.com/lukaszraczylo/kportal/releases"><img src="https://img.shields.io/github/v/release/lukaszraczylo/kportal" alt="Release"></a>
<a href="LICENSE"><img src="https://img.shields.io/github/license/lukaszraczylo/kportal" alt="License"></a>
<a href="https://goreportcard.com/report/github.com/lukaszraczylo/kportal"><img src="https://goreportcard.com/badge/github.com/lukaszraczylo/kportal" alt="Go Report Card"></a>
</p>
**Modern Kubernetes port-forward manager with interactive terminal UI**
<p align="center">
<strong>Modern Kubernetes port-forward manager with interactive terminal UI</strong>
</p>
kportal simplifies managing multiple Kubernetes port-forwards with an elegant, interactive terminal interface. Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea), it provides real-time status updates, automatic reconnection, and hot-reload configuration support.
@@ -13,9 +19,13 @@ kportal simplifies managing multiple Kubernetes port-forwards with an elegant, i
## ✨ Features
- 🎯 **Interactive TUI** - Beautiful terminal interface with keyboard navigation (↑↓/jk, Space to toggle, q to quit)
- **Live Add** - Add new port-forwards on-the-fly without editing config files or restarting
- ✏️ **Live Edit** - Modify existing port-forwards (ports, resources, aliases) in real-time
- 🗑️ **Live Delete** - Remove port-forwards instantly from the running session
- 🔄 **Auto-Reconnect** - Automatic retry with exponential backoff on connection failures (max 10s)
-**Hot-Reload** - Update configuration without restarting - changes applied automatically
- 🏥 **Health Checks** - Real-time port forward status monitoring with 5-second intervals
- 🏥 **Advanced Health Checks** - Multiple check methods (tcp-dial, data-transfer) with stale connection detection
- 🛡️ **Goroutine Watchdog** - Detects and recovers from completely hung workers
- 🎨 **Multi-Context** - Support for multiple Kubernetes contexts and namespaces
- 📦 **Batch Management** - Manage all port-forwards from a single configuration file
- 🔌 **Toggle Forwards** - Enable/disable individual port-forwards on the fly with Space key
@@ -93,6 +103,9 @@ kportal
3. **Navigate the interface**:
- `↑↓` or `j/k` - Navigate through forwards
- `Space` or `Enter` - Toggle forward on/off
- `a` - Add new port-forward interactively
- `e` - Edit selected port-forward
- `d` - Delete selected port-forward
- `q` - Quit application
## 📖 Configuration
@@ -182,6 +195,47 @@ contexts:
- **Service**: `service/service-name` or `svc/service-name`
- **Deployment**: `deployment/deployment-name` or `deploy/deployment-name`
### Health Check & Reliability (Advanced)
kportal includes advanced health checking to prevent stale connections during long-running operations like database dumps:
```yaml
healthCheck:
interval: "3s" # Health check frequency (default: 3s)
timeout: "2s" # Health check timeout (default: 2s)
method: "data-transfer" # Check method: "tcp-dial" or "data-transfer" (default: data-transfer)
maxConnectionAge: "25m" # Proactive reconnect before k8s timeout (default: 25m)
maxIdleTime: "10m" # Detect hung connections (default: 10m)
reliability:
tcpKeepalive: "30s" # TCP keepalive interval (default: 30s)
dialTimeout: "30s" # Connection dial timeout (default: 30s)
retryOnStale: true # Auto-reconnect stale connections (default: true)
```
**Health Check Methods:**
- **`tcp-dial`**: Fast TCP connection test - verifies local port is listening
- **`data-transfer`**: More reliable - attempts to read data to verify tunnel is functional
**Stale Detection:**
- **Max Connection Age**: Kubernetes API typically has 30-minute timeout. kportal reconnects at 25 minutes by default to avoid hitting this limit. **Important**: Age-based reconnection only occurs when the connection is ALSO idle - active transfers (like database dumps) are never interrupted.
- **Max Idle Time**: Detects connections with no data transfer, common when intermediate firewalls drop idle TCP connections
**Use Case Example - Database Dumps:**
```yaml
# Optimized for long-running pg_dump
healthCheck:
method: "data-transfer"
maxConnectionAge: "20m" # Only applies when idle - won't interrupt active dumps
maxIdleTime: "5m" # Detects truly stale connections
reliability:
tcpKeepalive: "30s"
retryOnStale: true
```
This configuration ensures multi-hour database dumps complete without interruption. The `maxConnectionAge` will only trigger reconnection if the connection has been idle for more than `maxIdleTime`, preventing interruption of active data transfers.
## 🎮 Usage
### Interactive Mode (Default)
+171
View File
@@ -0,0 +1,171 @@
# Interactive Add/Remove Wizards
kportal now includes interactive wizards for adding and removing port forwards directly from the running UI!
## Quick Start
Run kportal normally:
```bash
./kportal
```
From the main view:
- Press **`n`** to add a new port forward
- Press **`d`** to delete existing port forwards
## Add Forward Wizard (`n` key)
The wizard guides you through 7 steps to add a new forward:
### Step 1: Select Context
Choose from available Kubernetes contexts in your kubeconfig.
### Step 2: Select Namespace
Pick the namespace where your resource lives.
### Step 3: Select Resource Type
Three options:
- **Pod (by name prefix)** - Forward to a specific pod by prefix matching
- **Pod (by label selector)** - Forward to pods matching labels (survives restarts)
- **Service** - Most stable, load-balanced option
### Step 4: Enter Resource
- **Pod prefix**: Type a prefix like `nginx-` to match pods
- **Label selector**: Enter labels like `app=nginx,env=prod`
- **Service**: Select from a list of services
The wizard shows real-time validation and matching resources!
### Step 5: Remote Port
Enter the port number on the remote resource. The wizard displays detected ports from running containers.
### Step 6: Local Port
Enter the local port to bind to. The wizard checks availability in real-time.
### Step 7: Confirmation
Review your configuration and optionally add an alias (friendly name). Confirm to save!
### Navigation Keys
- **`↑`/`↓`** or **`j`/`k`** - Navigate options
- **`Enter`** - Confirm and proceed to next step
- **`Esc`** - Go back one step (or cancel on first step)
- **`Ctrl+C`** - Hard cancel and return to main view
- **`Backspace`** - Delete characters in text fields
## Remove Forward Wizard (`d` key)
Multi-select interface for removing forwards:
1. **Select forwards**: Use arrow keys to navigate, `Space` to toggle selection
2. **Confirm removal**: Press `Enter` and confirm your choice
### Navigation Keys
- **`↑`/`↓`** or **`j`/`k`** - Navigate forwards
- **`Space`** - Toggle selection of current forward
- **`a`** - Select all forwards
- **`n`** - Deselect all forwards
- **`Enter`** - Proceed to confirmation
- **`Esc`** - Cancel and return to main view
- **`Ctrl+C`** - Hard cancel
## Auto Hot-Reload
When you save a forward via the wizard:
1. The wizard writes to `.kportal.yaml` atomically
2. The file watcher detects the change (~100ms)
3. The manager reloads and starts the new forward
4. The UI updates automatically
No restart needed!
## Error Handling
The wizards handle errors gracefully:
- **Cluster unreachable**: Shows error but allows manual entry
- **Port conflicts**: Displays which process is using the port
- **Invalid selectors**: Shows validation errors in real-time
- **Duplicate ports**: Prevents adding forwards with conflicting ports
## Tips
### Pod Prefix Matching
When using pod prefix, you can type just the app name:
- `nginx` matches `nginx-deployment-abc123`
- `postgres` matches `postgres-statefulset-0`
### Label Selectors
Use standard Kubernetes label syntax:
- `app=nginx` - Single label
- `app=nginx,env=prod` - Multiple labels (comma-separated)
- Real-time validation shows matching pods as you type!
### Aliases
Use aliases for cleaner UI display:
- Instead of: `production/default/pod/nginx-deployment-abc123:80→8080`
- Shows as: `my-nginx:80→8080`
### Quick Selection
In list views, you can use `j`/`k` (Vim-style) or arrow keys for navigation.
## Example Workflow
Adding a forward for a PostgreSQL database:
1. Press `n` in main view
2. Select context: `production` (arrow keys + Enter)
3. Select namespace: `default` (arrow keys + Enter)
4. Select type: `Service` (arrow keys + Enter)
5. Select service: `postgres` (arrow keys + Enter)
6. Enter remote port: `5432` (type + Enter)
7. Enter local port: `5432` (type + Enter)
8. Add alias: `prod-db` (optional, type + Enter)
9. Confirm: Select "Add to .kportal.yaml" (Enter)
Done! The forward starts automatically within seconds.
## Architecture
The wizards use:
- **Config Mutator**: Safe, atomic YAML writes (temp file + rename)
- **K8s Discovery**: Lists contexts, namespaces, pods, services
- **Modal Overlays**: Wizards appear centered over the main view
- **Async Validation**: Port checks and selector validation run in background
- **Hot-Reload Integration**: File watcher picks up changes automatically
## Troubleshooting
### Wizards not appearing?
Check that kportal can connect to your Kubernetes cluster:
```bash
kubectl cluster-info
```
### Port check showing wrong status?
The port check happens asynchronously. Wait a moment after typing for validation.
### Changes not appearing?
The file watcher triggers within 100ms. If changes aren't visible, check:
1. `.kportal.yaml` was written correctly
2. No validation errors in the file
3. kportal process is still running
---
**Navigation Summary**
Main View:
- `n` - New forward wizard
- `d` - Delete forward wizard
- `Space` - Toggle forward on/off
- `↑↓/jk` - Navigate forwards
- `q` - Quit
Wizards:
- `Enter` - Next step / Confirm
- `Esc` - Previous step / Cancel
- `Ctrl+C` - Hard cancel
- `↑↓/jk` - Navigate
- `Space` - Toggle (in delete wizard)
+106 -8
View File
@@ -7,23 +7,31 @@ import (
"log"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/go-logr/logr"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/converter"
"github.com/nvm/kportal/internal/forward"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/ui"
"k8s.io/klog/v2"
)
const (
defaultConfigFile = ".kportal.yaml"
defaultConfigFile = ".kportal.yaml"
initialForwardSettleTime = 100 * time.Millisecond
tableUpdateInterval = 2 * time.Second
)
var (
configFile = flag.String("c", defaultConfigFile, "Path to configuration file")
verbose = flag.Bool("v", false, "Enable verbose logging")
logFormat = flag.String("log-format", "text", "Log format: text or json")
check = flag.Bool("check", false, "Validate configuration and exit")
showVersion = flag.Bool("version", false, "Show version and exit")
convertInput = flag.String("convert", "", "Convert kftray JSON config to kportal YAML (provide input file path)")
@@ -39,6 +47,81 @@ func main() {
os.Exit(0)
}
// Validate config path security
if *configFile != "" {
absConfigPath, err := filepath.Abs(*configFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid config path: %v\n", err)
os.Exit(1)
}
absConfigPath = filepath.Clean(absConfigPath)
// Block system directories
systemDirs := []string{"/etc", "/sys", "/proc", "/dev"}
for _, sysDir := range systemDirs {
if strings.HasPrefix(absConfigPath, sysDir) {
fmt.Fprintf(os.Stderr, "Error: Config file cannot be in system directory: %s\n", sysDir)
os.Exit(1)
}
}
*configFile = absConfigPath
}
// Initialize structured logger
var logLevel logger.Level
var logFmt logger.Format
var logOutput io.Writer
if *verbose {
logLevel = logger.LevelDebug
logOutput = os.Stderr
} else {
logLevel = logger.LevelInfo
logOutput = io.Discard // Silence logger in non-verbose mode to prevent UI corruption
}
switch *logFormat {
case "json":
logFmt = logger.FormatJSON
default:
logFmt = logger.FormatText
}
logger.Init(logLevel, logFmt, logOutput)
// Configure klog (used by kubernetes client-go) to route through our logger
// This prevents k8s logs from interfering with the UI
//
// klog v2 uses multiple output mechanisms:
// 1. SetOutput() - for basic text output
// 2. SetLogger() - for structured/error logs (logr interface)
//
// We must configure BOTH to capture all logs including error messages
// that would otherwise bypass SetOutput() and write directly to stderr.
klog.LogToStderr(false) // Disable direct stderr writes
if *verbose {
// In verbose mode, route all klog through our structured logger at DEBUG level
klogLogger := logger.New(logger.LevelDebug, logFmt, os.Stderr)
// Configure text output routing
klogWriter := logger.NewKlogWriter(klogLogger)
klog.SetOutput(klogWriter)
// Configure structured/error log routing via logr interface
// This captures "Unhandled Error" and other structured logs that bypass SetOutput
logrSink := logger.NewLogrAdapter(klogLogger)
klog.SetLogger(logr.New(logrSink))
} else {
// In non-verbose mode, completely silence ALL klog output
klog.SetOutput(io.Discard)
// Also silence structured/error logs via a discard logger
silentLogger := logger.New(logger.LevelError+1, logFmt, io.Discard) // Level above ERROR = silence all
logrSink := logger.NewLogrAdapter(silentLogger)
klog.SetLogger(logr.New(logrSink))
}
// Handle conversion mode
if *convertInput != "" {
if err := converter.ConvertKFTrayToKPortal(*convertInput, *convertOutput); err != nil {
@@ -68,11 +151,8 @@ func main() {
log.SetOutput(io.Discard)
log.SetPrefix("")
log.SetFlags(0)
// Disable klog (used by kubernetes client-go)
klog.SetOutput(io.Discard)
klog.LogToStderr(false)
} else {
// Verbose mode - enable standard log formatting
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
@@ -101,8 +181,21 @@ func main() {
log.Printf("Loading configuration from: %s", *configFile)
}
// Create Kubernetes client pool and discovery for wizards
pool, err := k8s.NewClientPool()
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: Failed to create k8s client pool: %v\n", err)
fmt.Fprintf(os.Stderr, "Add/remove wizards will not be available\n")
}
discovery := k8s.NewDiscovery(pool)
mutator := config.NewMutator(*configFile)
// Create forward manager
manager := forward.NewManager(*verbose)
manager, err := forward.NewManager(*verbose)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating forward manager: %v\n", err)
os.Exit(1)
}
// Create UI (bubbletea for interactive, simple table for verbose)
var bubbleTeaUI *ui.BubbleTeaUI
@@ -117,6 +210,11 @@ func main() {
manager.DisableForward(id)
}
}, version)
// Set wizard dependencies
// Note: mutator is always available (for delete/edit), discovery requires valid kubeconfig (for add)
bubbleTeaUI.SetWizardDependencies(discovery, mutator, *configFile)
manager.SetStatusUI(bubbleTeaUI)
} else {
// Verbose mode with simple table
@@ -140,7 +238,7 @@ func main() {
// Start table update loop
go func() {
ticker := time.NewTicker(2 * time.Second)
ticker := time.NewTicker(tableUpdateInterval)
defer ticker.Stop()
for range ticker.C {
tableUI.Render()
@@ -211,7 +309,7 @@ func main() {
}()
// Give a moment for initial forwards to be added
time.Sleep(100 * time.Millisecond)
time.Sleep(initialForwardSettleTime)
// Start the bubbletea app (blocks until quit)
if err := bubbleTeaUI.Start(); err != nil {
+939 -395
View File
File diff suppressed because it is too large Load Diff
+132
View File
@@ -0,0 +1,132 @@
<svg width="310" height="150" viewBox="0 0 310 150" xmlns="http://www.w3.org/2000/svg" id="darkLogo">
<defs>
<!-- Simple turbulence for portal edges -->
<filter id="portalTurbulence" x="-50%" y="-50%" width="200%" height="200%">
<feTurbulence type="fractalNoise" baseFrequency="0.02 0.03" numOctaves="2" result="turbulence" seed="5">
<animate attributeName="seed" values="5;10;5" dur="8s" repeatCount="indefinite"/>
</feTurbulence>
<feDisplacementMap in2="turbulence" in="SourceGraphic" scale="2" xChannelSelector="R" yChannelSelector="G"/>
</filter>
<!-- Blue glow -->
<filter id="blueGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="5" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Orange glow -->
<filter id="orangeGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="5" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Text glow -->
<filter id="textGlow" x="-20%" y="-20%" width="140%" height="140%">
<feGaussianBlur stdDeviation="1" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Gradients -->
<radialGradient id="bluePortal" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#000814;stop-opacity:0.9"/>
<stop offset="20%" style="stop-color:#001845;stop-opacity:0.8"/>
<stop offset="50%" style="stop-color:#0077B6;stop-opacity:0.95"/>
<stop offset="80%" style="stop-color:#00B4D8;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#90E0EF;stop-opacity:1"/>
</radialGradient>
<radialGradient id="orangePortal" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#1A0E00;stop-opacity:0.9"/>
<stop offset="20%" style="stop-color:#3D2314;stop-opacity:0.8"/>
<stop offset="50%" style="stop-color:#F77F00;stop-opacity:0.95"/>
<stop offset="80%" style="stop-color:#FCBF49;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FFD6A5;stop-opacity:1"/>
</radialGradient>
</defs>
<!-- Blue Portal (LEFT) -->
<g id="bluePortalGroup">
<!-- Outer rings -->
<ellipse cx="50" cy="75" rx="35" ry="50" fill="none" stroke="#90E0EF" stroke-width="0.5" opacity="0.2"/>
<ellipse cx="50" cy="75" rx="30" ry="44" fill="none" stroke="#00B4D8" stroke-width="1" opacity="0.3"/>
<!-- Main portal -->
<ellipse cx="50" cy="75" rx="26" ry="40" fill="url(#bluePortal)" filter="url(#blueGlow)" opacity="0.95"/>
<!-- Inner energy rings -->
<ellipse cx="50" cy="75" rx="20" ry="32" fill="none" stroke="#00B4D8" stroke-width="2" opacity="0.7">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="50" cy="75" rx="14" ry="24" fill="none" stroke="#90E0EF" stroke-width="1.5" opacity="0.5">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="50" cy="75" rx="7" ry="12" fill="#000814" opacity="0.95"/>
</g>
<!-- Text: "kportal" -->
<!-- Orange K -->
<text x="76" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="#FCBF49" filter="url(#textGlow)">
k
<animate attributeName="x" values="76;79;76" dur="4s" repeatCount="indefinite"/>
</text>
<!-- White "porta" -->
<text x="105" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="white" filter="url(#textGlow)">
porta
</text>
<!-- Blue L -->
<text x="220" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="#00B4D8" filter="url(#textGlow)">
l
<animate attributeName="x" values="220;223;220" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Orange Portal (RIGHT) at x=260 -->
<g id="orangePortalGroup">
<!-- Outer rings -->
<ellipse cx="260" cy="75" rx="35" ry="50" fill="none" stroke="#FFD6A5" stroke-width="0.5" opacity="0.2"/>
<ellipse cx="260" cy="75" rx="30" ry="44" fill="none" stroke="#FCBF49" stroke-width="1" opacity="0.3"/>
<!-- Main portal -->
<ellipse cx="260" cy="75" rx="26" ry="40" fill="url(#orangePortal)" filter="url(#orangeGlow)" opacity="0.95"/>
<!-- Inner energy rings -->
<ellipse cx="260" cy="75" rx="20" ry="32" fill="none" stroke="#FCBF49" stroke-width="2" opacity="0.7">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="260" cy="75" rx="14" ry="24" fill="none" stroke="#FFD6A5" stroke-width="1.5" opacity="0.5">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="260" cy="75" rx="7" ry="12" fill="#1A0E00" opacity="0.95"/>
</g>
<!-- Energy connection between portals -->
<path d="M 76 75 Q 180 70 222 75" stroke="url(#energyGradient)" stroke-width="0.5" fill="none" opacity="0.3">
<animate attributeName="opacity" values="0.1;0.3;0.1" dur="4s" repeatCount="indefinite"/>
</path>
<defs>
<linearGradient id="energyGradient" x1="0%" y1="0%" x2="100%" y2="0%">
<stop offset="0%" style="stop-color:#00B4D8;stop-opacity:1"/>
<stop offset="50%" style="stop-color:#667eea;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FCBF49;stop-opacity:1"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 7.7 KiB

+128
View File
@@ -0,0 +1,128 @@
<svg width="310" height="150" viewBox="0 0 310 150" xmlns="http://www.w3.org/2000/svg" id="lightLogo">
<defs>
<!-- Simple turbulence for portal edges -->
<filter id="portalTurbulenceLight" x="-50%" y="-50%" width="200%" height="200%">
<feTurbulence type="fractalNoise" baseFrequency="0.02 0.03" numOctaves="2" result="turbulence" seed="5">
<animate attributeName="seed" values="5;10;5" dur="8s" repeatCount="indefinite"/>
</feTurbulence>
<feDisplacementMap in2="turbulence" in="SourceGraphic" scale="2" xChannelSelector="R" yChannelSelector="G"/>
</filter>
<!-- Blue glow for light background -->
<filter id="blueGlowLight" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="4" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Orange glow for light background -->
<filter id="orangeGlowLight" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="4" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Text shadow for light background -->
<filter id="textShadowLight" x="-20%" y="-20%" width="140%" height="140%">
<feDropShadow dx="0" dy="1" stdDeviation="0.5" flood-opacity="0.15"/>
</filter>
<!-- Enhanced gradients for light background -->
<radialGradient id="bluePortalLight" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#001529;stop-opacity:1"/>
<stop offset="20%" style="stop-color:#002766;stop-opacity:0.95"/>
<stop offset="50%" style="stop-color:#0066CC;stop-opacity:0.98"/>
<stop offset="80%" style="stop-color:#0099FF;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#66CCFF;stop-opacity:1"/>
</radialGradient>
<radialGradient id="orangePortalLight" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#2E1A00;stop-opacity:1"/>
<stop offset="20%" style="stop-color:#5C3317;stop-opacity:0.95"/>
<stop offset="50%" style="stop-color:#E66100;stop-opacity:0.98"/>
<stop offset="80%" style="stop-color:#FF9933;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FFBB66;stop-opacity:1"/>
</radialGradient>
</defs>
<!-- Blue Portal (LEFT) -->
<g id="bluePortalGroupLight">
<!-- Outer rings -->
<ellipse cx="50" cy="75" rx="35" ry="50" fill="none" stroke="#0099FF" stroke-width="0.8" opacity="0.3"/>
<ellipse cx="50" cy="75" rx="30" ry="44" fill="none" stroke="#0066CC" stroke-width="1.2" opacity="0.4"/>
<!-- Main portal -->
<ellipse cx="50" cy="75" rx="26" ry="40" fill="url(#bluePortalLight)" filter="url(#blueGlowLight)" opacity="1"/>
<!-- Inner energy rings -->
<ellipse cx="50" cy="75" rx="20" ry="32" fill="none" stroke="#0099FF" stroke-width="2" opacity="0.8">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="50" cy="75" rx="14" ry="24" fill="none" stroke="#66CCFF" stroke-width="1.5" opacity="0.6">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="50" cy="75" rx="7" ry="12" fill="#001529" opacity="1"/>
</g>
<!-- Text: "kportal" with dark colors for light background -->
<!-- Orange K -->
<text x="76" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#E66100" filter="url(#textShadowLight)">
k
<animate attributeName="x" values="76;79;76" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Dark "porta" for light background -->
<text x="105" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#2C3E50" filter="url(#textShadowLight)">
porta
</text>
<!-- Blue L -->
<text x="220" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#0066CC" filter="url(#textShadowLight)">
l
<animate attributeName="x" values="220;223;220" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Orange Portal (RIGHT) at x=260 -->
<g id="orangePortalGroupLight">
<!-- Outer rings -->
<ellipse cx="260" cy="75" rx="35" ry="50" fill="none" stroke="#FFBB66" stroke-width="0.8" opacity="0.3"/>
<ellipse cx="260" cy="75" rx="30" ry="44" fill="none" stroke="#FF9933" stroke-width="1.2" opacity="0.4"/>
<!-- Main portal -->
<ellipse cx="260" cy="75" rx="26" ry="40" fill="url(#orangePortalLight)" filter="url(#orangeGlowLight)" opacity="1"/>
<!-- Inner energy rings -->
<ellipse cx="260" cy="75" rx="20" ry="32" fill="none" stroke="#FF9933" stroke-width="2" opacity="0.8">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="260" cy="75" rx="14" ry="24" fill="none" stroke="#FFBB66" stroke-width="1.5" opacity="0.6">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="260" cy="75" rx="7" ry="12" fill="#2E1A00" opacity="1"/>
</g>
<!-- Energy connection between portals -->
<path d="M 76 75 Q 180 70 222 75" stroke="url(#energyGradientLight)" stroke-width="0.7" fill="none" opacity="0.4">
<animate attributeName="opacity" values="0.2;0.4;0.2" dur="4s" repeatCount="indefinite"/>
</path>
<defs>
<linearGradient id="energyGradientLight" x1="0%" y1="0%" x2="100%" y2="0%">
<stop offset="0%" style="stop-color:#0066CC;stop-opacity:1"/>
<stop offset="50%" style="stop-color:#8B7CC6;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#E66100;stop-opacity:1"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 199 KiB

After

Width:  |  Height:  |  Size: 184 KiB

+121 -36
View File
@@ -3,13 +3,123 @@ package config
import (
"fmt"
"os"
"time"
"gopkg.in/yaml.v3"
)
const (
maxConfigSize = 10 * 1024 * 1024 // 10MB
)
// Config represents the root configuration structure from .kportal.yaml
type Config struct {
Contexts []Context `yaml:"contexts"`
Contexts []Context `yaml:"contexts"`
HealthCheck *HealthCheckSpec `yaml:"healthCheck,omitempty"`
Reliability *ReliabilitySpec `yaml:"reliability,omitempty"`
}
// HealthCheckSpec configures health check behavior
type HealthCheckSpec struct {
Interval string `yaml:"interval,omitempty"` // e.g., "3s", "5s"
Timeout string `yaml:"timeout,omitempty"` // e.g., "2s"
Method string `yaml:"method,omitempty"` // "tcp-dial" | "data-transfer"
MaxConnectionAge string `yaml:"maxConnectionAge,omitempty"` // e.g., "25m" - reconnect before k8s timeout
MaxIdleTime string `yaml:"maxIdleTime,omitempty"` // e.g., "10m" - reconnect if no activity
}
// ReliabilitySpec configures connection reliability features
type ReliabilitySpec struct {
TCPKeepalive string `yaml:"tcpKeepalive,omitempty"` // e.g., "30s" - OS-level keepalive
DialTimeout string `yaml:"dialTimeout,omitempty"` // e.g., "30s" - connection dial timeout
RetryOnStale bool `yaml:"retryOnStale,omitempty"` // Auto-reconnect on stale detection
WatchdogPeriod string `yaml:"watchdogPeriod,omitempty"` // e.g., "30s" - goroutine watchdog interval
}
// GetHealthCheckIntervalOrDefault returns the health check interval or default value
func (c *Config) GetHealthCheckIntervalOrDefault() time.Duration {
if c.HealthCheck != nil && c.HealthCheck.Interval != "" {
if d, err := time.ParseDuration(c.HealthCheck.Interval); err == nil {
return d
}
}
return 3 * time.Second // Default: check every 3 seconds
}
// GetHealthCheckTimeoutOrDefault returns the health check timeout or default value
func (c *Config) GetHealthCheckTimeoutOrDefault() time.Duration {
if c.HealthCheck != nil && c.HealthCheck.Timeout != "" {
if d, err := time.ParseDuration(c.HealthCheck.Timeout); err == nil {
return d
}
}
return 2 * time.Second // Default: 2 second timeout
}
// GetHealthCheckMethod returns the health check method or default
func (c *Config) GetHealthCheckMethod() string {
if c.HealthCheck != nil && c.HealthCheck.Method != "" {
return c.HealthCheck.Method
}
return "data-transfer" // Default: more reliable data transfer test
}
// GetMaxConnectionAge returns the max connection age or default
func (c *Config) GetMaxConnectionAge() time.Duration {
if c.HealthCheck != nil && c.HealthCheck.MaxConnectionAge != "" {
if d, err := time.ParseDuration(c.HealthCheck.MaxConnectionAge); err == nil {
return d
}
}
return 25 * time.Minute // Default: 25 minutes (before typical 30min k8s timeout)
}
// GetMaxIdleTime returns the max idle time or default
func (c *Config) GetMaxIdleTime() time.Duration {
if c.HealthCheck != nil && c.HealthCheck.MaxIdleTime != "" {
if d, err := time.ParseDuration(c.HealthCheck.MaxIdleTime); err == nil {
return d
}
}
return 10 * time.Minute // Default: 10 minutes idle before reconnect
}
// GetTCPKeepalive returns the TCP keepalive duration or default
func (c *Config) GetTCPKeepalive() time.Duration {
if c.Reliability != nil && c.Reliability.TCPKeepalive != "" {
if d, err := time.ParseDuration(c.Reliability.TCPKeepalive); err == nil {
return d
}
}
return 30 * time.Second // Default: 30 second keepalive
}
// GetRetryOnStale returns whether to retry on stale connections
func (c *Config) GetRetryOnStale() bool {
if c.Reliability != nil {
return c.Reliability.RetryOnStale
}
return true // Default: enabled
}
// GetWatchdogPeriod returns the goroutine watchdog check period or default
func (c *Config) GetWatchdogPeriod() time.Duration {
if c.Reliability != nil && c.Reliability.WatchdogPeriod != "" {
if d, err := time.ParseDuration(c.Reliability.WatchdogPeriod); err == nil {
return d
}
}
return 30 * time.Second // Default: check every 30 seconds
}
// GetDialTimeout returns the connection dial timeout or default
func (c *Config) GetDialTimeout() time.Duration {
if c.Reliability != nil && c.Reliability.DialTimeout != "" {
if d, err := time.ParseDuration(c.Reliability.DialTimeout); err == nil {
return d
}
}
return 30 * time.Second // Default: 30 second dial timeout
}
// Context represents a Kubernetes context with its namespaces
@@ -80,6 +190,16 @@ func (f *Forward) GetNamespace() string {
// LoadConfig loads and parses the configuration file from the given path.
func LoadConfig(path string) (*Config, error) {
// Validate file size before reading
fileInfo, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("failed to stat config file: %w", err)
}
if fileInfo.Size() > maxConfigSize {
return nil, fmt.Errorf("config file too large: %d bytes (max %d)", fileInfo.Size(), maxConfigSize)
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
@@ -122,38 +242,3 @@ func (c *Config) GetAllForwards() []Forward {
return forwards
}
// GetForwardsByContext returns all forwards for a specific context.
func (c *Config) GetForwardsByContext(contextName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
forwards = append(forwards, ns.Forwards...)
}
break
}
}
return forwards
}
// GetForwardsByNamespace returns all forwards for a specific context and namespace.
func (c *Config) GetForwardsByNamespace(contextName, namespaceName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
if ns.Name == namespaceName {
forwards = append(forwards, ns.Forwards...)
break
}
}
break
}
}
return forwards
}
+1 -67
View File
@@ -97,7 +97,7 @@ func TestLoadConfig_FileNotFound(t *testing.T) {
cfg, err := LoadConfig("/non/existent/path/.kportal.yaml")
assert.Error(t, err, "LoadConfig should fail with non-existent file")
assert.Nil(t, cfg, "config should be nil on error")
assert.Contains(t, err.Error(), "failed to read config file", "error should mention read failure")
assert.Contains(t, err.Error(), "failed to stat config file", "error should mention stat failure")
}
func TestForward_ID(t *testing.T) {
@@ -298,72 +298,6 @@ func TestConfig_GetAllForwards(t *testing.T) {
assert.Len(t, forwards, 4, "should return all forwards from all contexts and namespaces")
}
func TestConfig_GetForwardsByContext(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: cluster2
namespaces:
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByContext("cluster1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1")
forwards2 := cfg.GetForwardsByContext("cluster2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster2")
forwards3 := cfg.GetForwardsByContext("non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent context")
}
func TestConfig_GetForwardsByNamespace(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByNamespace("cluster1", "ns1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1/ns1")
forwards2 := cfg.GetForwardsByNamespace("cluster1", "ns2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster1/ns2")
forwards3 := cfg.GetForwardsByNamespace("cluster1", "non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent namespace")
}
func TestForward_SetContext(t *testing.T) {
fwd := Forward{
Resource: "pod/my-app",
+273
View File
@@ -0,0 +1,273 @@
package config
import (
"fmt"
"os"
"path/filepath"
"sync"
"gopkg.in/yaml.v3"
)
// Mutator provides safe, atomic mutations to the kportal configuration file.
// All operations use atomic file writes (write to temp, then rename) to prevent
// corruption and ensure the file watcher picks up changes.
type Mutator struct {
configPath string
mu sync.Mutex // Ensure only one mutation at a time
}
// NewMutator creates a new configuration mutator for the given config file path.
func NewMutator(configPath string) *Mutator {
return &Mutator{
configPath: configPath,
}
}
// findOrCreateContext finds an existing context or creates a new one
func (m *Mutator) findOrCreateContext(cfg *Config, contextName string) *Context {
for i := range cfg.Contexts {
if cfg.Contexts[i].Name == contextName {
return &cfg.Contexts[i]
}
}
// Create new context
cfg.Contexts = append(cfg.Contexts, Context{
Name: contextName,
Namespaces: []Namespace{},
})
return &cfg.Contexts[len(cfg.Contexts)-1]
}
// findOrCreateNamespace finds an existing namespace or creates a new one
func (m *Mutator) findOrCreateNamespace(ctx *Context, namespaceName string) *Namespace {
for i := range ctx.Namespaces {
if ctx.Namespaces[i].Name == namespaceName {
return &ctx.Namespaces[i]
}
}
// Create new namespace
ctx.Namespaces = append(ctx.Namespaces, Namespace{
Name: namespaceName,
Forwards: []Forward{},
})
return &ctx.Namespaces[len(ctx.Namespaces)-1]
}
// AddForward adds a new port forward to the configuration.
// If the context or namespace doesn't exist, they will be created.
// The new configuration is validated before writing.
// Returns an error if the port is already in use or validation fails.
func (m *Mutator) AddForward(contextName, namespaceName string, fwd Forward) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
// If file doesn't exist, create empty config
if os.IsNotExist(err) {
cfg = &Config{Contexts: []Context{}}
} else {
return fmt.Errorf("failed to load config: %w", err)
}
}
// Find or create context and namespace
targetContext := m.findOrCreateContext(cfg, contextName)
targetNamespace := m.findOrCreateNamespace(targetContext, namespaceName)
// Set context/namespace on the forward for validation
fwd.SetContext(contextName, namespaceName)
// Check for duplicate local port
allForwards := cfg.GetAllForwards()
for _, existing := range allForwards {
if existing.LocalPort == fwd.LocalPort {
return fmt.Errorf("port %d is already in use by %s", fwd.LocalPort, existing.String())
}
}
// Add the forward
targetNamespace.Forwards = append(targetNamespace.Forwards, fwd)
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// RemoveForwards removes forwards matching the predicate function.
// The predicate receives the context, namespace, and forward, and should return true
// to remove that forward.
// Empty namespaces and contexts are preserved (not automatically removed).
func (m *Mutator) RemoveForwards(predicate func(ctx, ns string, fwd Forward) bool) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// Iterate and filter
for i := range cfg.Contexts {
ctx := &cfg.Contexts[i]
filteredNamespaces := []Namespace{}
for j := range ctx.Namespaces {
ns := &ctx.Namespaces[j]
// Filter forwards
filtered := []Forward{}
for _, fwd := range ns.Forwards {
// CRITICAL: Set context/namespace so fwd.ID() generates correct ID
fwd.SetContext(ctx.Name, ns.Name)
if !predicate(ctx.Name, ns.Name, fwd) {
// Keep this forward
filtered = append(filtered, fwd)
}
}
ns.Forwards = filtered
// Only keep namespaces that have at least one forward
if len(ns.Forwards) > 0 {
filteredNamespaces = append(filteredNamespaces, *ns)
}
}
ctx.Namespaces = filteredNamespaces
}
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// RemoveForwardByID removes a specific forward by its ID.
func (m *Mutator) RemoveForwardByID(id string) error {
return m.RemoveForwards(func(ctx, ns string, fwd Forward) bool {
return fwd.ID() == id
})
}
// UpdateForward atomically replaces an existing forward with a new one.
// This is used for editing - it removes the old forward and adds the new one in a single transaction.
// If the old forward doesn't exist, returns an error.
// If the new forward validation fails, the operation is rolled back (old forward remains).
func (m *Mutator) UpdateForward(oldID, newContextName, newNamespaceName string, newFwd Forward) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// First, verify the old forward exists and remove it
oldForwardFound := false
for i := range cfg.Contexts {
ctx := &cfg.Contexts[i]
for j := range ctx.Namespaces {
ns := &ctx.Namespaces[j]
// Filter forwards, removing the old one
filtered := []Forward{}
for _, fwd := range ns.Forwards {
// CRITICAL: Set context/namespace so fwd.ID() generates correct ID
fwd.SetContext(ctx.Name, ns.Name)
if fwd.ID() == oldID {
oldForwardFound = true
// Skip this forward (remove it)
continue
}
// Keep this forward
filtered = append(filtered, fwd)
}
ns.Forwards = filtered
}
}
if !oldForwardFound {
return fmt.Errorf("forward with ID %s not found", oldID)
}
// Now add the new forward
// Find or create context and namespace
targetContext := m.findOrCreateContext(cfg, newContextName)
targetNamespace := m.findOrCreateNamespace(targetContext, newNamespaceName)
// Set context/namespace on the forward for validation
newFwd.SetContext(newContextName, newNamespaceName)
// Check for duplicate local port (excluding the one we just removed)
allForwards := cfg.GetAllForwards()
for _, existing := range allForwards {
if existing.LocalPort == newFwd.LocalPort && existing.ID() != oldID {
return fmt.Errorf("port %d is already in use by %s", newFwd.LocalPort, existing.String())
}
}
// Add the new forward
targetNamespace.Forwards = append(targetNamespace.Forwards, newFwd)
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// writeAtomic writes the configuration atomically to prevent corruption.
// Steps:
// 1. Marshal config to YAML
// 2. Write to temporary file (.kportal.yaml.tmp)
// 3. Atomic rename to actual config file
//
// This ensures the file watcher picks up a complete, valid file.
func (m *Mutator) writeAtomic(cfg *Config) error {
// Marshal to YAML
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create temporary file in same directory as config
dir := filepath.Dir(m.configPath)
tmpFile := filepath.Join(dir, ".kportal.yaml.tmp")
// Write to temp file
if err := os.WriteFile(tmpFile, data, 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename
if err := os.Rename(tmpFile, m.configPath); err != nil {
// Clean up temp file on failure
os.Remove(tmpFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
+20 -10
View File
@@ -6,6 +6,7 @@ import (
"path/filepath"
"github.com/fsnotify/fsnotify"
"github.com/nvm/kportal/internal/logger"
)
// ReloadCallback is called when the configuration file changes.
@@ -113,28 +114,37 @@ func (w *Watcher) handleReload() {
// Load new configuration
newCfg, err := LoadConfig(w.configPath)
if err != nil {
log.Printf("Failed to load configuration: %v", err)
log.Printf("Keeping previous configuration active")
logger.Error("Failed to load configuration during hot-reload", map[string]interface{}{
"config_path": w.configPath,
"error": err.Error(),
})
logger.Info("Keeping previous configuration active", nil)
return
}
// Validate new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(newCfg); len(errs) > 0 {
log.Printf("Configuration validation failed:")
log.Print(FormatValidationErrors(errs))
log.Printf("Keeping previous configuration active")
logger.Error("Configuration validation failed during hot-reload", map[string]interface{}{
"config_path": w.configPath,
"validation_errors": len(errs),
})
logger.Info("Keeping previous configuration active", nil)
return
}
// Call reload callback
if err := w.callback(newCfg); err != nil {
log.Printf("Failed to apply new configuration: %v", err)
log.Printf("Keeping previous configuration active")
logger.Error("Failed to apply new configuration", map[string]interface{}{
"config_path": w.configPath,
"error": err.Error(),
})
logger.Info("Keeping previous configuration active", nil)
return
}
if w.verbose {
log.Printf("Configuration reloaded successfully")
}
logger.Info("Configuration reloaded successfully", map[string]interface{}{
"config_path": w.configPath,
"forwards_count": len(newCfg.GetAllForwards()),
})
}
+137 -13
View File
@@ -9,6 +9,7 @@ import (
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/healthcheck"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
)
// StatusUpdater is an interface for updating forward status
@@ -28,23 +29,31 @@ type Manager struct {
portForwarder *k8s.PortForwarder
portChecker *PortChecker
healthChecker *healthcheck.Checker
watchdog *Watchdog
verbose bool
currentConfig *config.Config
statusUI StatusUpdater
}
// NewManager creates a new forward Manager.
func NewManager(verbose bool) *Manager {
// The health checker will be created with default settings and can be
// reconfigured via SetConfig().
func NewManager(verbose bool) (*Manager, error) {
clientPool, err := k8s.NewClientPool()
if err != nil {
log.Fatalf("Failed to create client pool: %v", err)
return nil, fmt.Errorf("failed to create client pool: %w", err)
}
resolver := k8s.NewResourceResolver(clientPool)
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
// Create health checker: check every 5 seconds with 2 second timeout
healthChecker := healthcheck.NewChecker(5*time.Second, 2*time.Second)
// Create health checker with defaults: check every 3 seconds with 2 second timeout
// Will be reconfigured when config is loaded
healthChecker := healthcheck.NewChecker(3*time.Second, 2*time.Second)
// Create watchdog with default settings: check every 30 seconds, 60 second hang threshold
// Will be reconfigured when config is loaded
watchdog := NewWatchdog(30*time.Second, 60*time.Second)
return &Manager{
workers: make(map[string]*ForwardWorker),
@@ -53,8 +62,54 @@ func NewManager(verbose bool) *Manager {
portForwarder: portForwarder,
portChecker: NewPortChecker(),
healthChecker: healthChecker,
watchdog: watchdog,
verbose: verbose,
}, nil
}
// configureHealthChecker creates a new health checker with settings from config
func (m *Manager) configureHealthChecker(cfg *config.Config) {
// Stop existing health checker
if m.healthChecker != nil {
m.healthChecker.Stop()
}
// Parse check method
methodStr := cfg.GetHealthCheckMethod()
var method healthcheck.CheckMethod
switch methodStr {
case "tcp-dial":
method = healthcheck.CheckMethodTCPDial
case "data-transfer":
method = healthcheck.CheckMethodDataTransfer
default:
method = healthcheck.CheckMethodDataTransfer
}
// Create new health checker with config settings
m.healthChecker = healthcheck.NewCheckerWithOptions(healthcheck.CheckerOptions{
Interval: cfg.GetHealthCheckIntervalOrDefault(),
Timeout: cfg.GetHealthCheckTimeoutOrDefault(),
Method: method,
MaxConnectionAge: cfg.GetMaxConnectionAge(),
MaxIdleTime: cfg.GetMaxIdleTime(),
})
// Configure TCP settings on port forwarder
tcpKeepalive := cfg.GetTCPKeepalive()
dialTimeout := cfg.GetDialTimeout()
m.portForwarder.SetTCPKeepalive(tcpKeepalive)
m.portForwarder.SetDialTimeout(dialTimeout)
logger.Info("Health checker and reliability configured", map[string]interface{}{
"interval": cfg.GetHealthCheckIntervalOrDefault().String(),
"timeout": cfg.GetHealthCheckTimeoutOrDefault().String(),
"method": methodStr,
"max_connection_age": cfg.GetMaxConnectionAge().String(),
"max_idle_time": cfg.GetMaxIdleTime().String(),
"tcp_keepalive": tcpKeepalive.String(),
"dial_timeout": dialTimeout.String(),
})
}
// SetStatusUI sets the status updater for the manager
@@ -70,6 +125,20 @@ func (m *Manager) Start(cfg *config.Config) error {
m.currentConfig = cfg
// Configure health checker with settings from config
m.configureHealthChecker(cfg)
// Start watchdog
watchdogPeriod := cfg.GetWatchdogPeriod()
m.watchdog.checkInterval = watchdogPeriod
m.watchdog.hangThreshold = watchdogPeriod * 2 // Hang threshold is 2x check interval
m.watchdog.Start()
logger.Info("Watchdog started", map[string]interface{}{
"check_interval": watchdogPeriod.String(),
"hang_threshold": (watchdogPeriod * 2).String(),
})
// Get all forwards from config
forwards := cfg.GetAllForwards()
@@ -93,7 +162,14 @@ func (m *Manager) Start(cfg *config.Config) error {
for _, fwd := range forwards {
if err := m.startWorker(fwd); err != nil {
log.Printf("Failed to start worker for %s: %v", fwd.ID(), err)
logger.Error("Failed to start worker", map[string]interface{}{
"forward_id": fwd.ID(),
"context": fwd.GetContext(),
"namespace": fwd.GetNamespace(),
"resource": fwd.Resource,
"local_port": fwd.LocalPort,
"error": err.Error(),
})
// Continue with other workers
}
}
@@ -106,8 +182,9 @@ func (m *Manager) Start(cfg *config.Config) error {
func (m *Manager) Stop() {
log.Printf("Stopping all port-forwards...")
// Stop health checker first
// Stop health checker and watchdog first
m.healthChecker.Stop()
m.watchdog.Stop()
m.workersMu.Lock()
workers := make([]*ForwardWorker, 0, len(m.workers))
@@ -146,7 +223,9 @@ func (m *Manager) Reload(newCfg *config.Config) error {
return fmt.Errorf("new configuration is nil")
}
log.Printf("Reloading configuration...")
logger.Info("Reloading configuration", map[string]interface{}{
"new_forwards_count": len(newCfg.GetAllForwards()),
})
// Get all forwards from new config
newForwards := newCfg.GetAllForwards()
@@ -258,21 +337,54 @@ func (m *Manager) startWorker(fwd config.Forward) error {
m.statusUI.AddForward(fwd.ID(), &fwd)
}
// Register with watchdog
m.watchdog.RegisterWorker(fwd.ID(), func(forwardID string) {
logger.Warn("Watchdog triggered reconnection for hung worker", map[string]interface{}{
"forward_id": forwardID,
})
// Find and trigger reconnect on hung worker
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("watchdog detected hung worker")
}
})
// Register with health checker
m.healthChecker.Register(fwd.ID(), fwd.LocalPort, func(forwardID string, status healthcheck.Status, errorMsg string) {
if m.statusUI != nil {
m.statusUI.UpdateStatus(forwardID, string(status))
// Send error separately if there is one
if status == healthcheck.StatusUnhealthy && errorMsg != "" {
if (status == healthcheck.StatusUnhealthy || status == healthcheck.StatusStale) && errorMsg != "" {
if ui, ok := m.statusUI.(interface{ SetError(id, msg string) }); ok {
ui.SetError(forwardID, errorMsg)
}
}
}
// Handle stale connections: trigger reconnection if retryOnStale is enabled
if status == healthcheck.StatusStale && m.currentConfig.GetRetryOnStale() {
logger.Info("Stale connection detected, triggering reconnection", map[string]interface{}{
"forward_id": forwardID,
"reason": errorMsg,
})
// Find and notify the worker to reconnect
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("stale connection")
}
}
})
// Create and start worker
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker)
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker, m.watchdog)
worker.Start()
// Store worker
@@ -283,6 +395,11 @@ func (m *Manager) startWorker(fwd config.Forward) error {
// stopWorker stops and removes a forward worker.
func (m *Manager) stopWorker(id string) error {
return m.stopWorkerInternal(id, true)
}
// stopWorkerInternal stops a worker with option to remove from UI or just update status
func (m *Manager) stopWorkerInternal(id string, removeFromUI bool) error {
m.workersMu.Lock()
worker, exists := m.workers[id]
if !exists {
@@ -292,11 +409,18 @@ func (m *Manager) stopWorker(id string) error {
delete(m.workers, id)
m.workersMu.Unlock()
// Unregister from health checker
// Unregister from health checker and watchdog
m.healthChecker.Unregister(id)
m.watchdog.UnregisterWorker(id)
// Note: We DON'T call Remove() here anymore - keep it in the UI
// The UI will show it as disabled instead
// Notify UI - either remove or update to disabled status
if m.statusUI != nil {
if removeFromUI {
m.statusUI.Remove(id)
} else {
m.statusUI.UpdateStatus(id, "Disabled")
}
}
// Stop the worker
worker.Stop()
@@ -346,7 +470,7 @@ func (m *Manager) getResourceForPort(forwards []config.Forward, port int) string
// DisableForward temporarily stops a forward by ID
func (m *Manager) DisableForward(id string) error {
if err := m.stopWorker(id); err != nil {
if err := m.stopWorkerInternal(id, false); err != nil {
return err
}
log.Printf("Disabled: %s", id)
+21 -13
View File
@@ -8,6 +8,19 @@ import (
"strings"
)
// isValidPID validates that a PID string contains only digits
func isValidPID(pid string) bool {
if len(pid) == 0 || len(pid) > 9 {
return false
}
for _, c := range pid {
if c < '0' || c > '9' {
return false
}
}
return true
}
// PortConflict represents a local port that is already in use.
type PortConflict struct {
Port int // The conflicting port number
@@ -93,6 +106,10 @@ func (pc *PortChecker) getProcessUsingPortUnix(port int) string {
pids := strings.Split(pidStr, "\n")
pid := pids[0]
if !isValidPID(pid) {
return "unknown"
}
// Get process name using ps
cmd = exec.Command("ps", "-p", pid, "-o", "comm=")
output, err = cmd.Output()
@@ -140,6 +157,10 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
pid := fields[len(fields)-1]
if !isValidPID(pid) {
return "unknown"
}
// Get process name using tasklist
cmd = exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
output, err = cmd.Output()
@@ -188,16 +209,3 @@ func FormatConflicts(conflicts []PortConflict) string {
return sb.String()
}
// GetPortsFromForwards extracts all local ports from a list of forward configurations.
func GetPortsFromForwards(forwards []interface{}) []int {
ports := make([]int, 0, len(forwards))
for _, fwd := range forwards {
// This function expects a generic interface to work with different forward types
// The actual implementation should use the Forward struct from config package
if f, ok := fwd.(interface{ GetLocalPort() int }); ok {
ports = append(ports, f.GetLocalPort())
}
}
return ports
}
+158
View File
@@ -0,0 +1,158 @@
package forward
import (
"context"
"sync"
"time"
"github.com/nvm/kportal/internal/logger"
)
// Watchdog monitors worker goroutines to detect hung workers
type Watchdog struct {
mu sync.RWMutex
workers map[string]*workerState // key: forward ID
checkInterval time.Duration
hangThreshold time.Duration // How long without heartbeat before considered hung
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// workerState tracks the health of a single worker
type workerState struct {
forwardID string
lastHeartbeat time.Time
heartbeatCount uint64
isHung bool
onHungCallback func(forwardID string)
}
// NewWatchdog creates a new goroutine watchdog
func NewWatchdog(checkInterval, hangThreshold time.Duration) *Watchdog {
ctx, cancel := context.WithCancel(context.Background())
return &Watchdog{
workers: make(map[string]*workerState),
checkInterval: checkInterval,
hangThreshold: hangThreshold,
ctx: ctx,
cancel: cancel,
}
}
// Start begins the watchdog monitoring loop
func (w *Watchdog) Start() {
w.wg.Add(1)
go w.monitorLoop()
}
// Stop stops the watchdog
func (w *Watchdog) Stop() {
w.cancel()
w.wg.Wait()
}
// RegisterWorker adds a worker to monitor
func (w *Watchdog) RegisterWorker(forwardID string, onHungCallback func(string)) {
w.mu.Lock()
defer w.mu.Unlock()
w.workers[forwardID] = &workerState{
forwardID: forwardID,
lastHeartbeat: time.Now(),
heartbeatCount: 0,
isHung: false,
onHungCallback: onHungCallback,
}
logger.Debug("Watchdog registered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// UnregisterWorker removes a worker from monitoring
func (w *Watchdog) UnregisterWorker(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.workers, forwardID)
logger.Debug("Watchdog unregistered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// Heartbeat records that a worker is alive and processing
// Workers should call this periodically (e.g., in their main loop)
func (w *Watchdog) Heartbeat(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
if state, exists := w.workers[forwardID]; exists {
state.lastHeartbeat = time.Now()
state.heartbeatCount++
state.isHung = false
}
}
// GetWorkerState returns the current state of a worker (for testing)
func (w *Watchdog) GetWorkerState(forwardID string) (lastHeartbeat time.Time, count uint64, exists bool) {
w.mu.RLock()
defer w.mu.RUnlock()
if state, ok := w.workers[forwardID]; ok {
return state.lastHeartbeat, state.heartbeatCount, true
}
return time.Time{}, 0, false
}
// monitorLoop periodically checks all workers
func (w *Watchdog) monitorLoop() {
defer w.wg.Done()
ticker := time.NewTicker(w.checkInterval)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
w.checkWorkers()
}
}
}
// checkWorkers checks all registered workers for hung state
func (w *Watchdog) checkWorkers() {
w.mu.Lock()
defer w.mu.Unlock()
now := time.Now()
for forwardID, state := range w.workers {
timeSinceHeartbeat := now.Sub(state.lastHeartbeat)
// Check if worker is hung
if timeSinceHeartbeat > w.hangThreshold {
if !state.isHung {
// First time detecting hung state
state.isHung = true
logger.Warn("Watchdog detected hung worker", map[string]interface{}{
"forward_id": forwardID,
"time_since_heartbeat": timeSinceHeartbeat.String(),
"hang_threshold": w.hangThreshold.String(),
"heartbeat_count": state.heartbeatCount,
})
// Trigger callback to handle hung worker (without holding lock)
if state.onHungCallback != nil {
callback := state.onHungCallback
w.mu.Unlock()
callback(forwardID)
w.mu.Lock()
}
}
}
}
}
+310
View File
@@ -0,0 +1,310 @@
package forward
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// WatchdogTestSuite contains tests for the watchdog
type WatchdogTestSuite struct {
suite.Suite
watchdog *Watchdog
}
func TestWatchdogSuite(t *testing.T) {
suite.Run(t, new(WatchdogTestSuite))
}
func (s *WatchdogTestSuite) SetupTest() {
// Create watchdog with fast intervals for testing
s.watchdog = NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
s.watchdog.Start()
}
func (s *WatchdogTestSuite) TearDownTest() {
if s.watchdog != nil {
s.watchdog.Stop()
}
}
// TestRegisterUnregister tests basic registration and unregistration
func (s *WatchdogTestSuite) TestRegisterUnregister() {
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
// Register worker
s.watchdog.RegisterWorker("test-forward", callback)
// Verify worker is registered
_, _, exists := s.watchdog.GetWorkerState("test-forward")
assert.True(s.T(), exists, "Worker should be registered")
// Unregister worker
s.watchdog.UnregisterWorker("test-forward")
// Verify worker is unregistered
_, _, exists = s.watchdog.GetWorkerState("test-forward")
assert.False(s.T(), exists, "Worker should be unregistered")
assert.False(s.T(), callbackCalled, "Callback should not have been called")
}
// TestHeartbeat tests that heartbeats update worker state
func (s *WatchdogTestSuite) TestHeartbeat() {
s.watchdog.RegisterWorker("test-forward", nil)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat1, count1, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(1), count1)
// Wait a bit
time.Sleep(50 * time.Millisecond)
// Send another heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat2, count2, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(2), count2)
assert.True(s.T(), lastHeartbeat2.After(lastHeartbeat1), "Second heartbeat should be after first")
}
// TestHungWorkerDetection tests that hung workers are detected
func (s *WatchdogTestSuite) TestHungWorkerDetection() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for worker to be considered hung (300ms threshold + 100ms check interval)
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
// TestHealthyWorkerNotDetectedAsHung tests that workers sending heartbeats are not considered hung
func (s *WatchdogTestSuite) TestHealthyWorkerNotDetectedAsHung() {
callbackCalled := false
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send periodic heartbeats (faster than hang threshold)
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
done := make(chan bool)
go func() {
for i := 0; i < 10; i++ {
<-ticker.C
s.watchdog.Heartbeat("test-forward")
}
done <- true
}()
// Wait for all heartbeats to complete
<-done
// Check that callback was not called
mu.Lock()
assert.False(s.T(), callbackCalled, "Callback should not be called for healthy worker")
mu.Unlock()
}
// TestMultipleWorkers tests monitoring multiple workers simultaneously
func (s *WatchdogTestSuite) TestMultipleWorkers() {
callbacks := make(map[string]int)
var mu sync.Mutex
makeCallback := func(id string) func(string) {
return func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbacks[id]++
}
}
// Register multiple workers
s.watchdog.RegisterWorker("worker-1", makeCallback("worker-1"))
s.watchdog.RegisterWorker("worker-2", makeCallback("worker-2"))
s.watchdog.RegisterWorker("worker-3", makeCallback("worker-3"))
// worker-1: Keep sending heartbeats (healthy)
ticker1 := time.NewTicker(50 * time.Millisecond)
defer ticker1.Stop()
go func() {
for i := 0; i < 10; i++ {
<-ticker1.C
s.watchdog.Heartbeat("worker-1")
}
}()
// worker-2: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-2")
// worker-3: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-3")
// Wait for hung workers to be detected
time.Sleep(600 * time.Millisecond)
// Check results
mu.Lock()
defer mu.Unlock()
assert.Equal(s.T(), 0, callbacks["worker-1"], "worker-1 should not trigger callback (healthy)")
assert.Greater(s.T(), callbacks["worker-2"], 0, "worker-2 should trigger callback (hung)")
assert.Greater(s.T(), callbacks["worker-3"], 0, "worker-3 should trigger callback (hung)")
}
// TestCallbackOnlyOnFirstDetection tests that callback is only called once when hung is first detected
func (s *WatchdogTestSuite) TestCallbackOnlyOnFirstDetection() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for multiple check cycles
time.Sleep(1 * time.Second)
// Check that callback was only called once
mu.Lock()
assert.Equal(s.T(), 1, callbackCount, "Callback should only be called once")
mu.Unlock()
}
// TestHeartbeatResetsHungState tests that sending heartbeat after hung detection resets state
func (s *WatchdogTestSuite) TestHeartbeatResetsHungState() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for hung detection
time.Sleep(500 * time.Millisecond)
mu.Lock()
firstCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 1, firstCount, "First hung detection should trigger callback")
// Send heartbeat to reset hung state
s.watchdog.Heartbeat("test-forward")
// Wait for worker to become hung again
time.Sleep(500 * time.Millisecond)
mu.Lock()
secondCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 2, secondCount, "Second hung detection should trigger callback again")
}
// TestConcurrentOperations tests thread safety
func (s *WatchdogTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numWorkers := 10
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := string(rune('a' + id))
s.watchdog.RegisterWorker(forwardID, nil)
for j := 0; j < 10; j++ {
s.watchdog.Heartbeat(forwardID)
time.Sleep(10 * time.Millisecond)
}
s.watchdog.UnregisterWorker(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestStopWatchdog tests that stopping watchdog cleans up properly
func TestStopWatchdog(t *testing.T) {
watchdog := NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
watchdog.Start()
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
watchdog.RegisterWorker("test-forward", callback)
watchdog.Heartbeat("test-forward")
// Stop watchdog before hang detection
time.Sleep(100 * time.Millisecond)
watchdog.Stop()
// Wait to ensure no more callbacks after stop
time.Sleep(500 * time.Millisecond)
assert.False(t, callbackCalled, "Callback should not be called after watchdog is stopped")
}
// TestWatchdogWithZeroHeartbeats tests detecting hung worker that never sends heartbeats
func (s *WatchdogTestSuite) TestWatchdogWithZeroHeartbeats() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
// Register worker but never send heartbeat
s.watchdog.RegisterWorker("test-forward", callback)
// Wait for hung detection
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
+114 -19
View File
@@ -5,31 +5,41 @@ import (
"fmt"
"io"
"log"
"sync"
"time"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/healthcheck"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/retry"
)
const (
portForwardReadyTimeout = 30 * time.Second
)
// ForwardWorker manages a single port-forward connection with automatic retry.
type ForwardWorker struct {
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
startTime time.Time // Track when the worker started
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
reconnectChan chan string // Channel to trigger reconnection
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
watchdog *Watchdog
startTime time.Time // Track when the worker started
forwardCancel context.CancelFunc // Cancel function for current forward attempt
forwardCancelMu sync.Mutex // Protects forwardCancel
}
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker) *ForwardWorker {
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker, watchdog *Watchdog) *ForwardWorker {
ctx, cancel := context.WithCancel(context.Background())
return &ForwardWorker{
@@ -39,13 +49,32 @@ func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verb
cancel: cancel,
stopChan: make(chan struct{}),
doneChan: make(chan struct{}),
reconnectChan: make(chan string, 1), // Buffered to avoid blocking
verbose: verbose,
statusUI: statusUI,
healthChecker: healthChecker,
watchdog: watchdog,
startTime: time.Now(),
}
}
// TriggerReconnect triggers a reconnection (e.g., due to stale connection)
func (w *ForwardWorker) TriggerReconnect(reason string) {
// Cancel current forward if running
w.forwardCancelMu.Lock()
if w.forwardCancel != nil {
w.forwardCancel()
}
w.forwardCancelMu.Unlock()
// Send reconnect signal (non-blocking)
select {
case w.reconnectChan <- reason:
default:
// Channel already has pending reconnect
}
}
// Start begins the port-forward worker in a goroutine.
// The worker will continuously retry on failures with exponential backoff.
func (w *ForwardWorker) Start() {
@@ -63,6 +92,12 @@ func (w *ForwardWorker) Stop() {
func (w *ForwardWorker) run() {
defer close(w.doneChan)
// Start heartbeat goroutine to continuously send heartbeats to watchdog
// This prevents false "hung worker" detection when connections are long-lived
if w.watchdog != nil {
go w.heartbeatLoop()
}
backoff := retry.NewBackoff()
for {
@@ -86,7 +121,13 @@ func (w *ForwardWorker) run() {
)
if err != nil {
log.Printf("[%s] Failed to resolve resource: %v", w.forward.ID(), err)
logger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": w.forward.ID(),
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
"resource": w.forward.Resource,
"error": err.Error(),
})
w.sleepWithBackoff(backoff)
continue
}
@@ -96,10 +137,20 @@ func (w *ForwardWorker) run() {
if w.healthChecker != nil {
w.healthChecker.MarkReconnecting(w.forward.ID())
}
log.Printf("[%s] Switched to new pod: %s → %s", w.forward.ID(), w.lastPod, podName)
logger.Info("Pod restart detected, switching to new pod", map[string]interface{}{
"forward_id": w.forward.ID(),
"old_pod": w.lastPod,
"new_pod": podName,
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
})
} else if w.lastPod == "" {
log.Printf("[%s] Forwarding %s → localhost:%d",
w.forward.ID(), w.forward.String(), w.forward.LocalPort)
logger.Info("Starting port forward", map[string]interface{}{
"forward_id": w.forward.ID(),
"target": w.forward.String(),
"local_port": w.forward.LocalPort,
"pod": podName,
})
if w.healthChecker != nil {
w.healthChecker.MarkStarting(w.forward.ID())
}
@@ -123,7 +174,14 @@ func (w *ForwardWorker) run() {
}
// Log the error
log.Printf("[%s] Port-forward connection failed: %v", w.forward.ID(), err)
logger.Warn("Port-forward connection failed, will retry", map[string]interface{}{
"forward_id": w.forward.ID(),
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
"resource": w.forward.Resource,
"local_port": w.forward.LocalPort,
"error": err.Error(),
})
// Clear last pod so we re-resolve on next attempt
w.lastPod = ""
@@ -145,6 +203,26 @@ func (w *ForwardWorker) run() {
}
}
// heartbeatLoop sends periodic heartbeats to the watchdog to prove the worker is alive
// This runs in a separate goroutine and continues throughout the worker's lifetime
func (w *ForwardWorker) heartbeatLoop() {
// Send heartbeats every 15 seconds (well within typical 60s watchdog timeout)
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
// Send immediate heartbeat
w.watchdog.Heartbeat(w.forward.ID())
for {
select {
case <-ticker.C:
w.watchdog.Heartbeat(w.forward.ID())
case <-w.ctx.Done():
return
}
}
}
// establishForward establishes a port-forward connection.
// This blocks until the connection is closed or an error occurs.
func (w *ForwardWorker) establishForward(podName string) error {
@@ -156,11 +234,24 @@ func (w *ForwardWorker) establishForward(podName string) error {
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
defer forwardCancel()
// Start a goroutine to monitor for stop signal
// Store cancel function so TriggerReconnect can use it
w.forwardCancelMu.Lock()
w.forwardCancel = forwardCancel
w.forwardCancelMu.Unlock()
defer func() {
w.forwardCancelMu.Lock()
w.forwardCancel = nil
w.forwardCancelMu.Unlock()
}()
// Start a goroutine to monitor for stop signal and reconnect triggers
go func() {
select {
case <-w.stopChan:
close(stopChan)
case <-w.reconnectChan:
close(stopChan)
case <-forwardCtx.Done():
close(stopChan)
}
@@ -202,11 +293,15 @@ func (w *ForwardWorker) establishForward(podName string) error {
if w.verbose {
log.Printf("[%s] Port-forward connection established", w.forward.ID())
}
// Mark connection as established in health checker
if w.healthChecker != nil {
w.healthChecker.MarkConnected(w.forward.ID())
}
case err := <-errChan:
return fmt.Errorf("failed to establish forward: %w", err)
case <-w.ctx.Done():
return nil
case <-time.After(30 * time.Second):
case <-time.After(portForwardReadyTimeout):
return fmt.Errorf("timeout waiting for port-forward to become ready")
}
+286
View File
@@ -0,0 +1,286 @@
package forward
import (
"testing"
"github.com/nvm/kportal/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogWriter_Write(t *testing.T) {
tests := []struct {
name string
prefix string
input string
expectedInLog string
description string
}{
{
name: "write simple message",
prefix: "[worker] ",
input: "test message",
expectedInLog: "[worker] test message",
description: "Should write message with prefix to log",
},
{
name: "write empty message",
prefix: "[test] ",
input: "",
expectedInLog: "[test] ",
description: "Should handle empty message",
},
{
name: "write multiline message",
prefix: "[fwd] ",
input: "line1\nline2",
expectedInLog: "[fwd] line1\nline2",
description: "Should handle multiline messages",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test logWriter
originalWriter := &logWriter{prefix: tt.prefix}
n, err := originalWriter.Write([]byte(tt.input))
require.NoError(t, err, "Write should not return error")
assert.Equal(t, len(tt.input), n, "Write should return number of bytes written")
})
}
}
func TestForwardWorker_GetForward(t *testing.T) {
tests := []struct {
name string
forward config.Forward
description string
}{
{
name: "get pod forward",
forward: config.Forward{
Resource: "pod/my-app",
LocalPort: 8080,
Port: 80,
Protocol: "tcp",
},
description: "Should return the forward configuration",
},
{
name: "get service forward",
forward: config.Forward{
Resource: "service/postgres",
LocalPort: 5432,
Port: 5432,
Protocol: "tcp",
},
description: "Should return service forward configuration",
},
{
name: "get forward with selector",
forward: config.Forward{
Resource: "pod",
Selector: "app=nginx,env=prod",
LocalPort: 8080,
Port: 80,
Protocol: "tcp",
},
description: "Should return forward with label selector",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Note: We can't easily test the full worker lifecycle without mocks,
// but we can test the constructor and simple getters
// This test would require proper mocking setup
// For now, we'll test the Forward struct directly
id := tt.forward.ID()
assert.NotEmpty(t, id, "Forward should have an ID")
forwardStr := tt.forward.String()
assert.NotEmpty(t, forwardStr, "Forward should have a string representation")
assert.Contains(t, forwardStr, tt.forward.Resource, "String should contain resource")
})
}
}
func TestForwardWorker_IsRunning(t *testing.T) {
// This is a basic test of the goroutine state tracking
// Full integration tests would require mock dependencies
t.Run("worker state tracking", func(t *testing.T) {
// Test the concept of the done channel
doneChan := make(chan struct{})
// Initially, channel is open (worker would be running)
select {
case <-doneChan:
t.Fatal("doneChan should be open initially")
default:
// Expected: channel is open
}
// Close the channel (simulating worker done)
close(doneChan)
// Now channel should be closed
select {
case <-doneChan:
// Expected: channel is closed
default:
t.Fatal("doneChan should be closed after close")
}
})
}
func TestForwardID(t *testing.T) {
tests := []struct {
name string
forward config.Forward
expectUnique bool
description string
}{
{
name: "unique IDs for different forwards",
forward: config.Forward{
Resource: "pod/app1",
LocalPort: 8080,
Port: 80,
},
expectUnique: true,
description: "Different forwards should have different IDs",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id1 := tt.forward.ID()
// Create a different forward
fwd2 := config.Forward{
Resource: "pod/app2",
LocalPort: 8081,
Port: 80,
}
id2 := fwd2.ID()
if tt.expectUnique {
assert.NotEqual(t, id1, id2, "Different forwards should have different IDs")
}
// Same forward should produce same ID
id3 := tt.forward.ID()
assert.Equal(t, id1, id3, "Same forward should produce same ID")
})
}
}
func TestForwardString(t *testing.T) {
tests := []struct {
name string
forward config.Forward
expectedContains []string
description string
}{
{
name: "pod forward string",
forward: config.Forward{
Resource: "pod/my-app",
LocalPort: 8080,
Port: 80,
},
expectedContains: []string{"pod/my-app", "8080", "80"},
description: "Should contain resource and ports",
},
{
name: "service forward string",
forward: config.Forward{
Resource: "service/postgres",
LocalPort: 5432,
Port: 5432,
},
expectedContains: []string{"service/postgres", "5432"},
description: "Should contain service and port",
},
{
name: "selector forward string",
forward: config.Forward{
Resource: "pod",
Selector: "app=nginx",
LocalPort: 8080,
Port: 80,
},
expectedContains: []string{"app=nginx", "8080", "80"},
description: "Should contain selector and ports",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.forward.String()
assert.NotEmpty(t, result, "String representation should not be empty")
for _, expected := range tt.expectedContains {
assert.Contains(t, result, expected,
"String should contain %s", expected)
}
})
}
}
func TestSleepWithBackoffConcept(t *testing.T) {
// Test the backoff concept without actually running a worker
t.Run("backoff delay increases", func(t *testing.T) {
// This tests the retry backoff behavior conceptually
delays := []int{1, 2, 4, 8, 10, 10, 10}
for i, expected := range delays {
// Simulate backoff calculation
delay := 1
for j := 0; j < i; j++ {
delay *= 2
if delay > 10 {
delay = 10
}
}
assert.Equal(t, expected, delay,
"Backoff at attempt %d should be %d", i, expected)
}
})
}
func TestWorkerVerboseMode(t *testing.T) {
tests := []struct {
name string
verbose bool
description string
}{
{
name: "verbose mode enabled",
verbose: true,
description: "Worker should respect verbose flag",
},
{
name: "verbose mode disabled",
verbose: false,
description: "Worker should respect non-verbose flag",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that verbose flag is a boolean
assert.IsType(t, bool(true), tt.verbose)
// In a real worker, this would control logging
// For now, we just verify the type
})
}
}
+203 -51
View File
@@ -3,11 +3,17 @@ package healthcheck
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
)
const (
startupGracePeriod = 10 * time.Second
dataTransferSize = 1024 // bytes to read in data transfer test
)
// Status represents the health status of a port forward
type Status string
@@ -16,15 +22,26 @@ const (
StatusUnhealthy Status = "Error"
StatusStarting Status = "Starting"
StatusReconnect Status = "Reconnecting"
StatusStale Status = "Stale" // Connection is old or idle
)
// CheckMethod represents the health check method
type CheckMethod string
const (
CheckMethodTCPDial CheckMethod = "tcp-dial" // Simple TCP connection test
CheckMethodDataTransfer CheckMethod = "data-transfer" // Try to read data from connection
)
// PortHealth represents the health status of a single port
type PortHealth struct {
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
ConnectionTime time.Time // When current connection was established
LastActivity time.Time // Last time data was transferred
}
// StatusCallback is called when a port's health status changes
@@ -32,26 +49,52 @@ type StatusCallback func(forwardID string, status Status, errorMsg string)
// Checker performs periodic health checks on local ports
type Checker struct {
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
method CheckMethod
maxConnectionAge time.Duration
maxIdleTime time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewChecker creates a new health checker
// CheckerOptions configures the health checker
type CheckerOptions struct {
Interval time.Duration
Timeout time.Duration
Method CheckMethod
MaxConnectionAge time.Duration
MaxIdleTime time.Duration
}
// NewChecker creates a new health checker with default options
func NewChecker(interval, timeout time.Duration) *Checker {
return NewCheckerWithOptions(CheckerOptions{
Interval: interval,
Timeout: timeout,
Method: CheckMethodDataTransfer,
MaxConnectionAge: 25 * time.Minute,
MaxIdleTime: 10 * time.Minute,
})
}
// NewCheckerWithOptions creates a new health checker with custom options
func NewCheckerWithOptions(opts CheckerOptions) *Checker {
ctx, cancel := context.WithCancel(context.Background())
return &Checker{
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: interval,
timeout: timeout,
ctx: ctx,
cancel: cancel,
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: opts.Interval,
timeout: opts.Timeout,
method: opts.Method,
maxConnectionAge: opts.MaxConnectionAge,
maxIdleTime: opts.MaxIdleTime,
ctx: ctx,
cancel: cancel,
}
}
@@ -60,11 +103,14 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
c.ports[forwardID] = &PortHealth{
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: time.Now(),
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: now,
ConnectionTime: now,
LastActivity: now,
}
c.callbacks[forwardID] = callback
@@ -73,6 +119,28 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
go c.checkLoop(forwardID)
}
// MarkConnected marks a forward as having established a new connection
func (c *Checker) MarkConnected(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
now := time.Now()
health.ConnectionTime = now
health.LastActivity = now
}
}
// RecordActivity records data transfer activity for a forward
func (c *Checker) RecordActivity(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
health.LastActivity = time.Now()
}
}
// Unregister removes a port from monitoring
func (c *Checker) Unregister(forwardID string) {
c.mu.Lock()
@@ -85,39 +153,41 @@ func (c *Checker) Unregister(forwardID string) {
// MarkReconnecting marks a forward as reconnecting (called by worker)
func (c *Checker) MarkReconnecting(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusReconnect
health.LastCheck = time.Now()
// Notify if status changed
c.mu.Unlock()
if oldStatus != StatusReconnect {
c.mu.Unlock()
c.notifyStatusChange(forwardID, StatusReconnect, "")
c.mu.Lock()
}
return
}
c.mu.Unlock()
}
// MarkStarting marks a forward as starting (called by worker)
func (c *Checker) MarkStarting(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusStarting
health.LastCheck = time.Now()
// Notify if status changed
c.mu.Unlock()
if oldStatus != StatusStarting {
c.mu.Unlock()
c.notifyStatusChange(forwardID, StatusStarting, "")
c.mu.Lock()
}
return
}
c.mu.Unlock()
}
// GetStatus returns the current health status of a forward
@@ -191,38 +261,64 @@ func (c *Checker) checkPort(forwardID string) {
port := health.Port
oldStatus := health.Status
registeredAt := health.RegisteredAt
connectionTime := health.ConnectionTime
lastActivity := health.LastActivity
c.mu.RUnlock()
// Attempt to connect to the local port
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
now := time.Now()
newStatus := StatusHealthy
errorMsg := ""
if err != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := time.Since(registeredAt)
if timeSinceStart < 10*time.Second {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = err.Error()
// Check for stale connections based on age or idle time
connectionAge := now.Sub(connectionTime)
idleTime := now.Sub(lastActivity)
// Only enforce max connection age if the connection is ALSO idle
// This prevents interrupting active transfers (e.g., database dumps)
if c.maxConnectionAge > 0 && connectionAge > c.maxConnectionAge && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("connection age %v exceeds max %v (and idle for %v)",
connectionAge.Round(time.Second), c.maxConnectionAge, idleTime.Round(time.Second))
} else if c.maxIdleTime > 0 && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("idle time %v exceeds max %v", idleTime.Round(time.Second), c.maxIdleTime)
} else {
conn.Close()
// Perform connectivity check
var checkErr error
switch c.method {
case CheckMethodDataTransfer:
checkErr = c.checkDataTransfer(port)
case CheckMethodTCPDial:
checkErr = c.checkTCPDial(port)
default:
checkErr = c.checkTCPDial(port)
}
if checkErr != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := now.Sub(registeredAt)
if timeSinceStart < startupGracePeriod {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = checkErr.Error()
}
}
// Update health status
c.mu.Lock()
if health, exists := c.ports[forwardID]; exists {
health.Status = newStatus
health.LastCheck = time.Now()
health.LastCheck = now
health.ErrorMessage = errorMsg
// Successful health check indicates connection is active
// This prevents false positives where healthy connections are marked as idle
if newStatus == StatusHealthy {
health.LastActivity = now
}
}
c.mu.Unlock()
@@ -232,6 +328,62 @@ func (c *Checker) checkPort(forwardID string) {
}
}
// checkTCPDial performs a simple TCP dial test
func (c *Checker) checkTCPDial(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
conn.Close()
return nil
}
// checkDataTransfer attempts to read data from the connection to verify tunnel health
func (c *Checker) checkDataTransfer(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
defer conn.Close()
// Set a short read deadline to detect hung connections
// We don't expect to receive data, but we want to verify the connection isn't hung
conn.SetReadDeadline(time.Now().Add(c.timeout))
// Try to read a small amount of data
// Most servers will either:
// 1. Send a banner (SSH, FTP, etc) - we'll read it successfully
// 2. Wait for client to send first (HTTP, postgres) - we'll timeout (which is OK)
// 3. Hung/stale connection - will timeout with different error
buf := make([]byte, dataTransferSize)
_, err = conn.Read(buf)
// We expect either:
// - No error (banner received)
// - EOF (connection closed by server after connect)
// - Timeout (server waiting for client)
// All of these indicate the tunnel is working
if err == nil || err == io.EOF {
return nil
}
// Timeout is acceptable - server is waiting for us to send data first
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil
}
// Other errors indicate a problem
return fmt.Errorf("data transfer check failed: %w", err)
}
// notifyStatusChange calls the callback for a forward
func (c *Checker) notifyStatusChange(forwardID string, status Status, errorMsg string) {
c.mu.RLock()
+551
View File
@@ -0,0 +1,551 @@
package healthcheck
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// HealthCheckTestSuite contains tests for the health checker
type HealthCheckTestSuite struct {
suite.Suite
checker *Checker
listener net.Listener
port int
}
func TestHealthCheckSuite(t *testing.T) {
suite.Run(t, new(HealthCheckTestSuite))
}
func (s *HealthCheckTestSuite) SetupTest() {
// Create a test listener on a random port
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
s.listener = ln
s.port = ln.Addr().(*net.TCPAddr).Port
// Create checker with fast intervals for testing
s.checker = NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 500 * time.Millisecond,
MaxIdleTime: 300 * time.Millisecond,
})
}
func (s *HealthCheckTestSuite) TearDownTest() {
if s.checker != nil {
s.checker.Stop()
}
if s.listener != nil {
s.listener.Close()
}
}
// TestRegisterAndUnregister tests basic registration and unregistration
func (s *HealthCheckTestSuite) TestRegisterAndUnregister() {
callbackCalled := false
var callbackStatus Status
var mu sync.Mutex
callback := func(forwardID string, status Status, errorMsg string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
callbackStatus = status
}
// Register port
s.checker.Register("test-forward", s.port, callback)
// Wait for health check to run
time.Sleep(200 * time.Millisecond)
// Verify callback was called with healthy status
mu.Lock()
assert.True(s.T(), callbackCalled, "Callback should have been called")
assert.Equal(s.T(), StatusHealthy, callbackStatus)
mu.Unlock()
// Unregister
s.checker.Unregister("test-forward")
// Verify port is no longer monitored
status, exists := s.checker.GetStatus("test-forward")
assert.False(s.T(), exists, "Port should no longer exist after unregister")
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestTCPDialMethod tests the TCP dial health check method
func (s *HealthCheckTestSuite) TestTCPDialMethod() {
tests := []struct {
name string
setupPort bool
expectedStatus Status
description string
}{
{
name: "port available - healthy",
setupPort: true,
expectedStatus: StatusHealthy,
description: "When port is listening, status should be healthy",
},
{
name: "port unavailable - unhealthy",
setupPort: false,
expectedStatus: StatusUnhealthy,
description: "When port is not listening, status should be unhealthy",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
if tt.setupPort {
// Use the existing listener
testPort = s.port
} else {
// Use a port that's not listening
testPort = 54321 // Likely unused port
}
// Create a new checker for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if !tt.setupPort {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(200 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status, tt.description)
if testListener != nil {
testListener.Close()
}
})
}
}
// TestDataTransferMethod tests the data transfer health check method
func (s *HealthCheckTestSuite) TestDataTransferMethod() {
tests := []struct {
name string
serverBehavior string // "banner", "silent", "close", "none"
expectedStatus Status
}{
{
name: "server sends banner - healthy",
serverBehavior: "banner",
expectedStatus: StatusHealthy,
},
{
name: "server waits silently - healthy (timeout OK)",
serverBehavior: "silent",
expectedStatus: StatusHealthy,
},
{
name: "server closes connection - healthy (EOF OK)",
serverBehavior: "close",
expectedStatus: StatusHealthy,
},
{
name: "no server listening - unhealthy",
serverBehavior: "none",
expectedStatus: StatusUnhealthy,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
var err error
if tt.serverBehavior != "none" {
// Start test server
testListener, err = net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
testPort = testListener.Addr().(*net.TCPAddr).Port
// Handle connections based on behavior
go func() {
for {
conn, err := testListener.Accept()
if err != nil {
return
}
switch tt.serverBehavior {
case "banner":
conn.Write([]byte("220 Welcome\r\n"))
time.Sleep(50 * time.Millisecond)
conn.Close()
case "close":
conn.Close()
case "silent":
// Just keep connection open
time.Sleep(200 * time.Millisecond)
conn.Close()
}
}
}()
defer testListener.Close()
} else {
testPort = 54322 // Unused port
}
// Create checker with data transfer method
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodDataTransfer,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if tt.serverBehavior == "none" {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(300 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status)
})
}
}
// TestConnectionAgeDetection tests max connection age detection
func (s *HealthCheckTestSuite) TestConnectionAgeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max connection age
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 150 * time.Millisecond, // Very short for testing
MaxIdleTime: 0, // Disable idle detection
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait for initial healthy status
var gotHealthy, gotStale bool
timeout := time.After(1 * time.Second)
for {
select {
case status := <-statusChanges:
if status == StatusHealthy || status == StatusStarting {
gotHealthy = true
}
if status == StatusStale {
gotStale = true
}
if gotHealthy && gotStale {
return // Test passed
}
case <-timeout:
s.T().Fatalf("Expected StatusStale after max connection age exceeded. gotHealthy=%v, gotStale=%v",
gotHealthy, gotStale)
}
}
}
// TestIdleTimeDetection tests that connections with passing health checks are NOT marked as stale
// This verifies that successful health checks update LastActivity, preventing false idle detection
func (s *HealthCheckTestSuite) TestIdleTimeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max idle time
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable age detection
MaxIdleTime: 150 * time.Millisecond, // Very short for testing
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait long enough that idle time WOULD be exceeded if health checks didn't update LastActivity
time.Sleep(500 * time.Millisecond)
// Verify connection is still healthy, not stale
// This proves that successful health checks are updating LastActivity
status, exists := checker.GetStatus("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), StatusHealthy, status, "Connection with passing health checks should NOT be marked as stale")
// Verify we never received a StatusStale callback
select {
case status := <-statusChanges:
if status == StatusStale {
s.T().Fatal("Connection should NOT be marked as stale when health checks are passing")
}
default:
// No stale status - this is correct
}
}
// TestMarkConnected tests that MarkConnected resets connection time
func (s *HealthCheckTestSuite) TestMarkConnected() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 200 * time.Millisecond,
MaxIdleTime: 0,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Mark as reconnected (resets connection time)
checker.MarkConnected("test-forward")
// Wait for connection age to exceed (relative to first connection time)
time.Sleep(200 * time.Millisecond)
// Check status - should still be healthy because we reset connection time
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Note: Might be StatusStale by now, but the key is that MarkConnected delayed it
// This is a timing-sensitive test, so we just verify the functionality exists
_ = status
}
// TestRecordActivity tests that RecordActivity resets idle time
func (s *HealthCheckTestSuite) TestRecordActivity() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 200 * time.Millisecond,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Periodically record activity to prevent idle detection
ticker := time.NewTicker(80 * time.Millisecond)
defer ticker.Stop()
go func() {
for i := 0; i < 5; i++ {
<-ticker.C
checker.RecordActivity("test-forward")
}
}()
// Wait longer than idle timeout
time.Sleep(500 * time.Millisecond)
// Should still be healthy due to activity
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// May transition to stale eventually, but activity recording should have delayed it
_ = status
}
// TestMarkReconnecting tests the MarkReconnecting functionality
func (s *HealthCheckTestSuite) TestMarkReconnecting() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
s.checker.Register("test-forward", s.port, callback)
// Wait for initial status
time.Sleep(150 * time.Millisecond)
// Mark as reconnecting
s.checker.MarkReconnecting("test-forward")
// Should receive reconnecting status
timeout := time.After(500 * time.Millisecond)
gotReconnect := false
for !gotReconnect {
select {
case status := <-statusChanges:
if status == StatusReconnect {
gotReconnect = true
}
case <-timeout:
s.T().Fatal("Expected StatusReconnect")
}
}
}
// TestStartingGracePeriod tests that errors during grace period show as "Starting"
func (s *HealthCheckTestSuite) TestStartingGracePeriod() {
// Use a port that's not listening
unavailablePort := 54323
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register without callback - we'll check status directly
checker.Register("test-forward", unavailablePort, nil)
// Immediately check status - should be Starting or not yet checked
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Initially should be Starting
assert.Equal(s.T(), StatusStarting, status)
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 200*time.Millisecond)
// Now should be Unhealthy
status, exists = checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestGetAllErrors tests retrieving all error messages
func (s *HealthCheckTestSuite) TestGetAllErrors() {
// Create a new checker with faster intervals for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register multiple forwards
checker.Register("forward1", s.port, nil)
checker.Register("forward2", 54324, nil) // Unavailable port
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 300*time.Millisecond)
errors := checker.GetAllErrors()
// forward2 should have an error
_, hasError := errors["forward2"]
assert.True(s.T(), hasError, "forward2 should have an error")
// forward1 should not have an error
_, hasError = errors["forward1"]
assert.False(s.T(), hasError, "forward1 should not have an error")
}
// TestConcurrentOperations tests thread safety
func (s *HealthCheckTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := fmt.Sprintf("forward-%d", id)
s.checker.Register(forwardID, s.port, nil)
time.Sleep(50 * time.Millisecond)
s.checker.MarkConnected(forwardID)
s.checker.RecordActivity(forwardID)
status, _ := s.checker.GetStatus(forwardID)
_ = status
s.checker.Unregister(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestDefaultOptions tests that NewChecker uses sensible defaults
func TestDefaultOptions(t *testing.T) {
checker := NewChecker(5*time.Second, 2*time.Second)
defer checker.Stop()
assert.Equal(t, 5*time.Second, checker.interval)
assert.Equal(t, 2*time.Second, checker.timeout)
assert.Equal(t, CheckMethodDataTransfer, checker.method)
assert.Equal(t, 25*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 10*time.Minute, checker.maxIdleTime)
}
// TestCustomOptions tests NewCheckerWithOptions
func TestCustomOptions(t *testing.T) {
opts := CheckerOptions{
Interval: 1 * time.Second,
Timeout: 500 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 5 * time.Minute,
MaxIdleTime: 2 * time.Minute,
}
checker := NewCheckerWithOptions(opts)
defer checker.Stop()
assert.Equal(t, 1*time.Second, checker.interval)
assert.Equal(t, 500*time.Millisecond, checker.timeout)
assert.Equal(t, CheckMethodTCPDial, checker.method)
assert.Equal(t, 5*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 2*time.Minute, checker.maxIdleTime)
}
+306
View File
@@ -0,0 +1,306 @@
package k8s
import (
"context"
"fmt"
"net"
"sort"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// Discovery provides cluster introspection capabilities for the UI wizards.
// It queries the Kubernetes API to list contexts, namespaces, pods, and services.
type Discovery struct {
pool *ClientPool
}
// NewDiscovery creates a new Discovery instance using the provided client pool.
func NewDiscovery(pool *ClientPool) *Discovery {
return &Discovery{
pool: pool,
}
}
// PodInfo contains information about a pod relevant for port forwarding.
type PodInfo struct {
Name string
Namespace string
Containers []ContainerInfo
Status string
Created metav1.Time
}
// ContainerInfo contains information about a container within a pod.
type ContainerInfo struct {
Name string
Ports []PortInfo
}
// PortInfo describes a port exposed by a container or service.
type PortInfo struct {
Name string
Port int32
Protocol string
}
// ServiceInfo contains information about a service.
type ServiceInfo struct {
Name string
Namespace string
Ports []PortInfo
Type string
}
// ListContexts returns all available Kubernetes contexts from kubeconfig.
func (d *Discovery) ListContexts() ([]string, error) {
return d.pool.ListContexts()
}
// GetCurrentContext returns the name of the current context from kubeconfig.
func (d *Discovery) GetCurrentContext() (string, error) {
return d.pool.GetCurrentContext()
}
// ListNamespaces returns all namespaces in the given context.
// Returns an error if the context is invalid or unreachable.
func (d *Discovery) ListNamespaces(ctx context.Context, contextName string) ([]string, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
nsList, err := client.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list namespaces: %w", err)
}
namespaces := make([]string, 0, len(nsList.Items))
for _, ns := range nsList.Items {
namespaces = append(namespaces, ns.Name)
}
// Sort alphabetically
sort.Strings(namespaces)
return namespaces, nil
}
// ListPods returns all running pods in the given namespace with their port information.
// Only returns pods in Running or Pending state.
func (d *Discovery) ListPods(ctx context.Context, contextName, namespace string) ([]PodInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
podList, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
pods := make([]PodInfo, 0)
for _, pod := range podList.Items {
// Only include Running or Pending pods
if pod.Status.Phase != corev1.PodRunning && pod.Status.Phase != corev1.PodPending {
continue
}
containers := make([]ContainerInfo, 0, len(pod.Spec.Containers))
for _, container := range pod.Spec.Containers {
ports := make([]PortInfo, 0, len(container.Ports))
for _, port := range container.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.ContainerPort,
Protocol: string(port.Protocol),
})
}
containers = append(containers, ContainerInfo{
Name: container.Name,
Ports: ports,
})
}
pods = append(pods, PodInfo{
Name: pod.Name,
Namespace: pod.Namespace,
Containers: containers,
Status: string(pod.Status.Phase),
Created: pod.CreationTimestamp,
})
}
// Sort by creation time (newest first)
sort.Slice(pods, func(i, j int) bool {
return pods[i].Created.After(pods[j].Created.Time)
})
return pods, nil
}
// ListPodsWithSelector returns pods matching the given label selector.
// Selector format: "key=value,key2=value2"
// Returns an error if the selector is invalid.
func (d *Discovery) ListPodsWithSelector(ctx context.Context, contextName, namespace, selector string) ([]PodInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
// Validate selector format
selector = strings.TrimSpace(selector)
if selector == "" {
return nil, fmt.Errorf("selector cannot be empty")
}
podList, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
})
if err != nil {
return nil, fmt.Errorf("failed to list pods with selector: %w", err)
}
pods := make([]PodInfo, 0)
for _, pod := range podList.Items {
// Only include Running pods for selector-based forwards
if pod.Status.Phase != corev1.PodRunning {
continue
}
containers := make([]ContainerInfo, 0, len(pod.Spec.Containers))
for _, container := range pod.Spec.Containers {
ports := make([]PortInfo, 0, len(container.Ports))
for _, port := range container.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.ContainerPort,
Protocol: string(port.Protocol),
})
}
containers = append(containers, ContainerInfo{
Name: container.Name,
Ports: ports,
})
}
pods = append(pods, PodInfo{
Name: pod.Name,
Namespace: pod.Namespace,
Containers: containers,
Status: string(pod.Status.Phase),
Created: pod.CreationTimestamp,
})
}
// Sort by creation time (newest first)
sort.Slice(pods, func(i, j int) bool {
return pods[i].Created.After(pods[j].Created.Time)
})
return pods, nil
}
// ListServices returns all services in the given namespace.
func (d *Discovery) ListServices(ctx context.Context, contextName, namespace string) ([]ServiceInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
svcList, err := client.CoreV1().Services(namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list services: %w", err)
}
services := make([]ServiceInfo, 0, len(svcList.Items))
for _, svc := range svcList.Items {
ports := make([]PortInfo, 0, len(svc.Spec.Ports))
for _, port := range svc.Spec.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.Port,
Protocol: string(port.Protocol),
})
}
services = append(services, ServiceInfo{
Name: svc.Name,
Namespace: svc.Namespace,
Ports: ports,
Type: string(svc.Spec.Type),
})
}
// Sort alphabetically
sort.Slice(services, func(i, j int) bool {
return services[i].Name < services[j].Name
})
return services, nil
}
// GetUniquePorts extracts unique ports from a list of pods.
// Returns a sorted list of port numbers with their names (if available).
func GetUniquePorts(pods []PodInfo) []PortInfo {
portMap := make(map[int32]string)
for _, pod := range pods {
for _, container := range pod.Containers {
for _, port := range container.Ports {
// Prefer named ports
if _, ok := portMap[port.Port]; !ok || port.Name != "" {
if port.Name != "" {
portMap[port.Port] = port.Name
} else if !ok {
portMap[port.Port] = fmt.Sprintf("port-%d", port.Port)
}
}
}
}
}
// Convert to slice
ports := make([]PortInfo, 0, len(portMap))
for port, name := range portMap {
ports = append(ports, PortInfo{
Name: name,
Port: port,
})
}
// Sort by port number
sort.Slice(ports, func(i, j int) bool {
return ports[i].Port < ports[j].Port
})
return ports
}
// CheckPortAvailability checks if a local port is available.
// Returns: available (bool), processInfo (string), error
func CheckPortAvailability(port int) (bool, string, error) {
if port < 1 || port > 65535 {
return false, "", fmt.Errorf("invalid port: %d", port)
}
// Try to listen on the port
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
// Port is in use
// Try to get process info (best-effort)
processInfo := "unknown process"
// Note: Getting process info requires platform-specific code
// For now, just return a generic message
return false, processInfo, nil
}
// Port is available, close the listener
listener.Close()
return true, "", nil
}
+34 -5
View File
@@ -4,9 +4,11 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -17,18 +19,32 @@ import (
// PortForwarder handles Kubernetes port-forwarding operations.
type PortForwarder struct {
clientPool *ClientPool
resolver *ResourceResolver
clientPool *ClientPool
resolver *ResourceResolver
tcpKeepalive time.Duration // TCP keepalive interval
dialTimeout time.Duration // Connection dial timeout
}
// NewPortForwarder creates a new PortForwarder instance.
// NewPortForwarder creates a new PortForwarder instance with default settings.
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
return &PortForwarder{
clientPool: clientPool,
resolver: resolver,
clientPool: clientPool,
resolver: resolver,
tcpKeepalive: 30 * time.Second, // Default: 30 second keepalive
dialTimeout: 30 * time.Second, // Default: 30 second dial timeout
}
}
// SetTCPKeepalive configures the TCP keepalive interval for new connections.
func (pf *PortForwarder) SetTCPKeepalive(keepalive time.Duration) {
pf.tcpKeepalive = keepalive
}
// SetDialTimeout configures the connection dial timeout.
func (pf *PortForwarder) SetDialTimeout(timeout time.Duration) {
pf.dialTimeout = timeout
}
// ForwardRequest contains the parameters for a port-forward request.
type ForwardRequest struct {
ContextName string // Kubernetes context name
@@ -164,6 +180,19 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
// executePortForward performs the actual port-forward operation.
func (pf *PortForwarder) executePortForward(config *rest.Config, url *url.URL, req *ForwardRequest) error {
// Configure TCP settings on the underlying connection
// This is set in the rest.Config which will be used by the SPDY transport
if config.Dial == nil {
// Create a custom dialer with configurable timeout and keepalive
// - Timeout: How long to wait for connection to establish
// - KeepAlive: TCP keepalive helps OS detect dead connections at network layer
dialer := &net.Dialer{
Timeout: pf.dialTimeout, // Configurable dial timeout
KeepAlive: pf.tcpKeepalive, // Configurable keepalive interval
}
config.Dial = dialer.DialContext
}
// Create SPDY roundtripper
transport, upgrader, err := spdy.RoundTripperFor(config)
if err != nil {
-26
View File
@@ -228,29 +228,3 @@ func (r *ResourceResolver) InvalidateCache(contextName, namespace, resource stri
}
}
}
// GetPodList returns a list of pods matching the given criteria.
// This is useful for debugging and testing.
func (r *ResourceResolver) GetPodList(ctx context.Context, contextName, namespace, selector string) ([]*corev1.Pod, error) {
client, err := r.clientPool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
listOptions := metav1.ListOptions{}
if selector != "" {
listOptions.LabelSelector = selector
}
pods, err := client.CoreV1().Pods(namespace).List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
result := make([]*corev1.Pod, len(pods.Items))
for i := range pods.Items {
result[i] = &pods.Items[i]
}
return result, nil
}
+70
View File
@@ -0,0 +1,70 @@
package logger_test
import (
"bytes"
"fmt"
"testing"
"github.com/nvm/kportal/internal/logger"
)
// This test demonstrates the logger output formats
func TestLoggerDemo(t *testing.T) {
t.Skip("Demo only - run manually with: go test -v -run TestLoggerDemo")
fmt.Println("\n=== TEXT FORMAT (DEFAULT) ===")
textBuf := &bytes.Buffer{}
textLogger := logger.New(logger.LevelInfo, logger.FormatText, textBuf)
textLogger.Info("Port forward started", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"local_port": 8080,
"pod": "app-xyz123",
})
textLogger.Warn("Connection failed, retrying", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "connection refused",
"retry": 3,
})
textLogger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "pod not found",
})
fmt.Print(textBuf.String())
fmt.Println("\n=== JSON FORMAT ===")
jsonBuf := &bytes.Buffer{}
jsonLogger := logger.New(logger.LevelInfo, logger.FormatJSON, jsonBuf)
jsonLogger.Info("Port forward started", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"local_port": 8080,
"pod": "app-xyz123",
})
jsonLogger.Warn("Connection failed, retrying", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "connection refused",
"retry": 3,
})
jsonLogger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "pod not found",
})
fmt.Print(jsonBuf.String())
fmt.Println("\n=== LOG LEVEL FILTERING (Debug level disabled) ===")
filteredBuf := &bytes.Buffer{}
filteredLogger := logger.New(logger.LevelInfo, logger.FormatText, filteredBuf)
filteredLogger.Debug("This will not appear", nil)
filteredLogger.Info("This will appear", nil)
filteredLogger.Warn("This will also appear", nil)
fmt.Print(filteredBuf.String())
}
+96
View File
@@ -0,0 +1,96 @@
package logger
import (
"bytes"
"io"
"strings"
"sync"
)
// KlogWriter is an io.Writer that routes klog output through our structured logger.
// It parses klog messages and routes them to appropriate log levels.
// It is thread-safe for concurrent writes.
type KlogWriter struct {
logger *Logger
buffer *bytes.Buffer
mu sync.Mutex
}
// NewKlogWriter creates a new KlogWriter that routes k8s client-go logs
// through our structured logger.
func NewKlogWriter(logger *Logger) *KlogWriter {
return &KlogWriter{
logger: logger,
buffer: &bytes.Buffer{},
}
}
// Write implements io.Writer.
// It parses klog output and routes it through our structured logger.
// This method is thread-safe.
func (w *KlogWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// Write to buffer first
w.buffer.Write(p)
// Process complete lines
for {
line, err := w.buffer.ReadString('\n')
if err != nil {
// No complete line yet, write back what we read and wait for more
if err == io.EOF && line != "" {
w.buffer.WriteString(line)
}
break
}
// Process the complete line
w.processLine(strings.TrimSpace(line))
}
return len(p), nil
}
// processLine parses a klog line and routes it to the appropriate log level.
func (w *KlogWriter) processLine(line string) {
if line == "" {
return
}
// Parse klog format: "I1124 12:34:56.789012 12345 file.go:123] message"
// First character indicates level: I=Info, W=Warning, E=Error, F=Fatal
if len(line) < 1 {
return
}
level := line[0]
message := line
// Try to extract just the message part after "]"
if idx := strings.Index(line, "] "); idx != -1 {
message = line[idx+2:]
}
// Determine log level and route accordingly
switch level {
case 'I': // Info
w.logger.Debug(message, map[string]interface{}{
"source": "k8s-client",
})
case 'W': // Warning
w.logger.Warn(message, map[string]interface{}{
"source": "k8s-client",
})
case 'E', 'F': // Error or Fatal
w.logger.Error(message, map[string]interface{}{
"source": "k8s-client",
})
default:
// Unknown format, log as debug
w.logger.Debug(message, map[string]interface{}{
"source": "k8s-client",
})
}
}
+280
View File
@@ -0,0 +1,280 @@
package logger
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKlogWriter(t *testing.T) {
tests := []struct {
name string
input string
expectedLevel string
expectedMsg string
loggerLevel Level
loggerFormat Format
shouldLog bool
description string
}{
{
name: "info level log",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Starting port forward\n",
expectedLevel: "DEBUG",
expectedMsg: "Starting port forward",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Info logs from k8s should be routed as DEBUG",
},
{
name: "warning level log",
input: "W1124 12:34:56.789012 12345 portforward.go:456] Connection unstable\n",
expectedLevel: "WARN",
expectedMsg: "Connection unstable",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Warning logs should be routed as WARN",
},
{
name: "error level log",
input: "E1124 12:34:56.789012 12345 portforward.go:789] Connection failed\n",
expectedLevel: "ERROR",
expectedMsg: "Connection failed",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Error logs should be routed as ERROR",
},
{
name: "fatal level log",
input: "F1124 12:34:56.789012 12345 portforward.go:999] Fatal error\n",
expectedLevel: "ERROR",
expectedMsg: "Fatal error",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Fatal logs should be routed as ERROR",
},
{
name: "multiline input",
input: "I1124 12:34:56.789012 12345 portforward.go:123] First message\nI1124 12:34:57.123456 12345 portforward.go:124] Second message\n",
expectedLevel: "DEBUG",
expectedMsg: "First message",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Should handle multiple log lines",
},
{
name: "log filtered by level",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Debug message\n",
expectedLevel: "DEBUG",
expectedMsg: "Debug message",
loggerLevel: LevelInfo, // Logger set to INFO, DEBUG should be filtered
loggerFormat: FormatText,
shouldLog: false,
description: "DEBUG logs should be filtered when logger level is INFO",
},
{
name: "unknown log format",
input: "X1124 12:34:56.789012 12345 portforward.go:123] Unknown format\n",
expectedLevel: "DEBUG",
expectedMsg: "Unknown format",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Unknown format should default to DEBUG",
},
{
name: "empty line",
input: "\n",
expectedLevel: "",
expectedMsg: "",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: false,
description: "Empty lines should be ignored",
},
{
name: "partial line no newline",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Partial",
expectedLevel: "",
expectedMsg: "",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: false,
description: "Partial lines without newline should be buffered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create output buffer
var buf bytes.Buffer
// Create logger with specified level and format
logger := New(tt.loggerLevel, tt.loggerFormat, &buf)
// Create klog writer
klogWriter := NewKlogWriter(logger)
// Write input
n, err := klogWriter.Write([]byte(tt.input))
require.NoError(t, err)
assert.Equal(t, len(tt.input), n)
// Check output
output := buf.String()
if !tt.shouldLog {
assert.Empty(t, output, "Expected no log output")
return
}
if tt.loggerFormat == FormatText {
// Text format: [LEVEL] message
assert.Contains(t, output, fmt.Sprintf("[%s]", tt.expectedLevel))
assert.Contains(t, output, tt.expectedMsg)
assert.Contains(t, output, "k8s-client") // Should include source field
} else {
// JSON format
var entry logEntry
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) > 0 {
err := json.Unmarshal([]byte(lines[0]), &entry)
require.NoError(t, err)
assert.Equal(t, tt.expectedLevel, entry.Level)
assert.Equal(t, tt.expectedMsg, entry.Message)
assert.Equal(t, "k8s-client", entry.Fields["source"])
}
}
})
}
}
func TestKlogWriterBuffering(t *testing.T) {
tests := []struct {
name string
writes []string
expectCount int
description string
}{
{
name: "single complete line",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] Complete line\n",
},
expectCount: 1,
description: "Single complete line should produce one log entry",
},
{
name: "partial then complete",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] Partial ",
"line\n",
},
expectCount: 1,
description: "Partial writes should be buffered and combined",
},
{
name: "multiple complete lines in chunks",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] First\n",
"I1124 12:34:57.123456 12345 portforward.go:124] Second\n",
"I1124 12:34:58.456789 12345 portforward.go:125] Third\n",
},
expectCount: 3,
description: "Multiple complete lines should produce multiple log entries",
},
{
name: "mixed partial and complete",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] First\nI1124 12:34:57.123456 12345 port",
"forward.go:124] Second\n",
},
expectCount: 2,
description: "Mixed partial and complete lines should be handled correctly",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, FormatText, &buf)
klogWriter := NewKlogWriter(logger)
// Write all chunks
for _, write := range tt.writes {
_, err := klogWriter.Write([]byte(write))
require.NoError(t, err)
}
// Count log entries (each line starts with [LEVEL])
output := buf.String()
count := strings.Count(output, "[DEBUG]") +
strings.Count(output, "[INFO]") +
strings.Count(output, "[WARN]") +
strings.Count(output, "[ERROR]")
assert.Equal(t, tt.expectCount, count, "Expected %d log entries, got %d", tt.expectCount, count)
})
}
}
func TestKlogWriterJSONFormat(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, FormatJSON, &buf)
klogWriter := NewKlogWriter(logger)
// Write a k8s log line
input := "I1124 12:34:56.789012 12345 portforward.go:123] Starting port forward\n"
_, err := klogWriter.Write([]byte(input))
require.NoError(t, err)
// Parse JSON output
var entry logEntry
err = json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
// Verify JSON structure
assert.Equal(t, "DEBUG", entry.Level)
assert.Equal(t, "Starting port forward", entry.Message)
assert.NotEmpty(t, entry.Time)
assert.Equal(t, "k8s-client", entry.Fields["source"])
}
func TestKlogWriterConcurrency(t *testing.T) {
// Test that concurrent writes don't cause data races
var buf bytes.Buffer
logger := New(LevelDebug, FormatText, &buf)
klogWriter := NewKlogWriter(logger)
done := make(chan bool)
numGoroutines := 10
numWrites := 100
for i := 0; i < numGoroutines; i++ {
go func(id int) {
for j := 0; j < numWrites; j++ {
msg := fmt.Sprintf("I1124 12:34:56.789012 12345 test.go:123] Message from goroutine %d iteration %d\n", id, j)
klogWriter.Write([]byte(msg))
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < numGoroutines; i++ {
<-done
}
// Just verify we didn't panic (data race detector would catch issues)
assert.NotEmpty(t, buf.String())
}
+105
View File
@@ -0,0 +1,105 @@
package logger
import (
"github.com/go-logr/logr"
)
// LogrAdapter implements the logr.LogSink interface to route klog v2 logs
// through our structured logger. This captures ALL klog output including
// error logs, structured logs, and named logger output.
type LogrAdapter struct {
logger *Logger
name string
level int
}
// NewLogrAdapter creates a new logr.LogSink that routes all klog v2 logs
// through our structured logger.
func NewLogrAdapter(logger *Logger) logr.LogSink {
return &LogrAdapter{
logger: logger,
name: "",
level: 0,
}
}
// Init initializes the logger with runtime info (not used in our implementation).
func (l *LogrAdapter) Init(info logr.RuntimeInfo) {
// No-op: we don't need runtime info
}
// Enabled tests whether this LogSink is enabled at the specified V-level.
// We route all logs through our logger's level filtering.
func (l *LogrAdapter) Enabled(level int) bool {
// Map logr V-levels to our levels:
// V(0) = Info level (always enabled if logger level <= Info)
// V(1+) = Debug level (enabled if logger level <= Debug)
if level == 0 {
return l.logger.level <= LevelInfo
}
return l.logger.level <= LevelDebug
}
// Info logs a non-error message with the given key/value pairs.
func (l *LogrAdapter) Info(level int, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
// Map logr V-levels to our levels:
// V(0) = Info, V(1+) = Debug
if level == 0 {
l.logger.Info(msg, fields)
} else {
l.logger.Debug(msg, fields)
}
}
// Error logs an error message with the given key/value pairs.
func (l *LogrAdapter) Error(err error, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
if err != nil {
fields["error"] = err.Error()
}
l.logger.Error(msg, fields)
}
// WithValues returns a new LogSink with additional key/value pairs.
func (l *LogrAdapter) WithValues(keysAndValues ...interface{}) logr.LogSink {
// For simplicity, we don't implement value accumulation
// Each log call receives all its keysAndValues directly
return l
}
// WithName returns a new LogSink with the specified name appended.
func (l *LogrAdapter) WithName(name string) logr.LogSink {
newLogger := *l
if l.name == "" {
newLogger.name = name
} else {
newLogger.name = l.name + "." + name
}
return &newLogger
}
// kvToMap converts a slice of alternating keys and values to a map.
func (l *LogrAdapter) kvToMap(keysAndValues []interface{}) map[string]interface{} {
fields := make(map[string]interface{})
fields["source"] = "k8s-client"
for i := 0; i < len(keysAndValues); i += 2 {
if i+1 < len(keysAndValues) {
key, ok := keysAndValues[i].(string)
if ok {
fields[key] = keysAndValues[i+1]
}
}
}
return fields
}
+367
View File
@@ -0,0 +1,367 @@
package logger
import (
"bytes"
"encoding/json"
"errors"
"testing"
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogrAdapter_Info(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "info log v0 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
message: "Connection established",
keysAndValues: []interface{}{"pod", "my-app-123", "port", 8080},
expectOutput: true,
expectContains: []string{"[INFO]", "Connection established", "pod", "my-app-123"},
},
{
name: "info log v0 with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Port forward ready",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Port forward ready"},
},
{
name: "info log v0 silenced with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
message: "This should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "debug log v1 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
message: "Detailed connection info",
keysAndValues: []interface{}{"details", "some-value"},
expectOutput: true,
expectContains: []string{"[DEBUG]", "Detailed connection info", "details"},
},
{
name: "debug log v1 silenced with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
message: "This debug should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "info with odd number of kvs (incomplete pair)",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Message with incomplete kv",
keysAndValues: []interface{}{"key1", "value1", "key2"}, // key2 has no value
expectOutput: true,
expectContains: []string{"[INFO]", "Message with incomplete kv", "key1", "value1"},
},
{
name: "info with source field added automatically",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Test source field",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Test source field", "source:k8s-client"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.V(tt.logrLevel).Info(tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_Error(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
err error
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "error with error object",
loggerLevel: LevelError,
err: errors.New("connection failed"),
message: "Port forward failed",
keysAndValues: []interface{}{"pod", "my-app-123"},
expectOutput: true,
expectContains: []string{"[ERROR]", "Port forward failed", "connection failed", "pod", "my-app-123"},
},
{
name: "error without error object",
loggerLevel: LevelError,
err: nil,
message: "Generic error message",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[ERROR]", "Generic error message"},
},
{
name: "error silenced with level above error",
loggerLevel: LevelError + 1,
err: errors.New("should not appear"),
message: "This error should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "error with multiple kvs",
loggerLevel: LevelError,
err: errors.New("sandbox not found"),
message: "Unhandled Error",
keysAndValues: []interface{}{"pod", "test-pod", "uid", "abc123", "port", 8080},
expectOutput: true,
expectContains: []string{"[ERROR]", "Unhandled Error", "sandbox not found", "pod", "test-pod", "uid", "abc123"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.Error(tt.err, tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_WithName(t *testing.T) {
tests := []struct {
name string
loggerNames []string
message string
expectContains string
}{
{
name: "single logger name",
loggerNames: []string{"portforward"},
message: "Test message",
expectContains: "logger:portforward",
},
{
name: "nested logger names",
loggerNames: []string{"controller", "worker", "healthcheck"},
message: "Nested message",
expectContains: "logger:controller.worker.healthcheck",
},
{
name: "no logger name",
loggerNames: []string{},
message: "No name message",
expectContains: "source:k8s-client", // Should still have source but no logger field
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Apply WithName calls
for _, name := range tt.loggerNames {
logrLogger = logrLogger.WithName(name)
}
logrLogger.Info(tt.message)
output := buf.String()
assert.Contains(t, output, tt.expectContains)
})
}
}
func TestLogrAdapter_Enabled(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
expectEnabled bool
}{
{
name: "v0 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 enabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 disabled with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
expectEnabled: false,
},
{
name: "v1 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
expectEnabled: true,
},
{
name: "v1 disabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
expectEnabled: false,
},
{
name: "v2 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 2,
expectEnabled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := New(tt.loggerLevel, FormatText, &bytes.Buffer{})
sink := NewLogrAdapter(logger)
enabled := sink.Enabled(tt.logrLevel)
assert.Equal(t, tt.expectEnabled, enabled)
})
}
}
func TestLogrAdapter_JSONFormat(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatJSON, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("test-component")
logrLogger.Info("Test JSON message", "key1", "value1", "key2", 123)
// Parse JSON output
var entry logEntry
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "INFO", entry.Level)
assert.Equal(t, "Test JSON message", entry.Message)
assert.Equal(t, "k8s-client", entry.Fields["source"])
assert.Equal(t, "test-component", entry.Fields["logger"])
assert.Equal(t, "value1", entry.Fields["key1"])
assert.Equal(t, float64(123), entry.Fields["key2"]) // JSON numbers decode as float64
}
func TestLogrAdapter_ConcurrentWrites(t *testing.T) {
// Note: bytes.Buffer is not thread-safe for writes, so this test verifies
// that our LogrAdapter doesn't panic under concurrent load, but we don't
// verify exact output (since logger uses fmt.Fprintf which is also not thread-safe)
buf := &bytes.Buffer{}
logger := New(LevelDebug, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Spawn multiple goroutines writing concurrently
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
logrLogger.Info("Concurrent message", "goroutine", id, "iteration", j)
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
output := buf.String()
// Verify we got substantial output (not checking exact count due to buffer race)
// The main goal is to ensure no panics occur during concurrent writes
assert.NotEmpty(t, output, "Should have some log output")
assert.Contains(t, output, "Concurrent message")
}
func TestLogrAdapter_RealWorldKlogError(t *testing.T) {
// Simulate the exact error message from the screenshot
buf := &bytes.Buffer{}
logger := New(LevelError, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("UnhandledError")
err := errors.New("an error occurred forwarding 8401 -> 8401: error forwarding port 8401 to pod 4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308, uid : failed to find sandbox '4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308' in store: not found")
logrLogger.Error(err, "Unhandled Error")
output := buf.String()
assert.Contains(t, output, "[ERROR]")
assert.Contains(t, output, "Unhandled Error")
assert.Contains(t, output, "failed to find sandbox")
assert.Contains(t, output, "logger:UnhandledError")
}
func TestLogrAdapter_SilenceMode(t *testing.T) {
// Test that logs are completely silenced when logger level is above error
buf := &bytes.Buffer{}
logger := New(LevelError+1, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Try all log levels
logrLogger.V(0).Info("Info message should not appear")
logrLogger.V(1).Info("Debug message should not appear")
logrLogger.Error(errors.New("error object"), "Error message should not appear")
output := buf.String()
assert.Empty(t, output, "All logs should be silenced")
}
+164
View File
@@ -0,0 +1,164 @@
package logger
import (
"encoding/json"
"fmt"
"io"
"os"
"sync"
"time"
)
type Level int
const (
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
)
type Format int
const (
FormatText Format = iota
FormatJSON
)
type Logger struct {
level Level
format Format
output io.Writer
mu sync.Mutex // Protects concurrent writes to output
}
type logEntry struct {
Time string `json:"time"`
Level string `json:"level"`
Message string `json:"message"`
Fields map[string]interface{} `json:"fields,omitempty"`
}
func New(level Level, format Format, output io.Writer) *Logger {
if output == nil {
output = os.Stderr
}
return &Logger{
level: level,
format: format,
output: output,
}
}
func (l *Logger) log(level Level, msg string, fields map[string]interface{}) {
if level < l.level {
return
}
levelStr := levelToString(level)
l.mu.Lock()
defer l.mu.Unlock()
if l.format == FormatJSON {
entry := logEntry{
Time: time.Now().Format(time.RFC3339),
Level: levelStr,
Message: msg,
Fields: fields,
}
data, _ := json.Marshal(entry)
fmt.Fprintln(l.output, string(data))
} else {
// Text format
if len(fields) > 0 {
fmt.Fprintf(l.output, "[%s] %s %v\n", levelStr, msg, fields)
} else {
fmt.Fprintf(l.output, "[%s] %s\n", levelStr, msg)
}
}
}
func (l *Logger) Debug(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelDebug, msg, f)
}
func (l *Logger) Info(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelInfo, msg, f)
}
func (l *Logger) Warn(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelWarn, msg, f)
}
func (l *Logger) Error(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelError, msg, f)
}
func levelToString(level Level) string {
switch level {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
// Global logger for backward compatibility
var globalLogger *Logger
func Init(level Level, format Format, output ...io.Writer) {
var out io.Writer
if len(output) > 0 && output[0] != nil {
out = output[0]
} else {
out = os.Stderr
}
globalLogger = New(level, format, out)
}
func Debug(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Debug(msg, fields...)
}
}
func Info(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Info(msg, fields...)
}
}
func Warn(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Warn(msg, fields...)
}
}
func Error(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Error(msg, fields...)
}
}
+521
View File
@@ -0,0 +1,521 @@
package logger
import (
"bytes"
"encoding/json"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoggerTextFormat(t *testing.T) {
tests := []struct {
name string
level Level
logLevel Level
message string
fields map[string]interface{}
expectOutput bool
expectContains []string
}{
{
name: "info logged at info level",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: nil,
expectOutput: true,
expectContains: []string{"[INFO]", "test message"},
},
{
name: "debug filtered at info level",
level: LevelInfo,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: false,
expectContains: []string{},
},
{
name: "error logged at info level",
level: LevelInfo,
logLevel: LevelError,
message: "error message",
fields: nil,
expectOutput: true,
expectContains: []string{"[ERROR]", "error message"},
},
{
name: "info with fields",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: map[string]interface{}{
"key1": "value1",
"key2": 123,
},
expectOutput: true,
expectContains: []string{"[INFO]", "test message", "key1", "value1"},
},
{
name: "warn logged at warn level",
level: LevelWarn,
logLevel: LevelWarn,
message: "warning message",
fields: nil,
expectOutput: true,
expectContains: []string{"[WARN]", "warning message"},
},
{
name: "info filtered at warn level",
level: LevelWarn,
logLevel: LevelInfo,
message: "info message",
fields: nil,
expectOutput: false,
expectContains: []string{},
},
{
name: "debug logged at debug level",
level: LevelDebug,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: true,
expectContains: []string{"[DEBUG]", "debug message"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.level, FormatText, buf)
// Log at the specified level
switch tt.logLevel {
case LevelDebug:
if tt.fields != nil {
logger.Debug(tt.message, tt.fields)
} else {
logger.Debug(tt.message)
}
case LevelInfo:
if tt.fields != nil {
logger.Info(tt.message, tt.fields)
} else {
logger.Info(tt.message)
}
case LevelWarn:
if tt.fields != nil {
logger.Warn(tt.message, tt.fields)
} else {
logger.Warn(tt.message)
}
case LevelError:
if tt.fields != nil {
logger.Error(tt.message, tt.fields)
} else {
logger.Error(tt.message)
}
}
output := buf.String()
if tt.expectOutput {
assert.NotEmpty(t, output, "Expected log output but got none")
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Expected output to contain: %s", expected)
}
} else {
assert.Empty(t, output, "Expected no log output but got: %s", output)
}
})
}
}
func TestLoggerJSONFormat(t *testing.T) {
tests := []struct {
name string
level Level
logLevel Level
message string
fields map[string]interface{}
expectOutput bool
expectLevel string
}{
{
name: "info logged at info level",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: nil,
expectOutput: true,
expectLevel: "INFO",
},
{
name: "debug filtered at info level",
level: LevelInfo,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: false,
expectLevel: "",
},
{
name: "error logged at debug level",
level: LevelDebug,
logLevel: LevelError,
message: "error message",
fields: nil,
expectOutput: true,
expectLevel: "ERROR",
},
{
name: "info with fields",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: map[string]interface{}{
"context": "production",
"port": 8080,
"retry": 3,
},
expectOutput: true,
expectLevel: "INFO",
},
{
name: "warn at warn level",
level: LevelWarn,
logLevel: LevelWarn,
message: "warning message",
fields: nil,
expectOutput: true,
expectLevel: "WARN",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.level, FormatJSON, buf)
// Log at the specified level
switch tt.logLevel {
case LevelDebug:
if tt.fields != nil {
logger.Debug(tt.message, tt.fields)
} else {
logger.Debug(tt.message)
}
case LevelInfo:
if tt.fields != nil {
logger.Info(tt.message, tt.fields)
} else {
logger.Info(tt.message)
}
case LevelWarn:
if tt.fields != nil {
logger.Warn(tt.message, tt.fields)
} else {
logger.Warn(tt.message)
}
case LevelError:
if tt.fields != nil {
logger.Error(tt.message, tt.fields)
} else {
logger.Error(tt.message)
}
}
output := buf.String()
if tt.expectOutput {
assert.NotEmpty(t, output, "Expected log output but got none")
// Parse JSON
var entry logEntry
err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry)
require.NoError(t, err, "Failed to parse JSON output: %s", output)
// Validate fields
assert.Equal(t, tt.expectLevel, entry.Level)
assert.Equal(t, tt.message, entry.Message)
assert.NotEmpty(t, entry.Time, "Time field should not be empty")
// Validate custom fields if provided
if tt.fields != nil {
require.NotNil(t, entry.Fields, "Expected fields in JSON output")
for key, expectedValue := range tt.fields {
actualValue, exists := entry.Fields[key]
assert.True(t, exists, "Expected field %s not found in output", key)
// JSON unmarshaling converts numbers to float64
if floatVal, ok := expectedValue.(int); ok {
assert.Equal(t, float64(floatVal), actualValue)
} else {
assert.Equal(t, expectedValue, actualValue)
}
}
}
} else {
assert.Empty(t, output, "Expected no log output but got: %s", output)
}
})
}
}
func TestGlobalLogger(t *testing.T) {
tests := []struct {
name string
initLevel Level
initFormat Format
logFunc func(string, ...map[string]interface{})
message string
expectContains string
}{
{
name: "global info logger text",
initLevel: LevelInfo,
initFormat: FormatText,
logFunc: Info,
message: "global info message",
expectContains: "[INFO]",
},
{
name: "global error logger text",
initLevel: LevelInfo,
initFormat: FormatText,
logFunc: Error,
message: "global error message",
expectContains: "[ERROR]",
},
{
name: "global warn logger json",
initLevel: LevelWarn,
initFormat: FormatJSON,
logFunc: Warn,
message: "global warn message",
expectContains: `"level":"WARN"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Capture stderr by replacing globalLogger's output
buf := &bytes.Buffer{}
Init(tt.initLevel, tt.initFormat)
globalLogger.output = buf
// Call the global log function
tt.logFunc(tt.message)
output := buf.String()
assert.Contains(t, output, tt.expectContains)
assert.Contains(t, output, tt.message)
})
}
}
func TestLogLevelsFiltering(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logAtLevels []Level
expectOutputs []bool
}{
{
name: "debug level logs everything",
loggerLevel: LevelDebug,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{true, true, true, true},
},
{
name: "info level filters debug",
loggerLevel: LevelInfo,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, true, true, true},
},
{
name: "warn level filters debug and info",
loggerLevel: LevelWarn,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, false, true, true},
},
{
name: "error level only logs errors",
loggerLevel: LevelError,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, false, false, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
for i, logLevel := range tt.logAtLevels {
buf.Reset()
switch logLevel {
case LevelDebug:
logger.Debug("test")
case LevelInfo:
logger.Info("test")
case LevelWarn:
logger.Warn("test")
case LevelError:
logger.Error("test")
}
hasOutput := buf.Len() > 0
assert.Equal(t, tt.expectOutputs[i], hasOutput,
"Level %v at logger level %v: expected output=%v, got=%v",
logLevel, tt.loggerLevel, tt.expectOutputs[i], hasOutput)
}
})
}
}
func TestLoggerNilOutput(t *testing.T) {
// Test that logger defaults to os.Stderr when output is nil
logger := New(LevelInfo, FormatText, nil)
assert.NotNil(t, logger.output, "Logger output should not be nil")
}
func TestLevelToString(t *testing.T) {
tests := []struct {
level Level
expected string
}{
{LevelDebug, "DEBUG"},
{LevelInfo, "INFO"},
{LevelWarn, "WARN"},
{LevelError, "ERROR"},
{Level(999), "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := levelToString(tt.level)
assert.Equal(t, tt.expected, result)
})
}
}
func TestJSONFieldTypes(t *testing.T) {
tests := []struct {
name string
fields map[string]interface{}
}{
{
name: "string fields",
fields: map[string]interface{}{
"key1": "value1",
"key2": "value2",
},
},
{
name: "numeric fields",
fields: map[string]interface{}{
"port": 8080,
"timeout": 30,
"retry": 3,
},
},
{
name: "boolean fields",
fields: map[string]interface{}{
"enabled": true,
"running": false,
},
},
{
name: "mixed types",
fields: map[string]interface{}{
"context": "production",
"port": 8080,
"enabled": true,
"namespace": "default",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatJSON, buf)
logger.Info("test message", tt.fields)
var entry logEntry
err := json.Unmarshal([]byte(strings.TrimSpace(buf.String())), &entry)
require.NoError(t, err)
assert.Equal(t, len(tt.fields), len(entry.Fields),
"Field count mismatch")
for key := range tt.fields {
_, exists := entry.Fields[key]
assert.True(t, exists, "Field %s not found in JSON output", key)
}
})
}
}
func TestInitWithCustomOutput(t *testing.T) {
tests := []struct {
name string
output io.Writer
expectDiscard bool
description string
}{
{
name: "init with custom buffer",
output: &bytes.Buffer{},
expectDiscard: false,
description: "Should use provided buffer",
},
{
name: "init with io.Discard",
output: io.Discard,
expectDiscard: true,
description: "Should use io.Discard to silence output",
},
{
name: "init without output defaults to stderr",
output: nil,
expectDiscard: false,
description: "Should default to stderr when no output provided",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.output != nil {
Init(LevelInfo, FormatText, tt.output)
} else {
Init(LevelInfo, FormatText)
}
// Verify global logger was initialized
assert.NotNil(t, globalLogger, "Global logger should be initialized")
if tt.output != nil && !tt.expectDiscard {
// For buffer, verify output works
if buf, ok := tt.output.(*bytes.Buffer); ok {
Info("test message")
output := buf.String()
assert.Contains(t, output, "test message")
assert.Contains(t, output, "[INFO]")
}
} else if tt.expectDiscard {
// For io.Discard, verify no output appears (we can't really test this directly,
// but we can verify the logger was set with the right output)
assert.Equal(t, io.Discard, globalLogger.output)
}
})
}
}
-5
View File
@@ -67,8 +67,3 @@ func (b *Backoff) calculateJitter(delay time.Duration) time.Duration {
jitter := (b.rng.Float64()*2 - 1) * maxJitter
return time.Duration(jitter)
}
// Sleep waits for the next backoff duration.
func (b *Backoff) Sleep() {
time.Sleep(b.Next())
}
+5 -3
View File
@@ -158,10 +158,12 @@ func TestBackoff_ExponentialProgression(t *testing.T) {
// We allow for jitter by checking a range
for i := 1; i < len(delays)-1; i++ {
// Each delay should be roughly double the previous (accounting for jitter)
// With 10% jitter on each value, worst case: (2.0 * 1.1) / 0.9 = 2.44
// We use 1.7x to 2.5x as a reasonable range with 10% jitter on each
// With 10% jitter on each value:
// Lower bound: (2.0 * 0.9) / 1.1 ≈ 1.636
// Upper bound: (2.0 * 1.1) / 0.9 ≈ 2.444
// We use 1.6x to 2.5x as a reasonable range to account for jitter variance
ratio := float64(delays[i]) / float64(delays[i-1])
assert.GreaterOrEqual(t, ratio, 1.7, "exponential growth should be ~2x")
assert.GreaterOrEqual(t, ratio, 1.6, "exponential growth should be ~2x")
assert.LessOrEqual(t, ratio, 2.5, "exponential growth should be ~2x")
}
}
+269 -42
View File
@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
// ForwardUpdateMsg is sent when a forward status changes
@@ -44,11 +45,29 @@ type BubbleTeaUI struct {
toggleCallback func(id string, enable bool)
version string
errors map[string]string // Track error messages by forward ID
// Modal wizard state
viewMode ViewMode
addWizard *AddWizardState
removeWizard *RemoveWizardState
// Delete confirmation state
deleteConfirming bool
deleteConfirmID string
deleteConfirmAlias string
deleteConfirmCursor int // 0 = Yes, 1 = No
// Dependencies for wizards
discovery *k8s.Discovery
mutator *config.Mutator
configPath string
}
// bubbletea model
type model struct {
ui *BubbleTeaUI
ui *BubbleTeaUI
termWidth int
termHeight int
}
// NewBubbleTeaUI creates a new bubbletea-based UI
@@ -61,11 +80,22 @@ func NewBubbleTeaUI(toggleCallback func(id string, enable bool), version string)
toggleCallback: toggleCallback,
version: version,
errors: make(map[string]string),
viewMode: ViewModeMain,
}
return ui
}
// SetWizardDependencies sets the dependencies needed for the add/remove wizards
func (ui *BubbleTeaUI) SetWizardDependencies(discovery *k8s.Discovery, mutator *config.Mutator, configPath string) {
ui.mu.Lock()
defer ui.mu.Unlock()
ui.discovery = discovery
ui.mutator = mutator
ui.configPath = configPath
}
// Start starts the bubbletea application
func (ui *BubbleTeaUI) Start() error {
m := model{ui: ui}
@@ -187,33 +217,55 @@ func (m model) Init() tea.Cmd {
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.ui.mu.RLock()
viewMode := m.ui.viewMode
m.ui.mu.RUnlock()
switch msg := msg.(type) {
case tea.WindowSizeMsg:
// Update terminal dimensions on resize
m.termWidth = msg.Width
m.termHeight = msg.Height
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q":
return m, tea.Quit
case "up", "k":
m.ui.moveSelection(-1)
case "down", "j":
m.ui.moveSelection(1)
case " ", "enter":
m.ui.toggleSelected()
// Route based on current view mode
switch viewMode {
case ViewModeMain:
return m.handleMainViewKeys(msg)
case ViewModeAddWizard:
return m.handleAddWizardKeys(msg)
case ViewModeRemoveWizard:
return m.handleRemoveWizardKeys(msg)
}
case ForwardAddMsg:
// Already handled in AddForward, just trigger re-render
// Forward management messages (always update main view data)
case ForwardAddMsg, ForwardUpdateMsg, ForwardErrorMsg, ForwardRemoveMsg:
return m, nil
case ForwardUpdateMsg:
// Already handled in UpdateStatus, just trigger re-render
return m, nil
case ForwardErrorMsg:
// Already handled in SetError, just trigger re-render
return m, nil
case ForwardRemoveMsg:
// Already handled in Remove, just trigger re-render
// Wizard-specific messages
case ContextsLoadedMsg:
return m.handleContextsLoaded(msg)
case NamespacesLoadedMsg:
return m.handleNamespacesLoaded(msg)
case PodsLoadedMsg:
return m.handlePodsLoaded(msg)
case ServicesLoadedMsg:
return m.handleServicesLoaded(msg)
case SelectorValidatedMsg:
return m.handleSelectorValidated(msg)
case PortCheckedMsg:
return m.handlePortChecked(msg)
case ForwardSavedMsg:
return m.handleForwardSaved(msg)
case ForwardsRemovedMsg:
return m.handleForwardsRemoved(msg)
case WizardCompleteMsg:
m.ui.mu.Lock()
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
m.ui.removeWizard = nil
m.ui.mu.Unlock()
return m, nil
}
@@ -221,11 +273,57 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
func (m model) View() string {
m.ui.mu.RLock()
viewMode := m.ui.viewMode
deleteConfirming := m.ui.deleteConfirming
m.ui.mu.RUnlock()
// Always render main view as base
mainView := m.renderMainView()
// Use actual terminal dimensions for proper centering
termWidth := m.termWidth
termHeight := m.termHeight
// Fallback to reasonable defaults if dimensions not yet received
if termWidth == 0 {
termWidth = 120
}
if termHeight == 0 {
termHeight = 40
}
// Overlay delete confirmation if active
if deleteConfirming {
modal := m.renderDeleteConfirmation()
return overlayContent(mainView, modal, termWidth, termHeight)
}
// Overlay wizard if active
switch viewMode {
case ViewModeAddWizard:
modal := m.renderAddWizard()
return overlayContent(mainView, modal, termWidth, termHeight)
case ViewModeRemoveWizard:
modal := m.renderRemoveWizard()
return overlayContent(mainView, modal, termWidth, termHeight)
default:
return mainView
}
}
func (m model) renderMainView() string {
m.ui.mu.RLock()
defer m.ui.mu.RUnlock()
var b strings.Builder
// Get terminal dimensions for proper sizing
termHeight := m.termHeight
if termHeight == 0 {
termHeight = 40 // Fallback
}
// Styles
titleStyle := lipgloss.NewStyle().
Bold(true).
@@ -280,7 +378,7 @@ func (m model) View() string {
}
isSelected := (idx == m.ui.selectedIndex)
isDisabled := m.ui.disabledMap[id]
isDisabled := m.ui.disabledMap[id] || fwd.Status == "Disabled"
// Selection indicator
indicator := " "
@@ -350,21 +448,7 @@ func (m model) View() string {
}
}
// Footer
b.WriteString("\n")
footerStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
keyStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("220"))
footer := fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: Quit │ Total: %d",
keyStyle.Render("↑↓"),
keyStyle.Render("jk"),
keyStyle.Render("Space"),
keyStyle.Render("q"),
len(m.ui.forwardOrder))
b.WriteString(footerStyle.Render(footer))
// Display errors if any
// Display errors if any (before footer)
if len(m.ui.errors) > 0 {
b.WriteString("\n\n")
errorHeaderStyle := lipgloss.NewStyle().
@@ -374,20 +458,104 @@ func (m model) View() string {
b.WriteString(errorHeaderStyle.Render("Errors:"))
b.WriteString("\n")
errorLineStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Width(118). // Slightly less than table width (120) for padding
MaxWidth(118)
for id, errMsg := range m.ui.errors {
// Find the forward to display its alias
if fwd, ok := m.ui.forwards[id]; ok {
errorLineStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("196"))
line := fmt.Sprintf(" • %s: %s", fwd.Alias, errMsg)
b.WriteString(errorLineStyle.Render(line))
b.WriteString("\n")
// Format: " • alias: error message"
prefix := fmt.Sprintf(" • %s: ", fwd.Alias)
// Wrap the error message if it's too long
// Max line length is 118, subtract prefix length
maxErrLen := 118 - len(prefix)
wrappedMsg := wrapText(errMsg, maxErrLen)
// Render first line with prefix
lines := strings.Split(wrappedMsg, "\n")
if len(lines) > 0 {
b.WriteString(errorLineStyle.Render(prefix + lines[0]))
b.WriteString("\n")
// Render subsequent lines with indentation
indent := strings.Repeat(" ", len(prefix))
for i := 1; i < len(lines); i++ {
b.WriteString(errorLineStyle.Render(indent + lines[i]))
b.WriteString("\n")
}
}
}
}
}
// Calculate current content height
currentContent := b.String()
currentLines := strings.Count(currentContent, "\n") + 1
// Footer styles
footerStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
keyStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("220"))
footer := fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: New %s: Edit %s: Delete %s: Quit │ Total: %d",
keyStyle.Render("↑↓"),
keyStyle.Render("jk"),
keyStyle.Render("Space"),
keyStyle.Render("n"),
keyStyle.Render("e"),
keyStyle.Render("d"),
keyStyle.Render("q"),
len(m.ui.forwardOrder))
// Fill space to push footer to bottom (reserve 2 lines: 1 for spacing, 1 for footer)
footerHeight := 2
remainingLines := termHeight - currentLines - footerHeight
if remainingLines > 0 {
b.WriteString(strings.Repeat("\n", remainingLines))
}
// Add footer at bottom
b.WriteString("\n")
b.WriteString(footerStyle.Render(footer))
return b.String()
}
// wrapText wraps text to the specified width, breaking at word boundaries
func wrapText(text string, width int) string {
if len(text) <= width {
return text
}
var result strings.Builder
var line strings.Builder
words := strings.Fields(text)
for i, word := range words {
// If adding this word would exceed width, start new line
if line.Len()+len(word)+1 > width && line.Len() > 0 {
result.WriteString(line.String())
result.WriteString("\n")
line.Reset()
}
// Add space before word (except first word on line)
if line.Len() > 0 {
line.WriteString(" ")
}
line.WriteString(word)
// Last word - flush the line
if i == len(words)-1 {
result.WriteString(line.String())
}
}
return result.String()
}
// moveSelection moves the selection up or down
func (ui *BubbleTeaUI) moveSelection(delta int) {
ui.mu.Lock()
@@ -406,6 +574,65 @@ func (ui *BubbleTeaUI) moveSelection(delta int) {
}
}
// renderDeleteConfirmation renders the delete confirmation dialog
func (m model) renderDeleteConfirmation() string {
m.ui.mu.RLock()
defer m.ui.mu.RUnlock()
var b strings.Builder
// Use wizard color palette for consistency
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(warningColor). // Yellow for warning (delete action)
Padding(0, 1)
buttonSelectedStyle := lipgloss.NewStyle().
Background(primaryColor). // Pink/Magenta background
Foreground(lipgloss.Color("230")). // Light yellow text
Bold(true).
Padding(0, 1)
buttonUnselectedStyle := lipgloss.NewStyle().
Foreground(mutedColor). // Gray
Padding(0, 1)
deleteInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("252")). // Light gray for info text
Italic(true)
// Title
b.WriteString(titleStyle.Render("⚠ Delete Port Forward"))
b.WriteString("\n\n")
// Message
b.WriteString("Are you sure you want to delete:\n\n")
b.WriteString(deleteInfoStyle.Render(" " + m.ui.deleteConfirmAlias))
b.WriteString("\n\n")
// Buttons
if m.ui.deleteConfirmCursor == 0 {
b.WriteString(buttonSelectedStyle.Render(" Yes "))
b.WriteString(" ")
b.WriteString(buttonUnselectedStyle.Render(" No "))
} else {
b.WriteString(buttonUnselectedStyle.Render(" Yes "))
b.WriteString(" ")
b.WriteString(buttonSelectedStyle.Render(" No "))
}
b.WriteString("\n\n")
b.WriteString(helpStyle.Render("←/→: Navigate Enter: Confirm Esc: Cancel"))
// Wrap in a box using wizard style
boxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor). // Purple border like other wizards
Padding(1, 2)
return boxStyle.Render(b.String())
}
// toggleSelected toggles the selected forward on/off
func (ui *BubbleTeaUI) toggleSelected() {
ui.mu.Lock()
-181
View File
@@ -1,181 +0,0 @@
package ui
import (
"fmt"
"os"
"sync"
"golang.org/x/term"
)
// InteractiveController handles keyboard input and selection state
type InteractiveController struct {
mu sync.RWMutex
selectedIndex int
forwardIDs []string // Ordered list of forward IDs
disabledMap map[string]bool // Tracks which forwards are disabled
toggleCallback func(id string, enable bool)
enabled bool
oldTermState *term.State
}
// NewInteractiveController creates a new interactive controller
func NewInteractiveController(toggleCallback func(id string, enable bool)) *InteractiveController {
return &InteractiveController{
selectedIndex: 0,
forwardIDs: make([]string, 0),
disabledMap: make(map[string]bool),
toggleCallback: toggleCallback,
enabled: false,
}
}
// Enable puts the terminal in raw mode for keyboard input
func (ic *InteractiveController) Enable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.enabled {
return nil
}
// Save current terminal state
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return fmt.Errorf("failed to enable raw mode: %w", err)
}
ic.oldTermState = oldState
ic.enabled = true
return nil
}
// Disable restores the terminal to normal mode
func (ic *InteractiveController) Disable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if !ic.enabled {
return nil
}
if ic.oldTermState != nil {
if err := term.Restore(int(os.Stdin.Fd()), ic.oldTermState); err != nil {
return fmt.Errorf("failed to restore terminal: %w", err)
}
}
ic.enabled = false
return nil
}
// UpdateForwardsList updates the list of forwards for navigation
func (ic *InteractiveController) UpdateForwardsList(ids []string) {
ic.mu.Lock()
defer ic.mu.Unlock()
ic.forwardIDs = ids
// Ensure selected index is valid
if ic.selectedIndex >= len(ic.forwardIDs) {
ic.selectedIndex = len(ic.forwardIDs) - 1
}
if ic.selectedIndex < 0 && len(ic.forwardIDs) > 0 {
ic.selectedIndex = 0
}
}
// MoveUp moves selection up
func (ic *InteractiveController) MoveUp() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex > 0 {
ic.selectedIndex--
}
}
// MoveDown moves selection down
func (ic *InteractiveController) MoveDown() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex < len(ic.forwardIDs)-1 {
ic.selectedIndex++
}
}
// ToggleSelected toggles the enable/disable state of the selected forward
func (ic *InteractiveController) ToggleSelected() {
ic.mu.Lock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
ic.mu.Unlock()
return
}
selectedID := ic.forwardIDs[ic.selectedIndex]
currentlyDisabled := ic.disabledMap[selectedID]
newState := !currentlyDisabled
ic.disabledMap[selectedID] = newState
ic.mu.Unlock()
// Call the toggle callback
if ic.toggleCallback != nil {
ic.toggleCallback(selectedID, !newState) // enable is inverse of disabled
}
}
// GetSelectedIndex returns the current selection index
func (ic *InteractiveController) GetSelectedIndex() int {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.selectedIndex
}
// IsDisabled returns whether a forward is disabled
func (ic *InteractiveController) IsDisabled(id string) bool {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.disabledMap[id]
}
// GetSelectedID returns the ID of the currently selected forward
func (ic *InteractiveController) GetSelectedID() string {
ic.mu.RLock()
defer ic.mu.RUnlock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
return ""
}
return ic.forwardIDs[ic.selectedIndex]
}
// HandleKey processes keyboard input and returns true if should continue
func (ic *InteractiveController) HandleKey(b []byte) bool {
if len(b) == 0 {
return true
}
// Handle single byte keys
if len(b) == 1 {
switch b[0] {
case 'q', 'Q', 3: // q, Q, or Ctrl+C
return false
case ' ', '\r': // Space or Enter to toggle
ic.ToggleSelected()
return true
}
}
// Handle escape sequences (arrow keys)
if len(b) == 3 && b[0] == 27 && b[1] == 91 {
switch b[2] {
case 65: // Up arrow
ic.MoveUp()
case 66: // Down arrow
ic.MoveDown()
}
}
return true
}
+7 -48
View File
@@ -23,10 +23,9 @@ type ForwardStatus struct {
// TableUI manages the terminal table display
type TableUI struct {
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
interactive *InteractiveController
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
}
// NewTableUI creates a new table UI manager
@@ -37,13 +36,6 @@ func NewTableUI(verbose bool) *TableUI {
}
}
// SetInteractiveController sets the interactive controller
func (t *TableUI) SetInteractiveController(ic *InteractiveController) {
t.mu.Lock()
defer t.mu.Unlock()
t.interactive = ic
}
// AddForward registers a new forward for display
func (t *TableUI) AddForward(id string, fwd *config.Forward) {
t.mu.Lock()
@@ -126,27 +118,10 @@ func (t *TableUI) Render() {
}
}
// Update interactive controller with current forward IDs (in display order)
if t.interactive != nil {
ids := make([]string, len(entries))
for i, entry := range entries {
ids[i] = entry.id
}
t.interactive.UpdateForwardsList(ids)
}
// Print each forward
for i, entry := range entries {
for _, entry := range entries {
fwd := entry.fwd
// Check if this row is selected
isSelected := false
isDisabled := false
if t.interactive != nil {
isSelected = (i == t.interactive.GetSelectedIndex())
isDisabled = t.interactive.IsDisabled(entry.id)
}
// Truncate long names
alias := truncate(fwd.Alias, 25)
resource := truncate(fwd.Resource, 25)
@@ -154,8 +129,8 @@ func (t *TableUI) Render() {
// Color code status with indicator
statusStr := formatStatusWithIndicator(fwd.Status)
// Build the row content
rowContent := fmt.Sprintf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s",
// Print the row
fmt.Printf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s\n",
fwd.Context,
fwd.Namespace,
alias,
@@ -164,26 +139,10 @@ func (t *TableUI) Render() {
fwd.RemotePort,
fwd.LocalPort,
statusStr)
// Apply selection highlighting or disabled styling
if isSelected {
// Replace leading spaces with arrow, then apply reverse video to entire line
rowContent = "\033[7m> " + rowContent[2:] + "\033[0m"
} else if isDisabled {
// Apply dimmed styling to entire line
rowContent = "\033[2m" + rowContent + "\033[0m"
}
fmt.Println(rowContent)
}
fmt.Println(strings.Repeat("=", 130))
helpText := "Total forwards: %d | ↑↓: Navigate | Space: Toggle | q: Quit"
if !t.verbose {
fmt.Printf(helpText+"\n", len(t.forwards))
} else {
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
}
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
// In verbose mode, add a newline to separate from logs
if t.verbose {
+222
View File
@@ -0,0 +1,222 @@
package ui
import (
"context"
"fmt"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
const (
k8sAPITimeout = 10 * time.Second
)
// Messages sent from async commands back to the update loop
// ContextsLoadedMsg is sent when contexts have been loaded
type ContextsLoadedMsg struct {
contexts []string
err error
}
// NamespacesLoadedMsg is sent when namespaces have been loaded
type NamespacesLoadedMsg struct {
namespaces []string
err error
}
// PodsLoadedMsg is sent when pods have been loaded
type PodsLoadedMsg struct {
pods []k8s.PodInfo
err error
}
// ServicesLoadedMsg is sent when services have been loaded
type ServicesLoadedMsg struct {
services []k8s.ServiceInfo
err error
}
// SelectorValidatedMsg is sent when a selector has been validated
type SelectorValidatedMsg struct {
valid bool
pods []k8s.PodInfo
err error
}
// PortCheckedMsg is sent when a port's availability has been checked
type PortCheckedMsg struct {
port int
available bool
message string
}
// ForwardSavedMsg is sent when a forward has been saved to config
type ForwardSavedMsg struct {
success bool
err error
}
// ForwardsRemovedMsg is sent when forwards have been removed from config
type ForwardsRemovedMsg struct {
success bool
count int
err error
}
// WizardCompleteMsg signals that the wizard has completed
type WizardCompleteMsg struct{}
// Command functions (return tea.Cmd)
// loadContextsCmd loads available Kubernetes contexts
func loadContextsCmd(discovery *k8s.Discovery) tea.Cmd {
return func() tea.Msg {
contexts, err := discovery.ListContexts()
if err != nil {
return ContextsLoadedMsg{err: err}
}
return ContextsLoadedMsg{contexts: contexts}
}
}
// loadNamespacesCmd loads namespaces for the given context
func loadNamespacesCmd(discovery *k8s.Discovery, contextName string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
namespaces, err := discovery.ListNamespaces(ctx, contextName)
if err != nil {
return NamespacesLoadedMsg{err: err}
}
return NamespacesLoadedMsg{namespaces: namespaces}
}
}
// loadPodsCmd loads pods for the given context and namespace
func loadPodsCmd(discovery *k8s.Discovery, contextName, namespace string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
pods, err := discovery.ListPods(ctx, contextName, namespace)
if err != nil {
return PodsLoadedMsg{err: err}
}
return PodsLoadedMsg{pods: pods}
}
}
// loadServicesCmd loads services for the given context and namespace
func loadServicesCmd(discovery *k8s.Discovery, contextName, namespace string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
services, err := discovery.ListServices(ctx, contextName, namespace)
if err != nil {
return ServicesLoadedMsg{err: err}
}
return ServicesLoadedMsg{services: services}
}
}
// validateSelectorCmd validates a label selector and returns matching pods
func validateSelectorCmd(discovery *k8s.Discovery, contextName, namespace, selector string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
pods, err := discovery.ListPodsWithSelector(ctx, contextName, namespace, selector)
if err != nil {
return SelectorValidatedMsg{valid: false, err: err}
}
return SelectorValidatedMsg{
valid: len(pods) > 0,
pods: pods,
}
}
}
// checkPortCmd checks if a local port is available
func checkPortCmd(port int) tea.Cmd {
return func() tea.Msg {
available, processInfo, err := k8s.CheckPortAvailability(port)
msg := ""
if err != nil {
msg = fmt.Sprintf("✗ Error: %v", err)
} else if available {
msg = fmt.Sprintf("✓ Port %d available", port)
} else {
msg = fmt.Sprintf("✗ Port %d in use by %s", port, processInfo)
}
return PortCheckedMsg{
port: port,
available: available,
message: msg,
}
}
}
// saveForwardCmd saves a new forward to the configuration file
func saveForwardCmd(mutator *config.Mutator, contextName, namespace string, fwd config.Forward) tea.Cmd {
return func() tea.Msg {
err := mutator.AddForward(contextName, namespace, fwd)
return ForwardSavedMsg{
success: err == nil,
err: err,
}
}
}
// updateForwardCmd atomically updates an existing forward (used in edit mode)
func updateForwardCmd(mutator *config.Mutator, oldID, contextName, namespace string, fwd config.Forward) tea.Cmd {
return func() tea.Msg {
err := mutator.UpdateForward(oldID, contextName, namespace, fwd)
return ForwardSavedMsg{
success: err == nil,
err: err,
}
}
}
// removeForwardsCmd removes selected forwards from the configuration file
func removeForwardsCmd(mutator *config.Mutator, forwards []RemovableForward) tea.Cmd {
return func() tea.Msg {
// Create a map of IDs to remove
idsToRemove := make(map[string]bool)
for _, fwd := range forwards {
idsToRemove[fwd.ID] = true
}
// Remove forwards matching the IDs
err := mutator.RemoveForwards(func(ctx, ns string, fwd config.Forward) bool {
return idsToRemove[fwd.ID()]
})
return ForwardsRemovedMsg{
success: err == nil,
count: len(forwards),
err: err,
}
}
}
// removeForwardByIDCmd removes a single forward by its ID
func removeForwardByIDCmd(mutator *config.Mutator, id string) tea.Cmd {
return func() tea.Msg {
err := mutator.RemoveForwardByID(id)
return ForwardsRemovedMsg{
success: err == nil,
count: 1,
err: err,
}
}
}
+811
View File
@@ -0,0 +1,811 @@
package ui
import (
"fmt"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
// isFilterableStep returns true if the step supports search/filter
func isFilterableStep(step AddWizardStep) bool {
switch step {
case StepSelectContext, StepSelectNamespace:
return true
case StepEnterResource:
// Only service selection is filterable (pod prefix and selector are text input)
return true // We'll check resource type in the handler
default:
return false
}
}
// handleMainViewKeys handles keyboard input in the main view
func (m model) handleMainViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// If delete confirmation is showing, handle it separately
if m.ui.deleteConfirming {
return m.handleDeleteConfirmation(msg)
}
switch msg.String() {
case "ctrl+c", "q":
return m, tea.Quit
case "up", "k":
m.ui.moveSelection(-1)
case "down", "j":
m.ui.moveSelection(1)
case " ", "enter":
m.ui.toggleSelected()
case "n": // Enter add wizard
m.ui.mu.Lock()
if m.ui.discovery == nil || m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
m.ui.viewMode = ViewModeAddWizard
m.ui.addWizard = newAddWizardState()
m.ui.addWizard.loading = true
m.ui.mu.Unlock()
// Load contexts
return m, loadContextsCmd(m.ui.discovery)
case "e": // Edit selected forward
m.ui.mu.Lock()
if len(m.ui.forwardOrder) == 0 {
// No forwards to edit
m.ui.mu.Unlock()
return m, nil
}
if m.ui.discovery == nil || m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
// Get the currently selected forward
currentSelectedIndex := m.ui.selectedIndex
if currentSelectedIndex < 0 || currentSelectedIndex >= len(m.ui.forwardOrder) {
m.ui.mu.Unlock()
return m, nil
}
selectedID := m.ui.forwardOrder[currentSelectedIndex]
selectedForward, ok := m.ui.forwards[selectedID]
if !ok {
m.ui.mu.Unlock()
return m, nil
}
// Create an add wizard pre-filled with the current forward's values
m.ui.viewMode = ViewModeAddWizard
m.ui.addWizard = newAddWizardState()
// Pre-fill the wizard with current values
m.ui.addWizard.selectedContext = selectedForward.Context
m.ui.addWizard.selectedNamespace = selectedForward.Namespace
m.ui.addWizard.resourceValue = selectedForward.Resource
m.ui.addWizard.remotePort = selectedForward.RemotePort
m.ui.addWizard.localPort = selectedForward.LocalPort
m.ui.addWizard.alias = selectedForward.Alias
// Determine resource type from the resource string
if strings.HasPrefix(selectedForward.Type, "service") {
m.ui.addWizard.selectedResourceType = ResourceTypeService
} else {
m.ui.addWizard.selectedResourceType = ResourceTypePodPrefix
}
// Mark as edit mode and store original ID
m.ui.addWizard.isEditing = true
m.ui.addWizard.originalID = selectedID
// Start at the remote port step (skip context/namespace/resource selection)
m.ui.addWizard.step = StepEnterRemotePort
// Load resources to detect ports
m.ui.addWizard.loading = true
m.ui.mu.Unlock()
// Load pods or services to detect available ports
if m.ui.addWizard.selectedResourceType == ResourceTypeService {
return m, loadServicesCmd(m.ui.discovery, selectedForward.Context, selectedForward.Namespace)
}
return m, loadPodsCmd(m.ui.discovery, selectedForward.Context, selectedForward.Namespace)
case "d": // Delete currently selected forward - show confirmation
m.ui.mu.Lock()
if len(m.ui.forwardOrder) == 0 {
// No forwards to delete
m.ui.mu.Unlock()
return m, nil
}
if m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
// Get the currently selected forward
currentSelectedIndex := m.ui.selectedIndex
if currentSelectedIndex < 0 || currentSelectedIndex >= len(m.ui.forwardOrder) {
m.ui.mu.Unlock()
return m, nil
}
selectedID := m.ui.forwardOrder[currentSelectedIndex]
selectedForward, ok := m.ui.forwards[selectedID]
if !ok {
m.ui.mu.Unlock()
return m, nil
}
// Show confirmation dialog
m.ui.deleteConfirming = true
m.ui.deleteConfirmID = selectedID
m.ui.deleteConfirmAlias = selectedForward.Alias
m.ui.deleteConfirmCursor = 0 // Default to "No" for safety
m.ui.mu.Unlock()
return m, nil
}
return m, nil
}
// handleDeleteConfirmation handles keyboard input for delete confirmation dialog
func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
switch msg.String() {
case "ctrl+c", "esc":
// Cancel deletion
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.mu.Unlock()
// Force a repaint by returning the model
return m, tea.ClearScreen
case "left", "h", "right", "l":
// Toggle between Yes/No
m.ui.deleteConfirmCursor = 1 - m.ui.deleteConfirmCursor
m.ui.mu.Unlock()
return m, nil
case "enter", "y":
// Confirm deletion (either Enter on Yes or pressing 'y')
if m.ui.deleteConfirmCursor == 0 || msg.String() == "y" {
id := m.ui.deleteConfirmID
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.mu.Unlock()
return m, removeForwardByIDCmd(m.ui.mutator, id)
}
// Enter on No = cancel
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.mu.Unlock()
return m, tea.ClearScreen
case "n":
// Quick 'n' for no
m.ui.deleteConfirming = false
m.ui.deleteConfirmID = ""
m.ui.deleteConfirmAlias = ""
m.ui.deleteConfirmCursor = 0 // Reset cursor
m.ui.mu.Unlock()
return m, tea.ClearScreen
}
m.ui.mu.Unlock()
return m, nil
}
// handleAddWizardKeys handles keyboard input in the add wizard
func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
wizard := m.ui.addWizard
if wizard == nil {
return m, nil
}
switch msg.String() {
case "ctrl+c":
// Hard cancel
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
case "esc":
// If there's an active search filter, clear it instead of going back
if wizard.searchFilter != "" && isFilterableStep(wizard.step) {
wizard.clearSearchFilter()
return m, nil
}
// In edit mode, Esc always cancels (don't navigate back through skipped steps)
if wizard.isEditing {
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
// In add mode, go back or cancel
if wizard.step == StepSelectContext {
// On first step, cancel entirely
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
} else {
// Go back one step
wizard.step--
wizard.cursor = 0
wizard.clearTextInput()
wizard.clearSearchFilter()
wizard.error = nil
// Reset input mode based on the step we're going back to
switch wizard.step {
case StepSelectContext, StepSelectNamespace, StepSelectResourceType:
wizard.inputMode = InputModeList
case StepEnterResource:
if wizard.selectedResourceType == ResourceTypeService {
wizard.inputMode = InputModeList
} else {
wizard.inputMode = InputModeText
}
case StepEnterRemotePort, StepEnterLocalPort:
wizard.inputMode = InputModeText
case StepConfirmation:
wizard.inputMode = InputModeList
}
}
return m, nil
case "up", "k":
// In confirmation step, toggle between alias and buttons
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusButtons {
wizard.confirmationFocus = FocusAlias
}
} else {
wizard.moveCursor(-1)
}
case "down", "j":
// In confirmation step, toggle between alias and buttons
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
} else {
wizard.moveCursor(1) // Navigate between buttons
}
} else {
wizard.moveCursor(1)
}
case "tab":
// Tab moves between alias field and buttons in confirmation
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
} else {
wizard.confirmationFocus = FocusAlias
}
}
case "enter":
return m.handleAddWizardEnter()
case "backspace":
// Allow backspace in text input mode OR when focused on alias in confirmation OR when filtering
canBackspace := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step) && len(wizard.searchFilter) > 0)
if canBackspace {
if isFilterableStep(wizard.step) && wizard.inputMode == InputModeList && len(wizard.searchFilter) > 0 {
// Backspace in search filter
wizard.searchFilter = wizard.searchFilter[:len(wizard.searchFilter)-1]
wizard.cursor = 0
wizard.scrollOffset = 0
} else if len(wizard.textInput) > 0 {
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
}
}
default:
// Handle text input
canTypeText := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step))
if canTypeText && len(msg.String()) == 1 {
// If in list mode on filterable step, add to search filter instead of textInput
if wizard.inputMode == InputModeList && isFilterableStep(wizard.step) {
char := rune(msg.String()[0])
// Only allow printable characters
if char >= 32 && char < 127 {
wizard.searchFilter += string(char)
wizard.cursor = 0
wizard.scrollOffset = 0
}
} else {
wizard.handleTextInput(rune(msg.String()[0]))
// Trigger validation for selector
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
if len(wizard.textInput) > 0 {
wizard.loading = true
wizard.error = nil
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
}
}
}
}
}
return m, nil
}
// handleAddWizardEnter handles Enter key in the add wizard
func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
wizard := m.ui.addWizard
switch wizard.step {
case StepSelectContext:
filteredContexts := wizard.getFilteredContexts()
if wizard.cursor >= 0 && wizard.cursor < len(filteredContexts) {
wizard.selectedContext = filteredContexts[wizard.cursor]
wizard.step = StepSelectNamespace
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.loading = true
return m, loadNamespacesCmd(m.ui.discovery, wizard.selectedContext)
}
case StepSelectNamespace:
filteredNamespaces := wizard.getFilteredNamespaces()
if wizard.cursor >= 0 && wizard.cursor < len(filteredNamespaces) {
wizard.selectedNamespace = filteredNamespaces[wizard.cursor]
wizard.step = StepSelectResourceType
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.inputMode = InputModeList
}
case StepSelectResourceType:
if wizard.cursor >= 0 && wizard.cursor < 3 {
wizard.selectedResourceType = ResourceType(wizard.cursor)
wizard.step = StepEnterResource
wizard.cursor = 0
if wizard.selectedResourceType == ResourceTypeService {
wizard.inputMode = InputModeList
wizard.loading = true
return m, loadServicesCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace)
} else {
wizard.inputMode = InputModeText
wizard.loading = true
return m, loadPodsCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace)
}
}
case StepEnterResource:
switch wizard.selectedResourceType {
case ResourceTypePodPrefix:
if wizard.textInput != "" {
wizard.resourceValue = wizard.textInput
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
// Detect ports from matching pods
wizard.detectedPorts = k8s.GetUniquePorts(wizard.pods)
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
case ResourceTypePodSelector:
if wizard.textInput != "" && len(wizard.matchingPods) > 0 {
wizard.resourceValue = "pod"
wizard.selector = wizard.textInput
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
// Detect ports from matching pods
wizard.detectedPorts = k8s.GetUniquePorts(wizard.matchingPods)
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
case ResourceTypeService:
filteredServices := wizard.getFilteredServices()
if wizard.cursor >= 0 && wizard.cursor < len(filteredServices) {
wizard.resourceValue = filteredServices[wizard.cursor].Name
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
wizard.clearSearchFilter()
// Get ports from selected service
wizard.detectedPorts = filteredServices[wizard.cursor].Ports
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
}
case StepEnterRemotePort:
if wizard.inputMode == InputModeList && len(wizard.detectedPorts) > 0 {
// List mode - user selected from detected ports
if wizard.cursor == len(wizard.detectedPorts) {
// Selected "Manual entry" option
wizard.inputMode = InputModeText
wizard.clearTextInput()
} else if wizard.cursor >= 0 && wizard.cursor < len(wizard.detectedPorts) {
// Selected a detected port
wizard.remotePort = int(wizard.detectedPorts[wizard.cursor].Port)
wizard.step = StepEnterLocalPort
wizard.clearTextInput()
wizard.inputMode = InputModeText
wizard.error = nil
}
} else {
// Text mode - manual entry
port, err := strconv.Atoi(wizard.textInput)
if err != nil || port < 1 || port > 65535 {
wizard.error = fmt.Errorf("invalid port number")
} else {
wizard.remotePort = port
wizard.step = StepEnterLocalPort
wizard.clearTextInput()
wizard.error = nil
}
}
case StepEnterLocalPort:
port, err := strconv.Atoi(wizard.textInput)
if err != nil || port < 1 || port > 65535 {
wizard.error = fmt.Errorf("invalid port number")
} else {
wizard.localPort = port
wizard.step = StepConfirmation
wizard.clearTextInput()
wizard.cursor = 0
wizard.inputMode = InputModeList
wizard.error = nil
wizard.loading = true
return m, checkPortCmd(port)
}
case StepConfirmation:
// If focused on alias field, move to buttons
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
return m, nil
}
// Handle button selection
if wizard.cursor == 0 {
// Confirmed - save the forward
wizard.alias = wizard.textInput
// Build the forward config
fwd := config.Forward{
Protocol: "tcp",
Port: wizard.remotePort,
LocalPort: wizard.localPort,
Alias: wizard.alias,
}
if wizard.selectedResourceType == ResourceTypePodPrefix {
fwd.Resource = "pod/" + wizard.resourceValue
} else if wizard.selectedResourceType == ResourceTypePodSelector {
fwd.Resource = wizard.resourceValue
fwd.Selector = wizard.selector
} else if wizard.selectedResourceType == ResourceTypeService {
fwd.Resource = "service/" + wizard.resourceValue
}
wizard.loading = true
// If editing, use atomic update operation
if wizard.isEditing {
return m, updateForwardCmd(m.ui.mutator, wizard.originalID, wizard.selectedContext, wizard.selectedNamespace, fwd)
}
return m, saveForwardCmd(m.ui.mutator, wizard.selectedContext, wizard.selectedNamespace, fwd)
} else {
// Cancelled
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
}
case StepSuccess:
if wizard.cursor == 0 {
// Add another
m.ui.addWizard = newAddWizardState()
m.ui.addWizard.loading = true
return m, loadContextsCmd(m.ui.discovery)
} else {
// Return to main view
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
}
}
return m, nil
}
// handleRemoveWizardKeys handles keyboard input in the remove wizard
func (m model) handleRemoveWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
wizard := m.ui.removeWizard
if wizard == nil {
return m, nil
}
switch msg.String() {
case "ctrl+c":
// Hard cancel - always exit
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
return m, tea.ClearScreen
case "esc":
if wizard.confirming {
// In confirmation mode, Esc confirms the removal (same as pressing Yes)
selectedForwards := wizard.getSelectedForwards()
return m, removeForwardsCmd(m.ui.mutator, selectedForwards)
} else {
// Not confirming yet - cancel entirely
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
}
return m, tea.ClearScreen
case "up", "k":
wizard.moveCursor(-1)
case "down", "j":
wizard.moveCursor(1)
case " ":
if !wizard.confirming {
wizard.toggleSelection()
}
case "a":
wizard.selectAll()
case "n":
wizard.selectNone()
case "enter":
if !wizard.confirming {
if wizard.getSelectedCount() == 0 {
// Nothing selected
return m, nil
}
// Show confirmation
wizard.confirming = true
wizard.confirmCursor = 0
} else {
// Confirmed
if wizard.confirmCursor == 0 {
// Yes, remove
selectedForwards := wizard.getSelectedForwards()
return m, removeForwardsCmd(m.ui.mutator, selectedForwards)
} else {
// No, cancel
wizard.confirming = false
}
}
}
return m, nil
}
// Message handlers
func (m model) handleContextsLoaded(msg ContextsLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
// Get current context and move it to the top
currentCtx, err := m.ui.discovery.GetCurrentContext()
if err == nil && currentCtx != "" {
// Reorder contexts with current first
reordered := []string{currentCtx}
for _, ctx := range msg.contexts {
if ctx != currentCtx {
reordered = append(reordered, ctx)
}
}
m.ui.addWizard.contexts = reordered
} else {
m.ui.addWizard.contexts = msg.contexts
}
}
}
return m, nil
}
func (m model) handleNamespacesLoaded(msg NamespacesLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.namespaces = msg.namespaces
}
}
return m, nil
}
func (m model) handlePodsLoaded(msg PodsLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.pods = msg.pods
// If we're at the remote port step (edit mode), detect ports now
if m.ui.addWizard.step == StepEnterRemotePort {
m.ui.addWizard.detectedPorts = k8s.GetUniquePorts(msg.pods)
if len(m.ui.addWizard.detectedPorts) > 0 {
m.ui.addWizard.inputMode = InputModeList
m.ui.addWizard.cursor = 0
} else {
m.ui.addWizard.inputMode = InputModeText
m.ui.addWizard.textInput = fmt.Sprintf("%d", m.ui.addWizard.remotePort)
}
}
}
}
return m, nil
}
func (m model) handleServicesLoaded(msg ServicesLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.services = msg.services
// If we're at the remote port step (edit mode), detect ports now
if m.ui.addWizard.step == StepEnterRemotePort {
// Find the service by name
for _, svc := range msg.services {
if svc.Name == m.ui.addWizard.resourceValue {
m.ui.addWizard.detectedPorts = svc.Ports
if len(m.ui.addWizard.detectedPorts) > 0 {
m.ui.addWizard.inputMode = InputModeList
m.ui.addWizard.cursor = 0
} else {
m.ui.addWizard.inputMode = InputModeText
m.ui.addWizard.textInput = fmt.Sprintf("%d", m.ui.addWizard.remotePort)
}
break
}
}
}
}
}
return m, nil
}
func (m model) handleSelectorValidated(msg SelectorValidatedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.valid {
m.ui.addWizard.matchingPods = msg.pods
} else {
m.ui.addWizard.matchingPods = nil
}
}
return m, nil
}
func (m model) handlePortChecked(msg PortCheckedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.portAvailable = msg.available
m.ui.addWizard.portCheckMsg = msg.message
}
return m, nil
}
func (m model) handleForwardSaved(msg ForwardSavedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
if msg.success {
// Move to success step
m.ui.addWizard.step = StepSuccess
m.ui.addWizard.cursor = 0
m.ui.addWizard.inputMode = InputModeList
} else {
m.ui.addWizard.error = msg.err
}
}
return m, nil
}
func (m model) handleForwardsRemoved(msg ForwardsRemovedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
// Delete now happens directly without wizard
// Just ensure we're back in main view
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
// If there was an error, it will be logged but we don't show it in UI for now
// The config watcher will either reload (success) or keep old config (failure)
return m, nil
}
+365
View File
@@ -0,0 +1,365 @@
package ui
import (
"strings"
"github.com/nvm/kportal/internal/k8s"
)
// filterStrings filters a slice of strings by a search filter (case-insensitive substring match)
func filterStrings(items []string, filter string) []string {
if filter == "" {
return items
}
filtered := []string{}
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item), filterLower) {
filtered = append(filtered, item)
}
}
return filtered
}
// matchesFilter checks if a string matches the filter (case-insensitive substring match)
func matchesFilter(item, filter string) bool {
if filter == "" {
return true
}
return strings.Contains(strings.ToLower(item), strings.ToLower(filter))
}
// ViewMode represents the current view state of the UI
type ViewMode int
const (
ViewModeMain ViewMode = iota
ViewModeAddWizard
ViewModeRemoveWizard
)
// InputMode represents whether the wizard is in list selection or text input mode
type InputMode int
const (
InputModeList InputMode = iota
InputModeText
)
// AddWizardStep represents the current step in the add wizard flow
type AddWizardStep int
const (
StepSelectContext AddWizardStep = iota
StepSelectNamespace
StepSelectResourceType
StepEnterResource
StepEnterRemotePort
StepEnterLocalPort
StepConfirmation
StepSuccess
)
// ConfirmationFocus represents what the user is focused on in confirmation step
type ConfirmationFocus int
const (
FocusAlias ConfirmationFocus = iota
FocusButtons
)
// ResourceType represents the type of Kubernetes resource to forward to
type ResourceType int
const (
ResourceTypePodPrefix ResourceType = iota
ResourceTypePodSelector
ResourceTypeService
)
// String returns a human-readable name for the resource type
func (r ResourceType) String() string {
switch r {
case ResourceTypePodPrefix:
return "Pod (by name prefix)"
case ResourceTypePodSelector:
return "Pod (by label selector)"
case ResourceTypeService:
return "Service"
default:
return "Unknown"
}
}
// Description returns a description of the resource type
func (r ResourceType) Description() string {
switch r {
case ResourceTypePodPrefix:
return "Recommended for specific pod instances"
case ResourceTypePodSelector:
return "Flexible, survives pod restarts automatically"
case ResourceTypeService:
return "Most stable, load-balanced"
default:
return ""
}
}
// AddWizardState maintains the state for the add port forward wizard
type AddWizardState struct {
step AddWizardStep
inputMode InputMode
cursor int
scrollOffset int // For scrolling long lists
textInput string
searchFilter string // For filtering lists (contexts, namespaces, services)
loading bool
error error
// Selections made by user
selectedContext string
selectedNamespace string
selectedResourceType ResourceType
resourceValue string // pod prefix or service name
selector string // for pod selector type
remotePort int
localPort int
alias string
// Available options (loaded asynchronously from k8s)
contexts []string
namespaces []string
pods []k8s.PodInfo
services []k8s.ServiceInfo
// Validation state
portAvailable bool
portCheckMsg string
matchingPods []k8s.PodInfo
// Edit mode
isEditing bool
originalID string // ID of the forward being edited
// Detected ports from resources
detectedPorts []k8s.PortInfo
// Confirmation focus (alias field vs buttons)
confirmationFocus ConfirmationFocus
}
// newAddWizardState creates a new add wizard state initialized to the first step
func newAddWizardState() *AddWizardState {
return &AddWizardState{
step: StepSelectContext,
inputMode: InputModeList,
cursor: 0,
contexts: []string{},
}
}
// moveCursor moves the cursor up or down in list selection mode
func (w *AddWizardState) moveCursor(delta int) {
if w.inputMode != InputModeList {
return
}
var maxItems int
switch w.step {
case StepSelectContext:
maxItems = len(w.getFilteredContexts())
case StepSelectNamespace:
maxItems = len(w.getFilteredNamespaces())
case StepSelectResourceType:
maxItems = 3 // Three resource types
case StepEnterResource:
if w.selectedResourceType == ResourceTypeService {
maxItems = len(w.getFilteredServices())
}
case StepEnterRemotePort:
if len(w.detectedPorts) > 0 {
maxItems = len(w.detectedPorts) + 1 // +1 for "Manual entry" option
}
}
w.cursor += delta
if w.cursor < 0 {
w.cursor = 0
}
if w.cursor >= maxItems && maxItems > 0 {
w.cursor = maxItems - 1
}
// Adjust scroll offset to keep cursor visible
// Viewport shows max 20 items at a time
const viewportHeight = 20
// If cursor moved below visible area, scroll down
if w.cursor >= w.scrollOffset+viewportHeight {
w.scrollOffset = w.cursor - viewportHeight + 1
}
// If cursor moved above visible area, scroll up
if w.cursor < w.scrollOffset {
w.scrollOffset = w.cursor
}
// Ensure scroll offset is valid
if w.scrollOffset < 0 {
w.scrollOffset = 0
}
}
// handleTextInput handles a single character input in text mode
func (w *AddWizardState) handleTextInput(char rune) {
// Note: Caller already checks if text input is allowed (inputMode or confirmation step)
// so we don't need to check inputMode here
// Handle backspace
if char == 127 || char == 8 {
if len(w.textInput) > 0 {
w.textInput = w.textInput[:len(w.textInput)-1]
}
return
}
// Only allow printable characters
if char >= 32 && char < 127 {
w.textInput += string(char)
}
}
// clearTextInput clears the text input field
func (w *AddWizardState) clearTextInput() {
w.textInput = ""
}
// RemoveWizardState maintains the state for the remove port forward wizard
type RemoveWizardState struct {
forwards []RemovableForward
cursor int
selected map[int]bool
confirming bool
confirmCursor int // 0 = Yes, 1 = No
}
// RemovableForward represents a forward that can be removed
type RemovableForward struct {
ID string
Context string
Namespace string
Alias string
Resource string
Selector string
Port int
LocalPort int
}
// moveCursor moves the cursor up or down
func (w *RemoveWizardState) moveCursor(delta int) {
if w.confirming {
// Move between Yes/No in confirmation
w.confirmCursor += delta
if w.confirmCursor < 0 {
w.confirmCursor = 0
}
if w.confirmCursor > 1 {
w.confirmCursor = 1
}
} else {
// Move between forwards
w.cursor += delta
if w.cursor < 0 {
w.cursor = 0
}
if w.cursor >= len(w.forwards) {
w.cursor = len(w.forwards) - 1
}
}
}
// toggleSelection toggles the selection of the current forward
func (w *RemoveWizardState) toggleSelection() {
if w.confirming {
return
}
w.selected[w.cursor] = !w.selected[w.cursor]
}
// selectAll selects all forwards for removal
func (w *RemoveWizardState) selectAll() {
if w.confirming {
return
}
for i := range w.forwards {
w.selected[i] = true
}
}
// selectNone deselects all forwards
func (w *RemoveWizardState) selectNone() {
if w.confirming {
return
}
w.selected = make(map[int]bool)
}
// getSelectedCount returns the number of selected forwards
func (w *RemoveWizardState) getSelectedCount() int {
count := 0
for _, selected := range w.selected {
if selected {
count++
}
}
return count
}
// getSelectedForwards returns a list of selected forwards
func (w *RemoveWizardState) getSelectedForwards() []RemovableForward {
selected := make([]RemovableForward, 0)
for i, fwd := range w.forwards {
if w.selected[i] {
selected = append(selected, fwd)
}
}
return selected
}
// getFilteredContexts returns contexts filtered by search string
func (w *AddWizardState) getFilteredContexts() []string {
if w.searchFilter == "" {
return w.contexts
}
return filterStrings(w.contexts, w.searchFilter)
}
// getFilteredNamespaces returns namespaces filtered by search string
func (w *AddWizardState) getFilteredNamespaces() []string {
if w.searchFilter == "" {
return w.namespaces
}
return filterStrings(w.namespaces, w.searchFilter)
}
// getFilteredServices returns services filtered by search string
func (w *AddWizardState) getFilteredServices() []k8s.ServiceInfo {
if w.searchFilter == "" {
return w.services
}
filtered := []k8s.ServiceInfo{}
for _, svc := range w.services {
if matchesFilter(svc.Name, w.searchFilter) {
filtered = append(filtered, svc)
}
}
return filtered
}
// clearSearchFilter clears the search filter and resets cursor/scroll
func (w *AddWizardState) clearSearchFilter() {
w.searchFilter = ""
w.cursor = 0
w.scrollOffset = 0
}
+350
View File
@@ -0,0 +1,350 @@
package ui
import (
"testing"
"github.com/nvm/kportal/internal/k8s"
"github.com/stretchr/testify/assert"
)
func TestFilterStrings(t *testing.T) {
tests := []struct {
name string
items []string
filter string
expected []string
}{
{
name: "empty filter returns all items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "",
expected: []string{"namespace-1", "namespace-2", "namespace-3"},
},
{
name: "filter matches multiple items",
items: []string{"prod-api", "prod-db", "staging-api", "dev-api"},
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter matches single item",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "2",
expected: []string{"namespace-2"},
},
{
name: "filter matches no items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "xyz",
expected: []string{},
},
{
name: "case insensitive matching",
items: []string{"Production", "Staging", "Development"},
filter: "prod",
expected: []string{"Production"},
},
{
name: "partial string matching",
items: []string{"my-app-frontend", "my-app-backend", "other-service"},
filter: "app",
expected: []string{"my-app-frontend", "my-app-backend"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterStrings(tt.items, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMatchesFilter(t *testing.T) {
tests := []struct {
name string
item string
filter string
expected bool
}{
{
name: "empty filter matches everything",
item: "namespace-1",
filter: "",
expected: true,
},
{
name: "exact match",
item: "namespace-1",
filter: "namespace-1",
expected: true,
},
{
name: "partial match",
item: "production-api",
filter: "prod",
expected: true,
},
{
name: "no match",
item: "namespace-1",
filter: "xyz",
expected: false,
},
{
name: "case insensitive match",
item: "Production",
filter: "prod",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchesFilter(tt.item, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredContexts(t *testing.T) {
wizard := &AddWizardState{
contexts: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-cluster"},
},
{
name: "filter by 'cluster'",
filter: "cluster",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'staging'",
filter: "staging",
expected: []string{"staging-cluster"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredContexts()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredNamespaces(t *testing.T) {
wizard := &AddWizardState{
namespaces: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter by 'kube'",
filter: "kube",
expected: []string{"kube-system", "kube-public"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"prod-api", "staging-api"},
},
{
name: "filter by 'ing' (partial match)",
filter: "ing",
expected: []string{"staging-api", "staging-db", "monitoring", "logging"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredNamespaces()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredServices(t *testing.T) {
wizard := &AddWizardState{
services: []k8s.ServiceInfo{
{Name: "api-gateway"},
{Name: "api-backend"},
{Name: "database"},
{Name: "redis-cache"},
{Name: "postgres-db"},
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"api-gateway", "api-backend", "database", "redis-cache", "postgres-db"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"api-gateway", "api-backend"},
},
{
name: "filter by 'db'",
filter: "db",
expected: []string{"postgres-db"},
},
{
name: "filter by 'base'",
filter: "base",
expected: []string{"database"},
},
{
name: "filter by 'redis'",
filter: "redis",
expected: []string{"redis-cache"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredServices()
resultNames := make([]string, len(result))
for i, svc := range result {
resultNames[i] = svc.Name
}
assert.Equal(t, tt.expected, resultNames)
})
}
}
func TestClearSearchFilter(t *testing.T) {
wizard := &AddWizardState{
searchFilter: "test",
cursor: 5,
scrollOffset: 10,
}
wizard.clearSearchFilter()
assert.Equal(t, "", wizard.searchFilter, "searchFilter should be cleared")
assert.Equal(t, 0, wizard.cursor, "cursor should be reset to 0")
assert.Equal(t, 0, wizard.scrollOffset, "scrollOffset should be reset to 0")
}
func TestMoveCursorWithFilteredLists(t *testing.T) {
tests := []struct {
name string
step AddWizardStep
contexts []string
namespaces []string
searchFilter string
initialCursor int
delta int
expectedCursor int
}{
{
name: "move down in filtered contexts",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 0,
delta: 1,
expectedCursor: 1,
},
{
name: "cannot move beyond filtered list",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 1,
delta: 1,
expectedCursor: 1, // Should stay at 1 (last item in filtered list)
},
{
name: "move up in filtered list",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3", "other"},
searchFilter: "ns",
initialCursor: 2,
delta: -1,
expectedCursor: 1,
},
{
name: "cannot move above 0",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3"},
searchFilter: "ns",
initialCursor: 0,
delta: -1,
expectedCursor: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard := &AddWizardState{
step: tt.step,
inputMode: InputModeList,
cursor: tt.initialCursor,
contexts: tt.contexts,
namespaces: tt.namespaces,
searchFilter: tt.searchFilter,
}
wizard.moveCursor(tt.delta)
assert.Equal(t, tt.expectedCursor, wizard.cursor)
})
}
}
+211
View File
@@ -0,0 +1,211 @@
package ui
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
// Color palette for wizards
var (
primaryColor = lipgloss.Color("205") // Pink/Magenta
successColor = lipgloss.Color("42") // Green
errorColor = lipgloss.Color("196") // Red
warningColor = lipgloss.Color("220") // Yellow
mutedColor = lipgloss.Color("241") // Gray
accentColor = lipgloss.Color("63") // Purple
highlightColor = lipgloss.Color("117") // Light blue
)
// Text styles
var (
wizardHeaderStyle = lipgloss.NewStyle().
Bold(true).
Foreground(primaryColor).
MarginBottom(0)
wizardStepStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true)
breadcrumbStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
selectedStyle = lipgloss.NewStyle().
Foreground(primaryColor).
Bold(true)
successStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
errorStyle = lipgloss.NewStyle().
Foreground(errorColor).
Bold(true)
warningStyle = lipgloss.NewStyle().
Foreground(warningColor).
Bold(true)
mutedStyle = lipgloss.NewStyle().
Foreground(mutedColor)
helpStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true)
spinnerStyle = lipgloss.NewStyle().
Foreground(accentColor).
Bold(true)
)
// Input styles
var (
inputStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
validInputStyle = lipgloss.NewStyle().
Foreground(successColor)
)
// Checkbox styles
var (
checkedBoxStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
uncheckedBoxStyle = lipgloss.NewStyle().
Foreground(mutedColor)
)
// Container styles
var (
wizardBoxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor).
Padding(1, 2).
Width(60)
)
// Helper functions for rendering
// renderProgress returns a step indicator like "Step 2/7"
func renderProgress(current, total int) string {
return wizardStepStyle.Render(fmt.Sprintf("Step %d/%d", current, total))
}
// renderHeader returns a formatted header with title and progress
func renderHeader(title, progress string) string {
header := wizardHeaderStyle.Render(title)
if progress != "" {
header += " " + progress
}
return header + "\n\n"
}
// renderBreadcrumb returns a formatted breadcrumb path
func renderBreadcrumb(parts ...string) string {
return breadcrumbStyle.Render(strings.Join(parts, " / "))
}
// renderList renders a list of items with cursor selection and viewport scrolling
func renderList(items []string, cursor int, prefix string, scrollOffset int) string {
var b strings.Builder
const viewportHeight = 20
totalItems := len(items)
// Show scroll up indicator if there are items above the viewport
if scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := scrollOffset
end := scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render visible items
for i := start; i < end; i++ {
cursorPrefix := prefix
if i == cursor {
cursorPrefix = "▸ "
b.WriteString(selectedStyle.Render(cursorPrefix + items[i]))
} else {
b.WriteString(cursorPrefix + items[i])
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
return b.String()
}
// renderTextInput renders a text input field with a cursor
func renderTextInput(label, value string, valid bool) string {
var b strings.Builder
b.WriteString(label)
inputText := value + "█"
if valid {
b.WriteString(validInputStyle.Render(inputText))
} else {
b.WriteString(inputStyle.Render(inputText))
}
return b.String()
}
// overlayContent overlays modal content centered on the base view
func overlayContent(base, modal string, termWidth, termHeight int) string {
baseLines := strings.Split(base, "\n")
modalLines := strings.Split(modal, "\n")
// Ensure base has enough lines
for len(baseLines) < termHeight {
baseLines = append(baseLines, "")
}
modalHeight := len(modalLines)
modalWidth := 0
for _, line := range modalLines {
w := lipgloss.Width(line)
if w > modalWidth {
modalWidth = w
}
}
// Calculate center position
startRow := (termHeight - modalHeight) / 2
if startRow < 0 {
startRow = 0
}
// Create result with modal overlaid
result := make([]string, len(baseLines))
copy(result, baseLines)
for i, modalLine := range modalLines {
row := startRow + i
if row >= 0 && row < len(result) {
// Center the modal line
padding := (termWidth - lipgloss.Width(modalLine)) / 2
if padding < 0 {
padding = 0
}
result[row] = strings.Repeat(" ", padding) + modalLine
}
}
return strings.Join(result, "\n")
}
+653
View File
@@ -0,0 +1,653 @@
package ui
import (
"fmt"
"strings"
)
// renderAddWizard renders the appropriate step of the add wizard
func (m model) renderAddWizard() string {
if m.ui.addWizard == nil {
return ""
}
wizard := m.ui.addWizard
var content string
switch wizard.step {
case StepSelectContext:
content = m.renderSelectContext()
case StepSelectNamespace:
content = m.renderSelectNamespace()
case StepSelectResourceType:
content = m.renderSelectResourceType()
case StepEnterResource:
content = m.renderEnterResource()
case StepEnterRemotePort:
content = m.renderEnterRemotePort()
case StepEnterLocalPort:
content = m.renderEnterLocalPort()
case StepConfirmation:
content = m.renderConfirmation()
case StepSuccess:
content = m.renderSuccess()
default:
content = "Unknown step"
}
return wizardBoxStyle.Render(content)
}
func (m model) renderSelectContext() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(1, 7)))
b.WriteString("Select Kubernetes Context:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading contexts..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Error: %v", wizard.error)))
} else if len(wizard.contexts) == 0 {
b.WriteString(mutedStyle.Render("No contexts found in kubeconfig"))
} else {
filteredContexts := wizard.getFilteredContexts()
if len(filteredContexts) == 0 {
b.WriteString(mutedStyle.Render("No matching contexts"))
} else {
const viewportHeight = 20
totalItems := len(filteredContexts)
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render visible contexts with (current) marker on first one (only if not filtered)
for i := start; i < end; i++ {
prefix := " "
text := filteredContexts[i]
// Only show (current) marker if no filter and this is the first item in original list
if wizard.searchFilter == "" && i == 0 {
text += mutedStyle.Render(" (current)")
}
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + text))
} else {
b.WriteString(prefix + text)
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
}
}
b.WriteString("\n")
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Cancel", len(wizard.getFilteredContexts()), len(wizard.contexts))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderSelectNamespace() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(2, 7)))
b.WriteString(fmt.Sprintf("Context: %s\n\n", breadcrumbStyle.Render(wizard.selectedContext)))
b.WriteString("Select Namespace:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading namespaces..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Error: %v\n", wizard.error)))
b.WriteString(mutedStyle.Render("\nCluster may be unreachable. Check context."))
} else if len(wizard.namespaces) == 0 {
b.WriteString(mutedStyle.Render("No namespaces found"))
} else {
filteredNamespaces := wizard.getFilteredNamespaces()
if len(filteredNamespaces) == 0 {
b.WriteString(mutedStyle.Render("No matching namespaces"))
} else {
b.WriteString(renderList(filteredNamespaces, wizard.cursor, " ", wizard.scrollOffset))
}
}
b.WriteString("\n")
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredNamespaces()), len(wizard.namespaces))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderSelectResourceType() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(3, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n\n")
b.WriteString("Select Resource Type:\n\n")
resourceTypes := []ResourceType{
ResourceTypePodPrefix,
ResourceTypePodSelector,
ResourceTypeService,
}
for i, rt := range resourceTypes {
prefix := " "
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + rt.String()))
b.WriteString("\n")
b.WriteString(mutedStyle.Render(" " + rt.Description()))
} else {
b.WriteString(prefix + rt.String())
}
b.WriteString("\n")
if i < len(resourceTypes)-1 {
b.WriteString("\n")
}
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
return b.String()
}
func (m model) renderEnterResource() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(4, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n\n")
switch wizard.selectedResourceType {
case ResourceTypePodPrefix:
b.WriteString("Enter pod name prefix:\n\n")
// Show running pods for reference
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading pods..."))
} else if len(wizard.pods) > 0 {
b.WriteString(mutedStyle.Render("Running pods:\n"))
showCount := 0
for _, pod := range wizard.pods {
if strings.HasPrefix(pod.Name, wizard.textInput) || wizard.textInput == "" {
if showCount < 5 { // Limit to 5 pods
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", pod.Name)))
showCount++
}
}
}
if showCount == 0 && wizard.textInput != "" {
b.WriteString(mutedStyle.Render(" (no matching pods)\n"))
} else if len(wizard.pods) > showCount {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" ... and %d more\n", len(wizard.pods)-showCount)))
}
b.WriteString("\n")
}
// Text input
b.WriteString(renderTextInput("Prefix: ", wizard.textInput, true))
b.WriteString("\n\n")
// Show match count
if wizard.textInput != "" {
matchCount := 0
for _, pod := range wizard.pods {
if strings.HasPrefix(pod.Name, wizard.textInput) {
matchCount++
}
}
if matchCount > 0 {
b.WriteString(successStyle.Render(fmt.Sprintf("✓ Matches %d pod(s)", matchCount)))
} else {
b.WriteString(warningStyle.Render("⚠ No matching pods (you can still proceed)"))
}
}
case ResourceTypePodSelector:
b.WriteString("Enter label selector:\n")
b.WriteString(mutedStyle.Render("Format: key=value,key2=value2\n\n"))
b.WriteString(renderTextInput("Selector: ", wizard.textInput, true))
b.WriteString("\n\n")
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Validating selector..."))
} else if len(wizard.matchingPods) > 0 {
b.WriteString(successStyle.Render(fmt.Sprintf("✓ Found %d matching pod(s):\n", len(wizard.matchingPods))))
showCount := 0
for _, pod := range wizard.matchingPods {
if showCount < 3 {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", pod.Name)))
showCount++
}
}
if len(wizard.matchingPods) > 3 {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" ... and %d more\n", len(wizard.matchingPods)-3)))
}
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Invalid selector: %v", wizard.error)))
}
case ResourceTypeService:
b.WriteString("Select service:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading services..."))
} else if len(wizard.services) == 0 {
b.WriteString(mutedStyle.Render("No services found"))
} else {
filteredServices := wizard.getFilteredServices()
if len(filteredServices) == 0 {
b.WriteString(mutedStyle.Render("No matching services"))
} else {
serviceNames := make([]string, len(filteredServices))
for i, svc := range filteredServices {
serviceNames[i] = svc.Name
}
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
}
}
}
b.WriteString("\n")
// Show appropriate help text based on resource type and filter state
if wizard.selectedResourceType == ResourceTypeService {
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredServices()), len(wizard.services))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
} else {
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderEnterRemotePort() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(5, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n")
// Show resource selection
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("%s [%s]", wizard.resourceValue, wizard.selector)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf("Resource: %s\n\n", resourceInfo)))
// If we have detected ports and in list mode, show them as a list
if len(wizard.detectedPorts) > 0 && wizard.inputMode == InputModeList {
b.WriteString("Select remote port:\n\n")
const viewportHeight = 20
totalItems := len(wizard.detectedPorts) + 1 // +1 for manual entry option
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render detected ports within viewport
for i := start; i < end && i < len(wizard.detectedPorts); i++ {
port := wizard.detectedPorts[i]
portDesc := fmt.Sprintf("%d", port.Port)
if port.Name != "" {
portDesc += fmt.Sprintf(" (%s)", port.Name)
}
prefix := " "
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + portDesc))
} else {
b.WriteString(prefix + portDesc)
}
b.WriteString("\n")
}
// Add "Manual entry" option if within viewport
manualIdx := len(wizard.detectedPorts)
if manualIdx >= start && manualIdx < end {
manualOption := "Manual entry (type port number)"
prefix := " "
if wizard.cursor == manualIdx {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + manualOption))
} else {
b.WriteString(mutedStyle.Render(prefix + manualOption))
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
} else {
// Text input mode (no detected ports or user chose manual entry)
if len(wizard.detectedPorts) > 0 {
b.WriteString(mutedStyle.Render("Detected ports:\n"))
for _, port := range wizard.detectedPorts {
portDesc := fmt.Sprintf("%d", port.Port)
if port.Name != "" {
portDesc += fmt.Sprintf(" (%s)", port.Name)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", portDesc)))
}
b.WriteString("\n")
}
b.WriteString(renderTextInput("Remote port: ", wizard.textInput, wizard.error == nil))
b.WriteString("\n\n")
if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ %v", wizard.error)))
} else if wizard.textInput != "" {
b.WriteString(mutedStyle.Render("Press Enter to continue"))
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderEnterLocalPort() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(6, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n")
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("%s [%s]", wizard.resourceValue, wizard.selector)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf("Resource: %s\n", resourceInfo)))
b.WriteString(mutedStyle.Render(fmt.Sprintf("Remote port: %d\n\n", wizard.remotePort)))
b.WriteString(renderTextInput("Local port: ", wizard.textInput, wizard.error == nil))
b.WriteString("\n\n")
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Checking availability..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ %v", wizard.error)))
} else if wizard.portCheckMsg != "" {
if wizard.portAvailable {
b.WriteString(successStyle.Render(wizard.portCheckMsg))
} else {
b.WriteString(errorStyle.Render(wizard.portCheckMsg))
}
} else if wizard.textInput != "" && wizard.localPort > 0 {
b.WriteString(mutedStyle.Render("Press Enter to check availability"))
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
return b.String()
}
func (m model) renderConfirmation() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(7, 7)))
b.WriteString("\n")
b.WriteString("Review Configuration:\n\n")
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("pod (selector: %s)", wizard.selector)
} else if wizard.selectedResourceType == ResourceTypePodPrefix {
resourceInfo = fmt.Sprintf("pod/%s", wizard.resourceValue)
} else if wizard.selectedResourceType == ResourceTypeService {
resourceInfo = fmt.Sprintf("service/%s", wizard.resourceValue)
}
b.WriteString(fmt.Sprintf(" Context: %s\n", wizard.selectedContext))
b.WriteString(fmt.Sprintf(" Namespace: %s\n", wizard.selectedNamespace))
b.WriteString(fmt.Sprintf(" Resource: %s\n", resourceInfo))
b.WriteString(fmt.Sprintf(" Remote Port: %d\n", wizard.remotePort))
b.WriteString(fmt.Sprintf(" Local Port: %d\n", wizard.localPort))
b.WriteString(" Protocol: tcp\n")
b.WriteString("\n")
// Show alias field with focus indicator
if wizard.confirmationFocus == FocusAlias {
b.WriteString(selectedStyle.Render("▸ Optional alias (friendly name):") + "\n")
b.WriteString(" Alias: " + validInputStyle.Render(wizard.textInput+"█") + "\n")
} else {
b.WriteString(mutedStyle.Render(" Optional alias (friendly name):") + "\n")
b.WriteString(mutedStyle.Render(" Alias: "+wizard.textInput) + "\n")
}
b.WriteString("\n")
// Show buttons with focus indicator
if wizard.confirmationFocus == FocusButtons {
if wizard.cursor == 0 {
b.WriteString(selectedStyle.Render("▸ Add to .kportal.yaml") + "\n")
b.WriteString(" Cancel\n")
} else {
b.WriteString(" Add to .kportal.yaml\n")
b.WriteString(selectedStyle.Render("▸ Cancel") + "\n")
}
} else {
b.WriteString(mutedStyle.Render(" Add to .kportal.yaml") + "\n")
b.WriteString(mutedStyle.Render(" Cancel") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓/Tab: Navigate Enter: Confirm Esc: Back"))
return b.String()
}
func (m model) renderSuccess() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(successStyle.Render("Success! ✓"))
b.WriteString("\n\n")
if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("Error: %v", wizard.error)))
} else {
b.WriteString("Added to .kportal.yaml\n\n")
forwardDesc := fmt.Sprintf("localhost:%d → %s:%d",
wizard.localPort,
wizard.resourceValue,
wizard.remotePort)
if wizard.alias != "" {
forwardDesc = fmt.Sprintf("%s (%s)", wizard.alias, forwardDesc)
}
b.WriteString(successStyle.Render(forwardDesc))
b.WriteString("\n\n")
b.WriteString(mutedStyle.Render("The port forward will be active shortly."))
}
b.WriteString("\n\n")
b.WriteString("Would you like to:\n")
if wizard.cursor == 0 {
b.WriteString(selectedStyle.Render("▸ Add another port forward") + "\n")
b.WriteString(" Return to main view\n")
} else {
b.WriteString(" Add another port forward\n")
b.WriteString(selectedStyle.Render("▸ Return to main view") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select"))
return b.String()
}
// renderRemoveWizard renders the remove wizard
func (m model) renderRemoveWizard() string {
if m.ui.removeWizard == nil {
return ""
}
wizard := m.ui.removeWizard
var content string
if wizard.confirming {
content = m.renderRemoveConfirmation()
} else {
content = m.renderRemoveSelection()
}
return wizardBoxStyle.Render(content)
}
func (m model) renderRemoveSelection() string {
wizard := m.ui.removeWizard
var b strings.Builder
b.WriteString(renderHeader("Remove Port Forwards", ""))
b.WriteString("\n")
b.WriteString("Select forwards to remove (Space to toggle):\n\n")
for i, fwd := range wizard.forwards {
isSelected := i == wizard.cursor
isChecked := wizard.selected[i]
line1 := fmt.Sprintf("%s:%d→%d", fwd.Alias, fwd.Port, fwd.LocalPort)
line2 := fmt.Sprintf(" %s/%s/%s", fwd.Context, fwd.Namespace, fwd.Resource)
checkbox := "[ ] "
if isChecked {
checkbox = "[✓] "
}
fullLine := checkbox + line1
if isSelected {
b.WriteString(selectedStyle.Render(fullLine))
} else {
if isChecked {
b.WriteString(checkedBoxStyle.Render(checkbox) + line1)
} else {
b.WriteString(uncheckedBoxStyle.Render(checkbox) + line1)
}
}
b.WriteString("\n")
b.WriteString(mutedStyle.Render(line2))
b.WriteString("\n\n")
}
selectedCount := wizard.getSelectedCount()
b.WriteString(fmt.Sprintf("%d of %d selected\n\n", selectedCount, len(wizard.forwards)))
b.WriteString(helpStyle.Render("Space: Toggle a: All n: None Enter: Remove Esc: Cancel"))
return b.String()
}
func (m model) renderRemoveConfirmation() string {
wizard := m.ui.removeWizard
var b strings.Builder
b.WriteString(renderHeader("Confirm Removal", ""))
b.WriteString("\n")
selectedCount := wizard.getSelectedCount()
b.WriteString(fmt.Sprintf("Remove %d port forward(s)?\n\n", selectedCount))
selectedForwards := wizard.getSelectedForwards()
for _, fwd := range selectedForwards {
b.WriteString(errorStyle.Render(fmt.Sprintf(" • %s:%d→%d\n", fwd.Alias, fwd.Port, fwd.LocalPort)))
b.WriteString(mutedStyle.Render(fmt.Sprintf(" %s/%s/%s\n", fwd.Context, fwd.Namespace, fwd.Resource)))
}
b.WriteString("\n")
b.WriteString(warningStyle.Render("This action cannot be undone."))
b.WriteString("\n\n")
// Yes/No buttons
if wizard.confirmCursor == 0 {
b.WriteString(selectedStyle.Render("▸ Yes, remove them") + "\n")
b.WriteString(" Cancel\n")
} else {
b.WriteString(" Yes, remove them\n")
b.WriteString(selectedStyle.Render("▸ Cancel") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Confirm Esc: Cancel"))
return b.String()
}