Compare commits

...

12 Commits

Author SHA1 Message Date
lukaszraczylo 1167847fd4 bugfixes nov2025 pt4 (#7)
* Add mDNS resolution.
* Update the website and documentation
2025-11-25 11:14:33 +00:00
lukaszraczylo 3a7cc6f502 bugfixes nov2025 pt3 (#6)
* Minor improvements.
* DRY the codebase.
* Add version checker / updater.
2025-11-25 01:28:23 +00:00
lukaszraczylo 49acba5679 Bugfixes nov2025 pt2 (#5)
* UI bugfixes.
* Fix open port check during new fwd setup wizard
2025-11-25 00:09:32 +00:00
lukaszraczylo 39fe4286b4 Fix the watchdog being too aggressive. 2025-11-24 13:19:44 +00:00
lukaszraczylo 2fdc5912e7 healtcheck improvements (#4)
* Advanced healtchecks.
* Add watchdog for stale connections handling.
2025-11-24 13:00:19 +00:00
lukaszraczylo 7df161aee0 bugfixes nov2025 (#3)
* Fix enter misbehaving.
* Cleanup after previous tui implementation.
* Fix race condition and improve logging
* Add filtering of the namespaces by text input in the wizard UI
2025-11-24 11:09:23 +00:00
lukaszraczylo f41c316b2b Add configuration wizard. (#2)
* Add configuration wizard.
2025-11-24 02:28:08 +00:00
lukaszraczylo 0f86f8e230 Update the tap link in README and github page 2025-11-23 18:35:49 +00:00
lukaszraczylo 0b07c99342 Remove semver misleading config. 2025-11-23 18:35:49 +00:00
lukaszraczylo 77b3f18a07 Base release - 0.2.x 2025-11-23 18:35:49 +00:00
lukaszraczylo 2e6db9ae2f Add kportal screenshot to documentation
- Add screenshot to README.md
- Add screenshot to GitHub Pages site with styling
- Include screenshot image in docs/ directory
2025-11-23 18:35:48 +00:00
lukaszraczylo 2ca8a2df69 Fix build and deployment issues
- Fix .gitignore to only ignore binary at root (/kportal)
- Add cmd/kportal/main.go to repository (was incorrectly ignored)
- Resolve merge conflict in static.yml workflow
- Ensure GitHub Pages workflow only triggers on docs/ changes
2025-11-23 18:35:37 +00:00
55 changed files with 10107 additions and 1705 deletions
-3
View File
@@ -5,11 +5,8 @@ on:
# Runs on pushes targeting the default branch
push:
branches: ["main"]
<<<<<<< HEAD
=======
paths:
- 'docs/**'
>>>>>>> b4f4c38 (Add github page.)
# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:
+1 -1
View File
@@ -1,5 +1,5 @@
CLAUDE.md
kportal
/kportal
DEPLOYMENT_SUMMARY.md
HOMEBREW_COMPLIANCE.md
RELEASE_SETUP.md
+1 -1
View File
@@ -19,7 +19,7 @@ builds:
- arm64
ldflags:
- -s -w
- -X main.version={{.Version}}
- -X main.appVersion={{.Version}}
archives:
- id: kportal
+21
View File
@@ -1,6 +1,27 @@
# Example kportal configuration
# Copy this file to your project and customize as needed
# Optional: Health check configuration
# These settings control how kportal monitors connection health and detects stale connections
healthCheck:
interval: "3s" # How often to check connection health (default: 3s)
timeout: "2s" # Timeout for health check operations (default: 2s)
method: "data-transfer" # Health check method: "tcp-dial" or "data-transfer" (default: data-transfer)
# - tcp-dial: Simple TCP connection test (fast, less reliable)
# - data-transfer: Attempts to read data (slower, more reliable)
maxConnectionAge: "25m" # Maximum connection age before proactive reconnect (default: 25m)
# Helps avoid Kubernetes API server timeouts (typically 30m)
maxIdleTime: "10m" # Maximum idle time before marking as stale (default: 10m)
# Connections with no data transfer are marked stale
# Optional: Reliability configuration
# These settings improve connection stability for long-running transfers
reliability:
tcpKeepalive: "30s" # TCP keepalive interval for OS-level connection monitoring (default: 30s)
dialTimeout: "30s" # Connection dial timeout (default: 30s)
retryOnStale: true # Automatically reconnect when stale connections detected (default: true)
watchdogPeriod: "30s" # Goroutine watchdog check interval to detect hung workers (default: 30s)
contexts:
# Production context
- name: production
+1 -1
View File
@@ -37,7 +37,7 @@ GOFMT=$(GOCMD) fmt
# Build flags
BUILD_FLAGS=-buildvcs=false
LDFLAGS=-ldflags="-s -w -X main.version=$(VERSION)"
LDFLAGS=-ldflags="-s -w -X main.appVersion=$(VERSION)"
all: fmt vet staticcheck test build
+169 -514
View File
@@ -1,42 +1,43 @@
# kportal
<p align="center">
<img src="docs/kportal-logo-dark.svg" alt="kportal logo" width="400">
</p>
[![Release](https://img.shields.io/github/v/release/lukaszraczylo/kportal)](https://github.com/lukaszraczylo/kportal/releases)
[![License](https://img.shields.io/github/license/lukaszraczylo/kportal)](LICENSE)
[![Go Report Card](https://goreportcard.com/badge/github.com/lukaszraczylo/kportal)](https://goreportcard.com/report/github.com/lukaszraczylo/kportal)
<p align="center">
<a href="https://github.com/lukaszraczylo/kportal/releases"><img src="https://img.shields.io/github/v/release/lukaszraczylo/kportal" alt="Release"></a>
<a href="LICENSE"><img src="https://img.shields.io/github/license/lukaszraczylo/kportal" alt="License"></a>
<a href="https://goreportcard.com/report/github.com/lukaszraczylo/kportal"><img src="https://goreportcard.com/badge/github.com/lukaszraczylo/kportal" alt="Go Report Card"></a>
</p>
**Modern Kubernetes port-forward manager with interactive terminal UI**
<p align="center">
<strong>Kubernetes port-forward manager with interactive terminal UI</strong>
</p>
kportal simplifies managing multiple Kubernetes port-forwards with an elegant, interactive terminal interface. Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea), it provides real-time status updates, automatic reconnection, and hot-reload configuration support.
kportal manages multiple Kubernetes port-forwards with an interactive terminal interface. It provides real-time status updates, automatic reconnection, hot-reload configuration, and mDNS hostname publishing.
![kportal Demo](docs/images/demo.png)
![kportal Screenshot](docs/kportal-screenshot.png)
## ✨ Features
- 🎯 **Interactive TUI** - Beautiful terminal interface with keyboard navigation (↑↓/jk, Space to toggle, q to quit)
- 🔄 **Auto-Reconnect** - Automatic retry with exponential backoff on connection failures (max 10s)
- **Hot-Reload** - Update configuration without restarting - changes applied automatically
- 🏥 **Health Checks** - Real-time port forward status monitoring with 5-second intervals
- 🎨 **Multi-Context** - Support for multiple Kubernetes contexts and namespaces
- 📦 **Batch Management** - Manage all port-forwards from a single configuration file
- 🔌 **Toggle Forwards** - Enable/disable individual port-forwards on the fly with Space key
- 🚀 **Grace Period** - Smart 10-second grace period to avoid false "Error" status on startup
- 📊 **Status Display** - Clear visual indicators: Active (●), Starting (○), Reconnecting (◐), Error (✗)
- 🔍 **Error Reporting** - Detailed error messages displayed below the table
- 🔄 **Pod Restart Handling** - Detects and reconnects to pods when they restart
- 🏷️ **Label Selector Support** - Dynamically target pods using label selectors
- 📋 **Prefix Matching** - Automatically find and reconnect to pods with name prefixes
- 🚫 **Port Conflict Detection** - Validates port availability before starting with detailed PID info
- 🎭 **Alias Support** - Cleaner, more readable display names for your forwards
- **Interactive TUI** - Terminal interface with keyboard navigation
- **Live management** - Add, edit, and delete port-forwards without restarting
- **Auto-reconnect** - Exponential backoff retry on connection failures
- **Hot-reload** - Configuration changes applied automatically
- **Health monitoring** - Multiple check methods with stale connection detection
- **Multi-context** - Support for multiple Kubernetes contexts and namespaces
- **Pod restart handling** - Automatic reconnection when pods restart
- **Label selectors** - Dynamic pod targeting using label selectors
- **Port conflict detection** - Validates port availability with PID information
- **mDNS hostnames** - Access forwards via `.local` hostnames
## 📦 Installation
### Homebrew (macOS/Linux)
```bash
brew install lukaszraczylo/tap/kportal
brew install lukaszraczylo/brew-taps/kportal
```
### Quick Install Script
### Quick Install
```bash
curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/kportal/main/install.sh | bash
@@ -44,24 +45,19 @@ curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/kportal/main/install.
### Manual Download
Download the latest binary for your platform from the [releases page](https://github.com/lukaszraczylo/kportal/releases):
- **macOS**: `kportal-{version}-darwin-{amd64|arm64}.tar.gz`
- **Linux**: `kportal-{version}-linux-{amd64|arm64}.tar.gz`
- **Windows**: `kportal-{version}-windows-{amd64|arm64}.zip`
Download binaries from the [releases page](https://github.com/lukaszraczylo/kportal/releases).
### Build from Source
```bash
git clone https://github.com/lukaszraczylo/kportal.git
cd kportal
make build
make install
make build && make install
```
## 🚀 Quick Start
1. **Create a configuration file** (`.kportal.yaml`):
Create `.kportal.yaml`:
```yaml
contexts:
@@ -75,29 +71,32 @@ contexts:
localPort: 5432
alias: prod-db
- name: frontend
forwards:
- resource: service/redis
protocol: tcp
port: 6379
localPort: 6380
alias: prod-redis
localPort: 6379
```
2. **Run kportal**:
Run:
```bash
kportal
```
3. **Navigate the interface**:
- `↑↓` or `j/k` - Navigate through forwards
- `Space` or `Enter` - Toggle forward on/off
- `q` - Quit application
### Keyboard Controls
| Key | Action |
|-----|--------|
| `↑↓` / `j/k` | Navigate |
| `Space` / `Enter` | Toggle forward |
| `a` | Add forward |
| `e` | Edit forward |
| `d` | Delete forward |
| `q` | Quit |
## 📖 Configuration
### Simple Configuration
### Basic Structure
```yaml
contexts:
@@ -105,551 +104,207 @@ contexts:
namespaces:
- name: <namespace-name>
forwards:
- resource: <resource-type>/<resource-name>
- resource: <type>/<name>
protocol: tcp
port: <remote-port>
localPort: <local-port>
alias: <friendly-name> # Optional
alias: <display-name> # optional
selector: <label-selector> # optional
```
### Advanced Configuration
### Forward Options
```yaml
contexts:
# Production cluster
- name: prod-us-west
namespaces:
- name: databases
forwards:
# Direct pod connection with prefix matching
- resource: pod/postgres-primary
protocol: tcp
port: 5432
localPort: 5432
alias: prod-postgres
# Service connection
- resource: service/redis-master
protocol: tcp
port: 6379
localPort: 6379
alias: prod-redis
# Pod with label selector
- resource: pod
selector: app=mongodb
protocol: tcp
port: 27017
localPort: 27017
alias: mongo
- name: applications
forwards:
- resource: deployment/api-server
protocol: tcp
port: 8080
localPort: 8080
alias: api
# Development cluster
- name: dev-local
namespaces:
- name: default
forwards:
- resource: service/grafana
protocol: tcp
port: 3000
localPort: 3000
alias: grafana-dashboard
```
### Configuration Options
| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `resource` | string | Yes | Kubernetes resource with type prefix (e.g., `service/name`, `pod/name`) |
| `protocol` | string | Yes | Connection protocol (typically `tcp`) |
| `port` | int | Yes | Remote port on the Kubernetes resource |
| `localPort` | int | Yes | Local port to forward to |
| `alias` | string | No | Friendly name for display (defaults to resource name) |
| `selector` | string | No | Label selector for dynamic pod selection (e.g., `app=nginx,env=prod`) |
| Field | Required | Description |
|-------|----------|-------------|
| `resource` | Yes | Resource type and name (e.g., `service/postgres`, `pod/my-app`) |
| `protocol` | Yes | Protocol (`tcp`) |
| `port` | Yes | Remote port |
| `localPort` | Yes | Local port |
| `alias` | No | Display name and mDNS hostname |
| `selector` | No | Label selector for pod resolution |
### Resource Formats
- **Pod by name**: `pod/pod-name` or just `pod-name`
- **Pod by prefix**: `pod/my-app` (matches `my-app-xyz789`, `my-app-abc123`, etc.)
- **Pod by selector**: Set `resource: pod` and use `selector: app=nginx`
- **Service**: `service/service-name` or `svc/service-name`
- **Deployment**: `deployment/deployment-name` or `deploy/deployment-name`
| Format | Description |
|--------|-------------|
| `service/name` | Service forwarding |
| `pod/name` | Direct pod by name |
| `pod/prefix` | Pod by prefix (matches `prefix-*`) |
| `pod` + `selector` | Pod by label selector |
| `deployment/name` | Deployment |
## 🎮 Usage
### Health Check Configuration
### Interactive Mode (Default)
```yaml
healthCheck:
interval: "3s" # Check frequency
timeout: "2s" # Check timeout
method: "data-transfer" # tcp-dial or data-transfer
maxConnectionAge: "25m" # Reconnect before k8s timeout
maxIdleTime: "10m" # Detect idle connections
reliability:
tcpKeepalive: "30s"
dialTimeout: "30s"
retryOnStale: true
```
Health check methods:
- `tcp-dial` - Fast TCP connection test
- `data-transfer` - Verifies tunnel functionality by attempting data read
Connection age reconnection only triggers when the connection is also idle, preventing interruption of active transfers like database dumps.
### mDNS Hostnames
Enable mDNS to access forwards via `.local` hostnames:
```yaml
mdns:
enabled: true
contexts:
- name: production
namespaces:
- name: default
forwards:
- resource: service/postgres
port: 5432
localPort: 5432
alias: prod-db # Accessible via prod-db.local:5432
```
- Explicit `alias` becomes `<alias>.local`
- Without alias, hostname is generated from resource name (`service/redis``redis.local`)
- Works on macOS (Bonjour) and Linux (avahi-daemon)
Verify registration:
```bash
dns-sd -B _kportal._tcp local # macOS
avahi-browse -t _kportal._tcp # Linux
```
## Usage
### Interactive Mode
```bash
kportal
```
Starts the interactive TUI where you can:
- View all configured port-forwards in a table
- See real-time status updates (Active, Starting, Reconnecting, Error)
- Toggle forwards on/off with Space key
- View detailed error messages at the bottom of the screen
### Verbose Mode
```bash
kportal -v
```
Runs in verbose mode with:
- Detailed logging to stdout
- Periodic status table updates every 2 seconds
- Full error traces
- No interactive controls (for automation/debugging)
### Validate Configuration
```bash
kportal --check
```
Validates your configuration file without starting any forwards:
- Checks YAML syntax
- Validates all required fields
- Detects duplicate local ports
- Shows validation errors with line numbers
### Custom Configuration File
### Custom Config File
```bash
kportal -c /path/to/config.yaml
```
### Version Information
## Status Indicators
| Indicator | Description |
|-----------|-------------|
| `● Active` | Connection healthy |
| `○ Starting` | Initial connection (10s grace period) |
| `◐ Reconnecting` | Reconnecting after failure |
| `✗ Error` | Connection failed |
| `○ Disabled` | Manually disabled |
## Advanced Features
### Hot-Reload
Configuration changes are applied automatically. Manual reload:
```bash
kportal --version
# Output: kportal version 0.1.5
kill -HUP $(pgrep kportal)
```
## 🔄 kftray Migration
### Port Conflict Detection
Migrate from kftray JSON configuration:
kportal validates port availability at startup and during hot-reload, showing which process is using conflicting ports.
### Retry Strategy
Exponential backoff: 1s → 2s → 4s → 8s → 10s (max). Retries continue indefinitely until connection succeeds.
## Migration from kftray
```bash
kportal --convert configs.json --convert-output .kportal.yaml
```
**Example conversion:**
## Signal Handling
kftray JSON:
```json
[
{
"service": "postgres",
"namespace": "default",
"local_port": 5432,
"remote_port": 5432,
"context": "production",
"workload_type": "service",
"protocol": "tcp",
"alias": "prod-db"
}
]
```
Converts to kportal YAML:
```yaml
contexts:
- name: production
namespaces:
- name: default
forwards:
- resource: service/postgres
protocol: tcp
port: 5432
localPort: 5432
alias: prod-db
```
## 🎨 Status Indicators
| Indicator | Status | Description |
|-----------|--------|-------------|
| `● Active` | 🟢 Green | Port-forward is active and healthy |
| `○ Starting` | 🟡 Yellow | Initial connection in progress (10s grace period) |
| `◐ Reconnecting` | 🟡 Yellow | Attempting to reconnect after failure |
| `✗ Error` | 🔴 Red | Connection failed - see error details below table |
| `○ Disabled` | ⚪ Gray | Port-forward manually disabled via Space key |
## 🛠️ Advanced Features
### Hot-Reload
kportal automatically watches for configuration file changes and reloads:
```bash
# Edit your config while kportal is running
vim .kportal.yaml
# Changes are applied automatically within seconds:
# - New forwards are started
# - Removed forwards are stopped
# - Existing forwards continue running unchanged
```
Supports manual reload via `SIGHUP`:
```bash
kill -HUP $(pgrep kportal)
```
### Health Checks
Built-in health monitoring system:
- **Check interval**: Every 5 seconds
- **Timeout**: 2 seconds per check
- **Grace period**: 10 seconds for new connections
- **Automatic updates**: Real-time status changes in UI
- **Error tracking**: Detailed error messages for failed connections
### Error Display
When connections fail, errors are displayed below the table:
```
Errors:
• prod-postgres: dial tcp 127.0.0.1:5432: connect: connection refused
• prod-redis: i/o timeout after 2.0s
```
Errors automatically clear when:
- Connection becomes healthy
- Forward is disabled
- Forward is removed
### Port Conflict Detection
kportal checks for port conflicts at multiple stages:
**At startup:**
```
Port conflicts detected:
Port 8080:
• Requested by: api-server (context: prod, namespace: default)
• Currently used by: PID 1234 (chrome)
```
**During hot-reload:**
- Only validates new ports being added
- Skips currently managed ports
- Rejects configuration if conflicts found
### Pod Restart Handling
When a pod restarts:
1. Port-forward connection breaks
2. kportal immediately re-resolves the resource:
- For prefix matches: Finds newest pod with matching prefix
- For selectors: Re-queries pods with matching labels
3. Reconnects to new pod
4. Logs the switch: `Switched to new pod: old-pod-abc → new-pod-xyz`
### Retry Strategy
Exponential backoff with maximum interval:
- **Intervals**: 1s → 2s → 4s → 8s → 10s (max)
- **Infinite retries**: Continues until connection succeeds
- **Independent**: Each forward has its own retry logic
- **Grace period**: First 10 seconds show "Starting" instead of "Error"
## 🔧 Development
### Prerequisites
- Go 1.23 or higher
- Access to a Kubernetes cluster
- kubectl configured with contexts
### Building
```bash
# Build binary
make build
# Run tests
make test
# Run all checks (fmt, vet, staticcheck, test)
make all
# Check current version
make version
# Install locally
make install
# Install system-wide
sudo make install-system
# Clean build artifacts
make clean
```
### Project Structure
```
kportal/
├── cmd/kportal/ # Main application entry point
├── internal/
│ ├── config/ # Configuration parsing and validation
│ ├── forward/ # Port-forward manager and workers
│ │ ├── manager.go # Orchestrates all forwards
│ │ ├── worker.go # Individual forward worker
│ │ └── port_checker.go # Port conflict detection
│ ├── healthcheck/ # Health monitoring system
│ │ └── checker.go # Port health checking
│ ├── k8s/ # Kubernetes client wrapper
│ │ ├── client.go # K8s client management
│ │ ├── port_forward.go # Port-forward implementation
│ │ └── resolver.go # Resource resolution
│ ├── retry/ # Retry logic with backoff
│ │ └── backoff.go # Exponential backoff
│ ├── ui/ # Terminal UI implementations
│ │ ├── bubbletea_ui.go # Interactive TUI (Bubble Tea)
│ │ └── table_ui.go # Simple table for verbose mode
│ └── converter/ # kftray JSON converter
├── Formula/ # Homebrew formula
├── .github/workflows/ # CI/CD pipelines
│ └── release.yml # Release automation
├── install.sh # Installation script
├── semver.yaml # Semantic version config
├── Makefile # Build automation
└── README.md # This file
```
## 📝 Examples
### Database Access
```yaml
contexts:
production:
namespaces:
databases:
- resource: postgres-primary
port: 5432
local_port: 5432
alias: prod-db
```
Connect with:
```bash
kportal # Start in another terminal
psql -h localhost -p 5432 -U postgres
```
### Multiple Services
```yaml
contexts:
dev:
namespaces:
default:
- resource: api
port: 8080
local_port: 8080
- resource: frontend
port: 3000
local_port: 3000
- resource: redis
port: 6379
local_port: 6379
```
Access:
- API: `http://localhost:8080`
- Frontend: `http://localhost:3000`
- Redis: `redis-cli -p 6379`
### Cross-Context Setup
```yaml
contexts:
prod-us:
namespaces:
backend:
- resource: api
port: 8080
local_port: 8080
alias: prod-us-api
prod-eu:
namespaces:
backend:
- resource: api
port: 8080
local_port: 8081 # Different local port
alias: prod-eu-api
```
Compare APIs across regions simultaneously.
### Debug Multiple Pod Versions
```yaml
contexts:
production:
namespaces:
default:
# Version 1
- resource: pod
selector: app=myapp,version=v1
port: 8080
local_port: 8080
alias: app-v1
# Version 2
- resource: pod
selector: app=myapp,version=v2
port: 8080
local_port: 8081
alias: app-v2
# Debug port for v2
- resource: pod
selector: app=myapp,version=v2
port: 6060
local_port: 6060
alias: app-v2-pprof
```
- `Ctrl+C` / `SIGTERM` - Graceful shutdown
- `SIGHUP` - Reload configuration
## 🐛 Troubleshooting
### Port Already in Use
**Problem**: `Port 8080: already in use by PID 1234 (chrome)`
**Solutions**:
```bash
# Find the process
lsof -i :8080
# Kill the process
kill 1234
# Or use a different local port in config
local_port: 8081
lsof -i :<port>
kill <pid>
```
### Connection Refused
**Problem**: `dial tcp 127.0.0.1:8080: connect: connection refused`
**Common causes**:
1. **Pod not ready yet** - Wait for status to change from "Starting" → "Active" (10s grace period)
2. **Wrong port number** - Verify the pod/service actually exposes that port
3. **Service not exposed** - Check with `kubectl get svc` and `kubectl describe svc <name>`
**Debug**:
```bash
# Check pod status
kubectl get pods -n <namespace>
# Check if port is exposed
kubectl describe pod <pod-name> -n <namespace>
# Check service endpoints
kubectl get endpoints <service-name> -n <namespace>
```
1. Verify pod is running: `kubectl get pods -n <namespace>`
2. Verify port is correct: `kubectl describe pod <pod>`
3. Check service endpoints: `kubectl get endpoints <service>`
### Context Not Found
**Problem**: `context "prod" not found in kubeconfig`
**Solution**:
```bash
# List available contexts
kubectl config get-contexts
# Verify context name matches
kubectl config current-context
# Update your config to use the correct context name
```
### Health Check Errors During Startup
## 🔧 Development
**Problem**: Seeing "Error" status immediately after starting
### Prerequisites
**This is normal!** kportal has a 10-second grace period. If the connection is still failing after 10 seconds, check:
- Pod is running: `kubectl get pods`
- Port is correct in config
- Network connectivity to cluster
- Go 1.23+
- Kubernetes cluster access
- kubectl configured
### Logs Covering UI
### Build
**Problem**: Kubernetes client logs appearing over the interactive UI
**This is fixed in v0.1.5+**. The interactive mode now completely suppresses all logs including:
- Standard Go `log` package
- Kubernetes `klog` output
- Any stderr/stdout leakage
If you still see logs, please file an issue!
## 🤝 Contributing
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Make your changes and add tests
4. Run checks: `make all`
5. Commit your changes (follow [semantic commit messages](#semantic-versioning))
6. Push to the branch (`git push origin feature/amazing-feature`)
7. Open a Pull Request
### Semantic Versioning
This project uses [semver-gen](https://github.com/lukaszraczylo/semver-generator) for automatic semantic version generation based on git commit messages.
**Version Keywords:**
- **Patch** (0.0.X): `fix`, `bugfix`, `hotfix`, `patch`, `docs`, `test`, `refactor`
- **Minor** (0.X.0): `feat`, `feature`, `add`, `enhance`, `update`, `improve`
- **Major** (X.0.0): `breaking`, `major`, `BREAKING CHANGE`
Example commits:
```bash
git commit -m "feat: add health check grace period" # Bumps minor version
git commit -m "fix: resolve port conflict detection" # Bumps patch version
git commit -m "breaking: change config file format" # Bumps major version
make build # Build binary
make test # Run tests
make all # fmt, vet, staticcheck, test
make install # Install locally
```
## 📄 License
## Contributing
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
## 🙏 Acknowledgments
## License
- Built with [Bubble Tea](https://github.com/charmbracelet/bubbletea) by Charm - An awesome framework for building terminal UIs
- Styled with [Lipgloss](https://github.com/charmbracelet/lipgloss) - Terminal styling library
- Inspired by [kftray](https://github.com/hcavarsan/kftray) - Original GUI port-forward manager
- Uses [client-go](https://github.com/kubernetes/client-go) for Kubernetes integration
- Version management by [semver-gen](https://github.com/lukaszraczylo/semver-generator)
MIT License - see [LICENSE](LICENSE).
## 📚 Documentation
## Acknowledgments
- [Bubble Tea](https://github.com/charmbracelet/bubbletea) - Terminal UI framework
- [Lipgloss](https://github.com/charmbracelet/lipgloss) - Terminal styling
- [client-go](https://github.com/kubernetes/client-go) - Kubernetes client
- [kftray](https://github.com/hcavarsan/kftray) - Inspiration
## Links
- [Website](https://lukaszraczylo.github.io/kportal)
- [Issue Tracker](https://github.com/lukaszraczylo/kportal/issues)
- [Issues](https://github.com/lukaszraczylo/kportal/issues)
- [Releases](https://github.com/lukaszraczylo/kportal/releases)
- [Changelog](CHANGELOG.md)
## Signal Handling
- `Ctrl+C` / `SIGTERM`: Graceful shutdown (closes all forwards)
- `SIGHUP`: Reload configuration file
---
Made with ❤️ by [Lukasz Raczylo](https://github.com/lukaszraczylo)
+57 -251
View File
@@ -1,320 +1,126 @@
# Release Infrastructure Setup Summary
# Release Infrastructure
This document summarizes all the release infrastructure that has been set up for kportal.
Documentation for kportal's release automation and distribution.
## Completed Setup
### 1. GitHub Actions CI/CD Pipeline
## 🔄 CI/CD Pipeline
**File**: `.github/workflows/release.yml`
**Features**:
- Multi-platform binary builds (Linux, macOS, Windows - amd64 & arm64)
- Automatic release creation on version tags
- Binary archiving (tar.gz for Unix, zip for Windows)
- SHA256 checksum generation
- Automated Homebrew formula updates
- Release notes generation
The pipeline builds multi-platform binaries, creates GitHub releases, and updates Homebrew on version tags.
### Trigger a Release
**How to trigger**:
```bash
# Commit with semantic versioning keywords
git commit -m "feat: add new feature"
# Tag the release
git tag -a v0.2.0 -m "Release v0.2.0"
# Push tags
git push origin v0.2.0
```
The pipeline will automatically:
The pipeline will:
1. Build binaries for all platforms
2. Create GitHub release with binaries
2. Create GitHub release with binaries and checksums
3. Update Homebrew tap formula
4. Generate release notes
### 2. Installation Methods
## 📦 Installation Methods
#### A. Homebrew Formula
### Homebrew
**File**: `Formula/kportal.rb`
**Installation command**:
```bash
brew install lukaszraczylo/tap/kportal
```
**Note**: Formula is automatically updated by CI/CD pipeline. You'll need to create a separate tap repository:
1. Create repo: `https://github.com/lukaszraczylo/brew-taps`
2. Add Formula/kportal.rb to that repo
3. Set `HOMEBREW_TAP_TOKEN` secret in GitHub repository settings
Formula is automatically updated by CI/CD. Requires:
- Tap repository: `https://github.com/lukaszraczylo/brew-taps`
- Secret: `HOMEBREW_TAP_TOKEN` with `repo` scope
#### B. Quick Install Script
### Install Script
**File**: `install.sh`
**Features**:
- Auto-detects OS and architecture
- Downloads appropriate binary
- Extracts and installs to /usr/local/bin
- Verifies installation
- Colorful output with emoji indicators
**Installation command**:
```bash
curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/kportal/main/install.sh | bash
```
#### C. Manual Download
Auto-detects OS/architecture and installs to `/usr/local/bin`.
Users can download binaries directly from GitHub releases:
```
https://github.com/lukaszraczylo/kportal/releases
```
### Manual Download
### 3. Documentation
Download from [releases page](https://github.com/lukaszraczylo/kportal/releases).
#### A. Comprehensive README.md
## Platform Support
**File**: `README.md`
| OS | Architecture | Format |
|----|--------------|--------|
| Linux | amd64, arm64 | tar.gz |
| macOS | amd64, arm64 | tar.gz |
| Windows | amd64, arm64 | zip |
**Contents**:
- Feature showcase with emojis
- Multiple installation methods
- Quick start guide
- Configuration examples
- Usage instructions
- Advanced features documentation
- Troubleshooting guide
- Contributing guidelines
## 🚀 Release Process
#### B. GitHub Pages Website
**File**: `docs/index.html`
**Features**:
- Modern, responsive design with TailwindCSS
- Hero section with clear CTA
- Feature showcase cards
- Installation guide
- Configuration examples with syntax highlighting
- Documentation links
- Mobile-friendly
**URL** (once enabled): `https://lukaszraczylo.github.io/kportal`
**To enable**:
1. Go to GitHub repository settings
2. Pages section
3. Source: Deploy from a branch
4. Branch: main
5. Folder: /docs
### 4. Supporting Files
#### CHANGELOG.md
**File**: `CHANGELOG.md`
Tracks all changes following Keep a Changelog format. Update this file with each release.
#### CONTRIBUTING.md
**File**: `CONTRIBUTING.md`
Guidelines for:
- Bug reporting
- Feature requests
- Pull request process
- Commit message format
- Development setup
- Testing guidelines
## 🚀 Release Workflow
### Standard Release Process
1. **Develop features**
1. **Make changes and test**
```bash
git checkout -b feature/my-feature
# Make changes
make test
make all
make test && make all
```
2. **Commit with semantic messages**
```bash
git commit -m "feat: add amazing feature"
git commit -m "fix: resolve bug in health check"
```
2. **Update CHANGELOG.md**
3. **Update CHANGELOG.md**
```markdown
## [0.2.0] - 2025-11-24
### Added
- Amazing new feature
### Fixed
- Bug in health check
```
4. **Tag the release**
3. **Tag and push**
```bash
git tag -a v0.2.0 -m "Release v0.2.0"
git push origin main
git push origin v0.2.0
```
5. **CI/CD automatically**:
- Builds all binaries
- Creates GitHub release
- Updates Homebrew formula
- Attaches binaries and checksums
## Version Bumping
### Version Bumping (Semantic Versioning)
Version determined by commit message keywords:
Version is automatically determined by semver-gen from commit messages:
| Bump | Keywords |
|------|----------|
| Patch (0.0.X) | `fix`, `bugfix`, `docs`, `test`, `refactor` |
| Minor (0.X.0) | `feat`, `feature`, `add`, `enhance`, `update` |
| Major (X.0.0) | `breaking`, `major`, `BREAKING CHANGE` |
- **Patch** (0.0.X): `fix`, `bugfix`, `hotfix`, `patch`, `docs`, `test`, `refactor`
- **Minor** (0.X.0): `feat`, `feature`, `add`, `enhance`, `update`, `improve`
- **Major** (X.0.0): `breaking`, `major`, `BREAKING CHANGE`
## Required Secrets
## 📦 Platform Support
| Secret | Purpose |
|--------|---------|
| `GITHUB_TOKEN` | Provided by GitHub Actions |
| `HOMEBREW_TAP_TOKEN` | Personal access token with `repo` scope |
### Supported Platforms
| OS | Architecture | Archive Format |
|---------|-------------|----------------|
| Linux | amd64 | tar.gz |
| Linux | arm64 | tar.gz |
| macOS | amd64 | tar.gz |
| macOS | arm64 | tar.gz |
| Windows | amd64 | zip |
| Windows | arm64 | zip |
## 🔒 Required GitHub Secrets
For full automation, set these secrets in your GitHub repository:
1. **GITHUB_TOKEN** - Automatically provided by GitHub Actions
2. **HOMEBREW_TAP_TOKEN** - Personal access token for updating Homebrew tap
- Create at: https://github.com/settings/tokens
- Permissions needed: `repo` scope
- Add to repository secrets
## 📝 Next Steps
## ⚙️ Initial Setup
### 1. Enable GitHub Pages
- Repository Settings → Pages → Source: main branch, /docs folder
### 2. Create Homebrew Tap Repository
Repository Settings → Pages → Source: main branch, /docs folder
### 2. Create Homebrew Tap
```bash
# Create new repository
gh repo create lukaszraczylo/brew-taps --public
# Clone and set up
git clone https://github.com/lukaszraczylo/brew-taps
cd brew-taps
cp ../kportal/Formula/kportal.rb ./Formula/
git add Formula/kportal.rb
git commit -m "Initial formula for kportal"
git push origin main
mkdir Formula
# Formula will be auto-updated by CI
```
### 3. Add GitHub Token to Secrets
- Repository Settings → Secrets and variables → Actions
- New repository secret
### 3. Add Token Secret
Repository Settings → Secrets → Actions → New secret:
- Name: `HOMEBREW_TAP_TOKEN`
- Value: Your personal access token
### 4. Create First Release
```bash
cd kportal
git add .
git commit -m "feat: initial release setup"
git push origin main
git tag -a v0.1.5 -m "Release v0.1.5"
git push origin v0.1.5
```
### 5. Test Installation Methods
After first release, test:
```bash
# Homebrew (once tap is set up)
brew install lukaszraczylo/tap/kportal
# Quick install script
curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/kportal/main/install.sh | bash
# Manual download
# Visit releases page and download binary
```
## 🎨 Customization
### Update Website Colors
Edit `docs/index.html`:
```javascript
tailwind.config = {
theme: {
extend: {
colors: {
primary: '#3b82f6', // Blue
secondary: '#8b5cf6', // Purple
dark: '#0f172a', // Dark slate
}
}
}
}
```
### Update Release Notes Template
Edit `.github/workflows/release.yml` in the "Generate release notes" step.
## 📊 Monitoring
After releases, monitor:
- GitHub Actions workflow runs
- GitHub Releases page
- Homebrew tap repository commits
- Download statistics on releases page
- Value: Personal access token with `repo` scope
## 🐛 Troubleshooting
### Release workflow fails
- Check GitHub Actions logs
- Verify all required secrets are set
- Ensure tag follows v\d+.\d+.\d+ format
- Verify secrets are configured
- Ensure tag follows `v\d+.\d+.\d+` format
### Homebrew formula not updating
- Verify HOMEBREW_TAP_TOKEN is valid
### Homebrew not updating
- Verify `HOMEBREW_TAP_TOKEN` is valid
- Check tap repository permissions
- Review release workflow logs
### Install script fails
- Test locally with different OS/arch combinations
- Check release binary naming matches script expectations
- Verify binaries are attached to release
## ✅ Checklist for First Release
- [ ] All code committed and pushed
- [ ] GitHub Pages enabled
- [ ] Homebrew tap repository created
- [ ] HOMEBREW_TAP_TOKEN secret set
- [ ] CHANGELOG.md updated
- [ ] Version tag created and pushed
- [ ] Release workflow completed successfully
- [ ] Binaries attached to release
- [ ] Homebrew formula updated
- [ ] Install script tested
- [ ] Documentation website live
- [ ] README.md installation links work
---
**Documentation last updated**: 2025-11-23
**Setup completed for**: kportal v0.1.5
- Verify release binaries are attached
- Check binary naming matches script expectations
+104
View File
@@ -0,0 +1,104 @@
# Interactive Wizards
kportal includes wizards for adding and removing port forwards from the running UI.
## ⌨️ Quick Reference
| Key | Action |
|-----|--------|
| `a` | Add new forward |
| `d` | Delete forwards |
## Add Forward Wizard
Press `a` from the main view to start the wizard.
### Steps
1. **Context** - Select Kubernetes context
2. **Namespace** - Select namespace
3. **Resource Type** - Choose pod (prefix), pod (selector), or service
4. **Resource** - Enter prefix, selector, or select service
5. **Remote Port** - Enter port on the resource
6. **Local Port** - Enter local port (validates availability)
7. **Confirm** - Review and optionally add an alias
### Navigation
| Key | Action |
|-----|--------|
| `↑↓` / `j/k` | Navigate options |
| `Enter` | Confirm and proceed |
| `Esc` | Go back / Cancel |
| `Ctrl+C` | Cancel immediately |
## 🗑️ Delete Forward Wizard
Press `d` from the main view.
### Navigation
| Key | Action |
|-----|--------|
| `↑↓` / `j/k` | Navigate |
| `Space` | Toggle selection |
| `a` | Select all |
| `n` | Deselect all |
| `Enter` | Confirm deletion |
| `Esc` | Cancel |
## 🎯 Resource Selection
### Pod by Prefix
Enter app name prefix to match pods:
- `nginx` matches `nginx-deployment-abc123`
- `postgres` matches `postgres-statefulset-0`
### Pod by Selector
Use Kubernetes label syntax:
- `app=nginx`
- `app=nginx,env=prod`
Matching pods are shown in real-time.
### Service
Select from discovered services in the namespace.
## 🔄 Auto Hot-Reload
Changes are applied automatically:
1. Wizard writes to `.kportal.yaml` atomically
2. File watcher detects change (~100ms)
3. Manager reloads and starts forward
4. UI updates
## Error Handling
The wizards handle:
- Cluster unreachable - allows manual entry
- Port conflicts - shows which process is using the port
- Invalid selectors - real-time validation
- Duplicate ports - prevents conflicts
## 🐛 Troubleshooting
### Wizard not appearing
Verify cluster connectivity:
```bash
kubectl cluster-info
```
### Port validation delayed
Port checks run asynchronously. Wait briefly after typing.
### Changes not visible
Check:
1. `.kportal.yaml` was written correctly
2. No validation errors in file
3. kportal process is running
+405
View File
@@ -0,0 +1,405 @@
package main
import (
"context"
"flag"
"fmt"
"io"
"log"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/go-logr/logr"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/converter"
"github.com/nvm/kportal/internal/forward"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/mdns"
"github.com/nvm/kportal/internal/ui"
"github.com/nvm/kportal/internal/version"
"k8s.io/klog/v2"
)
const (
defaultConfigFile = ".kportal.yaml"
initialForwardSettleTime = 100 * time.Millisecond
tableUpdateInterval = 2 * time.Second
// GitHub repository info for update checks
githubOwner = "lukaszraczylo"
githubRepo = "kportal"
)
var (
configFile = flag.String("c", defaultConfigFile, "Path to configuration file")
verbose = flag.Bool("v", false, "Enable verbose logging")
logFormat = flag.String("log-format", "text", "Log format: text or json")
check = flag.Bool("check", false, "Validate configuration and exit")
showVersion = flag.Bool("version", false, "Show version and exit")
checkUpdate = flag.Bool("update", false, "Check for updates and exit")
convertInput = flag.String("convert", "", "Convert kftray JSON config to kportal YAML (provide input file path)")
convertOutput = flag.String("convert-output", ".kportal.yaml", "Output file for converted configuration")
appVersion = "0.1.0" // Set via ldflags during build
)
func main() {
flag.Parse()
if *showVersion {
fmt.Printf("kportal version %s\n", appVersion)
os.Exit(0)
}
if *checkUpdate {
checkForUpdates()
os.Exit(0)
}
// Validate config path security
if *configFile != "" {
absConfigPath, err := filepath.Abs(*configFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid config path: %v\n", err)
os.Exit(1)
}
absConfigPath = filepath.Clean(absConfigPath)
// Block system directories
systemDirs := []string{"/etc", "/sys", "/proc", "/dev"}
for _, sysDir := range systemDirs {
if strings.HasPrefix(absConfigPath, sysDir) {
fmt.Fprintf(os.Stderr, "Error: Config file cannot be in system directory: %s\n", sysDir)
os.Exit(1)
}
}
*configFile = absConfigPath
}
// Initialize structured logger
var logLevel logger.Level
var logFmt logger.Format
var logOutput io.Writer
if *verbose {
logLevel = logger.LevelDebug
logOutput = os.Stderr
} else {
logLevel = logger.LevelInfo
logOutput = io.Discard // Silence logger in non-verbose mode to prevent UI corruption
}
switch *logFormat {
case "json":
logFmt = logger.FormatJSON
default:
logFmt = logger.FormatText
}
logger.Init(logLevel, logFmt, logOutput)
// Configure klog (used by kubernetes client-go) to route through our logger
// This prevents k8s logs from interfering with the UI
//
// klog v2 uses multiple output mechanisms:
// 1. SetOutput() - for basic text output
// 2. SetLogger() - for structured/error logs (logr interface)
//
// We must configure BOTH to capture all logs including error messages
// that would otherwise bypass SetOutput() and write directly to stderr.
klog.LogToStderr(false) // Disable direct stderr writes
if *verbose {
// In verbose mode, route all klog through our structured logger at DEBUG level
klogLogger := logger.New(logger.LevelDebug, logFmt, os.Stderr)
// Configure text output routing
klogWriter := logger.NewKlogWriter(klogLogger)
klog.SetOutput(klogWriter)
// Configure structured/error log routing via logr interface
// This captures "Unhandled Error" and other structured logs that bypass SetOutput
logrSink := logger.NewLogrAdapter(klogLogger)
klog.SetLogger(logr.New(logrSink))
} else {
// In non-verbose mode, completely silence ALL klog output
klog.SetOutput(io.Discard)
// Also silence structured/error logs via a discard logger
silentLogger := logger.New(logger.LevelError+1, logFmt, io.Discard) // Level above ERROR = silence all
logrSink := logger.NewLogrAdapter(silentLogger)
klog.SetLogger(logr.New(logrSink))
}
// Handle conversion mode
if *convertInput != "" {
if err := converter.ConvertKFTrayToKPortal(*convertInput, *convertOutput); err != nil {
fmt.Fprintf(os.Stderr, "Error converting configuration: %v\n", err)
os.Exit(1)
}
// Print summary
contextMap, totalForwards, err := converter.GetConversionSummary(*convertInput)
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: Could not generate summary: %v\n", err)
} else {
fmt.Printf("Successfully converted %d forwards from %s to %s\n", totalForwards, *convertInput, *convertOutput)
fmt.Printf("Generated configuration with:\n")
for ctx, namespaces := range contextMap {
fmt.Printf(" - Context '%s':\n", ctx)
for ns, count := range namespaces {
fmt.Printf(" - Namespace '%s': %d forwards\n", ns, count)
}
}
}
os.Exit(0)
}
if !*verbose {
// In interactive mode, disable ALL logging to avoid interfering with bubbletea UI
log.SetOutput(io.Discard)
log.SetPrefix("")
log.SetFlags(0)
} else {
// Verbose mode - enable standard log formatting
log.SetFlags(log.LstdFlags | log.Lshortfile)
}
// Load configuration
cfg, err := config.LoadConfig(*configFile)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading config: %v\n", err)
os.Exit(1)
}
// Validate configuration
validator := config.NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
fmt.Fprint(os.Stderr, config.FormatValidationErrors(errs))
os.Exit(1)
}
if *check {
fmt.Println("Configuration is valid")
os.Exit(0)
}
// Only log startup messages in verbose mode
if *verbose {
log.Printf("kportal v%s", appVersion)
log.Printf("Loading configuration from: %s", *configFile)
}
// Create Kubernetes client pool and discovery for wizards
pool, err := k8s.NewClientPool()
if err != nil {
fmt.Fprintf(os.Stderr, "Warning: Failed to create k8s client pool: %v\n", err)
fmt.Fprintf(os.Stderr, "Add/remove wizards will not be available\n")
}
discovery := k8s.NewDiscovery(pool)
mutator := config.NewMutator(*configFile)
// Create forward manager
manager, err := forward.NewManager(*verbose)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating forward manager: %v\n", err)
os.Exit(1)
}
// Create mDNS publisher if enabled in config
mdnsPublisher := mdns.NewPublisher(cfg.IsMDNSEnabled())
manager.SetMDNSPublisher(mdnsPublisher)
if cfg.IsMDNSEnabled() && *verbose {
log.Printf("mDNS hostname publishing enabled - aliases will be accessible via <alias>.local")
}
// Create UI (bubbletea for interactive, simple table for verbose)
var bubbleTeaUI *ui.BubbleTeaUI
var tableUI *ui.TableUI
if !*verbose {
// Interactive mode with bubbletea
bubbleTeaUI = ui.NewBubbleTeaUI(func(id string, enable bool) {
if enable {
manager.EnableForward(id)
} else {
manager.DisableForward(id)
}
}, appVersion)
// Set wizard dependencies
// Note: mutator is always available (for delete/edit), discovery requires valid kubeconfig (for add)
bubbleTeaUI.SetWizardDependencies(discovery, mutator, *configFile)
// Check for updates in background (non-blocking)
go func() {
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if update := checker.CheckForUpdate(ctx); update != nil {
bubbleTeaUI.SetUpdateAvailable(update.LatestVersion, update.ReleaseURL)
}
}()
manager.SetStatusUI(bubbleTeaUI)
} else {
// Verbose mode with simple table
tableUI = ui.NewTableUI(*verbose)
manager.SetStatusUI(tableUI)
// Check for updates and print to log
go func() {
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if update := checker.CheckForUpdate(ctx); update != nil {
log.Printf("Update available: v%s (current: v%s) - %s",
update.LatestVersion, update.CurrentVersion, update.ReleaseURL)
}
}()
}
// Start forwards
if err := manager.Start(cfg); err != nil {
fmt.Fprintf(os.Stderr, "Error starting forwards: %v\n", err)
os.Exit(1)
}
if *verbose {
// Verbose mode - use simple table with periodic updates
tableUI.RenderInitial()
// Setup signal handling
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
// Start table update loop
go func() {
ticker := time.NewTicker(tableUpdateInterval)
defer ticker.Stop()
for range ticker.C {
tableUI.Render()
}
}()
// Setup config watcher for hot-reload
watcher, err := config.NewWatcher(*configFile, func(newCfg *config.Config) error {
return manager.Reload(newCfg)
}, *verbose)
if err != nil {
log.Printf("Warning: Failed to setup config watcher: %v", err)
log.Printf("Hot-reload will not be available")
} else {
watcher.Start()
defer watcher.Stop()
}
log.Printf("Press Ctrl+C to stop")
// Wait for signals
for {
sig := <-sigChan
switch sig {
case syscall.SIGHUP:
log.Printf("Received SIGHUP, reloading configuration...")
newCfg, err := config.LoadConfig(*configFile)
if err != nil {
log.Printf("Failed to reload config: %v", err)
continue
}
if errs := validator.ValidateConfig(newCfg); len(errs) > 0 {
log.Printf("Config validation failed:")
log.Print(config.FormatValidationErrors(errs))
continue
}
if err := manager.Reload(newCfg); err != nil {
log.Printf("Failed to reload: %v", err)
}
case os.Interrupt, syscall.SIGTERM:
log.Printf("Received shutdown signal, stopping...")
// Graceful shutdown with timeout - force exit if it takes too long
shutdownDone := make(chan struct{})
go func() {
manager.Stop()
close(shutdownDone)
}()
select {
case <-shutdownDone:
log.Printf("Graceful shutdown complete")
case <-time.After(5 * time.Second):
log.Printf("Shutdown timed out, forcing exit...")
case sig := <-sigChan:
// Second signal received - force exit immediately
log.Printf("Received second signal (%v), forcing exit...", sig)
}
os.Exit(0)
}
}
} else {
// Interactive mode with bubbletea
// Setup config watcher in background
watcher, err := config.NewWatcher(*configFile, func(newCfg *config.Config) error {
return manager.Reload(newCfg)
}, *verbose)
if err == nil {
watcher.Start()
defer watcher.Stop()
}
// Setup signal handler for clean shutdown
go func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
<-sigChan
bubbleTeaUI.Stop()
manager.Stop()
os.Exit(0)
}()
// Give a moment for initial forwards to be added
time.Sleep(initialForwardSettleTime)
// Start the bubbletea app (blocks until quit)
if err := bubbleTeaUI.Start(); err != nil {
fmt.Fprintf(os.Stderr, "Failed to start UI: %v\n", err)
manager.Stop()
os.Exit(1)
}
// Clean shutdown
manager.Stop()
}
}
// checkForUpdates checks for available updates and prints the result
func checkForUpdates() {
fmt.Printf("kportal version %s\n", appVersion)
fmt.Println("Checking for updates...")
checker := version.NewChecker(githubOwner, githubRepo, appVersion)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
update := checker.CheckForUpdate(ctx)
if update == nil {
fmt.Println("You are running the latest version.")
return
}
fmt.Printf("\nUpdate available: v%s\n", update.LatestVersion)
fmt.Printf("Download: %s\n", update.ReleaseURL)
fmt.Println("\nTo update, download the latest release from the URL above")
fmt.Println("or use your package manager (e.g., 'brew upgrade kportal').")
}
+696 -346
View File
File diff suppressed because it is too large Load Diff
+132
View File
@@ -0,0 +1,132 @@
<svg width="310" height="150" viewBox="0 0 310 150" xmlns="http://www.w3.org/2000/svg" id="darkLogo">
<defs>
<!-- Simple turbulence for portal edges -->
<filter id="portalTurbulence" x="-50%" y="-50%" width="200%" height="200%">
<feTurbulence type="fractalNoise" baseFrequency="0.02 0.03" numOctaves="2" result="turbulence" seed="5">
<animate attributeName="seed" values="5;10;5" dur="8s" repeatCount="indefinite"/>
</feTurbulence>
<feDisplacementMap in2="turbulence" in="SourceGraphic" scale="2" xChannelSelector="R" yChannelSelector="G"/>
</filter>
<!-- Blue glow -->
<filter id="blueGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="5" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Orange glow -->
<filter id="orangeGlow" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="5" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Text glow -->
<filter id="textGlow" x="-20%" y="-20%" width="140%" height="140%">
<feGaussianBlur stdDeviation="1" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Gradients -->
<radialGradient id="bluePortal" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#000814;stop-opacity:0.9"/>
<stop offset="20%" style="stop-color:#001845;stop-opacity:0.8"/>
<stop offset="50%" style="stop-color:#0077B6;stop-opacity:0.95"/>
<stop offset="80%" style="stop-color:#00B4D8;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#90E0EF;stop-opacity:1"/>
</radialGradient>
<radialGradient id="orangePortal" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#1A0E00;stop-opacity:0.9"/>
<stop offset="20%" style="stop-color:#3D2314;stop-opacity:0.8"/>
<stop offset="50%" style="stop-color:#F77F00;stop-opacity:0.95"/>
<stop offset="80%" style="stop-color:#FCBF49;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FFD6A5;stop-opacity:1"/>
</radialGradient>
</defs>
<!-- Blue Portal (LEFT) -->
<g id="bluePortalGroup">
<!-- Outer rings -->
<ellipse cx="50" cy="75" rx="35" ry="50" fill="none" stroke="#90E0EF" stroke-width="0.5" opacity="0.2"/>
<ellipse cx="50" cy="75" rx="30" ry="44" fill="none" stroke="#00B4D8" stroke-width="1" opacity="0.3"/>
<!-- Main portal -->
<ellipse cx="50" cy="75" rx="26" ry="40" fill="url(#bluePortal)" filter="url(#blueGlow)" opacity="0.95"/>
<!-- Inner energy rings -->
<ellipse cx="50" cy="75" rx="20" ry="32" fill="none" stroke="#00B4D8" stroke-width="2" opacity="0.7">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="50" cy="75" rx="14" ry="24" fill="none" stroke="#90E0EF" stroke-width="1.5" opacity="0.5">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="50" cy="75" rx="7" ry="12" fill="#000814" opacity="0.95"/>
</g>
<!-- Text: "kportal" -->
<!-- Orange K -->
<text x="76" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="#FCBF49" filter="url(#textGlow)">
k
<animate attributeName="x" values="76;79;76" dur="4s" repeatCount="indefinite"/>
</text>
<!-- White "porta" -->
<text x="105" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="white" filter="url(#textGlow)">
porta
</text>
<!-- Blue L -->
<text x="220" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="300" fill="#00B4D8" filter="url(#textGlow)">
l
<animate attributeName="x" values="220;223;220" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Orange Portal (RIGHT) at x=260 -->
<g id="orangePortalGroup">
<!-- Outer rings -->
<ellipse cx="260" cy="75" rx="35" ry="50" fill="none" stroke="#FFD6A5" stroke-width="0.5" opacity="0.2"/>
<ellipse cx="260" cy="75" rx="30" ry="44" fill="none" stroke="#FCBF49" stroke-width="1" opacity="0.3"/>
<!-- Main portal -->
<ellipse cx="260" cy="75" rx="26" ry="40" fill="url(#orangePortal)" filter="url(#orangeGlow)" opacity="0.95"/>
<!-- Inner energy rings -->
<ellipse cx="260" cy="75" rx="20" ry="32" fill="none" stroke="#FCBF49" stroke-width="2" opacity="0.7">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="260" cy="75" rx="14" ry="24" fill="none" stroke="#FFD6A5" stroke-width="1.5" opacity="0.5">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="260" cy="75" rx="7" ry="12" fill="#1A0E00" opacity="0.95"/>
</g>
<!-- Energy connection between portals -->
<path d="M 76 75 Q 180 70 222 75" stroke="url(#energyGradient)" stroke-width="0.5" fill="none" opacity="0.3">
<animate attributeName="opacity" values="0.1;0.3;0.1" dur="4s" repeatCount="indefinite"/>
</path>
<defs>
<linearGradient id="energyGradient" x1="0%" y1="0%" x2="100%" y2="0%">
<stop offset="0%" style="stop-color:#00B4D8;stop-opacity:1"/>
<stop offset="50%" style="stop-color:#667eea;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FCBF49;stop-opacity:1"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 7.7 KiB

+128
View File
@@ -0,0 +1,128 @@
<svg width="310" height="150" viewBox="0 0 310 150" xmlns="http://www.w3.org/2000/svg" id="lightLogo">
<defs>
<!-- Simple turbulence for portal edges -->
<filter id="portalTurbulenceLight" x="-50%" y="-50%" width="200%" height="200%">
<feTurbulence type="fractalNoise" baseFrequency="0.02 0.03" numOctaves="2" result="turbulence" seed="5">
<animate attributeName="seed" values="5;10;5" dur="8s" repeatCount="indefinite"/>
</feTurbulence>
<feDisplacementMap in2="turbulence" in="SourceGraphic" scale="2" xChannelSelector="R" yChannelSelector="G"/>
</filter>
<!-- Blue glow for light background -->
<filter id="blueGlowLight" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="4" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Orange glow for light background -->
<filter id="orangeGlowLight" x="-50%" y="-50%" width="200%" height="200%">
<feGaussianBlur stdDeviation="4" result="coloredBlur"/>
<feMerge>
<feMergeNode in="coloredBlur"/>
<feMergeNode in="SourceGraphic"/>
</feMerge>
</filter>
<!-- Text shadow for light background -->
<filter id="textShadowLight" x="-20%" y="-20%" width="140%" height="140%">
<feDropShadow dx="0" dy="1" stdDeviation="0.5" flood-opacity="0.15"/>
</filter>
<!-- Enhanced gradients for light background -->
<radialGradient id="bluePortalLight" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#001529;stop-opacity:1"/>
<stop offset="20%" style="stop-color:#002766;stop-opacity:0.95"/>
<stop offset="50%" style="stop-color:#0066CC;stop-opacity:0.98"/>
<stop offset="80%" style="stop-color:#0099FF;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#66CCFF;stop-opacity:1"/>
</radialGradient>
<radialGradient id="orangePortalLight" cx="50%" cy="50%">
<stop offset="0%" style="stop-color:#2E1A00;stop-opacity:1"/>
<stop offset="20%" style="stop-color:#5C3317;stop-opacity:0.95"/>
<stop offset="50%" style="stop-color:#E66100;stop-opacity:0.98"/>
<stop offset="80%" style="stop-color:#FF9933;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#FFBB66;stop-opacity:1"/>
</radialGradient>
</defs>
<!-- Blue Portal (LEFT) -->
<g id="bluePortalGroupLight">
<!-- Outer rings -->
<ellipse cx="50" cy="75" rx="35" ry="50" fill="none" stroke="#0099FF" stroke-width="0.8" opacity="0.3"/>
<ellipse cx="50" cy="75" rx="30" ry="44" fill="none" stroke="#0066CC" stroke-width="1.2" opacity="0.4"/>
<!-- Main portal -->
<ellipse cx="50" cy="75" rx="26" ry="40" fill="url(#bluePortalLight)" filter="url(#blueGlowLight)" opacity="1"/>
<!-- Inner energy rings -->
<ellipse cx="50" cy="75" rx="20" ry="32" fill="none" stroke="#0099FF" stroke-width="2" opacity="0.8">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="50" cy="75" rx="14" ry="24" fill="none" stroke="#66CCFF" stroke-width="1.5" opacity="0.6">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="50" cy="75" rx="7" ry="12" fill="#001529" opacity="1"/>
</g>
<!-- Text: "kportal" with dark colors for light background -->
<!-- Orange K -->
<text x="76" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#E66100" filter="url(#textShadowLight)">
k
<animate attributeName="x" values="76;79;76" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Dark "porta" for light background -->
<text x="105" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#2C3E50" filter="url(#textShadowLight)">
porta
</text>
<!-- Blue L -->
<text x="220" y="90" font-family="'Helvetica Neue', Arial, sans-serif" font-size="52" font-weight="400" fill="#0066CC" filter="url(#textShadowLight)">
l
<animate attributeName="x" values="220;223;220" dur="4s" repeatCount="indefinite"/>
</text>
<!-- Orange Portal (RIGHT) at x=260 -->
<g id="orangePortalGroupLight">
<!-- Outer rings -->
<ellipse cx="260" cy="75" rx="35" ry="50" fill="none" stroke="#FFBB66" stroke-width="0.8" opacity="0.3"/>
<ellipse cx="260" cy="75" rx="30" ry="44" fill="none" stroke="#FF9933" stroke-width="1.2" opacity="0.4"/>
<!-- Main portal -->
<ellipse cx="260" cy="75" rx="26" ry="40" fill="url(#orangePortalLight)" filter="url(#orangeGlowLight)" opacity="1"/>
<!-- Inner energy rings -->
<ellipse cx="260" cy="75" rx="20" ry="32" fill="none" stroke="#FF9933" stroke-width="2" opacity="0.8">
<animate attributeName="rx" values="20;18;20" dur="3s" repeatCount="indefinite"/>
<animate attributeName="ry" values="32;30;32" dur="3s" repeatCount="indefinite"/>
</ellipse>
<ellipse cx="260" cy="75" rx="14" ry="24" fill="none" stroke="#FFBB66" stroke-width="1.5" opacity="0.6">
<animate attributeName="rx" values="14;16;14" dur="2.5s" repeatCount="indefinite"/>
<animate attributeName="ry" values="24;26;24" dur="2.5s" repeatCount="indefinite"/>
</ellipse>
<!-- Portal core -->
<ellipse cx="260" cy="75" rx="7" ry="12" fill="#2E1A00" opacity="1"/>
</g>
<!-- Energy connection between portals -->
<path d="M 76 75 Q 180 70 222 75" stroke="url(#energyGradientLight)" stroke-width="0.7" fill="none" opacity="0.4">
<animate attributeName="opacity" values="0.2;0.4;0.2" dur="4s" repeatCount="indefinite"/>
</path>
<defs>
<linearGradient id="energyGradientLight" x1="0%" y1="0%" x2="100%" y2="0%">
<stop offset="0%" style="stop-color:#0066CC;stop-opacity:1"/>
<stop offset="50%" style="stop-color:#8B7CC6;stop-opacity:1"/>
<stop offset="100%" style="stop-color:#E66100;stop-opacity:1"/>
</linearGradient>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 184 KiB

+5
View File
@@ -17,6 +17,7 @@ require (
require (
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/cenkalti/backoff v2.2.1+incompatible // indirect
github.com/charmbracelet/colorprofile v0.3.3 // indirect
github.com/charmbracelet/x/ansi v0.11.1 // indirect
github.com/charmbracelet/x/cellbuf v0.0.14 // indirect
@@ -25,6 +26,7 @@ require (
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emicklei/go-restful/v3 v3.13.0 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fxamacker/cbor/v2 v2.9.0 // indirect
@@ -47,11 +49,13 @@ require (
github.com/google/gnostic-models v0.7.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/grandcat/zeroconf v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/miekg/dns v1.1.27 // indirect
github.com/moby/spdystream v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
@@ -67,6 +71,7 @@ require (
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/oauth2 v0.33.0 // indirect
golang.org/x/sys v0.38.0 // indirect
+16
View File
@@ -2,6 +2,8 @@ github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPd
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.3.3 h1:DjJzJtLP6/NZ8p7Cgjno0CKGr7wwRJGxWUwh2IyhfAI=
@@ -23,6 +25,8 @@ github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsV
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bFY/oTyCes=
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
@@ -82,6 +86,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo=
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA=
github.com/grandcat/zeroconf v1.0.0 h1:uHhahLBKqwWBV6WZUDAT71044vwOTL+McW0mBJvo6kE=
github.com/grandcat/zeroconf v1.0.0/go.mod h1:lTKmG1zh86XyCoUeIHSA4FJMBwCJiQmGfcP2PdzytEs=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -98,6 +104,8 @@ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2J
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM=
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/moby/spdystream v0.5.0 h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU=
github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -120,6 +128,7 @@ github.com/onsi/ginkgo/v2 v2.21.0 h1:7rg/4f3rB88pb5obDgNZrNHrQ4e6WpjonchcpuBRnZM
github.com/onsi/ginkgo/v2 v2.21.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo=
github.com/onsi/gomega v1.35.1 h1:Cwbd75ZBPxFSuZ6T+rN/WCb/gOc6YgFBXLlZLhC7Ds4=
github.com/onsi/gomega v1.35.1/go.mod h1:PvZbdDc8J6XJEpDK4HCuRBm8a6Fzp9/DmhC9C7yFlog=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
@@ -147,12 +156,17 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
@@ -164,6 +178,7 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -179,6 +194,7 @@ golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
+171 -37
View File
@@ -1,15 +1,149 @@
package config
import (
"bytes"
"fmt"
"os"
"strings"
"time"
"gopkg.in/yaml.v3"
)
const (
// maxConfigSize is the maximum allowed configuration file size (10MB)
maxConfigSize = 10 * 1024 * 1024
// Default health check settings
DefaultHealthCheckInterval = 3 * time.Second // How often to check connection health
DefaultHealthCheckTimeout = 2 * time.Second // Timeout for health check probes
DefaultHealthCheckMethod = "data-transfer" // More reliable than tcp-dial
DefaultMaxConnectionAge = 25 * time.Minute // Reconnect before k8s 30min timeout
DefaultMaxIdleTime = 10 * time.Minute // Reconnect if no activity
// Default reliability settings
DefaultTCPKeepalive = 30 * time.Second // OS-level TCP keepalive interval
DefaultDialTimeout = 30 * time.Second // Connection establishment timeout
DefaultWatchdogPeriod = 30 * time.Second // Goroutine health check interval
)
// Config represents the root configuration structure from .kportal.yaml
type Config struct {
Contexts []Context `yaml:"contexts"`
Contexts []Context `yaml:"contexts"`
HealthCheck *HealthCheckSpec `yaml:"healthCheck,omitempty"`
Reliability *ReliabilitySpec `yaml:"reliability,omitempty"`
MDNS *MDNSSpec `yaml:"mdns,omitempty"`
}
// MDNSSpec configures mDNS (multicast DNS) hostname publishing
// When enabled, forwards with aliases can be accessed via <alias>.local hostnames
type MDNSSpec struct {
Enabled bool `yaml:"enabled"` // Enable mDNS hostname publishing
}
// HealthCheckSpec configures health check behavior
type HealthCheckSpec struct {
Interval string `yaml:"interval,omitempty"` // e.g., "3s", "5s"
Timeout string `yaml:"timeout,omitempty"` // e.g., "2s"
Method string `yaml:"method,omitempty"` // "tcp-dial" | "data-transfer"
MaxConnectionAge string `yaml:"maxConnectionAge,omitempty"` // e.g., "25m" - reconnect before k8s timeout
MaxIdleTime string `yaml:"maxIdleTime,omitempty"` // e.g., "10m" - reconnect if no activity
}
// ReliabilitySpec configures connection reliability features
type ReliabilitySpec struct {
TCPKeepalive string `yaml:"tcpKeepalive,omitempty"` // e.g., "30s" - OS-level keepalive
DialTimeout string `yaml:"dialTimeout,omitempty"` // e.g., "30s" - connection dial timeout
RetryOnStale bool `yaml:"retryOnStale,omitempty"` // Auto-reconnect on stale detection
WatchdogPeriod string `yaml:"watchdogPeriod,omitempty"` // e.g., "30s" - goroutine watchdog interval
}
// parseDurationOrDefault parses a duration string and returns the default if empty or invalid.
func parseDurationOrDefault(value string, defaultDur time.Duration) time.Duration {
if value == "" {
return defaultDur
}
if d, err := time.ParseDuration(value); err == nil {
return d
}
return defaultDur
}
// GetHealthCheckIntervalOrDefault returns the health check interval or default value
func (c *Config) GetHealthCheckIntervalOrDefault() time.Duration {
if c.HealthCheck == nil {
return DefaultHealthCheckInterval
}
return parseDurationOrDefault(c.HealthCheck.Interval, DefaultHealthCheckInterval)
}
// GetHealthCheckTimeoutOrDefault returns the health check timeout or default value
func (c *Config) GetHealthCheckTimeoutOrDefault() time.Duration {
if c.HealthCheck == nil {
return DefaultHealthCheckTimeout
}
return parseDurationOrDefault(c.HealthCheck.Timeout, DefaultHealthCheckTimeout)
}
// GetHealthCheckMethod returns the health check method or default
func (c *Config) GetHealthCheckMethod() string {
if c.HealthCheck != nil && c.HealthCheck.Method != "" {
return c.HealthCheck.Method
}
return DefaultHealthCheckMethod
}
// GetMaxConnectionAge returns the max connection age or default
func (c *Config) GetMaxConnectionAge() time.Duration {
if c.HealthCheck == nil {
return DefaultMaxConnectionAge
}
return parseDurationOrDefault(c.HealthCheck.MaxConnectionAge, DefaultMaxConnectionAge)
}
// GetMaxIdleTime returns the max idle time or default
func (c *Config) GetMaxIdleTime() time.Duration {
if c.HealthCheck == nil {
return DefaultMaxIdleTime
}
return parseDurationOrDefault(c.HealthCheck.MaxIdleTime, DefaultMaxIdleTime)
}
// GetTCPKeepalive returns the TCP keepalive duration or default
func (c *Config) GetTCPKeepalive() time.Duration {
if c.Reliability == nil {
return DefaultTCPKeepalive
}
return parseDurationOrDefault(c.Reliability.TCPKeepalive, DefaultTCPKeepalive)
}
// GetRetryOnStale returns whether to retry on stale connections
func (c *Config) GetRetryOnStale() bool {
if c.Reliability != nil {
return c.Reliability.RetryOnStale
}
return true // Default: enabled
}
// GetWatchdogPeriod returns the goroutine watchdog check period or default
func (c *Config) GetWatchdogPeriod() time.Duration {
if c.Reliability == nil {
return DefaultWatchdogPeriod
}
return parseDurationOrDefault(c.Reliability.WatchdogPeriod, DefaultWatchdogPeriod)
}
// GetDialTimeout returns the connection dial timeout or default
func (c *Config) GetDialTimeout() time.Duration {
if c.Reliability == nil {
return DefaultDialTimeout
}
return parseDurationOrDefault(c.Reliability.DialTimeout, DefaultDialTimeout)
}
// IsMDNSEnabled returns whether mDNS hostname publishing is enabled
func (c *Config) IsMDNSEnabled() bool {
return c.MDNS != nil && c.MDNS.Enabled
}
// Context represents a Kubernetes context with its namespaces
@@ -78,8 +212,37 @@ func (f *Forward) GetNamespace() string {
return f.namespaceName
}
// GetMDNSAlias returns the alias to use for mDNS hostname registration.
// If an explicit alias is set, it returns that.
// Otherwise, it generates one from the resource name (e.g., "service/logto" -> "logto").
func (f *Forward) GetMDNSAlias() string {
if f.Alias != "" {
return f.Alias
}
// Generate alias from resource name
// Format is "type/name" (e.g., "service/logto", "pod/my-app")
parts := strings.SplitN(f.Resource, "/", 2)
if len(parts) == 2 && parts[1] != "" {
return parts[1]
}
// Fallback: can't generate a valid alias (e.g., "pod" with selector)
return ""
}
// LoadConfig loads and parses the configuration file from the given path.
func LoadConfig(path string) (*Config, error) {
// Validate file size before reading
fileInfo, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("failed to stat config file: %w", err)
}
if fileInfo.Size() > maxConfigSize {
return nil, fmt.Errorf("config file too large: %d bytes (max %d)", fileInfo.Size(), maxConfigSize)
}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
@@ -89,9 +252,15 @@ func LoadConfig(path string) (*Config, error) {
}
// ParseConfig parses YAML configuration data into a Config struct.
// It uses strict parsing that rejects unknown keys to catch typos.
func ParseConfig(data []byte) (*Config, error) {
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
// Use decoder with KnownFields to reject unknown keys (catches typos)
decoder := yaml.NewDecoder(bytes.NewReader(data))
decoder.KnownFields(true)
if err := decoder.Decode(&cfg); err != nil {
return nil, fmt.Errorf("failed to parse YAML: %w", err)
}
@@ -122,38 +291,3 @@ func (c *Config) GetAllForwards() []Forward {
return forwards
}
// GetForwardsByContext returns all forwards for a specific context.
func (c *Config) GetForwardsByContext(contextName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
forwards = append(forwards, ns.Forwards...)
}
break
}
}
return forwards
}
// GetForwardsByNamespace returns all forwards for a specific context and namespace.
func (c *Config) GetForwardsByNamespace(contextName, namespaceName string) []Forward {
var forwards []Forward
for _, ctx := range c.Contexts {
if ctx.Name == contextName {
for _, ns := range ctx.Namespaces {
if ns.Name == namespaceName {
forwards = append(forwards, ns.Forwards...)
break
}
}
break
}
}
return forwards
}
+1 -67
View File
@@ -97,7 +97,7 @@ func TestLoadConfig_FileNotFound(t *testing.T) {
cfg, err := LoadConfig("/non/existent/path/.kportal.yaml")
assert.Error(t, err, "LoadConfig should fail with non-existent file")
assert.Nil(t, cfg, "config should be nil on error")
assert.Contains(t, err.Error(), "failed to read config file", "error should mention read failure")
assert.Contains(t, err.Error(), "failed to stat config file", "error should mention stat failure")
}
func TestForward_ID(t *testing.T) {
@@ -298,72 +298,6 @@ func TestConfig_GetAllForwards(t *testing.T) {
assert.Len(t, forwards, 4, "should return all forwards from all contexts and namespaces")
}
func TestConfig_GetForwardsByContext(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: cluster2
namespaces:
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByContext("cluster1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1")
forwards2 := cfg.GetForwardsByContext("cluster2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster2")
forwards3 := cfg.GetForwardsByContext("non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent context")
}
func TestConfig_GetForwardsByNamespace(t *testing.T) {
yamlData := []byte(`contexts:
- name: cluster1
namespaces:
- name: ns1
forwards:
- resource: pod/app1
port: 8080
localPort: 8080
- resource: pod/app2
port: 8081
localPort: 8081
- name: ns2
forwards:
- resource: pod/app3
port: 9090
localPort: 9090
`)
cfg, err := ParseConfig(yamlData)
assert.NoError(t, err)
forwards := cfg.GetForwardsByNamespace("cluster1", "ns1")
assert.Len(t, forwards, 2, "should return forwards only from cluster1/ns1")
forwards2 := cfg.GetForwardsByNamespace("cluster1", "ns2")
assert.Len(t, forwards2, 1, "should return forwards only from cluster1/ns2")
forwards3 := cfg.GetForwardsByNamespace("cluster1", "non-existent")
assert.Len(t, forwards3, 0, "should return empty slice for non-existent namespace")
}
func TestForward_SetContext(t *testing.T) {
fwd := Forward{
Resource: "pod/my-app",
+273
View File
@@ -0,0 +1,273 @@
package config
import (
"fmt"
"os"
"path/filepath"
"sync"
"gopkg.in/yaml.v3"
)
// Mutator provides safe, atomic mutations to the kportal configuration file.
// All operations use atomic file writes (write to temp, then rename) to prevent
// corruption and ensure the file watcher picks up changes.
type Mutator struct {
configPath string
mu sync.Mutex // Ensure only one mutation at a time
}
// NewMutator creates a new configuration mutator for the given config file path.
func NewMutator(configPath string) *Mutator {
return &Mutator{
configPath: configPath,
}
}
// findOrCreateContext finds an existing context or creates a new one
func (m *Mutator) findOrCreateContext(cfg *Config, contextName string) *Context {
for i := range cfg.Contexts {
if cfg.Contexts[i].Name == contextName {
return &cfg.Contexts[i]
}
}
// Create new context
cfg.Contexts = append(cfg.Contexts, Context{
Name: contextName,
Namespaces: []Namespace{},
})
return &cfg.Contexts[len(cfg.Contexts)-1]
}
// findOrCreateNamespace finds an existing namespace or creates a new one
func (m *Mutator) findOrCreateNamespace(ctx *Context, namespaceName string) *Namespace {
for i := range ctx.Namespaces {
if ctx.Namespaces[i].Name == namespaceName {
return &ctx.Namespaces[i]
}
}
// Create new namespace
ctx.Namespaces = append(ctx.Namespaces, Namespace{
Name: namespaceName,
Forwards: []Forward{},
})
return &ctx.Namespaces[len(ctx.Namespaces)-1]
}
// AddForward adds a new port forward to the configuration.
// If the context or namespace doesn't exist, they will be created.
// The new configuration is validated before writing.
// Returns an error if the port is already in use or validation fails.
func (m *Mutator) AddForward(contextName, namespaceName string, fwd Forward) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
// If file doesn't exist, create empty config
if os.IsNotExist(err) {
cfg = &Config{Contexts: []Context{}}
} else {
return fmt.Errorf("failed to load config: %w", err)
}
}
// Find or create context and namespace
targetContext := m.findOrCreateContext(cfg, contextName)
targetNamespace := m.findOrCreateNamespace(targetContext, namespaceName)
// Set context/namespace on the forward for validation
fwd.SetContext(contextName, namespaceName)
// Check for duplicate local port
allForwards := cfg.GetAllForwards()
for _, existing := range allForwards {
if existing.LocalPort == fwd.LocalPort {
return fmt.Errorf("port %d is already in use by %s", fwd.LocalPort, existing.String())
}
}
// Add the forward
targetNamespace.Forwards = append(targetNamespace.Forwards, fwd)
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// RemoveForwards removes forwards matching the predicate function.
// The predicate receives the context, namespace, and forward, and should return true
// to remove that forward.
// Empty namespaces and contexts are preserved (not automatically removed).
func (m *Mutator) RemoveForwards(predicate func(ctx, ns string, fwd Forward) bool) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// Iterate and filter
for i := range cfg.Contexts {
ctx := &cfg.Contexts[i]
filteredNamespaces := []Namespace{}
for j := range ctx.Namespaces {
ns := &ctx.Namespaces[j]
// Filter forwards
filtered := []Forward{}
for _, fwd := range ns.Forwards {
// CRITICAL: Set context/namespace so fwd.ID() generates correct ID
fwd.SetContext(ctx.Name, ns.Name)
if !predicate(ctx.Name, ns.Name, fwd) {
// Keep this forward
filtered = append(filtered, fwd)
}
}
ns.Forwards = filtered
// Only keep namespaces that have at least one forward
if len(ns.Forwards) > 0 {
filteredNamespaces = append(filteredNamespaces, *ns)
}
}
ctx.Namespaces = filteredNamespaces
}
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// RemoveForwardByID removes a specific forward by its ID.
func (m *Mutator) RemoveForwardByID(id string) error {
return m.RemoveForwards(func(ctx, ns string, fwd Forward) bool {
return fwd.ID() == id
})
}
// UpdateForward atomically replaces an existing forward with a new one.
// This is used for editing - it removes the old forward and adds the new one in a single transaction.
// If the old forward doesn't exist, returns an error.
// If the new forward validation fails, the operation is rolled back (old forward remains).
func (m *Mutator) UpdateForward(oldID, newContextName, newNamespaceName string, newFwd Forward) error {
m.mu.Lock()
defer m.mu.Unlock()
// Load current config
cfg, err := LoadConfig(m.configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// First, verify the old forward exists and remove it
oldForwardFound := false
for i := range cfg.Contexts {
ctx := &cfg.Contexts[i]
for j := range ctx.Namespaces {
ns := &ctx.Namespaces[j]
// Filter forwards, removing the old one
filtered := []Forward{}
for _, fwd := range ns.Forwards {
// CRITICAL: Set context/namespace so fwd.ID() generates correct ID
fwd.SetContext(ctx.Name, ns.Name)
if fwd.ID() == oldID {
oldForwardFound = true
// Skip this forward (remove it)
continue
}
// Keep this forward
filtered = append(filtered, fwd)
}
ns.Forwards = filtered
}
}
if !oldForwardFound {
return fmt.Errorf("forward with ID %s not found", oldID)
}
// Now add the new forward
// Find or create context and namespace
targetContext := m.findOrCreateContext(cfg, newContextName)
targetNamespace := m.findOrCreateNamespace(targetContext, newNamespaceName)
// Set context/namespace on the forward for validation
newFwd.SetContext(newContextName, newNamespaceName)
// Check for duplicate local port (excluding the one we just removed)
allForwards := cfg.GetAllForwards()
for _, existing := range allForwards {
if existing.LocalPort == newFwd.LocalPort && existing.ID() != oldID {
return fmt.Errorf("port %d is already in use by %s", newFwd.LocalPort, existing.String())
}
}
// Add the new forward
targetNamespace.Forwards = append(targetNamespace.Forwards, newFwd)
// Validate the new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(cfg); len(errs) > 0 {
return fmt.Errorf("validation failed: %s", FormatValidationErrors(errs))
}
// Write atomically
return m.writeAtomic(cfg)
}
// writeAtomic writes the configuration atomically to prevent corruption.
// Steps:
// 1. Marshal config to YAML
// 2. Write to temporary file (.kportal.yaml.tmp)
// 3. Atomic rename to actual config file
//
// This ensures the file watcher picks up a complete, valid file.
func (m *Mutator) writeAtomic(cfg *Config) error {
// Marshal to YAML
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
// Create temporary file in same directory as config
dir := filepath.Dir(m.configPath)
tmpFile := filepath.Join(dir, ".kportal.yaml.tmp")
// Write to temp file
if err := os.WriteFile(tmpFile, data, 0600); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename
if err := os.Rename(tmpFile, m.configPath); err != nil {
// Clean up temp file on failure
os.Remove(tmpFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
+99 -7
View File
@@ -6,10 +6,15 @@ import (
)
const (
minPort = 1
maxPort = 65535
MinPort = 1
MaxPort = 65535
)
// IsValidPort returns true if the port number is within the valid range (1-65535).
func IsValidPort(port int) bool {
return port >= MinPort && port <= MaxPort
}
// ValidationError represents a configuration validation error with context.
type ValidationError struct {
Field string // The field that failed validation
@@ -56,6 +61,11 @@ func (v *Validator) ValidateConfig(cfg *Config) []ValidationError {
// Check for duplicate local ports
errs = append(errs, v.validateDuplicatePorts(cfg)...)
// Validate mDNS configuration
if cfg.IsMDNSEnabled() {
errs = append(errs, v.validateMDNS(cfg)...)
}
return errs
}
@@ -84,7 +94,7 @@ func (v *Validator) validateStructure(cfg *Config) []ValidationError {
Field: fmt.Sprintf("contexts[%d].namespaces", i),
Message: fmt.Sprintf("Context '%s' must have at least one namespace", ctx.Name),
})
continue
// Don't continue - still validate other aspects of the context if any
}
for j, ns := range ctx.Namespaces {
@@ -130,17 +140,17 @@ func (v *Validator) validateForward(fwd *Forward) []ValidationError {
}
// Validate ports
if fwd.Port < minPort || fwd.Port > maxPort {
if fwd.Port < MinPort || fwd.Port > MaxPort {
errs = append(errs, ValidationError{
Field: "port",
Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), minPort, maxPort),
Message: fmt.Sprintf("Invalid port %d for forward %s (must be between %d and %d)", fwd.Port, fwd.ID(), MinPort, MaxPort),
})
}
if fwd.LocalPort < minPort || fwd.LocalPort > maxPort {
if fwd.LocalPort < MinPort || fwd.LocalPort > MaxPort {
errs = append(errs, ValidationError{
Field: "localPort",
Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), minPort, maxPort),
Message: fmt.Sprintf("Invalid localPort %d for forward %s (must be between %d and %d)", fwd.LocalPort, fwd.ID(), MinPort, MaxPort),
})
}
@@ -265,3 +275,85 @@ func FormatValidationErrors(errs []ValidationError) string {
return sb.String()
}
// validateMDNS validates mDNS configuration when enabled.
// It checks that aliases used for mDNS hostnames are valid and unique.
// This includes both explicit aliases and auto-generated ones from resource names.
func (v *Validator) validateMDNS(cfg *Config) []ValidationError {
var errs []ValidationError
aliasMap := make(map[string][]string) // alias -> list of forward IDs using it
for _, ctx := range cfg.Contexts {
for _, ns := range ctx.Namespaces {
for _, fwd := range ns.Forwards {
// Get the mDNS alias (explicit or generated from resource name)
mdnsAlias := fwd.GetMDNSAlias()
if mdnsAlias == "" {
// No alias available (e.g., "pod" with selector only)
continue
}
// Validate alias is a valid hostname (RFC 1123)
if !isValidHostname(mdnsAlias) {
errs = append(errs, ValidationError{
Field: "alias",
Message: fmt.Sprintf("Forward %s has invalid mDNS hostname '%s' (must be a valid RFC 1123 hostname)", fwd.ID(), mdnsAlias),
})
}
aliasMap[mdnsAlias] = append(aliasMap[mdnsAlias], fwd.ID())
}
}
}
// Check for duplicate aliases (would cause mDNS conflicts)
for alias, forwards := range aliasMap {
if len(forwards) > 1 {
errs = append(errs, ValidationError{
Field: "alias",
Message: fmt.Sprintf("Duplicate mDNS hostname '%s' used by multiple forwards (would cause conflict)", alias),
Context: map[string]string{
"alias": alias,
"forwards": strings.Join(forwards, ", "),
},
})
}
}
return errs
}
// isValidHostname checks if a string is a valid RFC 1123 hostname.
// Hostnames must start with alphanumeric, contain only alphanumeric and hyphens,
// and be 1-63 characters long.
func isValidHostname(name string) bool {
if len(name) == 0 || len(name) > 63 {
return false
}
// Must start with alphanumeric
if !isAlphanumeric(name[0]) {
return false
}
// Must end with alphanumeric
if !isAlphanumeric(name[len(name)-1]) {
return false
}
// Check all characters
for i := 0; i < len(name); i++ {
c := name[i]
if !isAlphanumeric(c) && c != '-' {
return false
}
}
return true
}
// isAlphanumeric returns true if the character is a letter or digit.
func isAlphanumeric(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')
}
+271
View File
@@ -701,3 +701,274 @@ func TestValidator_ValidateStructure(t *testing.T) {
})
}
}
func TestValidator_ValidateMDNS(t *testing.T) {
validator := NewValidator()
tests := []struct {
name string
config *Config
expectErrors bool
errorContains []string
}{
{
name: "mDNS disabled - no validation",
config: &Config{
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app", Port: 8080, LocalPort: 8080, Alias: "invalid_alias", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: false,
},
{
name: "mDNS enabled - valid aliases",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app1", Port: 8080, LocalPort: 8080, Alias: "my-app", contextName: "dev", namespaceName: "default"},
{Resource: "pod/app2", Port: 8081, LocalPort: 8081, Alias: "my-service", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: false,
},
{
name: "mDNS enabled - no alias (allowed)",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app", Port: 8080, LocalPort: 8080, contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: false,
},
{
name: "mDNS enabled - invalid alias with underscore",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app", Port: 8080, LocalPort: 8080, Alias: "my_app", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: true,
errorContains: []string{"invalid mDNS hostname", "RFC 1123"},
},
{
name: "mDNS enabled - alias starts with hyphen",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app", Port: 8080, LocalPort: 8080, Alias: "-myapp", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: true,
errorContains: []string{"invalid mDNS hostname"},
},
{
name: "mDNS enabled - alias ends with hyphen",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app", Port: 8080, LocalPort: 8080, Alias: "myapp-", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: true,
errorContains: []string{"invalid mDNS hostname"},
},
{
name: "mDNS enabled - duplicate aliases",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "dev",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app1", Port: 8080, LocalPort: 8080, Alias: "myapp", contextName: "dev", namespaceName: "default"},
{Resource: "pod/app2", Port: 8081, LocalPort: 8081, Alias: "myapp", contextName: "dev", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: true,
errorContains: []string{"Duplicate mDNS hostname", "conflict"},
},
{
name: "mDNS enabled - duplicate aliases across contexts",
config: &Config{
MDNS: &MDNSSpec{Enabled: true},
Contexts: []Context{
{
Name: "cluster1",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app1", Port: 8080, LocalPort: 8080, Alias: "shared-name", contextName: "cluster1", namespaceName: "default"},
},
},
},
},
{
Name: "cluster2",
Namespaces: []Namespace{
{
Name: "default",
Forwards: []Forward{
{Resource: "pod/app2", Port: 8081, LocalPort: 8081, Alias: "shared-name", contextName: "cluster2", namespaceName: "default"},
},
},
},
},
},
},
expectErrors: true,
errorContains: []string{"Duplicate mDNS hostname", "shared-name"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errs := validator.ValidateConfig(tt.config)
if tt.expectErrors {
assert.NotEmpty(t, errs, "expected validation errors")
// Check that expected error messages are present
for _, expectedMsg := range tt.errorContains {
found := false
for _, err := range errs {
if strings.Contains(err.Message, expectedMsg) {
found = true
break
}
}
assert.True(t, found, "expected error message '%s' not found in errors: %v", expectedMsg, errs)
}
} else {
assert.Empty(t, errs, "expected no validation errors, got: %v", errs)
}
})
}
}
func TestIsValidHostname(t *testing.T) {
tests := []struct {
name string
hostname string
valid bool
}{
{"valid simple", "myservice", true},
{"valid with hyphen", "my-service", true},
{"valid with numbers", "service123", true},
{"valid mixed", "my-service-123", true},
{"valid uppercase", "MyService", true},
{"valid single char", "a", true},
{"valid single digit", "1", true},
{"valid max length (63)", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", true},
{"invalid empty", "", false},
{"invalid too long (64)", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", false},
{"invalid starts with hyphen", "-myservice", false},
{"invalid ends with hyphen", "myservice-", false},
{"invalid underscore", "my_service", false},
{"invalid dot", "my.service", false},
{"invalid space", "my service", false},
{"invalid special char", "my@service", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isValidHostname(tt.hostname)
assert.Equal(t, tt.valid, result, "isValidHostname(%q) = %v, want %v", tt.hostname, result, tt.valid)
})
}
}
func TestIsAlphanumeric(t *testing.T) {
tests := []struct {
char byte
valid bool
}{
{'a', true},
{'z', true},
{'A', true},
{'Z', true},
{'0', true},
{'9', true},
{'-', false},
{'_', false},
{'.', false},
{' ', false},
{'@', false},
}
for _, tt := range tests {
t.Run(string(tt.char), func(t *testing.T) {
result := isAlphanumeric(tt.char)
assert.Equal(t, tt.valid, result, "isAlphanumeric(%q) = %v, want %v", tt.char, result, tt.valid)
})
}
}
+20 -10
View File
@@ -6,6 +6,7 @@ import (
"path/filepath"
"github.com/fsnotify/fsnotify"
"github.com/nvm/kportal/internal/logger"
)
// ReloadCallback is called when the configuration file changes.
@@ -113,28 +114,37 @@ func (w *Watcher) handleReload() {
// Load new configuration
newCfg, err := LoadConfig(w.configPath)
if err != nil {
log.Printf("Failed to load configuration: %v", err)
log.Printf("Keeping previous configuration active")
logger.Error("Failed to load configuration during hot-reload", map[string]interface{}{
"config_path": w.configPath,
"error": err.Error(),
})
logger.Info("Keeping previous configuration active", nil)
return
}
// Validate new configuration
validator := NewValidator()
if errs := validator.ValidateConfig(newCfg); len(errs) > 0 {
log.Printf("Configuration validation failed:")
log.Print(FormatValidationErrors(errs))
log.Printf("Keeping previous configuration active")
logger.Error("Configuration validation failed during hot-reload", map[string]interface{}{
"config_path": w.configPath,
"validation_errors": len(errs),
})
logger.Info("Keeping previous configuration active", nil)
return
}
// Call reload callback
if err := w.callback(newCfg); err != nil {
log.Printf("Failed to apply new configuration: %v", err)
log.Printf("Keeping previous configuration active")
logger.Error("Failed to apply new configuration", map[string]interface{}{
"config_path": w.configPath,
"error": err.Error(),
})
logger.Info("Keeping previous configuration active", nil)
return
}
if w.verbose {
log.Printf("Configuration reloaded successfully")
}
logger.Info("Configuration reloaded successfully", map[string]interface{}{
"config_path": w.configPath,
"forwards_count": len(newCfg.GetAllForwards()),
})
}
+171 -13
View File
@@ -9,6 +9,8 @@ import (
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/healthcheck"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/mdns"
)
// StatusUpdater is an interface for updating forward status
@@ -28,23 +30,32 @@ type Manager struct {
portForwarder *k8s.PortForwarder
portChecker *PortChecker
healthChecker *healthcheck.Checker
watchdog *Watchdog
mdnsPublisher *mdns.Publisher
verbose bool
currentConfig *config.Config
statusUI StatusUpdater
}
// NewManager creates a new forward Manager.
func NewManager(verbose bool) *Manager {
// The health checker will be created with default settings and can be
// reconfigured via SetConfig().
func NewManager(verbose bool) (*Manager, error) {
clientPool, err := k8s.NewClientPool()
if err != nil {
log.Fatalf("Failed to create client pool: %v", err)
return nil, fmt.Errorf("failed to create client pool: %w", err)
}
resolver := k8s.NewResourceResolver(clientPool)
portForwarder := k8s.NewPortForwarder(clientPool, resolver)
// Create health checker: check every 5 seconds with 2 second timeout
healthChecker := healthcheck.NewChecker(5*time.Second, 2*time.Second)
// Create health checker with defaults: check every 3 seconds with 2 second timeout
// Will be reconfigured when config is loaded
healthChecker := healthcheck.NewChecker(3*time.Second, 2*time.Second)
// Create watchdog with default settings: check every 30 seconds, 60 second hang threshold
// Will be reconfigured when config is loaded
watchdog := NewWatchdog(30*time.Second, 60*time.Second)
return &Manager{
workers: make(map[string]*ForwardWorker),
@@ -53,8 +64,54 @@ func NewManager(verbose bool) *Manager {
portForwarder: portForwarder,
portChecker: NewPortChecker(),
healthChecker: healthChecker,
watchdog: watchdog,
verbose: verbose,
}, nil
}
// configureHealthChecker creates a new health checker with settings from config
func (m *Manager) configureHealthChecker(cfg *config.Config) {
// Stop existing health checker
if m.healthChecker != nil {
m.healthChecker.Stop()
}
// Parse check method
methodStr := cfg.GetHealthCheckMethod()
var method healthcheck.CheckMethod
switch methodStr {
case "tcp-dial":
method = healthcheck.CheckMethodTCPDial
case "data-transfer":
method = healthcheck.CheckMethodDataTransfer
default:
method = healthcheck.CheckMethodDataTransfer
}
// Create new health checker with config settings
m.healthChecker = healthcheck.NewCheckerWithOptions(healthcheck.CheckerOptions{
Interval: cfg.GetHealthCheckIntervalOrDefault(),
Timeout: cfg.GetHealthCheckTimeoutOrDefault(),
Method: method,
MaxConnectionAge: cfg.GetMaxConnectionAge(),
MaxIdleTime: cfg.GetMaxIdleTime(),
})
// Configure TCP settings on port forwarder
tcpKeepalive := cfg.GetTCPKeepalive()
dialTimeout := cfg.GetDialTimeout()
m.portForwarder.SetTCPKeepalive(tcpKeepalive)
m.portForwarder.SetDialTimeout(dialTimeout)
logger.Info("Health checker and reliability configured", map[string]interface{}{
"interval": cfg.GetHealthCheckIntervalOrDefault().String(),
"timeout": cfg.GetHealthCheckTimeoutOrDefault().String(),
"method": methodStr,
"max_connection_age": cfg.GetMaxConnectionAge().String(),
"max_idle_time": cfg.GetMaxIdleTime().String(),
"tcp_keepalive": tcpKeepalive.String(),
"dial_timeout": dialTimeout.String(),
})
}
// SetStatusUI sets the status updater for the manager
@@ -62,6 +119,11 @@ func (m *Manager) SetStatusUI(ui StatusUpdater) {
m.statusUI = ui
}
// SetMDNSPublisher sets the mDNS publisher for the manager
func (m *Manager) SetMDNSPublisher(publisher *mdns.Publisher) {
m.mdnsPublisher = publisher
}
// Start initializes and starts all port-forwards from the configuration.
func (m *Manager) Start(cfg *config.Config) error {
if cfg == nil {
@@ -70,6 +132,20 @@ func (m *Manager) Start(cfg *config.Config) error {
m.currentConfig = cfg
// Configure health checker with settings from config
m.configureHealthChecker(cfg)
// Start watchdog
watchdogPeriod := cfg.GetWatchdogPeriod()
m.watchdog.checkInterval = watchdogPeriod
m.watchdog.hangThreshold = watchdogPeriod * 2 // Hang threshold is 2x check interval
m.watchdog.Start()
logger.Info("Watchdog started", map[string]interface{}{
"check_interval": watchdogPeriod.String(),
"hang_threshold": (watchdogPeriod * 2).String(),
})
// Get all forwards from config
forwards := cfg.GetAllForwards()
@@ -93,7 +169,14 @@ func (m *Manager) Start(cfg *config.Config) error {
for _, fwd := range forwards {
if err := m.startWorker(fwd); err != nil {
log.Printf("Failed to start worker for %s: %v", fwd.ID(), err)
logger.Error("Failed to start worker", map[string]interface{}{
"forward_id": fwd.ID(),
"context": fwd.GetContext(),
"namespace": fwd.GetNamespace(),
"resource": fwd.Resource,
"local_port": fwd.LocalPort,
"error": err.Error(),
})
// Continue with other workers
}
}
@@ -106,8 +189,14 @@ func (m *Manager) Start(cfg *config.Config) error {
func (m *Manager) Stop() {
log.Printf("Stopping all port-forwards...")
// Stop health checker first
// Stop health checker and watchdog first
m.healthChecker.Stop()
m.watchdog.Stop()
// Stop mDNS publisher
if m.mdnsPublisher != nil {
m.mdnsPublisher.Stop()
}
m.workersMu.Lock()
workers := make([]*ForwardWorker, 0, len(m.workers))
@@ -146,7 +235,9 @@ func (m *Manager) Reload(newCfg *config.Config) error {
return fmt.Errorf("new configuration is nil")
}
log.Printf("Reloading configuration...")
logger.Info("Reloading configuration", map[string]interface{}{
"new_forwards_count": len(newCfg.GetAllForwards()),
})
// Get all forwards from new config
newForwards := newCfg.GetAllForwards()
@@ -258,31 +349,86 @@ func (m *Manager) startWorker(fwd config.Forward) error {
m.statusUI.AddForward(fwd.ID(), &fwd)
}
// Register with watchdog
m.watchdog.RegisterWorker(fwd.ID(), func(forwardID string) {
logger.Warn("Watchdog triggered reconnection for hung worker", map[string]interface{}{
"forward_id": forwardID,
})
// Find and trigger reconnect on hung worker
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("watchdog detected hung worker")
}
})
// Register with health checker
m.healthChecker.Register(fwd.ID(), fwd.LocalPort, func(forwardID string, status healthcheck.Status, errorMsg string) {
if m.statusUI != nil {
m.statusUI.UpdateStatus(forwardID, string(status))
// Send error separately if there is one
if status == healthcheck.StatusUnhealthy && errorMsg != "" {
if (status == healthcheck.StatusUnhealthy || status == healthcheck.StatusStale) && errorMsg != "" {
if ui, ok := m.statusUI.(interface{ SetError(id, msg string) }); ok {
ui.SetError(forwardID, errorMsg)
}
}
}
// Handle stale connections: trigger reconnection if retryOnStale is enabled
if status == healthcheck.StatusStale && m.currentConfig.GetRetryOnStale() {
logger.Info("Stale connection detected, triggering reconnection", map[string]interface{}{
"forward_id": forwardID,
"reason": errorMsg,
})
// Find and notify the worker to reconnect
m.workersMu.RLock()
worker, exists := m.workers[forwardID]
m.workersMu.RUnlock()
if exists {
worker.TriggerReconnect("stale connection")
}
}
})
// Create and start worker
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker)
worker := NewForwardWorker(fwd, m.portForwarder, m.verbose, m.statusUI, m.healthChecker, m.watchdog)
worker.Start()
// Store worker
m.workers[fwd.ID()] = worker
// Register mDNS hostname if enabled
// Uses explicit alias if set, otherwise generates from resource name
if m.mdnsPublisher != nil {
mdnsAlias := fwd.GetMDNSAlias()
if mdnsAlias != "" {
if err := m.mdnsPublisher.Register(fwd.ID(), mdnsAlias, fwd.LocalPort); err != nil {
logger.Warn("Failed to register mDNS hostname", map[string]interface{}{
"forward_id": fwd.ID(),
"alias": mdnsAlias,
"error": err.Error(),
})
// Don't fail the forward start - mDNS is optional
}
}
}
return nil
}
// stopWorker stops and removes a forward worker.
func (m *Manager) stopWorker(id string) error {
return m.stopWorkerInternal(id, true)
}
// stopWorkerInternal stops a worker with option to remove from UI or just update status
func (m *Manager) stopWorkerInternal(id string, removeFromUI bool) error {
m.workersMu.Lock()
worker, exists := m.workers[id]
if !exists {
@@ -292,11 +438,23 @@ func (m *Manager) stopWorker(id string) error {
delete(m.workers, id)
m.workersMu.Unlock()
// Unregister from health checker
// Unregister from health checker and watchdog
m.healthChecker.Unregister(id)
m.watchdog.UnregisterWorker(id)
// Note: We DON'T call Remove() here anymore - keep it in the UI
// The UI will show it as disabled instead
// Unregister mDNS hostname
if m.mdnsPublisher != nil {
m.mdnsPublisher.Unregister(id)
}
// Notify UI - either remove or update to disabled status
if m.statusUI != nil {
if removeFromUI {
m.statusUI.Remove(id)
} else {
m.statusUI.UpdateStatus(id, "Disabled")
}
}
// Stop the worker
worker.Stop()
@@ -346,7 +504,7 @@ func (m *Manager) getResourceForPort(forwards []config.Forward, port int) string
// DisableForward temporarily stops a forward by ID
func (m *Manager) DisableForward(id string) error {
if err := m.stopWorker(id); err != nil {
if err := m.stopWorkerInternal(id, false); err != nil {
return err
}
log.Printf("Disabled: %s", id)
+157 -46
View File
@@ -6,8 +6,96 @@ import (
"os/exec"
"runtime"
"strings"
"github.com/nvm/kportal/internal/logger"
)
const (
// maxPIDLength is the maximum length of a valid PID string (9 digits covers PIDs up to 999,999,999)
maxPIDLength = 9
// minNetstatFields is the minimum number of fields expected in netstat output
minNetstatFields = 5
)
// isValidPID validates that a PID string contains only digits
func isValidPID(pid string) bool {
if len(pid) == 0 || len(pid) > maxPIDLength {
return false
}
for _, c := range pid {
if c < '0' || c > '9' {
return false
}
}
return true
}
// processInfo holds information about a process using a port
type processInfo struct {
pid string
name string
isValid bool
}
// formatProcessInfo formats process information for display
func formatProcessInfo(info processInfo) string {
if !info.isValid {
return "unknown"
}
if info.name != "" {
return fmt.Sprintf("%s (PID %s)", info.name, info.pid)
}
return fmt.Sprintf("PID %s", info.pid)
}
// formatProcessList formats a list of processes into a human-readable string.
// Returns "unknown" if the list is empty.
func formatProcessList(processes []processInfo) string {
if len(processes) == 0 {
return "unknown"
}
if len(processes) == 1 {
return formatProcessInfo(processes[0])
}
// Multiple processes - format as comma-separated list
parts := make([]string, len(processes))
for i, p := range processes {
parts[i] = formatProcessInfo(p)
}
return strings.Join(parts, ", ")
}
// getProcessNameByPID retrieves the process name for a given PID on Unix systems
func getProcessNameByPID(pid string) string {
cmd := exec.Command("ps", "-p", pid, "-o", "comm=")
output, err := cmd.Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(output))
}
// getProcessNameByPIDWindows retrieves the process name for a given PID on Windows
func getProcessNameByPIDWindows(pid string) string {
cmd := exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
output, err := cmd.Output()
if err != nil {
return ""
}
// Parse CSV output: "process.exe","1234","Console","1","12,345 K"
csvLine := strings.TrimSpace(string(output))
if csvLine == "" {
return ""
}
parts := strings.Split(csvLine, ",")
if len(parts) > 0 {
return strings.Trim(parts[0], "\"")
}
return ""
}
// PortConflict represents a local port that is already in use.
type PortConflict struct {
Port int // The conflicting port number
@@ -89,23 +177,55 @@ func (pc *PortChecker) getProcessUsingPortUnix(port int) string {
return "unknown"
}
// Get the first PID if multiple are returned
// Handle multiple PIDs (multiple processes on same port)
pids := strings.Split(pidStr, "\n")
pid := pids[0]
var validProcesses []processInfo
// Get process name using ps
cmd = exec.Command("ps", "-p", pid, "-o", "comm=")
output, err = cmd.Output()
if err != nil {
return fmt.Sprintf("PID %s", pid)
for _, pid := range pids {
pid = strings.TrimSpace(pid)
if pid == "" {
continue
}
if !isValidPID(pid) {
logger.Debug("Invalid PID format from lsof output", map[string]interface{}{
"port": port,
"raw_pid": pid,
})
continue
}
procName := getProcessNameByPID(pid)
validProcesses = append(validProcesses, processInfo{
pid: pid,
name: procName,
isValid: true,
})
}
procName := strings.TrimSpace(string(output))
if procName == "" {
return fmt.Sprintf("PID %s", pid)
return formatProcessList(validProcesses)
}
// isListeningState checks if a netstat line indicates a listening state.
// This handles both English and potentially other locales by checking for common patterns.
func isListeningState(line string, fields []string) bool {
upperLine := strings.ToUpper(line)
// Check for common listening state indicators across locales
// English: LISTENING, German: ABHÖREN, French: ÉCOUTE, etc.
// The most reliable check is the state field position (4th field, 0-indexed = 3)
// and that it's a TCP connection with 0.0.0.0:0 or *:* as foreign address
if len(fields) >= minNetstatFields {
state := strings.ToUpper(fields[3])
// Common listening state values across Windows locales
if state == "LISTENING" || state == "ABHÖREN" || state == "ÉCOUTE" ||
state == "ESCUCHANDO" || state == "ASCOLTO" || state == "NASŁUCHIWANIE" {
return true
}
}
return fmt.Sprintf("%s (PID %s)", procName, pid)
// Fallback: check if line contains LISTENING (most common case)
return strings.Contains(upperLine, "LISTENING")
}
// getProcessUsingPortWindows uses netstat to find the process using a port on Windows.
@@ -121,6 +241,8 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
lines := strings.Split(string(output), "\n")
portStr := fmt.Sprintf(":%d", port)
var validProcesses []processInfo
for _, line := range lines {
if !strings.Contains(line, portStr) {
continue
@@ -129,40 +251,42 @@ func (pc *PortChecker) getProcessUsingPortWindows(port int) string {
// Parse the line to extract PID
// Format: TCP 0.0.0.0:8080 0.0.0.0:0 LISTENING 1234
fields := strings.Fields(line)
if len(fields) < 5 {
if len(fields) < minNetstatFields {
continue
}
// Check if this is a LISTENING state
if !strings.Contains(strings.ToUpper(line), "LISTENING") {
// Check if this is a LISTENING state (locale-aware)
if !isListeningState(line, fields) {
continue
}
// Verify the local address field actually contains our port
// (avoid matching port in foreign address)
localAddr := fields[1]
if !strings.HasSuffix(localAddr, portStr) {
continue
}
pid := fields[len(fields)-1]
// Get process name using tasklist
cmd = exec.Command("tasklist", "/FI", fmt.Sprintf("PID eq %s", pid), "/FO", "CSV", "/NH")
output, err = cmd.Output()
if err != nil {
return fmt.Sprintf("PID %s", pid)
if !isValidPID(pid) {
logger.Debug("Invalid PID format from netstat output", map[string]interface{}{
"port": port,
"raw_pid": pid,
"line": line,
})
continue
}
// Parse CSV output: "process.exe","1234","Console","1","12,345 K"
csvLine := strings.TrimSpace(string(output))
if csvLine == "" {
return fmt.Sprintf("PID %s", pid)
}
parts := strings.Split(csvLine, ",")
if len(parts) > 0 {
procName := strings.Trim(parts[0], "\"")
return fmt.Sprintf("%s (PID %s)", procName, pid)
}
return fmt.Sprintf("PID %s", pid)
procName := getProcessNameByPIDWindows(pid)
validProcesses = append(validProcesses, processInfo{
pid: pid,
name: procName,
isValid: true,
})
}
return "unknown"
return formatProcessList(validProcesses)
}
// FormatConflicts formats port conflicts into a human-readable error message.
@@ -188,16 +312,3 @@ func FormatConflicts(conflicts []PortConflict) string {
return sb.String()
}
// GetPortsFromForwards extracts all local ports from a list of forward configurations.
func GetPortsFromForwards(forwards []interface{}) []int {
ports := make([]int, 0, len(forwards))
for _, fwd := range forwards {
// This function expects a generic interface to work with different forward types
// The actual implementation should use the Forward struct from config package
if f, ok := fwd.(interface{ GetLocalPort() int }); ok {
ports = append(ports, f.GetLocalPort())
}
}
return ports
}
+174
View File
@@ -0,0 +1,174 @@
package forward
import (
"context"
"sync"
"time"
"github.com/nvm/kportal/internal/logger"
)
// Watchdog monitors worker goroutines to detect hung workers
type Watchdog struct {
mu sync.RWMutex
workers map[string]*workerState // key: forward ID
checkInterval time.Duration
hangThreshold time.Duration // How long without heartbeat before considered hung
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// workerState tracks the health of a single worker
type workerState struct {
forwardID string
lastHeartbeat time.Time
heartbeatCount uint64
isHung bool
onHungCallback func(forwardID string)
}
// NewWatchdog creates a new goroutine watchdog
func NewWatchdog(checkInterval, hangThreshold time.Duration) *Watchdog {
ctx, cancel := context.WithCancel(context.Background())
return &Watchdog{
workers: make(map[string]*workerState),
checkInterval: checkInterval,
hangThreshold: hangThreshold,
ctx: ctx,
cancel: cancel,
}
}
// Start begins the watchdog monitoring loop
func (w *Watchdog) Start() {
w.wg.Add(1)
go w.monitorLoop()
}
// Stop stops the watchdog
func (w *Watchdog) Stop() {
w.cancel()
w.wg.Wait()
}
// RegisterWorker adds a worker to monitor
func (w *Watchdog) RegisterWorker(forwardID string, onHungCallback func(string)) {
w.mu.Lock()
defer w.mu.Unlock()
w.workers[forwardID] = &workerState{
forwardID: forwardID,
lastHeartbeat: time.Now(),
heartbeatCount: 0,
isHung: false,
onHungCallback: onHungCallback,
}
logger.Debug("Watchdog registered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// UnregisterWorker removes a worker from monitoring
func (w *Watchdog) UnregisterWorker(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
delete(w.workers, forwardID)
logger.Debug("Watchdog unregistered worker", map[string]interface{}{
"forward_id": forwardID,
})
}
// Heartbeat records that a worker is alive and processing
// Workers should call this periodically (e.g., in their main loop)
func (w *Watchdog) Heartbeat(forwardID string) {
w.mu.Lock()
defer w.mu.Unlock()
if state, exists := w.workers[forwardID]; exists {
state.lastHeartbeat = time.Now()
state.heartbeatCount++
state.isHung = false
}
}
// GetWorkerState returns the current state of a worker (for testing)
func (w *Watchdog) GetWorkerState(forwardID string) (lastHeartbeat time.Time, count uint64, exists bool) {
w.mu.RLock()
defer w.mu.RUnlock()
if state, ok := w.workers[forwardID]; ok {
return state.lastHeartbeat, state.heartbeatCount, true
}
return time.Time{}, 0, false
}
// monitorLoop periodically checks all workers
func (w *Watchdog) monitorLoop() {
defer w.wg.Done()
ticker := time.NewTicker(w.checkInterval)
defer ticker.Stop()
for {
select {
case <-w.ctx.Done():
return
case <-ticker.C:
w.checkWorkers()
}
}
}
// hungWorkerInfo stores information about a hung worker for deferred callback execution
type hungWorkerInfo struct {
forwardID string
callback func(string)
}
// checkWorkers checks all registered workers for hung state
func (w *Watchdog) checkWorkers() {
// Collect hung workers while holding the lock
var hungWorkers []hungWorkerInfo
w.mu.Lock()
now := time.Now()
for forwardID, state := range w.workers {
timeSinceHeartbeat := now.Sub(state.lastHeartbeat)
// Check if worker is hung
if timeSinceHeartbeat > w.hangThreshold {
if !state.isHung {
// First time detecting hung state
state.isHung = true
logger.Warn("Watchdog detected hung worker", map[string]interface{}{
"forward_id": forwardID,
"time_since_heartbeat": timeSinceHeartbeat.String(),
"hang_threshold": w.hangThreshold.String(),
"heartbeat_count": state.heartbeatCount,
})
// Collect callback for deferred execution outside the lock
if state.onHungCallback != nil {
hungWorkers = append(hungWorkers, hungWorkerInfo{
forwardID: forwardID,
callback: state.onHungCallback,
})
}
}
}
}
w.mu.Unlock()
// Execute callbacks outside the lock to prevent deadlocks and ensure
// consistent state during callback execution. Callbacks are idempotent
// (they trigger reconnection via channels), so concurrent state changes
// between detection and callback execution are safe.
for _, hw := range hungWorkers {
hw.callback(hw.forwardID)
}
}
+310
View File
@@ -0,0 +1,310 @@
package forward
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// WatchdogTestSuite contains tests for the watchdog
type WatchdogTestSuite struct {
suite.Suite
watchdog *Watchdog
}
func TestWatchdogSuite(t *testing.T) {
suite.Run(t, new(WatchdogTestSuite))
}
func (s *WatchdogTestSuite) SetupTest() {
// Create watchdog with fast intervals for testing
s.watchdog = NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
s.watchdog.Start()
}
func (s *WatchdogTestSuite) TearDownTest() {
if s.watchdog != nil {
s.watchdog.Stop()
}
}
// TestRegisterUnregister tests basic registration and unregistration
func (s *WatchdogTestSuite) TestRegisterUnregister() {
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
// Register worker
s.watchdog.RegisterWorker("test-forward", callback)
// Verify worker is registered
_, _, exists := s.watchdog.GetWorkerState("test-forward")
assert.True(s.T(), exists, "Worker should be registered")
// Unregister worker
s.watchdog.UnregisterWorker("test-forward")
// Verify worker is unregistered
_, _, exists = s.watchdog.GetWorkerState("test-forward")
assert.False(s.T(), exists, "Worker should be unregistered")
assert.False(s.T(), callbackCalled, "Callback should not have been called")
}
// TestHeartbeat tests that heartbeats update worker state
func (s *WatchdogTestSuite) TestHeartbeat() {
s.watchdog.RegisterWorker("test-forward", nil)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat1, count1, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(1), count1)
// Wait a bit
time.Sleep(50 * time.Millisecond)
// Send another heartbeat
s.watchdog.Heartbeat("test-forward")
lastHeartbeat2, count2, exists := s.watchdog.GetWorkerState("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), uint64(2), count2)
assert.True(s.T(), lastHeartbeat2.After(lastHeartbeat1), "Second heartbeat should be after first")
}
// TestHungWorkerDetection tests that hung workers are detected
func (s *WatchdogTestSuite) TestHungWorkerDetection() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for worker to be considered hung (300ms threshold + 100ms check interval)
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
// TestHealthyWorkerNotDetectedAsHung tests that workers sending heartbeats are not considered hung
func (s *WatchdogTestSuite) TestHealthyWorkerNotDetectedAsHung() {
callbackCalled := false
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send periodic heartbeats (faster than hang threshold)
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
done := make(chan bool)
go func() {
for i := 0; i < 10; i++ {
<-ticker.C
s.watchdog.Heartbeat("test-forward")
}
done <- true
}()
// Wait for all heartbeats to complete
<-done
// Check that callback was not called
mu.Lock()
assert.False(s.T(), callbackCalled, "Callback should not be called for healthy worker")
mu.Unlock()
}
// TestMultipleWorkers tests monitoring multiple workers simultaneously
func (s *WatchdogTestSuite) TestMultipleWorkers() {
callbacks := make(map[string]int)
var mu sync.Mutex
makeCallback := func(id string) func(string) {
return func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbacks[id]++
}
}
// Register multiple workers
s.watchdog.RegisterWorker("worker-1", makeCallback("worker-1"))
s.watchdog.RegisterWorker("worker-2", makeCallback("worker-2"))
s.watchdog.RegisterWorker("worker-3", makeCallback("worker-3"))
// worker-1: Keep sending heartbeats (healthy)
ticker1 := time.NewTicker(50 * time.Millisecond)
defer ticker1.Stop()
go func() {
for i := 0; i < 10; i++ {
<-ticker1.C
s.watchdog.Heartbeat("worker-1")
}
}()
// worker-2: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-2")
// worker-3: Send initial heartbeat then stop (will become hung)
s.watchdog.Heartbeat("worker-3")
// Wait for hung workers to be detected
time.Sleep(600 * time.Millisecond)
// Check results
mu.Lock()
defer mu.Unlock()
assert.Equal(s.T(), 0, callbacks["worker-1"], "worker-1 should not trigger callback (healthy)")
assert.Greater(s.T(), callbacks["worker-2"], 0, "worker-2 should trigger callback (hung)")
assert.Greater(s.T(), callbacks["worker-3"], 0, "worker-3 should trigger callback (hung)")
}
// TestCallbackOnlyOnFirstDetection tests that callback is only called once when hung is first detected
func (s *WatchdogTestSuite) TestCallbackOnlyOnFirstDetection() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for multiple check cycles
time.Sleep(1 * time.Second)
// Check that callback was only called once
mu.Lock()
assert.Equal(s.T(), 1, callbackCount, "Callback should only be called once")
mu.Unlock()
}
// TestHeartbeatResetsHungState tests that sending heartbeat after hung detection resets state
func (s *WatchdogTestSuite) TestHeartbeatResetsHungState() {
callbackCount := 0
var mu sync.Mutex
callback := func(forwardID string) {
mu.Lock()
defer mu.Unlock()
callbackCount++
}
s.watchdog.RegisterWorker("test-forward", callback)
// Send initial heartbeat
s.watchdog.Heartbeat("test-forward")
// Wait for hung detection
time.Sleep(500 * time.Millisecond)
mu.Lock()
firstCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 1, firstCount, "First hung detection should trigger callback")
// Send heartbeat to reset hung state
s.watchdog.Heartbeat("test-forward")
// Wait for worker to become hung again
time.Sleep(500 * time.Millisecond)
mu.Lock()
secondCount := callbackCount
mu.Unlock()
assert.Equal(s.T(), 2, secondCount, "Second hung detection should trigger callback again")
}
// TestConcurrentOperations tests thread safety
func (s *WatchdogTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numWorkers := 10
for i := 0; i < numWorkers; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := string(rune('a' + id))
s.watchdog.RegisterWorker(forwardID, nil)
for j := 0; j < 10; j++ {
s.watchdog.Heartbeat(forwardID)
time.Sleep(10 * time.Millisecond)
}
s.watchdog.UnregisterWorker(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestStopWatchdog tests that stopping watchdog cleans up properly
func TestStopWatchdog(t *testing.T) {
watchdog := NewWatchdog(100*time.Millisecond, 300*time.Millisecond)
watchdog.Start()
callbackCalled := false
callback := func(forwardID string) {
callbackCalled = true
}
watchdog.RegisterWorker("test-forward", callback)
watchdog.Heartbeat("test-forward")
// Stop watchdog before hang detection
time.Sleep(100 * time.Millisecond)
watchdog.Stop()
// Wait to ensure no more callbacks after stop
time.Sleep(500 * time.Millisecond)
assert.False(t, callbackCalled, "Callback should not be called after watchdog is stopped")
}
// TestWatchdogWithZeroHeartbeats tests detecting hung worker that never sends heartbeats
func (s *WatchdogTestSuite) TestWatchdogWithZeroHeartbeats() {
callbackCalled := make(chan string, 1)
callback := func(forwardID string) {
callbackCalled <- forwardID
}
// Register worker but never send heartbeat
s.watchdog.RegisterWorker("test-forward", callback)
// Wait for hung detection
timeout := time.After(1 * time.Second)
select {
case forwardID := <-callbackCalled:
assert.Equal(s.T(), "test-forward", forwardID)
case <-timeout:
s.T().Fatal("Timeout waiting for hung worker callback")
}
}
+124 -20
View File
@@ -5,31 +5,41 @@ import (
"fmt"
"io"
"log"
"sync"
"time"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/healthcheck"
"github.com/nvm/kportal/internal/k8s"
"github.com/nvm/kportal/internal/logger"
"github.com/nvm/kportal/internal/retry"
)
const (
portForwardReadyTimeout = 30 * time.Second
)
// ForwardWorker manages a single port-forward connection with automatic retry.
type ForwardWorker struct {
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
startTime time.Time // Track when the worker started
forward config.Forward
portForwarder *k8s.PortForwarder
ctx context.Context
cancel context.CancelFunc
stopChan chan struct{}
doneChan chan struct{}
reconnectChan chan string // Channel to trigger reconnection
verbose bool
lastPod string // Track the last pod we connected to
statusUI StatusUpdater
healthChecker *healthcheck.Checker
watchdog *Watchdog
startTime time.Time // Track when the worker started
forwardCancel context.CancelFunc // Cancel function for current forward attempt
forwardCancelMu sync.Mutex // Protects forwardCancel
}
// NewForwardWorker creates a new ForwardWorker for a single forward configuration.
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker) *ForwardWorker {
func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verbose bool, statusUI StatusUpdater, healthChecker *healthcheck.Checker, watchdog *Watchdog) *ForwardWorker {
ctx, cancel := context.WithCancel(context.Background())
return &ForwardWorker{
@@ -39,13 +49,32 @@ func NewForwardWorker(fwd config.Forward, portForwarder *k8s.PortForwarder, verb
cancel: cancel,
stopChan: make(chan struct{}),
doneChan: make(chan struct{}),
reconnectChan: make(chan string, 1), // Buffered to avoid blocking
verbose: verbose,
statusUI: statusUI,
healthChecker: healthChecker,
watchdog: watchdog,
startTime: time.Now(),
}
}
// TriggerReconnect triggers a reconnection (e.g., due to stale connection)
func (w *ForwardWorker) TriggerReconnect(reason string) {
// Cancel current forward if running
w.forwardCancelMu.Lock()
if w.forwardCancel != nil {
w.forwardCancel()
}
w.forwardCancelMu.Unlock()
// Send reconnect signal (non-blocking)
select {
case w.reconnectChan <- reason:
default:
// Channel already has pending reconnect
}
}
// Start begins the port-forward worker in a goroutine.
// The worker will continuously retry on failures with exponential backoff.
func (w *ForwardWorker) Start() {
@@ -56,13 +85,28 @@ func (w *ForwardWorker) Start() {
func (w *ForwardWorker) Stop() {
w.cancel()
close(w.stopChan)
<-w.doneChan // Wait for worker to finish
// Wait for worker to finish with timeout to prevent blocking forever
select {
case <-w.doneChan:
// Worker finished gracefully
case <-time.After(3 * time.Second):
// Worker didn't finish in time, but we've cancelled its context
// so it will clean up eventually
log.Printf("[%s] Worker stop timed out, continuing...", w.forward.ID())
}
}
// run is the main worker loop that handles retries.
func (w *ForwardWorker) run() {
defer close(w.doneChan)
// Start heartbeat goroutine to continuously send heartbeats to watchdog
// This prevents false "hung worker" detection when connections are long-lived
if w.watchdog != nil {
go w.heartbeatLoop()
}
backoff := retry.NewBackoff()
for {
@@ -86,7 +130,13 @@ func (w *ForwardWorker) run() {
)
if err != nil {
log.Printf("[%s] Failed to resolve resource: %v", w.forward.ID(), err)
logger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": w.forward.ID(),
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
"resource": w.forward.Resource,
"error": err.Error(),
})
w.sleepWithBackoff(backoff)
continue
}
@@ -96,10 +146,20 @@ func (w *ForwardWorker) run() {
if w.healthChecker != nil {
w.healthChecker.MarkReconnecting(w.forward.ID())
}
log.Printf("[%s] Switched to new pod: %s → %s", w.forward.ID(), w.lastPod, podName)
logger.Info("Pod restart detected, switching to new pod", map[string]interface{}{
"forward_id": w.forward.ID(),
"old_pod": w.lastPod,
"new_pod": podName,
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
})
} else if w.lastPod == "" {
log.Printf("[%s] Forwarding %s → localhost:%d",
w.forward.ID(), w.forward.String(), w.forward.LocalPort)
logger.Info("Starting port forward", map[string]interface{}{
"forward_id": w.forward.ID(),
"target": w.forward.String(),
"local_port": w.forward.LocalPort,
"pod": podName,
})
if w.healthChecker != nil {
w.healthChecker.MarkStarting(w.forward.ID())
}
@@ -123,7 +183,14 @@ func (w *ForwardWorker) run() {
}
// Log the error
log.Printf("[%s] Port-forward connection failed: %v", w.forward.ID(), err)
logger.Warn("Port-forward connection failed, will retry", map[string]interface{}{
"forward_id": w.forward.ID(),
"context": w.forward.GetContext(),
"namespace": w.forward.GetNamespace(),
"resource": w.forward.Resource,
"local_port": w.forward.LocalPort,
"error": err.Error(),
})
// Clear last pod so we re-resolve on next attempt
w.lastPod = ""
@@ -145,6 +212,26 @@ func (w *ForwardWorker) run() {
}
}
// heartbeatLoop sends periodic heartbeats to the watchdog to prove the worker is alive
// This runs in a separate goroutine and continues throughout the worker's lifetime
func (w *ForwardWorker) heartbeatLoop() {
// Send heartbeats every 15 seconds (well within typical 60s watchdog timeout)
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
// Send immediate heartbeat
w.watchdog.Heartbeat(w.forward.ID())
for {
select {
case <-ticker.C:
w.watchdog.Heartbeat(w.forward.ID())
case <-w.ctx.Done():
return
}
}
}
// establishForward establishes a port-forward connection.
// This blocks until the connection is closed or an error occurs.
func (w *ForwardWorker) establishForward(podName string) error {
@@ -156,11 +243,24 @@ func (w *ForwardWorker) establishForward(podName string) error {
forwardCtx, forwardCancel := context.WithCancel(w.ctx)
defer forwardCancel()
// Start a goroutine to monitor for stop signal
// Store cancel function so TriggerReconnect can use it
w.forwardCancelMu.Lock()
w.forwardCancel = forwardCancel
w.forwardCancelMu.Unlock()
defer func() {
w.forwardCancelMu.Lock()
w.forwardCancel = nil
w.forwardCancelMu.Unlock()
}()
// Start a goroutine to monitor for stop signal and reconnect triggers
go func() {
select {
case <-w.stopChan:
close(stopChan)
case <-w.reconnectChan:
close(stopChan)
case <-forwardCtx.Done():
close(stopChan)
}
@@ -202,11 +302,15 @@ func (w *ForwardWorker) establishForward(podName string) error {
if w.verbose {
log.Printf("[%s] Port-forward connection established", w.forward.ID())
}
// Mark connection as established in health checker
if w.healthChecker != nil {
w.healthChecker.MarkConnected(w.forward.ID())
}
case err := <-errChan:
return fmt.Errorf("failed to establish forward: %w", err)
case <-w.ctx.Done():
return nil
case <-time.After(30 * time.Second):
case <-time.After(portForwardReadyTimeout):
return fmt.Errorf("timeout waiting for port-forward to become ready")
}
+286
View File
@@ -0,0 +1,286 @@
package forward
import (
"testing"
"github.com/nvm/kportal/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogWriter_Write(t *testing.T) {
tests := []struct {
name string
prefix string
input string
expectedInLog string
description string
}{
{
name: "write simple message",
prefix: "[worker] ",
input: "test message",
expectedInLog: "[worker] test message",
description: "Should write message with prefix to log",
},
{
name: "write empty message",
prefix: "[test] ",
input: "",
expectedInLog: "[test] ",
description: "Should handle empty message",
},
{
name: "write multiline message",
prefix: "[fwd] ",
input: "line1\nline2",
expectedInLog: "[fwd] line1\nline2",
description: "Should handle multiline messages",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test logWriter
originalWriter := &logWriter{prefix: tt.prefix}
n, err := originalWriter.Write([]byte(tt.input))
require.NoError(t, err, "Write should not return error")
assert.Equal(t, len(tt.input), n, "Write should return number of bytes written")
})
}
}
func TestForwardWorker_GetForward(t *testing.T) {
tests := []struct {
name string
forward config.Forward
description string
}{
{
name: "get pod forward",
forward: config.Forward{
Resource: "pod/my-app",
LocalPort: 8080,
Port: 80,
Protocol: "tcp",
},
description: "Should return the forward configuration",
},
{
name: "get service forward",
forward: config.Forward{
Resource: "service/postgres",
LocalPort: 5432,
Port: 5432,
Protocol: "tcp",
},
description: "Should return service forward configuration",
},
{
name: "get forward with selector",
forward: config.Forward{
Resource: "pod",
Selector: "app=nginx,env=prod",
LocalPort: 8080,
Port: 80,
Protocol: "tcp",
},
description: "Should return forward with label selector",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Note: We can't easily test the full worker lifecycle without mocks,
// but we can test the constructor and simple getters
// This test would require proper mocking setup
// For now, we'll test the Forward struct directly
id := tt.forward.ID()
assert.NotEmpty(t, id, "Forward should have an ID")
forwardStr := tt.forward.String()
assert.NotEmpty(t, forwardStr, "Forward should have a string representation")
assert.Contains(t, forwardStr, tt.forward.Resource, "String should contain resource")
})
}
}
func TestForwardWorker_IsRunning(t *testing.T) {
// This is a basic test of the goroutine state tracking
// Full integration tests would require mock dependencies
t.Run("worker state tracking", func(t *testing.T) {
// Test the concept of the done channel
doneChan := make(chan struct{})
// Initially, channel is open (worker would be running)
select {
case <-doneChan:
t.Fatal("doneChan should be open initially")
default:
// Expected: channel is open
}
// Close the channel (simulating worker done)
close(doneChan)
// Now channel should be closed
select {
case <-doneChan:
// Expected: channel is closed
default:
t.Fatal("doneChan should be closed after close")
}
})
}
func TestForwardID(t *testing.T) {
tests := []struct {
name string
forward config.Forward
expectUnique bool
description string
}{
{
name: "unique IDs for different forwards",
forward: config.Forward{
Resource: "pod/app1",
LocalPort: 8080,
Port: 80,
},
expectUnique: true,
description: "Different forwards should have different IDs",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id1 := tt.forward.ID()
// Create a different forward
fwd2 := config.Forward{
Resource: "pod/app2",
LocalPort: 8081,
Port: 80,
}
id2 := fwd2.ID()
if tt.expectUnique {
assert.NotEqual(t, id1, id2, "Different forwards should have different IDs")
}
// Same forward should produce same ID
id3 := tt.forward.ID()
assert.Equal(t, id1, id3, "Same forward should produce same ID")
})
}
}
func TestForwardString(t *testing.T) {
tests := []struct {
name string
forward config.Forward
expectedContains []string
description string
}{
{
name: "pod forward string",
forward: config.Forward{
Resource: "pod/my-app",
LocalPort: 8080,
Port: 80,
},
expectedContains: []string{"pod/my-app", "8080", "80"},
description: "Should contain resource and ports",
},
{
name: "service forward string",
forward: config.Forward{
Resource: "service/postgres",
LocalPort: 5432,
Port: 5432,
},
expectedContains: []string{"service/postgres", "5432"},
description: "Should contain service and port",
},
{
name: "selector forward string",
forward: config.Forward{
Resource: "pod",
Selector: "app=nginx",
LocalPort: 8080,
Port: 80,
},
expectedContains: []string{"app=nginx", "8080", "80"},
description: "Should contain selector and ports",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.forward.String()
assert.NotEmpty(t, result, "String representation should not be empty")
for _, expected := range tt.expectedContains {
assert.Contains(t, result, expected,
"String should contain %s", expected)
}
})
}
}
func TestSleepWithBackoffConcept(t *testing.T) {
// Test the backoff concept without actually running a worker
t.Run("backoff delay increases", func(t *testing.T) {
// This tests the retry backoff behavior conceptually
delays := []int{1, 2, 4, 8, 10, 10, 10}
for i, expected := range delays {
// Simulate backoff calculation
delay := 1
for j := 0; j < i; j++ {
delay *= 2
if delay > 10 {
delay = 10
}
}
assert.Equal(t, expected, delay,
"Backoff at attempt %d should be %d", i, expected)
}
})
}
func TestWorkerVerboseMode(t *testing.T) {
tests := []struct {
name string
verbose bool
description string
}{
{
name: "verbose mode enabled",
verbose: true,
description: "Worker should respect verbose flag",
},
{
name: "verbose mode disabled",
verbose: false,
description: "Worker should respect non-verbose flag",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that verbose flag is a boolean
assert.IsType(t, bool(true), tt.verbose)
// In a real worker, this would control logging
// For now, we just verify the type
})
}
}
+228 -73
View File
@@ -3,9 +3,17 @@ package healthcheck
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/nvm/kportal/internal/config"
)
const (
startupGracePeriod = 10 * time.Second
dataTransferSize = 1024 // bytes to read in data transfer test
)
// Status represents the health status of a port forward
@@ -16,15 +24,26 @@ const (
StatusUnhealthy Status = "Error"
StatusStarting Status = "Starting"
StatusReconnect Status = "Reconnecting"
StatusStale Status = "Stale" // Connection is old or idle
)
// CheckMethod represents the health check method
type CheckMethod string
const (
CheckMethodTCPDial CheckMethod = "tcp-dial" // Simple TCP connection test
CheckMethodDataTransfer CheckMethod = "data-transfer" // Try to read data from connection
)
// PortHealth represents the health status of a single port
type PortHealth struct {
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
Port int
LastCheck time.Time
Status Status
ErrorMessage string
RegisteredAt time.Time // When this port was registered
ConnectionTime time.Time // When current connection was established
LastActivity time.Time // Last time data was transferred
}
// StatusCallback is called when a port's health status changes
@@ -32,26 +51,52 @@ type StatusCallback func(forwardID string, status Status, errorMsg string)
// Checker performs periodic health checks on local ports
type Checker struct {
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
ports map[string]*PortHealth // key: forward ID
callbacks map[string]StatusCallback
interval time.Duration
timeout time.Duration
method CheckMethod
maxConnectionAge time.Duration
maxIdleTime time.Duration
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewChecker creates a new health checker
// CheckerOptions configures the health checker
type CheckerOptions struct {
Interval time.Duration
Timeout time.Duration
Method CheckMethod
MaxConnectionAge time.Duration
MaxIdleTime time.Duration
}
// NewChecker creates a new health checker with default options
func NewChecker(interval, timeout time.Duration) *Checker {
return NewCheckerWithOptions(CheckerOptions{
Interval: interval,
Timeout: timeout,
Method: CheckMethodDataTransfer,
MaxConnectionAge: config.DefaultMaxConnectionAge,
MaxIdleTime: config.DefaultMaxIdleTime,
})
}
// NewCheckerWithOptions creates a new health checker with custom options
func NewCheckerWithOptions(opts CheckerOptions) *Checker {
ctx, cancel := context.WithCancel(context.Background())
return &Checker{
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: interval,
timeout: timeout,
ctx: ctx,
cancel: cancel,
ports: make(map[string]*PortHealth),
callbacks: make(map[string]StatusCallback),
interval: opts.Interval,
timeout: opts.Timeout,
method: opts.Method,
maxConnectionAge: opts.MaxConnectionAge,
maxIdleTime: opts.MaxIdleTime,
ctx: ctx,
cancel: cancel,
}
}
@@ -60,11 +105,14 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
c.mu.Lock()
defer c.mu.Unlock()
now := time.Now()
c.ports[forwardID] = &PortHealth{
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: time.Now(),
Port: port,
LastCheck: time.Time{},
Status: StatusStarting,
RegisteredAt: now,
ConnectionTime: now,
LastActivity: now,
}
c.callbacks[forwardID] = callback
@@ -73,6 +121,28 @@ func (c *Checker) Register(forwardID string, port int, callback StatusCallback)
go c.checkLoop(forwardID)
}
// MarkConnected marks a forward as having established a new connection
func (c *Checker) MarkConnected(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
now := time.Now()
health.ConnectionTime = now
health.LastActivity = now
}
}
// RecordActivity records data transfer activity for a forward
func (c *Checker) RecordActivity(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
health.LastActivity = time.Now()
}
}
// Unregister removes a port from monitoring
func (c *Checker) Unregister(forwardID string) {
c.mu.Lock()
@@ -82,42 +152,34 @@ func (c *Checker) Unregister(forwardID string) {
delete(c.callbacks, forwardID)
}
// markStatus is a helper to set a forward's status and notify on change.
func (c *Checker) markStatus(forwardID string, newStatus Status) {
c.mu.Lock()
health, exists := c.ports[forwardID]
if !exists {
c.mu.Unlock()
return
}
oldStatus := health.Status
health.Status = newStatus
health.LastCheck = time.Now()
c.mu.Unlock()
if oldStatus != newStatus {
c.notifyStatusChange(forwardID, newStatus, "")
}
}
// MarkReconnecting marks a forward as reconnecting (called by worker)
func (c *Checker) MarkReconnecting(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusReconnect
health.LastCheck = time.Now()
// Notify if status changed
if oldStatus != StatusReconnect {
c.mu.Unlock()
c.notifyStatusChange(forwardID, StatusReconnect, "")
c.mu.Lock()
}
}
c.markStatus(forwardID, StatusReconnect)
}
// MarkStarting marks a forward as starting (called by worker)
func (c *Checker) MarkStarting(forwardID string) {
c.mu.Lock()
defer c.mu.Unlock()
if health, exists := c.ports[forwardID]; exists {
oldStatus := health.Status
health.Status = StatusStarting
health.LastCheck = time.Now()
// Notify if status changed
if oldStatus != StatusStarting {
c.mu.Unlock()
c.notifyStatusChange(forwardID, StatusStarting, "")
c.mu.Lock()
}
}
c.markStatus(forwardID, StatusStarting)
}
// GetStatus returns the current health status of a forward
@@ -131,6 +193,17 @@ func (c *Checker) GetStatus(forwardID string) (Status, bool) {
return StatusUnhealthy, false
}
// GetLastCheckTime returns the last health check time for a forward
func (c *Checker) GetLastCheckTime(forwardID string) (time.Time, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if health, exists := c.ports[forwardID]; exists {
return health.LastCheck, true
}
return time.Time{}, false
}
// GetAllErrors returns all forwards with errors and their error messages
func (c *Checker) GetAllErrors() map[string]string {
c.mu.RLock()
@@ -191,38 +264,64 @@ func (c *Checker) checkPort(forwardID string) {
port := health.Port
oldStatus := health.Status
registeredAt := health.RegisteredAt
connectionTime := health.ConnectionTime
lastActivity := health.LastActivity
c.mu.RUnlock()
// Attempt to connect to the local port
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
now := time.Now()
newStatus := StatusHealthy
errorMsg := ""
if err != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := time.Since(registeredAt)
if timeSinceStart < 10*time.Second {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = err.Error()
// Check for stale connections based on age or idle time
connectionAge := now.Sub(connectionTime)
idleTime := now.Sub(lastActivity)
// Only enforce max connection age if the connection is ALSO idle
// This prevents interrupting active transfers (e.g., database dumps)
if c.maxConnectionAge > 0 && connectionAge > c.maxConnectionAge && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("connection age %v exceeds max %v (and idle for %v)",
connectionAge.Round(time.Second), c.maxConnectionAge, idleTime.Round(time.Second))
} else if c.maxIdleTime > 0 && idleTime > c.maxIdleTime {
newStatus = StatusStale
errorMsg = fmt.Sprintf("idle time %v exceeds max %v", idleTime.Round(time.Second), c.maxIdleTime)
} else {
conn.Close()
// Perform connectivity check
var checkErr error
switch c.method {
case CheckMethodDataTransfer:
checkErr = c.checkDataTransfer(port)
case CheckMethodTCPDial:
checkErr = c.checkTCPDial(port)
default:
checkErr = c.checkTCPDial(port)
}
if checkErr != nil {
// Grace period: if forward is less than 10 seconds old, keep it as "Starting"
// This avoids scary "Error" messages during initial connection attempts
timeSinceStart := now.Sub(registeredAt)
if timeSinceStart < startupGracePeriod {
newStatus = StatusStarting
} else {
newStatus = StatusUnhealthy
}
errorMsg = checkErr.Error()
}
}
// Update health status
c.mu.Lock()
if health, exists := c.ports[forwardID]; exists {
health.Status = newStatus
health.LastCheck = time.Now()
health.LastCheck = now
health.ErrorMessage = errorMsg
// Successful health check indicates connection is active
// This prevents false positives where healthy connections are marked as idle
if newStatus == StatusHealthy {
health.LastActivity = now
}
}
c.mu.Unlock()
@@ -232,6 +331,62 @@ func (c *Checker) checkPort(forwardID string) {
}
}
// checkTCPDial performs a simple TCP dial test
func (c *Checker) checkTCPDial(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
conn.Close()
return nil
}
// checkDataTransfer attempts to read data from the connection to verify tunnel health
func (c *Checker) checkDataTransfer(port int) error {
ctx, cancel := context.WithTimeout(c.ctx, c.timeout)
defer cancel()
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
return err
}
defer conn.Close()
// Set a short read deadline to detect hung connections
// We don't expect to receive data, but we want to verify the connection isn't hung
conn.SetReadDeadline(time.Now().Add(c.timeout))
// Try to read a small amount of data
// Most servers will either:
// 1. Send a banner (SSH, FTP, etc) - we'll read it successfully
// 2. Wait for client to send first (HTTP, postgres) - we'll timeout (which is OK)
// 3. Hung/stale connection - will timeout with different error
buf := make([]byte, dataTransferSize)
_, err = conn.Read(buf)
// We expect either:
// - No error (banner received)
// - EOF (connection closed by server after connect)
// - Timeout (server waiting for client)
// All of these indicate the tunnel is working
if err == nil || err == io.EOF {
return nil
}
// Timeout is acceptable - server is waiting for us to send data first
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
return nil
}
// Other errors indicate a problem
return fmt.Errorf("data transfer check failed: %w", err)
}
// notifyStatusChange calls the callback for a forward
func (c *Checker) notifyStatusChange(forwardID string, status Status, errorMsg string) {
c.mu.RLock()
+551
View File
@@ -0,0 +1,551 @@
package healthcheck
import (
"fmt"
"net"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// HealthCheckTestSuite contains tests for the health checker
type HealthCheckTestSuite struct {
suite.Suite
checker *Checker
listener net.Listener
port int
}
func TestHealthCheckSuite(t *testing.T) {
suite.Run(t, new(HealthCheckTestSuite))
}
func (s *HealthCheckTestSuite) SetupTest() {
// Create a test listener on a random port
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
s.listener = ln
s.port = ln.Addr().(*net.TCPAddr).Port
// Create checker with fast intervals for testing
s.checker = NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 500 * time.Millisecond,
MaxIdleTime: 300 * time.Millisecond,
})
}
func (s *HealthCheckTestSuite) TearDownTest() {
if s.checker != nil {
s.checker.Stop()
}
if s.listener != nil {
s.listener.Close()
}
}
// TestRegisterAndUnregister tests basic registration and unregistration
func (s *HealthCheckTestSuite) TestRegisterAndUnregister() {
callbackCalled := false
var callbackStatus Status
var mu sync.Mutex
callback := func(forwardID string, status Status, errorMsg string) {
mu.Lock()
defer mu.Unlock()
callbackCalled = true
callbackStatus = status
}
// Register port
s.checker.Register("test-forward", s.port, callback)
// Wait for health check to run
time.Sleep(200 * time.Millisecond)
// Verify callback was called with healthy status
mu.Lock()
assert.True(s.T(), callbackCalled, "Callback should have been called")
assert.Equal(s.T(), StatusHealthy, callbackStatus)
mu.Unlock()
// Unregister
s.checker.Unregister("test-forward")
// Verify port is no longer monitored
status, exists := s.checker.GetStatus("test-forward")
assert.False(s.T(), exists, "Port should no longer exist after unregister")
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestTCPDialMethod tests the TCP dial health check method
func (s *HealthCheckTestSuite) TestTCPDialMethod() {
tests := []struct {
name string
setupPort bool
expectedStatus Status
description string
}{
{
name: "port available - healthy",
setupPort: true,
expectedStatus: StatusHealthy,
description: "When port is listening, status should be healthy",
},
{
name: "port unavailable - unhealthy",
setupPort: false,
expectedStatus: StatusUnhealthy,
description: "When port is not listening, status should be unhealthy",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
if tt.setupPort {
// Use the existing listener
testPort = s.port
} else {
// Use a port that's not listening
testPort = 54321 // Likely unused port
}
// Create a new checker for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if !tt.setupPort {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(200 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status, tt.description)
if testListener != nil {
testListener.Close()
}
})
}
}
// TestDataTransferMethod tests the data transfer health check method
func (s *HealthCheckTestSuite) TestDataTransferMethod() {
tests := []struct {
name string
serverBehavior string // "banner", "silent", "close", "none"
expectedStatus Status
}{
{
name: "server sends banner - healthy",
serverBehavior: "banner",
expectedStatus: StatusHealthy,
},
{
name: "server waits silently - healthy (timeout OK)",
serverBehavior: "silent",
expectedStatus: StatusHealthy,
},
{
name: "server closes connection - healthy (EOF OK)",
serverBehavior: "close",
expectedStatus: StatusHealthy,
},
{
name: "no server listening - unhealthy",
serverBehavior: "none",
expectedStatus: StatusUnhealthy,
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
var testPort int
var testListener net.Listener
var err error
if tt.serverBehavior != "none" {
// Start test server
testListener, err = net.Listen("tcp", "127.0.0.1:0")
require.NoError(s.T(), err)
testPort = testListener.Addr().(*net.TCPAddr).Port
// Handle connections based on behavior
go func() {
for {
conn, err := testListener.Accept()
if err != nil {
return
}
switch tt.serverBehavior {
case "banner":
conn.Write([]byte("220 Welcome\r\n"))
time.Sleep(50 * time.Millisecond)
conn.Close()
case "close":
conn.Close()
case "silent":
// Just keep connection open
time.Sleep(200 * time.Millisecond)
conn.Close()
}
}
}()
defer testListener.Close()
} else {
testPort = 54322 // Unused port
}
// Create checker with data transfer method
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodDataTransfer,
MaxConnectionAge: 0, // Disable for this test
MaxIdleTime: 0, // Disable for this test
})
defer checker.Stop()
checker.Register("test-forward", testPort, nil)
// Wait for health checks to complete
if tt.serverBehavior == "none" {
// For unhealthy case, wait for grace period
time.Sleep(startupGracePeriod + 200*time.Millisecond)
} else {
time.Sleep(300 * time.Millisecond)
}
// Check status directly
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), tt.expectedStatus, status)
})
}
}
// TestConnectionAgeDetection tests max connection age detection
func (s *HealthCheckTestSuite) TestConnectionAgeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max connection age
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 150 * time.Millisecond, // Very short for testing
MaxIdleTime: 0, // Disable idle detection
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait for initial healthy status
var gotHealthy, gotStale bool
timeout := time.After(1 * time.Second)
for {
select {
case status := <-statusChanges:
if status == StatusHealthy || status == StatusStarting {
gotHealthy = true
}
if status == StatusStale {
gotStale = true
}
if gotHealthy && gotStale {
return // Test passed
}
case <-timeout:
s.T().Fatalf("Expected StatusStale after max connection age exceeded. gotHealthy=%v, gotStale=%v",
gotHealthy, gotStale)
}
}
}
// TestIdleTimeDetection tests that connections with passing health checks are NOT marked as stale
// This verifies that successful health checks update LastActivity, preventing false idle detection
func (s *HealthCheckTestSuite) TestIdleTimeDetection() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
// Create checker with very short max idle time
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0, // Disable age detection
MaxIdleTime: 150 * time.Millisecond, // Very short for testing
})
defer checker.Stop()
checker.Register("test-forward", s.port, callback)
// Wait long enough that idle time WOULD be exceeded if health checks didn't update LastActivity
time.Sleep(500 * time.Millisecond)
// Verify connection is still healthy, not stale
// This proves that successful health checks are updating LastActivity
status, exists := checker.GetStatus("test-forward")
require.True(s.T(), exists)
assert.Equal(s.T(), StatusHealthy, status, "Connection with passing health checks should NOT be marked as stale")
// Verify we never received a StatusStale callback
select {
case status := <-statusChanges:
if status == StatusStale {
s.T().Fatal("Connection should NOT be marked as stale when health checks are passing")
}
default:
// No stale status - this is correct
}
}
// TestMarkConnected tests that MarkConnected resets connection time
func (s *HealthCheckTestSuite) TestMarkConnected() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 200 * time.Millisecond,
MaxIdleTime: 0,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Wait a bit
time.Sleep(100 * time.Millisecond)
// Mark as reconnected (resets connection time)
checker.MarkConnected("test-forward")
// Wait for connection age to exceed (relative to first connection time)
time.Sleep(200 * time.Millisecond)
// Check status - should still be healthy because we reset connection time
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Note: Might be StatusStale by now, but the key is that MarkConnected delayed it
// This is a timing-sensitive test, so we just verify the functionality exists
_ = status
}
// TestRecordActivity tests that RecordActivity resets idle time
func (s *HealthCheckTestSuite) TestRecordActivity() {
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 200 * time.Millisecond,
})
defer checker.Stop()
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
checker.Register("test-forward", s.port, callback)
// Periodically record activity to prevent idle detection
ticker := time.NewTicker(80 * time.Millisecond)
defer ticker.Stop()
go func() {
for i := 0; i < 5; i++ {
<-ticker.C
checker.RecordActivity("test-forward")
}
}()
// Wait longer than idle timeout
time.Sleep(500 * time.Millisecond)
// Should still be healthy due to activity
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// May transition to stale eventually, but activity recording should have delayed it
_ = status
}
// TestMarkReconnecting tests the MarkReconnecting functionality
func (s *HealthCheckTestSuite) TestMarkReconnecting() {
statusChanges := make(chan Status, 10)
callback := func(forwardID string, status Status, errorMsg string) {
statusChanges <- status
}
s.checker.Register("test-forward", s.port, callback)
// Wait for initial status
time.Sleep(150 * time.Millisecond)
// Mark as reconnecting
s.checker.MarkReconnecting("test-forward")
// Should receive reconnecting status
timeout := time.After(500 * time.Millisecond)
gotReconnect := false
for !gotReconnect {
select {
case status := <-statusChanges:
if status == StatusReconnect {
gotReconnect = true
}
case <-timeout:
s.T().Fatal("Expected StatusReconnect")
}
}
}
// TestStartingGracePeriod tests that errors during grace period show as "Starting"
func (s *HealthCheckTestSuite) TestStartingGracePeriod() {
// Use a port that's not listening
unavailablePort := 54323
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 50 * time.Millisecond,
Timeout: 25 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register without callback - we'll check status directly
checker.Register("test-forward", unavailablePort, nil)
// Immediately check status - should be Starting or not yet checked
status, exists := checker.GetStatus("test-forward")
assert.True(s.T(), exists)
// Initially should be Starting
assert.Equal(s.T(), StatusStarting, status)
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 200*time.Millisecond)
// Now should be Unhealthy
status, exists = checker.GetStatus("test-forward")
assert.True(s.T(), exists)
assert.Equal(s.T(), StatusUnhealthy, status)
}
// TestGetAllErrors tests retrieving all error messages
func (s *HealthCheckTestSuite) TestGetAllErrors() {
// Create a new checker with faster intervals for this test
checker := NewCheckerWithOptions(CheckerOptions{
Interval: 100 * time.Millisecond,
Timeout: 50 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 0,
MaxIdleTime: 0,
})
defer checker.Stop()
// Register multiple forwards
checker.Register("forward1", s.port, nil)
checker.Register("forward2", 54324, nil) // Unavailable port
// Wait for grace period to expire
time.Sleep(startupGracePeriod + 300*time.Millisecond)
errors := checker.GetAllErrors()
// forward2 should have an error
_, hasError := errors["forward2"]
assert.True(s.T(), hasError, "forward2 should have an error")
// forward1 should not have an error
_, hasError = errors["forward1"]
assert.False(s.T(), hasError, "forward1 should not have an error")
}
// TestConcurrentOperations tests thread safety
func (s *HealthCheckTestSuite) TestConcurrentOperations() {
var wg sync.WaitGroup
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
forwardID := fmt.Sprintf("forward-%d", id)
s.checker.Register(forwardID, s.port, nil)
time.Sleep(50 * time.Millisecond)
s.checker.MarkConnected(forwardID)
s.checker.RecordActivity(forwardID)
status, _ := s.checker.GetStatus(forwardID)
_ = status
s.checker.Unregister(forwardID)
}(i)
}
wg.Wait()
// If we get here without deadlocks or panics, test passes
}
// TestDefaultOptions tests that NewChecker uses sensible defaults
func TestDefaultOptions(t *testing.T) {
checker := NewChecker(5*time.Second, 2*time.Second)
defer checker.Stop()
assert.Equal(t, 5*time.Second, checker.interval)
assert.Equal(t, 2*time.Second, checker.timeout)
assert.Equal(t, CheckMethodDataTransfer, checker.method)
assert.Equal(t, 25*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 10*time.Minute, checker.maxIdleTime)
}
// TestCustomOptions tests NewCheckerWithOptions
func TestCustomOptions(t *testing.T) {
opts := CheckerOptions{
Interval: 1 * time.Second,
Timeout: 500 * time.Millisecond,
Method: CheckMethodTCPDial,
MaxConnectionAge: 5 * time.Minute,
MaxIdleTime: 2 * time.Minute,
}
checker := NewCheckerWithOptions(opts)
defer checker.Stop()
assert.Equal(t, 1*time.Second, checker.interval)
assert.Equal(t, 500*time.Millisecond, checker.timeout)
assert.Equal(t, CheckMethodTCPDial, checker.method)
assert.Equal(t, 5*time.Minute, checker.maxConnectionAge)
assert.Equal(t, 2*time.Minute, checker.maxIdleTime)
}
+302
View File
@@ -0,0 +1,302 @@
package k8s
import (
"context"
"fmt"
"net"
"sort"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// Discovery provides cluster introspection capabilities for the UI wizards.
// It queries the Kubernetes API to list contexts, namespaces, pods, and services.
type Discovery struct {
pool *ClientPool
}
// NewDiscovery creates a new Discovery instance using the provided client pool.
func NewDiscovery(pool *ClientPool) *Discovery {
return &Discovery{
pool: pool,
}
}
// PodInfo contains information about a pod relevant for port forwarding.
type PodInfo struct {
Name string
Namespace string
Containers []ContainerInfo
Status string
Created metav1.Time
}
// ContainerInfo contains information about a container within a pod.
type ContainerInfo struct {
Name string
Ports []PortInfo
}
// PortInfo describes a port exposed by a container or service.
type PortInfo struct {
Name string
Port int32
Protocol string
}
// ServiceInfo contains information about a service.
type ServiceInfo struct {
Name string
Namespace string
Ports []PortInfo
Type string
}
// ListContexts returns all available Kubernetes contexts from kubeconfig.
func (d *Discovery) ListContexts() ([]string, error) {
return d.pool.ListContexts()
}
// GetCurrentContext returns the name of the current context from kubeconfig.
func (d *Discovery) GetCurrentContext() (string, error) {
return d.pool.GetCurrentContext()
}
// ListNamespaces returns all namespaces in the given context.
// Returns an error if the context is invalid or unreachable.
func (d *Discovery) ListNamespaces(ctx context.Context, contextName string) ([]string, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
nsList, err := client.CoreV1().Namespaces().List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list namespaces: %w", err)
}
namespaces := make([]string, 0, len(nsList.Items))
for _, ns := range nsList.Items {
namespaces = append(namespaces, ns.Name)
}
// Sort alphabetically
sort.Strings(namespaces)
return namespaces, nil
}
// ListPods returns all running pods in the given namespace with their port information.
// Only returns pods in Running or Pending state.
func (d *Discovery) ListPods(ctx context.Context, contextName, namespace string) ([]PodInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
podList, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
pods := make([]PodInfo, 0)
for _, pod := range podList.Items {
// Only include Running or Pending pods
if pod.Status.Phase != corev1.PodRunning && pod.Status.Phase != corev1.PodPending {
continue
}
containers := make([]ContainerInfo, 0, len(pod.Spec.Containers))
for _, container := range pod.Spec.Containers {
ports := make([]PortInfo, 0, len(container.Ports))
for _, port := range container.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.ContainerPort,
Protocol: string(port.Protocol),
})
}
containers = append(containers, ContainerInfo{
Name: container.Name,
Ports: ports,
})
}
pods = append(pods, PodInfo{
Name: pod.Name,
Namespace: pod.Namespace,
Containers: containers,
Status: string(pod.Status.Phase),
Created: pod.CreationTimestamp,
})
}
// Sort by creation time (newest first)
sort.Slice(pods, func(i, j int) bool {
return pods[i].Created.After(pods[j].Created.Time)
})
return pods, nil
}
// ListPodsWithSelector returns pods matching the given label selector.
// Selector format: "key=value,key2=value2"
// Returns an error if the selector is invalid.
func (d *Discovery) ListPodsWithSelector(ctx context.Context, contextName, namespace, selector string) ([]PodInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
// Validate selector format
selector = strings.TrimSpace(selector)
if selector == "" {
return nil, fmt.Errorf("selector cannot be empty")
}
podList, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
})
if err != nil {
return nil, fmt.Errorf("failed to list pods with selector: %w", err)
}
pods := make([]PodInfo, 0)
for _, pod := range podList.Items {
// Only include Running pods for selector-based forwards
if pod.Status.Phase != corev1.PodRunning {
continue
}
containers := make([]ContainerInfo, 0, len(pod.Spec.Containers))
for _, container := range pod.Spec.Containers {
ports := make([]PortInfo, 0, len(container.Ports))
for _, port := range container.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.ContainerPort,
Protocol: string(port.Protocol),
})
}
containers = append(containers, ContainerInfo{
Name: container.Name,
Ports: ports,
})
}
pods = append(pods, PodInfo{
Name: pod.Name,
Namespace: pod.Namespace,
Containers: containers,
Status: string(pod.Status.Phase),
Created: pod.CreationTimestamp,
})
}
// Sort by creation time (newest first)
sort.Slice(pods, func(i, j int) bool {
return pods[i].Created.After(pods[j].Created.Time)
})
return pods, nil
}
// ListServices returns all services in the given namespace.
func (d *Discovery) ListServices(ctx context.Context, contextName, namespace string) ([]ServiceInfo, error) {
client, err := d.pool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
svcList, err := client.CoreV1().Services(namespace).List(ctx, metav1.ListOptions{})
if err != nil {
return nil, fmt.Errorf("failed to list services: %w", err)
}
services := make([]ServiceInfo, 0, len(svcList.Items))
for _, svc := range svcList.Items {
ports := make([]PortInfo, 0, len(svc.Spec.Ports))
for _, port := range svc.Spec.Ports {
ports = append(ports, PortInfo{
Name: port.Name,
Port: port.Port,
Protocol: string(port.Protocol),
})
}
services = append(services, ServiceInfo{
Name: svc.Name,
Namespace: svc.Namespace,
Ports: ports,
Type: string(svc.Spec.Type),
})
}
// Sort alphabetically
sort.Slice(services, func(i, j int) bool {
return services[i].Name < services[j].Name
})
return services, nil
}
// GetUniquePorts extracts unique ports from a list of pods.
// Returns a sorted list of port numbers with their names (if available).
func GetUniquePorts(pods []PodInfo) []PortInfo {
portMap := make(map[int32]string)
for _, pod := range pods {
for _, container := range pod.Containers {
for _, port := range container.Ports {
// Prefer named ports
if _, ok := portMap[port.Port]; !ok || port.Name != "" {
if port.Name != "" {
portMap[port.Port] = port.Name
} else if !ok {
portMap[port.Port] = fmt.Sprintf("port-%d", port.Port)
}
}
}
}
}
// Convert to slice
ports := make([]PortInfo, 0, len(portMap))
for port, name := range portMap {
ports = append(ports, PortInfo{
Name: name,
Port: port,
})
}
// Sort by port number
sort.Slice(ports, func(i, j int) bool {
return ports[i].Port < ports[j].Port
})
return ports
}
// CheckPortAvailability checks if a local port is available.
// Returns: available (bool), processInfo (string), error
func CheckPortAvailability(port int) (bool, string, error) {
if port < 1 || port > 65535 {
return false, "", fmt.Errorf("invalid port: %d", port)
}
// Try to listen on the port
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
// Port is in use - return error details
return false, err.Error(), nil
}
// Port is available, close the listener
listener.Close()
return true, "", nil
}
+42 -5
View File
@@ -4,9 +4,13 @@ import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/nvm/kportal/internal/config"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -17,18 +21,32 @@ import (
// PortForwarder handles Kubernetes port-forwarding operations.
type PortForwarder struct {
clientPool *ClientPool
resolver *ResourceResolver
clientPool *ClientPool
resolver *ResourceResolver
tcpKeepalive time.Duration // TCP keepalive interval
dialTimeout time.Duration // Connection dial timeout
}
// NewPortForwarder creates a new PortForwarder instance.
// NewPortForwarder creates a new PortForwarder instance with default settings.
func NewPortForwarder(clientPool *ClientPool, resolver *ResourceResolver) *PortForwarder {
return &PortForwarder{
clientPool: clientPool,
resolver: resolver,
clientPool: clientPool,
resolver: resolver,
tcpKeepalive: config.DefaultTCPKeepalive,
dialTimeout: config.DefaultDialTimeout,
}
}
// SetTCPKeepalive configures the TCP keepalive interval for new connections.
func (pf *PortForwarder) SetTCPKeepalive(keepalive time.Duration) {
pf.tcpKeepalive = keepalive
}
// SetDialTimeout configures the connection dial timeout.
func (pf *PortForwarder) SetDialTimeout(timeout time.Duration) {
pf.dialTimeout = timeout
}
// ForwardRequest contains the parameters for a port-forward request.
type ForwardRequest struct {
ContextName string // Kubernetes context name
@@ -124,6 +142,9 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
}
// Get pods backing the service using label selector
if len(service.Spec.Selector) == 0 {
return fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", serviceName)
}
selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
pods, err := client.CoreV1().Pods(req.Namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
@@ -164,6 +185,19 @@ func (pf *PortForwarder) forwardToService(ctx context.Context, req *ForwardReque
// executePortForward performs the actual port-forward operation.
func (pf *PortForwarder) executePortForward(config *rest.Config, url *url.URL, req *ForwardRequest) error {
// Configure TCP settings on the underlying connection
// This is set in the rest.Config which will be used by the SPDY transport
if config.Dial == nil {
// Create a custom dialer with configurable timeout and keepalive
// - Timeout: How long to wait for connection to establish
// - KeepAlive: TCP keepalive helps OS detect dead connections at network layer
dialer := &net.Dialer{
Timeout: pf.dialTimeout, // Configurable dial timeout
KeepAlive: pf.tcpKeepalive, // Configurable keepalive interval
}
config.Dial = dialer.DialContext
}
// Create SPDY roundtripper
transport, upgrader, err := spdy.RoundTripperFor(config)
if err != nil {
@@ -228,6 +262,9 @@ func (pf *PortForwarder) GetPodForResource(ctx context.Context, contextName, nam
return "", fmt.Errorf("failed to get service: %w", err)
}
if len(service.Spec.Selector) == 0 {
return "", fmt.Errorf("service %s has no selector (headless service without selector cannot be port-forwarded)", resourceName)
}
selector := metav1.FormatLabelSelector(&metav1.LabelSelector{MatchLabels: service.Spec.Selector})
pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{
LabelSelector: selector,
-26
View File
@@ -228,29 +228,3 @@ func (r *ResourceResolver) InvalidateCache(contextName, namespace, resource stri
}
}
}
// GetPodList returns a list of pods matching the given criteria.
// This is useful for debugging and testing.
func (r *ResourceResolver) GetPodList(ctx context.Context, contextName, namespace, selector string) ([]*corev1.Pod, error) {
client, err := r.clientPool.GetClient(contextName)
if err != nil {
return nil, fmt.Errorf("failed to get client: %w", err)
}
listOptions := metav1.ListOptions{}
if selector != "" {
listOptions.LabelSelector = selector
}
pods, err := client.CoreV1().Pods(namespace).List(ctx, listOptions)
if err != nil {
return nil, fmt.Errorf("failed to list pods: %w", err)
}
result := make([]*corev1.Pod, len(pods.Items))
for i := range pods.Items {
result[i] = &pods.Items[i]
}
return result, nil
}
+70
View File
@@ -0,0 +1,70 @@
package logger_test
import (
"bytes"
"fmt"
"testing"
"github.com/nvm/kportal/internal/logger"
)
// This test demonstrates the logger output formats
func TestLoggerDemo(t *testing.T) {
t.Skip("Demo only - run manually with: go test -v -run TestLoggerDemo")
fmt.Println("\n=== TEXT FORMAT (DEFAULT) ===")
textBuf := &bytes.Buffer{}
textLogger := logger.New(logger.LevelInfo, logger.FormatText, textBuf)
textLogger.Info("Port forward started", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"local_port": 8080,
"pod": "app-xyz123",
})
textLogger.Warn("Connection failed, retrying", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "connection refused",
"retry": 3,
})
textLogger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "pod not found",
})
fmt.Print(textBuf.String())
fmt.Println("\n=== JSON FORMAT ===")
jsonBuf := &bytes.Buffer{}
jsonLogger := logger.New(logger.LevelInfo, logger.FormatJSON, jsonBuf)
jsonLogger.Info("Port forward started", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"local_port": 8080,
"pod": "app-xyz123",
})
jsonLogger.Warn("Connection failed, retrying", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "connection refused",
"retry": 3,
})
jsonLogger.Error("Failed to resolve resource", map[string]interface{}{
"forward_id": "prod/default/pod/app:8080",
"error": "pod not found",
})
fmt.Print(jsonBuf.String())
fmt.Println("\n=== LOG LEVEL FILTERING (Debug level disabled) ===")
filteredBuf := &bytes.Buffer{}
filteredLogger := logger.New(logger.LevelInfo, logger.FormatText, filteredBuf)
filteredLogger.Debug("This will not appear", nil)
filteredLogger.Info("This will appear", nil)
filteredLogger.Warn("This will also appear", nil)
fmt.Print(filteredBuf.String())
}
+96
View File
@@ -0,0 +1,96 @@
package logger
import (
"bytes"
"io"
"strings"
"sync"
)
// KlogWriter is an io.Writer that routes klog output through our structured logger.
// It parses klog messages and routes them to appropriate log levels.
// It is thread-safe for concurrent writes.
type KlogWriter struct {
logger *Logger
buffer *bytes.Buffer
mu sync.Mutex
}
// NewKlogWriter creates a new KlogWriter that routes k8s client-go logs
// through our structured logger.
func NewKlogWriter(logger *Logger) *KlogWriter {
return &KlogWriter{
logger: logger,
buffer: &bytes.Buffer{},
}
}
// Write implements io.Writer.
// It parses klog output and routes it through our structured logger.
// This method is thread-safe.
func (w *KlogWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// Write to buffer first
w.buffer.Write(p)
// Process complete lines
for {
line, err := w.buffer.ReadString('\n')
if err != nil {
// No complete line yet, write back what we read and wait for more
if err == io.EOF && line != "" {
w.buffer.WriteString(line)
}
break
}
// Process the complete line
w.processLine(strings.TrimSpace(line))
}
return len(p), nil
}
// processLine parses a klog line and routes it to the appropriate log level.
func (w *KlogWriter) processLine(line string) {
if line == "" {
return
}
// Parse klog format: "I1124 12:34:56.789012 12345 file.go:123] message"
// First character indicates level: I=Info, W=Warning, E=Error, F=Fatal
if len(line) < 1 {
return
}
level := line[0]
message := line
// Try to extract just the message part after "]"
if idx := strings.Index(line, "] "); idx != -1 {
message = line[idx+2:]
}
// Determine log level and route accordingly
switch level {
case 'I': // Info
w.logger.Debug(message, map[string]interface{}{
"source": "k8s-client",
})
case 'W': // Warning
w.logger.Warn(message, map[string]interface{}{
"source": "k8s-client",
})
case 'E', 'F': // Error or Fatal
w.logger.Error(message, map[string]interface{}{
"source": "k8s-client",
})
default:
// Unknown format, log as debug
w.logger.Debug(message, map[string]interface{}{
"source": "k8s-client",
})
}
}
+280
View File
@@ -0,0 +1,280 @@
package logger
import (
"bytes"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestKlogWriter(t *testing.T) {
tests := []struct {
name string
input string
expectedLevel string
expectedMsg string
loggerLevel Level
loggerFormat Format
shouldLog bool
description string
}{
{
name: "info level log",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Starting port forward\n",
expectedLevel: "DEBUG",
expectedMsg: "Starting port forward",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Info logs from k8s should be routed as DEBUG",
},
{
name: "warning level log",
input: "W1124 12:34:56.789012 12345 portforward.go:456] Connection unstable\n",
expectedLevel: "WARN",
expectedMsg: "Connection unstable",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Warning logs should be routed as WARN",
},
{
name: "error level log",
input: "E1124 12:34:56.789012 12345 portforward.go:789] Connection failed\n",
expectedLevel: "ERROR",
expectedMsg: "Connection failed",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Error logs should be routed as ERROR",
},
{
name: "fatal level log",
input: "F1124 12:34:56.789012 12345 portforward.go:999] Fatal error\n",
expectedLevel: "ERROR",
expectedMsg: "Fatal error",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Fatal logs should be routed as ERROR",
},
{
name: "multiline input",
input: "I1124 12:34:56.789012 12345 portforward.go:123] First message\nI1124 12:34:57.123456 12345 portforward.go:124] Second message\n",
expectedLevel: "DEBUG",
expectedMsg: "First message",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Should handle multiple log lines",
},
{
name: "log filtered by level",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Debug message\n",
expectedLevel: "DEBUG",
expectedMsg: "Debug message",
loggerLevel: LevelInfo, // Logger set to INFO, DEBUG should be filtered
loggerFormat: FormatText,
shouldLog: false,
description: "DEBUG logs should be filtered when logger level is INFO",
},
{
name: "unknown log format",
input: "X1124 12:34:56.789012 12345 portforward.go:123] Unknown format\n",
expectedLevel: "DEBUG",
expectedMsg: "Unknown format",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: true,
description: "Unknown format should default to DEBUG",
},
{
name: "empty line",
input: "\n",
expectedLevel: "",
expectedMsg: "",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: false,
description: "Empty lines should be ignored",
},
{
name: "partial line no newline",
input: "I1124 12:34:56.789012 12345 portforward.go:123] Partial",
expectedLevel: "",
expectedMsg: "",
loggerLevel: LevelDebug,
loggerFormat: FormatText,
shouldLog: false,
description: "Partial lines without newline should be buffered",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create output buffer
var buf bytes.Buffer
// Create logger with specified level and format
logger := New(tt.loggerLevel, tt.loggerFormat, &buf)
// Create klog writer
klogWriter := NewKlogWriter(logger)
// Write input
n, err := klogWriter.Write([]byte(tt.input))
require.NoError(t, err)
assert.Equal(t, len(tt.input), n)
// Check output
output := buf.String()
if !tt.shouldLog {
assert.Empty(t, output, "Expected no log output")
return
}
if tt.loggerFormat == FormatText {
// Text format: [LEVEL] message
assert.Contains(t, output, fmt.Sprintf("[%s]", tt.expectedLevel))
assert.Contains(t, output, tt.expectedMsg)
assert.Contains(t, output, "k8s-client") // Should include source field
} else {
// JSON format
var entry logEntry
lines := strings.Split(strings.TrimSpace(output), "\n")
if len(lines) > 0 {
err := json.Unmarshal([]byte(lines[0]), &entry)
require.NoError(t, err)
assert.Equal(t, tt.expectedLevel, entry.Level)
assert.Equal(t, tt.expectedMsg, entry.Message)
assert.Equal(t, "k8s-client", entry.Fields["source"])
}
}
})
}
}
func TestKlogWriterBuffering(t *testing.T) {
tests := []struct {
name string
writes []string
expectCount int
description string
}{
{
name: "single complete line",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] Complete line\n",
},
expectCount: 1,
description: "Single complete line should produce one log entry",
},
{
name: "partial then complete",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] Partial ",
"line\n",
},
expectCount: 1,
description: "Partial writes should be buffered and combined",
},
{
name: "multiple complete lines in chunks",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] First\n",
"I1124 12:34:57.123456 12345 portforward.go:124] Second\n",
"I1124 12:34:58.456789 12345 portforward.go:125] Third\n",
},
expectCount: 3,
description: "Multiple complete lines should produce multiple log entries",
},
{
name: "mixed partial and complete",
writes: []string{
"I1124 12:34:56.789012 12345 portforward.go:123] First\nI1124 12:34:57.123456 12345 port",
"forward.go:124] Second\n",
},
expectCount: 2,
description: "Mixed partial and complete lines should be handled correctly",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, FormatText, &buf)
klogWriter := NewKlogWriter(logger)
// Write all chunks
for _, write := range tt.writes {
_, err := klogWriter.Write([]byte(write))
require.NoError(t, err)
}
// Count log entries (each line starts with [LEVEL])
output := buf.String()
count := strings.Count(output, "[DEBUG]") +
strings.Count(output, "[INFO]") +
strings.Count(output, "[WARN]") +
strings.Count(output, "[ERROR]")
assert.Equal(t, tt.expectCount, count, "Expected %d log entries, got %d", tt.expectCount, count)
})
}
}
func TestKlogWriterJSONFormat(t *testing.T) {
var buf bytes.Buffer
logger := New(LevelDebug, FormatJSON, &buf)
klogWriter := NewKlogWriter(logger)
// Write a k8s log line
input := "I1124 12:34:56.789012 12345 portforward.go:123] Starting port forward\n"
_, err := klogWriter.Write([]byte(input))
require.NoError(t, err)
// Parse JSON output
var entry logEntry
err = json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
// Verify JSON structure
assert.Equal(t, "DEBUG", entry.Level)
assert.Equal(t, "Starting port forward", entry.Message)
assert.NotEmpty(t, entry.Time)
assert.Equal(t, "k8s-client", entry.Fields["source"])
}
func TestKlogWriterConcurrency(t *testing.T) {
// Test that concurrent writes don't cause data races
var buf bytes.Buffer
logger := New(LevelDebug, FormatText, &buf)
klogWriter := NewKlogWriter(logger)
done := make(chan bool)
numGoroutines := 10
numWrites := 100
for i := 0; i < numGoroutines; i++ {
go func(id int) {
for j := 0; j < numWrites; j++ {
msg := fmt.Sprintf("I1124 12:34:56.789012 12345 test.go:123] Message from goroutine %d iteration %d\n", id, j)
klogWriter.Write([]byte(msg))
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < numGoroutines; i++ {
<-done
}
// Just verify we didn't panic (data race detector would catch issues)
assert.NotEmpty(t, buf.String())
}
+105
View File
@@ -0,0 +1,105 @@
package logger
import (
"github.com/go-logr/logr"
)
// LogrAdapter implements the logr.LogSink interface to route klog v2 logs
// through our structured logger. This captures ALL klog output including
// error logs, structured logs, and named logger output.
type LogrAdapter struct {
logger *Logger
name string
level int
}
// NewLogrAdapter creates a new logr.LogSink that routes all klog v2 logs
// through our structured logger.
func NewLogrAdapter(logger *Logger) logr.LogSink {
return &LogrAdapter{
logger: logger,
name: "",
level: 0,
}
}
// Init initializes the logger with runtime info (not used in our implementation).
func (l *LogrAdapter) Init(info logr.RuntimeInfo) {
// No-op: we don't need runtime info
}
// Enabled tests whether this LogSink is enabled at the specified V-level.
// We route all logs through our logger's level filtering.
func (l *LogrAdapter) Enabled(level int) bool {
// Map logr V-levels to our levels:
// V(0) = Info level (always enabled if logger level <= Info)
// V(1+) = Debug level (enabled if logger level <= Debug)
if level == 0 {
return l.logger.level <= LevelInfo
}
return l.logger.level <= LevelDebug
}
// Info logs a non-error message with the given key/value pairs.
func (l *LogrAdapter) Info(level int, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
// Map logr V-levels to our levels:
// V(0) = Info, V(1+) = Debug
if level == 0 {
l.logger.Info(msg, fields)
} else {
l.logger.Debug(msg, fields)
}
}
// Error logs an error message with the given key/value pairs.
func (l *LogrAdapter) Error(err error, msg string, keysAndValues ...interface{}) {
fields := l.kvToMap(keysAndValues)
if l.name != "" {
fields["logger"] = l.name
}
if err != nil {
fields["error"] = err.Error()
}
l.logger.Error(msg, fields)
}
// WithValues returns a new LogSink with additional key/value pairs.
func (l *LogrAdapter) WithValues(keysAndValues ...interface{}) logr.LogSink {
// For simplicity, we don't implement value accumulation
// Each log call receives all its keysAndValues directly
return l
}
// WithName returns a new LogSink with the specified name appended.
func (l *LogrAdapter) WithName(name string) logr.LogSink {
newLogger := *l
if l.name == "" {
newLogger.name = name
} else {
newLogger.name = l.name + "." + name
}
return &newLogger
}
// kvToMap converts a slice of alternating keys and values to a map.
func (l *LogrAdapter) kvToMap(keysAndValues []interface{}) map[string]interface{} {
fields := make(map[string]interface{})
fields["source"] = "k8s-client"
for i := 0; i < len(keysAndValues); i += 2 {
if i+1 < len(keysAndValues) {
key, ok := keysAndValues[i].(string)
if ok {
fields[key] = keysAndValues[i+1]
}
}
}
return fields
}
+367
View File
@@ -0,0 +1,367 @@
package logger
import (
"bytes"
"encoding/json"
"errors"
"testing"
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogrAdapter_Info(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "info log v0 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
message: "Connection established",
keysAndValues: []interface{}{"pod", "my-app-123", "port", 8080},
expectOutput: true,
expectContains: []string{"[INFO]", "Connection established", "pod", "my-app-123"},
},
{
name: "info log v0 with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Port forward ready",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Port forward ready"},
},
{
name: "info log v0 silenced with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
message: "This should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "debug log v1 with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
message: "Detailed connection info",
keysAndValues: []interface{}{"details", "some-value"},
expectOutput: true,
expectContains: []string{"[DEBUG]", "Detailed connection info", "details"},
},
{
name: "debug log v1 silenced with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
message: "This debug should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "info with odd number of kvs (incomplete pair)",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Message with incomplete kv",
keysAndValues: []interface{}{"key1", "value1", "key2"}, // key2 has no value
expectOutput: true,
expectContains: []string{"[INFO]", "Message with incomplete kv", "key1", "value1"},
},
{
name: "info with source field added automatically",
loggerLevel: LevelInfo,
logrLevel: 0,
message: "Test source field",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[INFO]", "Test source field", "source:k8s-client"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.V(tt.logrLevel).Info(tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_Error(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
err error
message string
keysAndValues []interface{}
expectOutput bool
expectContains []string
}{
{
name: "error with error object",
loggerLevel: LevelError,
err: errors.New("connection failed"),
message: "Port forward failed",
keysAndValues: []interface{}{"pod", "my-app-123"},
expectOutput: true,
expectContains: []string{"[ERROR]", "Port forward failed", "connection failed", "pod", "my-app-123"},
},
{
name: "error without error object",
loggerLevel: LevelError,
err: nil,
message: "Generic error message",
keysAndValues: []interface{}{},
expectOutput: true,
expectContains: []string{"[ERROR]", "Generic error message"},
},
{
name: "error silenced with level above error",
loggerLevel: LevelError + 1,
err: errors.New("should not appear"),
message: "This error should not appear",
keysAndValues: []interface{}{},
expectOutput: false,
expectContains: []string{},
},
{
name: "error with multiple kvs",
loggerLevel: LevelError,
err: errors.New("sandbox not found"),
message: "Unhandled Error",
keysAndValues: []interface{}{"pod", "test-pod", "uid", "abc123", "port", 8080},
expectOutput: true,
expectContains: []string{"[ERROR]", "Unhandled Error", "sandbox not found", "pod", "test-pod", "uid", "abc123"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
logrLogger.Error(tt.err, tt.message, tt.keysAndValues...)
output := buf.String()
if tt.expectOutput {
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Output should contain: %s", expected)
}
} else {
assert.Empty(t, output, "No output expected for this log level")
}
})
}
}
func TestLogrAdapter_WithName(t *testing.T) {
tests := []struct {
name string
loggerNames []string
message string
expectContains string
}{
{
name: "single logger name",
loggerNames: []string{"portforward"},
message: "Test message",
expectContains: "logger:portforward",
},
{
name: "nested logger names",
loggerNames: []string{"controller", "worker", "healthcheck"},
message: "Nested message",
expectContains: "logger:controller.worker.healthcheck",
},
{
name: "no logger name",
loggerNames: []string{},
message: "No name message",
expectContains: "source:k8s-client", // Should still have source but no logger field
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Apply WithName calls
for _, name := range tt.loggerNames {
logrLogger = logrLogger.WithName(name)
}
logrLogger.Info(tt.message)
output := buf.String()
assert.Contains(t, output, tt.expectContains)
})
}
}
func TestLogrAdapter_Enabled(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logrLevel int
expectEnabled bool
}{
{
name: "v0 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 enabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 0,
expectEnabled: true,
},
{
name: "v0 disabled with warn logger",
loggerLevel: LevelWarn,
logrLevel: 0,
expectEnabled: false,
},
{
name: "v1 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 1,
expectEnabled: true,
},
{
name: "v1 disabled with info logger",
loggerLevel: LevelInfo,
logrLevel: 1,
expectEnabled: false,
},
{
name: "v2 enabled with debug logger",
loggerLevel: LevelDebug,
logrLevel: 2,
expectEnabled: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logger := New(tt.loggerLevel, FormatText, &bytes.Buffer{})
sink := NewLogrAdapter(logger)
enabled := sink.Enabled(tt.logrLevel)
assert.Equal(t, tt.expectEnabled, enabled)
})
}
}
func TestLogrAdapter_JSONFormat(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatJSON, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("test-component")
logrLogger.Info("Test JSON message", "key1", "value1", "key2", 123)
// Parse JSON output
var entry logEntry
err := json.Unmarshal(buf.Bytes(), &entry)
require.NoError(t, err)
assert.Equal(t, "INFO", entry.Level)
assert.Equal(t, "Test JSON message", entry.Message)
assert.Equal(t, "k8s-client", entry.Fields["source"])
assert.Equal(t, "test-component", entry.Fields["logger"])
assert.Equal(t, "value1", entry.Fields["key1"])
assert.Equal(t, float64(123), entry.Fields["key2"]) // JSON numbers decode as float64
}
func TestLogrAdapter_ConcurrentWrites(t *testing.T) {
// Note: bytes.Buffer is not thread-safe for writes, so this test verifies
// that our LogrAdapter doesn't panic under concurrent load, but we don't
// verify exact output (since logger uses fmt.Fprintf which is also not thread-safe)
buf := &bytes.Buffer{}
logger := New(LevelDebug, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Spawn multiple goroutines writing concurrently
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(id int) {
for j := 0; j < 100; j++ {
logrLogger.Info("Concurrent message", "goroutine", id, "iteration", j)
}
done <- true
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
output := buf.String()
// Verify we got substantial output (not checking exact count due to buffer race)
// The main goal is to ensure no panics occur during concurrent writes
assert.NotEmpty(t, output, "Should have some log output")
assert.Contains(t, output, "Concurrent message")
}
func TestLogrAdapter_RealWorldKlogError(t *testing.T) {
// Simulate the exact error message from the screenshot
buf := &bytes.Buffer{}
logger := New(LevelError, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink).WithName("UnhandledError")
err := errors.New("an error occurred forwarding 8401 -> 8401: error forwarding port 8401 to pod 4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308, uid : failed to find sandbox '4e1e861c28e3b25a88b082e79788169b5d8a7a117904b7bb8c7cd59285cf1d308' in store: not found")
logrLogger.Error(err, "Unhandled Error")
output := buf.String()
assert.Contains(t, output, "[ERROR]")
assert.Contains(t, output, "Unhandled Error")
assert.Contains(t, output, "failed to find sandbox")
assert.Contains(t, output, "logger:UnhandledError")
}
func TestLogrAdapter_SilenceMode(t *testing.T) {
// Test that logs are completely silenced when logger level is above error
buf := &bytes.Buffer{}
logger := New(LevelError+1, FormatText, buf)
sink := NewLogrAdapter(logger)
logrLogger := logr.New(sink)
// Try all log levels
logrLogger.V(0).Info("Info message should not appear")
logrLogger.V(1).Info("Debug message should not appear")
logrLogger.Error(errors.New("error object"), "Error message should not appear")
output := buf.String()
assert.Empty(t, output, "All logs should be silenced")
}
+164
View File
@@ -0,0 +1,164 @@
package logger
import (
"encoding/json"
"fmt"
"io"
"os"
"sync"
"time"
)
type Level int
const (
LevelDebug Level = iota
LevelInfo
LevelWarn
LevelError
)
type Format int
const (
FormatText Format = iota
FormatJSON
)
type Logger struct {
level Level
format Format
output io.Writer
mu sync.Mutex // Protects concurrent writes to output
}
type logEntry struct {
Time string `json:"time"`
Level string `json:"level"`
Message string `json:"message"`
Fields map[string]interface{} `json:"fields,omitempty"`
}
func New(level Level, format Format, output io.Writer) *Logger {
if output == nil {
output = os.Stderr
}
return &Logger{
level: level,
format: format,
output: output,
}
}
func (l *Logger) log(level Level, msg string, fields map[string]interface{}) {
if level < l.level {
return
}
levelStr := levelToString(level)
l.mu.Lock()
defer l.mu.Unlock()
if l.format == FormatJSON {
entry := logEntry{
Time: time.Now().Format(time.RFC3339),
Level: levelStr,
Message: msg,
Fields: fields,
}
data, _ := json.Marshal(entry)
fmt.Fprintln(l.output, string(data))
} else {
// Text format
if len(fields) > 0 {
fmt.Fprintf(l.output, "[%s] %s %v\n", levelStr, msg, fields)
} else {
fmt.Fprintf(l.output, "[%s] %s\n", levelStr, msg)
}
}
}
func (l *Logger) Debug(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelDebug, msg, f)
}
func (l *Logger) Info(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelInfo, msg, f)
}
func (l *Logger) Warn(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelWarn, msg, f)
}
func (l *Logger) Error(msg string, fields ...map[string]interface{}) {
f := make(map[string]interface{})
if len(fields) > 0 {
f = fields[0]
}
l.log(LevelError, msg, f)
}
func levelToString(level Level) string {
switch level {
case LevelDebug:
return "DEBUG"
case LevelInfo:
return "INFO"
case LevelWarn:
return "WARN"
case LevelError:
return "ERROR"
default:
return "UNKNOWN"
}
}
// Global logger for backward compatibility
var globalLogger *Logger
func Init(level Level, format Format, output ...io.Writer) {
var out io.Writer
if len(output) > 0 && output[0] != nil {
out = output[0]
} else {
out = os.Stderr
}
globalLogger = New(level, format, out)
}
func Debug(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Debug(msg, fields...)
}
}
func Info(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Info(msg, fields...)
}
}
func Warn(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Warn(msg, fields...)
}
}
func Error(msg string, fields ...map[string]interface{}) {
if globalLogger != nil {
globalLogger.Error(msg, fields...)
}
}
+521
View File
@@ -0,0 +1,521 @@
package logger
import (
"bytes"
"encoding/json"
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLoggerTextFormat(t *testing.T) {
tests := []struct {
name string
level Level
logLevel Level
message string
fields map[string]interface{}
expectOutput bool
expectContains []string
}{
{
name: "info logged at info level",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: nil,
expectOutput: true,
expectContains: []string{"[INFO]", "test message"},
},
{
name: "debug filtered at info level",
level: LevelInfo,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: false,
expectContains: []string{},
},
{
name: "error logged at info level",
level: LevelInfo,
logLevel: LevelError,
message: "error message",
fields: nil,
expectOutput: true,
expectContains: []string{"[ERROR]", "error message"},
},
{
name: "info with fields",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: map[string]interface{}{
"key1": "value1",
"key2": 123,
},
expectOutput: true,
expectContains: []string{"[INFO]", "test message", "key1", "value1"},
},
{
name: "warn logged at warn level",
level: LevelWarn,
logLevel: LevelWarn,
message: "warning message",
fields: nil,
expectOutput: true,
expectContains: []string{"[WARN]", "warning message"},
},
{
name: "info filtered at warn level",
level: LevelWarn,
logLevel: LevelInfo,
message: "info message",
fields: nil,
expectOutput: false,
expectContains: []string{},
},
{
name: "debug logged at debug level",
level: LevelDebug,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: true,
expectContains: []string{"[DEBUG]", "debug message"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.level, FormatText, buf)
// Log at the specified level
switch tt.logLevel {
case LevelDebug:
if tt.fields != nil {
logger.Debug(tt.message, tt.fields)
} else {
logger.Debug(tt.message)
}
case LevelInfo:
if tt.fields != nil {
logger.Info(tt.message, tt.fields)
} else {
logger.Info(tt.message)
}
case LevelWarn:
if tt.fields != nil {
logger.Warn(tt.message, tt.fields)
} else {
logger.Warn(tt.message)
}
case LevelError:
if tt.fields != nil {
logger.Error(tt.message, tt.fields)
} else {
logger.Error(tt.message)
}
}
output := buf.String()
if tt.expectOutput {
assert.NotEmpty(t, output, "Expected log output but got none")
for _, expected := range tt.expectContains {
assert.Contains(t, output, expected, "Expected output to contain: %s", expected)
}
} else {
assert.Empty(t, output, "Expected no log output but got: %s", output)
}
})
}
}
func TestLoggerJSONFormat(t *testing.T) {
tests := []struct {
name string
level Level
logLevel Level
message string
fields map[string]interface{}
expectOutput bool
expectLevel string
}{
{
name: "info logged at info level",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: nil,
expectOutput: true,
expectLevel: "INFO",
},
{
name: "debug filtered at info level",
level: LevelInfo,
logLevel: LevelDebug,
message: "debug message",
fields: nil,
expectOutput: false,
expectLevel: "",
},
{
name: "error logged at debug level",
level: LevelDebug,
logLevel: LevelError,
message: "error message",
fields: nil,
expectOutput: true,
expectLevel: "ERROR",
},
{
name: "info with fields",
level: LevelInfo,
logLevel: LevelInfo,
message: "test message",
fields: map[string]interface{}{
"context": "production",
"port": 8080,
"retry": 3,
},
expectOutput: true,
expectLevel: "INFO",
},
{
name: "warn at warn level",
level: LevelWarn,
logLevel: LevelWarn,
message: "warning message",
fields: nil,
expectOutput: true,
expectLevel: "WARN",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.level, FormatJSON, buf)
// Log at the specified level
switch tt.logLevel {
case LevelDebug:
if tt.fields != nil {
logger.Debug(tt.message, tt.fields)
} else {
logger.Debug(tt.message)
}
case LevelInfo:
if tt.fields != nil {
logger.Info(tt.message, tt.fields)
} else {
logger.Info(tt.message)
}
case LevelWarn:
if tt.fields != nil {
logger.Warn(tt.message, tt.fields)
} else {
logger.Warn(tt.message)
}
case LevelError:
if tt.fields != nil {
logger.Error(tt.message, tt.fields)
} else {
logger.Error(tt.message)
}
}
output := buf.String()
if tt.expectOutput {
assert.NotEmpty(t, output, "Expected log output but got none")
// Parse JSON
var entry logEntry
err := json.Unmarshal([]byte(strings.TrimSpace(output)), &entry)
require.NoError(t, err, "Failed to parse JSON output: %s", output)
// Validate fields
assert.Equal(t, tt.expectLevel, entry.Level)
assert.Equal(t, tt.message, entry.Message)
assert.NotEmpty(t, entry.Time, "Time field should not be empty")
// Validate custom fields if provided
if tt.fields != nil {
require.NotNil(t, entry.Fields, "Expected fields in JSON output")
for key, expectedValue := range tt.fields {
actualValue, exists := entry.Fields[key]
assert.True(t, exists, "Expected field %s not found in output", key)
// JSON unmarshaling converts numbers to float64
if floatVal, ok := expectedValue.(int); ok {
assert.Equal(t, float64(floatVal), actualValue)
} else {
assert.Equal(t, expectedValue, actualValue)
}
}
}
} else {
assert.Empty(t, output, "Expected no log output but got: %s", output)
}
})
}
}
func TestGlobalLogger(t *testing.T) {
tests := []struct {
name string
initLevel Level
initFormat Format
logFunc func(string, ...map[string]interface{})
message string
expectContains string
}{
{
name: "global info logger text",
initLevel: LevelInfo,
initFormat: FormatText,
logFunc: Info,
message: "global info message",
expectContains: "[INFO]",
},
{
name: "global error logger text",
initLevel: LevelInfo,
initFormat: FormatText,
logFunc: Error,
message: "global error message",
expectContains: "[ERROR]",
},
{
name: "global warn logger json",
initLevel: LevelWarn,
initFormat: FormatJSON,
logFunc: Warn,
message: "global warn message",
expectContains: `"level":"WARN"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Capture stderr by replacing globalLogger's output
buf := &bytes.Buffer{}
Init(tt.initLevel, tt.initFormat)
globalLogger.output = buf
// Call the global log function
tt.logFunc(tt.message)
output := buf.String()
assert.Contains(t, output, tt.expectContains)
assert.Contains(t, output, tt.message)
})
}
}
func TestLogLevelsFiltering(t *testing.T) {
tests := []struct {
name string
loggerLevel Level
logAtLevels []Level
expectOutputs []bool
}{
{
name: "debug level logs everything",
loggerLevel: LevelDebug,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{true, true, true, true},
},
{
name: "info level filters debug",
loggerLevel: LevelInfo,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, true, true, true},
},
{
name: "warn level filters debug and info",
loggerLevel: LevelWarn,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, false, true, true},
},
{
name: "error level only logs errors",
loggerLevel: LevelError,
logAtLevels: []Level{LevelDebug, LevelInfo, LevelWarn, LevelError},
expectOutputs: []bool{false, false, false, true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(tt.loggerLevel, FormatText, buf)
for i, logLevel := range tt.logAtLevels {
buf.Reset()
switch logLevel {
case LevelDebug:
logger.Debug("test")
case LevelInfo:
logger.Info("test")
case LevelWarn:
logger.Warn("test")
case LevelError:
logger.Error("test")
}
hasOutput := buf.Len() > 0
assert.Equal(t, tt.expectOutputs[i], hasOutput,
"Level %v at logger level %v: expected output=%v, got=%v",
logLevel, tt.loggerLevel, tt.expectOutputs[i], hasOutput)
}
})
}
}
func TestLoggerNilOutput(t *testing.T) {
// Test that logger defaults to os.Stderr when output is nil
logger := New(LevelInfo, FormatText, nil)
assert.NotNil(t, logger.output, "Logger output should not be nil")
}
func TestLevelToString(t *testing.T) {
tests := []struct {
level Level
expected string
}{
{LevelDebug, "DEBUG"},
{LevelInfo, "INFO"},
{LevelWarn, "WARN"},
{LevelError, "ERROR"},
{Level(999), "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.expected, func(t *testing.T) {
result := levelToString(tt.level)
assert.Equal(t, tt.expected, result)
})
}
}
func TestJSONFieldTypes(t *testing.T) {
tests := []struct {
name string
fields map[string]interface{}
}{
{
name: "string fields",
fields: map[string]interface{}{
"key1": "value1",
"key2": "value2",
},
},
{
name: "numeric fields",
fields: map[string]interface{}{
"port": 8080,
"timeout": 30,
"retry": 3,
},
},
{
name: "boolean fields",
fields: map[string]interface{}{
"enabled": true,
"running": false,
},
},
{
name: "mixed types",
fields: map[string]interface{}{
"context": "production",
"port": 8080,
"enabled": true,
"namespace": "default",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := &bytes.Buffer{}
logger := New(LevelInfo, FormatJSON, buf)
logger.Info("test message", tt.fields)
var entry logEntry
err := json.Unmarshal([]byte(strings.TrimSpace(buf.String())), &entry)
require.NoError(t, err)
assert.Equal(t, len(tt.fields), len(entry.Fields),
"Field count mismatch")
for key := range tt.fields {
_, exists := entry.Fields[key]
assert.True(t, exists, "Field %s not found in JSON output", key)
}
})
}
}
func TestInitWithCustomOutput(t *testing.T) {
tests := []struct {
name string
output io.Writer
expectDiscard bool
description string
}{
{
name: "init with custom buffer",
output: &bytes.Buffer{},
expectDiscard: false,
description: "Should use provided buffer",
},
{
name: "init with io.Discard",
output: io.Discard,
expectDiscard: true,
description: "Should use io.Discard to silence output",
},
{
name: "init without output defaults to stderr",
output: nil,
expectDiscard: false,
description: "Should default to stderr when no output provided",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.output != nil {
Init(LevelInfo, FormatText, tt.output)
} else {
Init(LevelInfo, FormatText)
}
// Verify global logger was initialized
assert.NotNil(t, globalLogger, "Global logger should be initialized")
if tt.output != nil && !tt.expectDiscard {
// For buffer, verify output works
if buf, ok := tt.output.(*bytes.Buffer); ok {
Info("test message")
output := buf.String()
assert.Contains(t, output, "test message")
assert.Contains(t, output, "[INFO]")
}
} else if tt.expectDiscard {
// For io.Discard, verify no output appears (we can't really test this directly,
// but we can verify the logger was set with the right output)
assert.Equal(t, io.Discard, globalLogger.output)
}
})
}
}
+220
View File
@@ -0,0 +1,220 @@
package mdns
import (
"fmt"
"net"
"sync"
"time"
"github.com/grandcat/zeroconf"
"github.com/nvm/kportal/internal/logger"
)
const (
// shutdownTimeout is the maximum time to wait for mDNS server shutdown
shutdownTimeout = 2 * time.Second
// mdnsDomain is the standard mDNS domain (RFC 6762)
// This is always ".local" for multicast DNS - it's not configurable
// and is different from your network's DNS search domain
mdnsDomain = "local"
)
// Publisher manages mDNS hostname registrations for port forwards.
// It allows forwards with aliases to be accessible via <alias>.local hostnames.
type Publisher struct {
mu sync.RWMutex
servers map[string]*zeroconf.Server // forwardID -> server
aliases map[string]string // forwardID -> alias (for logging)
enabled bool
localIPs []string
}
// NewPublisher creates a new mDNS Publisher.
// If enabled is false, all registration calls will be no-ops.
func NewPublisher(enabled bool) *Publisher {
p := &Publisher{
servers: make(map[string]*zeroconf.Server),
aliases: make(map[string]string),
enabled: enabled,
localIPs: getLocalIPs(),
}
if enabled {
logger.Info("mDNS publisher initialized", map[string]interface{}{
"domain": mdnsDomain,
"local_ips": p.localIPs,
})
}
return p
}
// Register publishes an mDNS hostname for a forward.
// The hostname will be <alias>.local and will resolve to 127.0.0.1.
// If the forward has no alias or mDNS is disabled, this is a no-op.
func (p *Publisher) Register(forwardID, alias string, localPort int) error {
if !p.enabled || alias == "" {
return nil
}
p.mu.Lock()
defer p.mu.Unlock()
// Check if already registered
if _, exists := p.servers[forwardID]; exists {
logger.Debug("mDNS hostname already registered", map[string]interface{}{
"forward_id": forwardID,
"alias": alias,
})
return nil
}
// Register the mDNS service
// We use a generic service type and rely on the hostname registration
server, err := zeroconf.RegisterProxy(
alias, // Instance name (shown in service discovery)
"_kportal._tcp", // Service type (custom for kportal)
"local.", // Domain
localPort, // Port
alias, // Hostname (will be <alias>.local)
[]string{"127.0.0.1"}, // IPs to resolve to
[]string{fmt.Sprintf("forward=%s", forwardID)}, // TXT records
nil, // interfaces (nil = all)
)
if err != nil {
return fmt.Errorf("failed to register mDNS for %s: %w", alias, err)
}
p.servers[forwardID] = server
p.aliases[forwardID] = alias
logger.Info("mDNS hostname registered", map[string]interface{}{
"forward_id": forwardID,
"hostname": GetHostname(alias),
"port": localPort,
})
return nil
}
// Unregister removes the mDNS hostname for a forward.
func (p *Publisher) Unregister(forwardID string) {
if !p.enabled {
return
}
p.mu.Lock()
defer p.mu.Unlock()
server, exists := p.servers[forwardID]
if !exists {
return
}
alias := p.aliases[forwardID]
shutdownWithTimeout(server, forwardID)
delete(p.servers, forwardID)
delete(p.aliases, forwardID)
logger.Info("mDNS hostname unregistered", map[string]interface{}{
"forward_id": forwardID,
"hostname": GetHostname(alias),
})
}
// Stop shuts down all mDNS registrations.
func (p *Publisher) Stop() {
if !p.enabled {
return
}
p.mu.Lock()
defer p.mu.Unlock()
// Shutdown all servers concurrently with timeout
var wg sync.WaitGroup
for forwardID, server := range p.servers {
wg.Add(1)
go func(id string, srv *zeroconf.Server) {
defer wg.Done()
shutdownWithTimeout(srv, id)
}(forwardID, server)
}
// Wait for all shutdowns to complete (or timeout)
wg.Wait()
p.servers = make(map[string]*zeroconf.Server)
p.aliases = make(map[string]string)
logger.Info("mDNS publisher stopped", nil)
}
// shutdownWithTimeout attempts to shutdown a zeroconf server with a timeout.
// If shutdown hangs, it logs a warning and returns anyway.
func shutdownWithTimeout(server *zeroconf.Server, forwardID string) {
done := make(chan struct{})
go func() {
server.Shutdown()
close(done)
}()
select {
case <-done:
// Shutdown completed successfully
case <-time.After(shutdownTimeout):
logger.Warn("mDNS shutdown timed out, continuing anyway", map[string]interface{}{
"forward_id": forwardID,
"timeout": shutdownTimeout.String(),
})
}
}
// IsEnabled returns whether mDNS publishing is enabled.
func (p *Publisher) IsEnabled() bool {
return p.enabled
}
// GetDomain returns the mDNS domain being used (always "local" per RFC 6762).
func (p *Publisher) GetDomain() string {
return mdnsDomain
}
// GetHostname returns the full mDNS hostname for an alias.
// Example: GetHostname("myapp") returns "myapp.local"
func GetHostname(alias string) string {
return alias + "." + mdnsDomain
}
// GetRegisteredCount returns the number of currently registered hostnames.
func (p *Publisher) GetRegisteredCount() int {
p.mu.RLock()
defer p.mu.RUnlock()
return len(p.servers)
}
// getLocalIPs returns the local IP addresses for logging purposes.
func getLocalIPs() []string {
var ips []string
addrs, err := net.InterfaceAddrs()
if err != nil {
return []string{"127.0.0.1"}
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
ips = append(ips, ipnet.IP.String())
}
}
}
if len(ips) == 0 {
return []string{"127.0.0.1"}
}
return ips
}
+154
View File
@@ -0,0 +1,154 @@
package mdns
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Note: Tests that actually register mDNS services require network I/O
// and can be slow or hang in CI environments. We test the logic paths
// without actually calling zeroconf for most tests.
func TestNewPublisher_Disabled(t *testing.T) {
p := NewPublisher(false)
assert.False(t, p.IsEnabled())
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestNewPublisher_Enabled(t *testing.T) {
p := NewPublisher(true)
assert.True(t, p.IsEnabled())
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestRegister_WhenDisabled_NoOp(t *testing.T) {
p := NewPublisher(false)
err := p.Register("forward-1", "test-alias", 8080)
assert.NoError(t, err)
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestRegister_EmptyAlias_NoOp(t *testing.T) {
p := NewPublisher(true)
err := p.Register("forward-1", "", 8080)
assert.NoError(t, err)
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestUnregister_WhenDisabled_NoOp(t *testing.T) {
p := NewPublisher(false)
// Should not panic
p.Unregister("forward-1")
}
func TestUnregister_NotRegistered_NoOp(t *testing.T) {
p := NewPublisher(true)
// Should not panic
p.Unregister("non-existent")
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestStop_WhenDisabled_NoOp(t *testing.T) {
p := NewPublisher(false)
// Should not panic
p.Stop()
}
func TestStop_WhenNoRegistrations(t *testing.T) {
p := NewPublisher(true)
// Should not panic
p.Stop()
assert.Equal(t, 0, p.GetRegisteredCount())
}
func TestGetLocalIPs(t *testing.T) {
ips := getLocalIPs()
// Should return at least one IP
assert.NotEmpty(t, ips, "getLocalIPs should return at least one IP")
// All IPs should be non-empty strings
for _, ip := range ips {
assert.NotEmpty(t, ip, "IP address should not be empty")
}
}
// Integration tests - only run when explicitly requested
// These tests actually register mDNS services and require network access
func TestRegister_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping mDNS integration test in short mode")
}
p := NewPublisher(true)
defer p.Stop()
err := p.Register("forward-1", "test-service", 8080)
assert.NoError(t, err)
assert.Equal(t, 1, p.GetRegisteredCount())
}
func TestRegister_Duplicate_Idempotent_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping mDNS integration test in short mode")
}
p := NewPublisher(true)
defer p.Stop()
// First registration
err := p.Register("forward-1", "test-service", 8080)
assert.NoError(t, err)
assert.Equal(t, 1, p.GetRegisteredCount())
// Second registration with same ID should be idempotent
err = p.Register("forward-1", "test-service", 8080)
assert.NoError(t, err)
assert.Equal(t, 1, p.GetRegisteredCount())
}
func TestRegister_MultipleForwards_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping mDNS integration test in short mode")
}
p := NewPublisher(true)
defer p.Stop()
err1 := p.Register("forward-1", "service-a", 8080)
err2 := p.Register("forward-2", "service-b", 8081)
err3 := p.Register("forward-3", "service-c", 8082)
assert.NoError(t, err1)
assert.NoError(t, err2)
assert.NoError(t, err3)
assert.Equal(t, 3, p.GetRegisteredCount())
}
func TestUnregister_Success_Integration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping mDNS integration test in short mode")
}
p := NewPublisher(true)
defer p.Stop()
p.Register("forward-1", "test-service", 8080)
assert.Equal(t, 1, p.GetRegisteredCount())
p.Unregister("forward-1")
assert.Equal(t, 0, p.GetRegisteredCount())
}
-5
View File
@@ -67,8 +67,3 @@ func (b *Backoff) calculateJitter(delay time.Duration) time.Duration {
jitter := (b.rng.Float64()*2 - 1) * maxJitter
return time.Duration(jitter)
}
// Sleep waits for the next backoff duration.
func (b *Backoff) Sleep() {
time.Sleep(b.Next())
}
+5 -3
View File
@@ -158,10 +158,12 @@ func TestBackoff_ExponentialProgression(t *testing.T) {
// We allow for jitter by checking a range
for i := 1; i < len(delays)-1; i++ {
// Each delay should be roughly double the previous (accounting for jitter)
// With 10% jitter on each value, worst case: (2.0 * 1.1) / 0.9 = 2.44
// We use 1.7x to 2.5x as a reasonable range with 10% jitter on each
// With 10% jitter on each value:
// Lower bound: (2.0 * 0.9) / 1.1 ≈ 1.636
// Upper bound: (2.0 * 1.1) / 0.9 ≈ 2.444
// We use 1.6x to 2.5x as a reasonable range to account for jitter variance
ratio := float64(delays[i]) / float64(delays[i-1])
assert.GreaterOrEqual(t, ratio, 1.7, "exponential growth should be ~2x")
assert.GreaterOrEqual(t, ratio, 1.6, "exponential growth should be ~2x")
assert.LessOrEqual(t, ratio, 2.5, "exponential growth should be ~2x")
}
}
+306 -45
View File
@@ -8,6 +8,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
// ForwardUpdateMsg is sent when a forward status changes
@@ -44,11 +45,34 @@ type BubbleTeaUI struct {
toggleCallback func(id string, enable bool)
version string
errors map[string]string // Track error messages by forward ID
// Update notification
updateAvailable bool
updateVersion string
updateURL string
// Modal wizard state
viewMode ViewMode
addWizard *AddWizardState
removeWizard *RemoveWizardState
// Delete confirmation state
deleteConfirming bool
deleteConfirmID string
deleteConfirmAlias string
deleteConfirmCursor int // 0 = Yes, 1 = No
// Dependencies for wizards
discovery *k8s.Discovery
mutator *config.Mutator
configPath string
}
// bubbletea model
type model struct {
ui *BubbleTeaUI
ui *BubbleTeaUI
termWidth int
termHeight int
}
// NewBubbleTeaUI creates a new bubbletea-based UI
@@ -61,11 +85,32 @@ func NewBubbleTeaUI(toggleCallback func(id string, enable bool), version string)
toggleCallback: toggleCallback,
version: version,
errors: make(map[string]string),
viewMode: ViewModeMain,
}
return ui
}
// SetWizardDependencies sets the dependencies needed for the add/remove wizards
func (ui *BubbleTeaUI) SetWizardDependencies(discovery *k8s.Discovery, mutator *config.Mutator, configPath string) {
ui.mu.Lock()
defer ui.mu.Unlock()
ui.discovery = discovery
ui.mutator = mutator
ui.configPath = configPath
}
// SetUpdateAvailable sets the update notification to be displayed
func (ui *BubbleTeaUI) SetUpdateAvailable(version, url string) {
ui.mu.Lock()
defer ui.mu.Unlock()
ui.updateAvailable = true
ui.updateVersion = version
ui.updateURL = url
}
// Start starts the bubbletea application
func (ui *BubbleTeaUI) Start() error {
m := model{ui: ui}
@@ -139,8 +184,9 @@ func (ui *BubbleTeaUI) UpdateStatus(id string, status string) {
if fwd, ok := ui.forwards[id]; ok {
fwd.Status = status
}
// Clear error if status is not Error
if status != "Error" {
// Only clear error when forward becomes Active again
// This keeps error visible during Reconnecting/Starting states
if status == "Active" {
delete(ui.errors, id)
}
ui.mu.Unlock()
@@ -187,45 +233,113 @@ func (m model) Init() tea.Cmd {
}
func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.ui.mu.RLock()
viewMode := m.ui.viewMode
m.ui.mu.RUnlock()
switch msg := msg.(type) {
case tea.WindowSizeMsg:
// Update terminal dimensions on resize
m.termWidth = msg.Width
m.termHeight = msg.Height
return m, nil
case tea.KeyMsg:
switch msg.String() {
case "ctrl+c", "q":
return m, tea.Quit
case "up", "k":
m.ui.moveSelection(-1)
case "down", "j":
m.ui.moveSelection(1)
case " ", "enter":
m.ui.toggleSelected()
// Route based on current view mode
switch viewMode {
case ViewModeMain:
return m.handleMainViewKeys(msg)
case ViewModeAddWizard:
return m.handleAddWizardKeys(msg)
case ViewModeRemoveWizard:
return m.handleRemoveWizardKeys(msg)
}
case ForwardAddMsg:
// Already handled in AddForward, just trigger re-render
// Forward management messages (always update main view data)
case ForwardAddMsg, ForwardUpdateMsg, ForwardErrorMsg, ForwardRemoveMsg:
return m, nil
case ForwardUpdateMsg:
// Already handled in UpdateStatus, just trigger re-render
return m, nil
case ForwardErrorMsg:
// Already handled in SetError, just trigger re-render
return m, nil
case ForwardRemoveMsg:
// Already handled in Remove, just trigger re-render
return m, nil
// Wizard-specific messages
case ContextsLoadedMsg:
return m.handleContextsLoaded(msg)
case NamespacesLoadedMsg:
return m.handleNamespacesLoaded(msg)
case PodsLoadedMsg:
return m.handlePodsLoaded(msg)
case ServicesLoadedMsg:
return m.handleServicesLoaded(msg)
case SelectorValidatedMsg:
return m.handleSelectorValidated(msg)
case PortCheckedMsg:
return m.handlePortChecked(msg)
case ForwardSavedMsg:
return m.handleForwardSaved(msg)
case ForwardsRemovedMsg:
return m.handleForwardsRemoved(msg)
case WizardCompleteMsg:
m.ui.mu.Lock()
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
m.ui.removeWizard = nil
m.ui.mu.Unlock()
return m, tea.ClearScreen
}
return m, nil
}
func (m model) View() string {
m.ui.mu.RLock()
viewMode := m.ui.viewMode
deleteConfirming := m.ui.deleteConfirming
m.ui.mu.RUnlock()
// Always render main view as base
mainView := m.renderMainView()
// Use actual terminal dimensions for proper centering
termWidth := m.termWidth
termHeight := m.termHeight
// Fallback to reasonable defaults if dimensions not yet received
if termWidth == 0 {
termWidth = 120
}
if termHeight == 0 {
termHeight = 40
}
// Overlay delete confirmation if active
if deleteConfirming {
modal := m.renderDeleteConfirmation()
return overlayContent(mainView, modal, termWidth, termHeight)
}
// Overlay wizard if active
switch viewMode {
case ViewModeAddWizard:
modal := m.renderAddWizard()
return overlayContent(mainView, modal, termWidth, termHeight)
case ViewModeRemoveWizard:
modal := m.renderRemoveWizard()
return overlayContent(mainView, modal, termWidth, termHeight)
default:
return mainView
}
}
func (m model) renderMainView() string {
m.ui.mu.RLock()
defer m.ui.mu.RUnlock()
var b strings.Builder
// Get terminal dimensions for proper sizing
termHeight := m.termHeight
if termHeight == 0 {
termHeight = 40 // Fallback
}
// Styles
titleStyle := lipgloss.NewStyle().
Bold(true).
@@ -258,6 +372,15 @@ func (m model) View() string {
// Title with version
title := fmt.Sprintf("kportal v%s - Port Forwarding Status", m.ui.version)
b.WriteString(titleStyle.Render(title))
// Show update notification if available
if m.ui.updateAvailable {
updateStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("42")). // Green
Bold(true)
updateMsg := fmt.Sprintf(" Update available: v%s", m.ui.updateVersion)
b.WriteString(updateStyle.Render(updateMsg))
}
b.WriteString("\n\n")
// Header
@@ -280,7 +403,7 @@ func (m model) View() string {
}
isSelected := (idx == m.ui.selectedIndex)
isDisabled := m.ui.disabledMap[id]
isDisabled := m.ui.disabledMap[id] || fwd.Status == "Disabled"
// Selection indicator
indicator := " "
@@ -350,21 +473,7 @@ func (m model) View() string {
}
}
// Footer
b.WriteString("\n")
footerStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
keyStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("220"))
footer := fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: Quit │ Total: %d",
keyStyle.Render("↑↓"),
keyStyle.Render("jk"),
keyStyle.Render("Space"),
keyStyle.Render("q"),
len(m.ui.forwardOrder))
b.WriteString(footerStyle.Render(footer))
// Display errors if any
// Display errors if any (before footer)
if len(m.ui.errors) > 0 {
b.WriteString("\n\n")
errorHeaderStyle := lipgloss.NewStyle().
@@ -374,20 +483,104 @@ func (m model) View() string {
b.WriteString(errorHeaderStyle.Render("Errors:"))
b.WriteString("\n")
errorLineStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("196")).
Width(118). // Slightly less than table width (120) for padding
MaxWidth(118)
for id, errMsg := range m.ui.errors {
// Find the forward to display its alias
if fwd, ok := m.ui.forwards[id]; ok {
errorLineStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("196"))
line := fmt.Sprintf(" • %s: %s", fwd.Alias, errMsg)
b.WriteString(errorLineStyle.Render(line))
b.WriteString("\n")
// Format: " • alias: error message"
prefix := fmt.Sprintf(" • %s: ", fwd.Alias)
// Wrap the error message if it's too long
// Max line length is 118, subtract prefix length
maxErrLen := 118 - len(prefix)
wrappedMsg := wrapText(errMsg, maxErrLen)
// Render first line with prefix
lines := strings.Split(wrappedMsg, "\n")
if len(lines) > 0 {
b.WriteString(errorLineStyle.Render(prefix + lines[0]))
b.WriteString("\n")
// Render subsequent lines with indentation
indent := strings.Repeat(" ", len(prefix))
for i := 1; i < len(lines); i++ {
b.WriteString(errorLineStyle.Render(indent + lines[i]))
b.WriteString("\n")
}
}
}
}
}
// Calculate current content height
currentContent := b.String()
currentLines := strings.Count(currentContent, "\n") + 1
// Footer styles
footerStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("240"))
keyStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("220"))
footer := fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: New %s: Edit %s: Delete %s: Quit │ Total: %d",
keyStyle.Render("↑↓"),
keyStyle.Render("jk"),
keyStyle.Render("Space"),
keyStyle.Render("n"),
keyStyle.Render("e"),
keyStyle.Render("d"),
keyStyle.Render("q"),
len(m.ui.forwardOrder))
// Fill space to push footer to bottom (reserve 2 lines: 1 for spacing, 1 for footer)
footerHeight := 2
remainingLines := termHeight - currentLines - footerHeight
if remainingLines > 0 {
b.WriteString(strings.Repeat("\n", remainingLines))
}
// Add footer at bottom
b.WriteString("\n")
b.WriteString(footerStyle.Render(footer))
return b.String()
}
// wrapText wraps text to the specified width, breaking at word boundaries
func wrapText(text string, width int) string {
if len(text) <= width {
return text
}
var result strings.Builder
var line strings.Builder
words := strings.Fields(text)
for i, word := range words {
// If adding this word would exceed width, start new line
if line.Len()+len(word)+1 > width && line.Len() > 0 {
result.WriteString(line.String())
result.WriteString("\n")
line.Reset()
}
// Add space before word (except first word on line)
if line.Len() > 0 {
line.WriteString(" ")
}
line.WriteString(word)
// Last word - flush the line
if i == len(words)-1 {
result.WriteString(line.String())
}
}
return result.String()
}
// moveSelection moves the selection up or down
func (ui *BubbleTeaUI) moveSelection(delta int) {
ui.mu.Lock()
@@ -406,6 +599,74 @@ func (ui *BubbleTeaUI) moveSelection(delta int) {
}
}
// resetDeleteConfirmation resets the delete confirmation dialog state.
// Caller must hold ui.mu lock.
func (ui *BubbleTeaUI) resetDeleteConfirmation() {
ui.deleteConfirming = false
ui.deleteConfirmID = ""
ui.deleteConfirmAlias = ""
ui.deleteConfirmCursor = 0
}
// renderDeleteConfirmation renders the delete confirmation dialog
func (m model) renderDeleteConfirmation() string {
m.ui.mu.RLock()
defer m.ui.mu.RUnlock()
var b strings.Builder
// Use wizard color palette for consistency
titleStyle := lipgloss.NewStyle().
Bold(true).
Foreground(warningColor). // Yellow for warning (delete action)
Padding(0, 1)
buttonSelectedStyle := lipgloss.NewStyle().
Background(primaryColor). // Pink/Magenta background
Foreground(lipgloss.Color("230")). // Light yellow text
Bold(true).
Padding(0, 1)
buttonUnselectedStyle := lipgloss.NewStyle().
Foreground(mutedColor). // Gray
Padding(0, 1)
deleteInfoStyle := lipgloss.NewStyle().
Foreground(lipgloss.Color("252")). // Light gray for info text
Italic(true)
// Title
b.WriteString(titleStyle.Render("⚠ Delete Port Forward"))
b.WriteString("\n\n")
// Message
b.WriteString("Are you sure you want to delete:\n\n")
b.WriteString(deleteInfoStyle.Render(" " + m.ui.deleteConfirmAlias))
b.WriteString("\n\n")
// Buttons
if m.ui.deleteConfirmCursor == 0 {
b.WriteString(buttonSelectedStyle.Render(" Yes "))
b.WriteString(" ")
b.WriteString(buttonUnselectedStyle.Render(" No "))
} else {
b.WriteString(buttonUnselectedStyle.Render(" Yes "))
b.WriteString(" ")
b.WriteString(buttonSelectedStyle.Render(" No "))
}
b.WriteString("\n\n")
b.WriteString(helpStyle.Render("←/→: Navigate Enter: Confirm Esc: Cancel"))
// Wrap in a box using wizard style
boxStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor). // Purple border like other wizards
Padding(1, 2)
return boxStyle.Render(b.String())
}
// toggleSelected toggles the selected forward on/off
func (ui *BubbleTeaUI) toggleSelected() {
ui.mu.Lock()
-181
View File
@@ -1,181 +0,0 @@
package ui
import (
"fmt"
"os"
"sync"
"golang.org/x/term"
)
// InteractiveController handles keyboard input and selection state
type InteractiveController struct {
mu sync.RWMutex
selectedIndex int
forwardIDs []string // Ordered list of forward IDs
disabledMap map[string]bool // Tracks which forwards are disabled
toggleCallback func(id string, enable bool)
enabled bool
oldTermState *term.State
}
// NewInteractiveController creates a new interactive controller
func NewInteractiveController(toggleCallback func(id string, enable bool)) *InteractiveController {
return &InteractiveController{
selectedIndex: 0,
forwardIDs: make([]string, 0),
disabledMap: make(map[string]bool),
toggleCallback: toggleCallback,
enabled: false,
}
}
// Enable puts the terminal in raw mode for keyboard input
func (ic *InteractiveController) Enable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.enabled {
return nil
}
// Save current terminal state
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return fmt.Errorf("failed to enable raw mode: %w", err)
}
ic.oldTermState = oldState
ic.enabled = true
return nil
}
// Disable restores the terminal to normal mode
func (ic *InteractiveController) Disable() error {
ic.mu.Lock()
defer ic.mu.Unlock()
if !ic.enabled {
return nil
}
if ic.oldTermState != nil {
if err := term.Restore(int(os.Stdin.Fd()), ic.oldTermState); err != nil {
return fmt.Errorf("failed to restore terminal: %w", err)
}
}
ic.enabled = false
return nil
}
// UpdateForwardsList updates the list of forwards for navigation
func (ic *InteractiveController) UpdateForwardsList(ids []string) {
ic.mu.Lock()
defer ic.mu.Unlock()
ic.forwardIDs = ids
// Ensure selected index is valid
if ic.selectedIndex >= len(ic.forwardIDs) {
ic.selectedIndex = len(ic.forwardIDs) - 1
}
if ic.selectedIndex < 0 && len(ic.forwardIDs) > 0 {
ic.selectedIndex = 0
}
}
// MoveUp moves selection up
func (ic *InteractiveController) MoveUp() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex > 0 {
ic.selectedIndex--
}
}
// MoveDown moves selection down
func (ic *InteractiveController) MoveDown() {
ic.mu.Lock()
defer ic.mu.Unlock()
if ic.selectedIndex < len(ic.forwardIDs)-1 {
ic.selectedIndex++
}
}
// ToggleSelected toggles the enable/disable state of the selected forward
func (ic *InteractiveController) ToggleSelected() {
ic.mu.Lock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
ic.mu.Unlock()
return
}
selectedID := ic.forwardIDs[ic.selectedIndex]
currentlyDisabled := ic.disabledMap[selectedID]
newState := !currentlyDisabled
ic.disabledMap[selectedID] = newState
ic.mu.Unlock()
// Call the toggle callback
if ic.toggleCallback != nil {
ic.toggleCallback(selectedID, !newState) // enable is inverse of disabled
}
}
// GetSelectedIndex returns the current selection index
func (ic *InteractiveController) GetSelectedIndex() int {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.selectedIndex
}
// IsDisabled returns whether a forward is disabled
func (ic *InteractiveController) IsDisabled(id string) bool {
ic.mu.RLock()
defer ic.mu.RUnlock()
return ic.disabledMap[id]
}
// GetSelectedID returns the ID of the currently selected forward
func (ic *InteractiveController) GetSelectedID() string {
ic.mu.RLock()
defer ic.mu.RUnlock()
if ic.selectedIndex < 0 || ic.selectedIndex >= len(ic.forwardIDs) {
return ""
}
return ic.forwardIDs[ic.selectedIndex]
}
// HandleKey processes keyboard input and returns true if should continue
func (ic *InteractiveController) HandleKey(b []byte) bool {
if len(b) == 0 {
return true
}
// Handle single byte keys
if len(b) == 1 {
switch b[0] {
case 'q', 'Q', 3: // q, Q, or Ctrl+C
return false
case ' ', '\r': // Space or Enter to toggle
ic.ToggleSelected()
return true
}
}
// Handle escape sequences (arrow keys)
if len(b) == 3 && b[0] == 27 && b[1] == 91 {
switch b[2] {
case 65: // Up arrow
ic.MoveUp()
case 66: // Down arrow
ic.MoveDown()
}
}
return true
}
+7 -48
View File
@@ -23,10 +23,9 @@ type ForwardStatus struct {
// TableUI manages the terminal table display
type TableUI struct {
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
interactive *InteractiveController
mu sync.RWMutex
forwards map[string]*ForwardStatus // key is forward ID
verbose bool
}
// NewTableUI creates a new table UI manager
@@ -37,13 +36,6 @@ func NewTableUI(verbose bool) *TableUI {
}
}
// SetInteractiveController sets the interactive controller
func (t *TableUI) SetInteractiveController(ic *InteractiveController) {
t.mu.Lock()
defer t.mu.Unlock()
t.interactive = ic
}
// AddForward registers a new forward for display
func (t *TableUI) AddForward(id string, fwd *config.Forward) {
t.mu.Lock()
@@ -126,27 +118,10 @@ func (t *TableUI) Render() {
}
}
// Update interactive controller with current forward IDs (in display order)
if t.interactive != nil {
ids := make([]string, len(entries))
for i, entry := range entries {
ids[i] = entry.id
}
t.interactive.UpdateForwardsList(ids)
}
// Print each forward
for i, entry := range entries {
for _, entry := range entries {
fwd := entry.fwd
// Check if this row is selected
isSelected := false
isDisabled := false
if t.interactive != nil {
isSelected = (i == t.interactive.GetSelectedIndex())
isDisabled = t.interactive.IsDisabled(entry.id)
}
// Truncate long names
alias := truncate(fwd.Alias, 25)
resource := truncate(fwd.Resource, 25)
@@ -154,8 +129,8 @@ func (t *TableUI) Render() {
// Color code status with indicator
statusStr := formatStatusWithIndicator(fwd.Status)
// Build the row content
rowContent := fmt.Sprintf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s",
// Print the row
fmt.Printf(" %-15s %-18s %-25s %-10s %-25s %-12d %-12d %s\n",
fwd.Context,
fwd.Namespace,
alias,
@@ -164,26 +139,10 @@ func (t *TableUI) Render() {
fwd.RemotePort,
fwd.LocalPort,
statusStr)
// Apply selection highlighting or disabled styling
if isSelected {
// Replace leading spaces with arrow, then apply reverse video to entire line
rowContent = "\033[7m> " + rowContent[2:] + "\033[0m"
} else if isDisabled {
// Apply dimmed styling to entire line
rowContent = "\033[2m" + rowContent + "\033[0m"
}
fmt.Println(rowContent)
}
fmt.Println(strings.Repeat("=", 130))
helpText := "Total forwards: %d | ↑↓: Navigate | Space: Toggle | q: Quit"
if !t.verbose {
fmt.Printf(helpText+"\n", len(t.forwards))
} else {
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
}
fmt.Printf("Total forwards: %d | Press Ctrl+C to stop\n", len(t.forwards))
// In verbose mode, add a newline to separate from logs
if t.verbose {
+239
View File
@@ -0,0 +1,239 @@
package ui
import (
"context"
"fmt"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
const (
k8sAPITimeout = 10 * time.Second
)
// Messages sent from async commands back to the update loop
// ContextsLoadedMsg is sent when contexts have been loaded
type ContextsLoadedMsg struct {
contexts []string
err error
}
// NamespacesLoadedMsg is sent when namespaces have been loaded
type NamespacesLoadedMsg struct {
namespaces []string
err error
}
// PodsLoadedMsg is sent when pods have been loaded
type PodsLoadedMsg struct {
pods []k8s.PodInfo
err error
}
// ServicesLoadedMsg is sent when services have been loaded
type ServicesLoadedMsg struct {
services []k8s.ServiceInfo
err error
}
// SelectorValidatedMsg is sent when a selector has been validated
type SelectorValidatedMsg struct {
valid bool
pods []k8s.PodInfo
err error
}
// PortCheckedMsg is sent when a port's availability has been checked
type PortCheckedMsg struct {
port int
available bool
message string
}
// ForwardSavedMsg is sent when a forward has been saved to config
type ForwardSavedMsg struct {
success bool
err error
}
// ForwardsRemovedMsg is sent when forwards have been removed from config
type ForwardsRemovedMsg struct {
success bool
count int
err error
}
// WizardCompleteMsg signals that the wizard has completed
type WizardCompleteMsg struct{}
// Command functions (return tea.Cmd)
// loadContextsCmd loads available Kubernetes contexts
func loadContextsCmd(discovery *k8s.Discovery) tea.Cmd {
return func() tea.Msg {
contexts, err := discovery.ListContexts()
if err != nil {
return ContextsLoadedMsg{err: err}
}
return ContextsLoadedMsg{contexts: contexts}
}
}
// loadNamespacesCmd loads namespaces for the given context
func loadNamespacesCmd(discovery *k8s.Discovery, contextName string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
namespaces, err := discovery.ListNamespaces(ctx, contextName)
if err != nil {
return NamespacesLoadedMsg{err: err}
}
return NamespacesLoadedMsg{namespaces: namespaces}
}
}
// loadPodsCmd loads pods for the given context and namespace
func loadPodsCmd(discovery *k8s.Discovery, contextName, namespace string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
pods, err := discovery.ListPods(ctx, contextName, namespace)
if err != nil {
return PodsLoadedMsg{err: err}
}
return PodsLoadedMsg{pods: pods}
}
}
// loadServicesCmd loads services for the given context and namespace
func loadServicesCmd(discovery *k8s.Discovery, contextName, namespace string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
services, err := discovery.ListServices(ctx, contextName, namespace)
if err != nil {
return ServicesLoadedMsg{err: err}
}
return ServicesLoadedMsg{services: services}
}
}
// validateSelectorCmd validates a label selector and returns matching pods
func validateSelectorCmd(discovery *k8s.Discovery, contextName, namespace, selector string) tea.Cmd {
return func() tea.Msg {
ctx, cancel := context.WithTimeout(context.Background(), k8sAPITimeout)
defer cancel()
pods, err := discovery.ListPodsWithSelector(ctx, contextName, namespace, selector)
if err != nil {
return SelectorValidatedMsg{valid: false, err: err}
}
return SelectorValidatedMsg{
valid: len(pods) > 0,
pods: pods,
}
}
}
// checkPortCmd checks if a local port is available
func checkPortCmd(port int, configPath string) tea.Cmd {
return func() tea.Msg {
// First check if port is already in the configuration
cfg, err := config.LoadConfig(configPath)
if err == nil {
// Check all forwards in config for this port
allForwards := cfg.GetAllForwards()
for _, fwd := range allForwards {
if fwd.LocalPort == port {
return PortCheckedMsg{
port: port,
available: false,
message: fmt.Sprintf("✗ Port %d already assigned to %s", port, fwd.ID()),
}
}
}
}
// Then check if port is available at OS level
available, processInfo, err := k8s.CheckPortAvailability(port)
msg := ""
if err != nil {
msg = fmt.Sprintf("✗ Error: %v", err)
} else if available {
msg = fmt.Sprintf("✓ Port %d available", port)
} else {
msg = fmt.Sprintf("✗ Port %d in use by %s", port, processInfo)
}
return PortCheckedMsg{
port: port,
available: available,
message: msg,
}
}
}
// saveForwardCmd saves a new forward to the configuration file
func saveForwardCmd(mutator *config.Mutator, contextName, namespace string, fwd config.Forward) tea.Cmd {
return func() tea.Msg {
err := mutator.AddForward(contextName, namespace, fwd)
return ForwardSavedMsg{
success: err == nil,
err: err,
}
}
}
// updateForwardCmd atomically updates an existing forward (used in edit mode)
func updateForwardCmd(mutator *config.Mutator, oldID, contextName, namespace string, fwd config.Forward) tea.Cmd {
return func() tea.Msg {
err := mutator.UpdateForward(oldID, contextName, namespace, fwd)
return ForwardSavedMsg{
success: err == nil,
err: err,
}
}
}
// removeForwardsCmd removes selected forwards from the configuration file
func removeForwardsCmd(mutator *config.Mutator, forwards []RemovableForward) tea.Cmd {
return func() tea.Msg {
// Create a map of IDs to remove
idsToRemove := make(map[string]bool)
for _, fwd := range forwards {
idsToRemove[fwd.ID] = true
}
// Remove forwards matching the IDs
err := mutator.RemoveForwards(func(ctx, ns string, fwd config.Forward) bool {
return idsToRemove[fwd.ID()]
})
return ForwardsRemovedMsg{
success: err == nil,
count: len(forwards),
err: err,
}
}
}
// removeForwardByIDCmd removes a single forward by its ID
func removeForwardByIDCmd(mutator *config.Mutator, id string) tea.Cmd {
return func() tea.Msg {
err := mutator.RemoveForwardByID(id)
return ForwardsRemovedMsg{
success: err == nil,
count: 1,
err: err,
}
}
}
+819
View File
@@ -0,0 +1,819 @@
package ui
import (
"fmt"
"strconv"
"strings"
tea "github.com/charmbracelet/bubbletea"
"github.com/nvm/kportal/internal/config"
"github.com/nvm/kportal/internal/k8s"
)
// isFilterableStep returns true if the step supports search/filter
func isFilterableStep(step AddWizardStep) bool {
switch step {
case StepSelectContext, StepSelectNamespace:
return true
case StepEnterResource:
// Only service selection is filterable (pod prefix and selector are text input)
return true // We'll check resource type in the handler
default:
return false
}
}
// handleMainViewKeys handles keyboard input in the main view
func (m model) handleMainViewKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
// If delete confirmation is showing, handle it separately
if m.ui.deleteConfirming {
return m.handleDeleteConfirmation(msg)
}
switch msg.String() {
case "ctrl+c", "q":
return m, tea.Quit
case "up", "k":
m.ui.moveSelection(-1)
case "down", "j":
m.ui.moveSelection(1)
case " ", "enter":
m.ui.toggleSelected()
case "n": // Enter add wizard
m.ui.mu.Lock()
if m.ui.discovery == nil || m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
m.ui.viewMode = ViewModeAddWizard
m.ui.addWizard = newAddWizardState()
m.ui.addWizard.loading = true
m.ui.mu.Unlock()
// Load contexts
return m, loadContextsCmd(m.ui.discovery)
case "e": // Edit selected forward
m.ui.mu.Lock()
if len(m.ui.forwardOrder) == 0 {
// No forwards to edit
m.ui.mu.Unlock()
return m, nil
}
if m.ui.discovery == nil || m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
// Get the currently selected forward
currentSelectedIndex := m.ui.selectedIndex
if currentSelectedIndex < 0 || currentSelectedIndex >= len(m.ui.forwardOrder) {
m.ui.mu.Unlock()
return m, nil
}
selectedID := m.ui.forwardOrder[currentSelectedIndex]
selectedForward, ok := m.ui.forwards[selectedID]
if !ok {
m.ui.mu.Unlock()
return m, nil
}
// Create an add wizard pre-filled with the current forward's values
m.ui.viewMode = ViewModeAddWizard
m.ui.addWizard = newAddWizardState()
// Pre-fill the wizard with current values
m.ui.addWizard.selectedContext = selectedForward.Context
m.ui.addWizard.selectedNamespace = selectedForward.Namespace
m.ui.addWizard.resourceValue = selectedForward.Resource
m.ui.addWizard.remotePort = selectedForward.RemotePort
m.ui.addWizard.localPort = selectedForward.LocalPort
m.ui.addWizard.alias = selectedForward.Alias
// Determine resource type from the resource string
if strings.HasPrefix(selectedForward.Type, "service") {
m.ui.addWizard.selectedResourceType = ResourceTypeService
} else {
m.ui.addWizard.selectedResourceType = ResourceTypePodPrefix
}
// Mark as edit mode and store original ID
m.ui.addWizard.isEditing = true
m.ui.addWizard.originalID = selectedID
// Start at the remote port step (skip context/namespace/resource selection)
m.ui.addWizard.step = StepEnterRemotePort
// Load resources to detect ports
m.ui.addWizard.loading = true
m.ui.mu.Unlock()
// Load pods or services to detect available ports
if m.ui.addWizard.selectedResourceType == ResourceTypeService {
return m, loadServicesCmd(m.ui.discovery, selectedForward.Context, selectedForward.Namespace)
}
return m, loadPodsCmd(m.ui.discovery, selectedForward.Context, selectedForward.Namespace)
case "d": // Delete currently selected forward - show confirmation
m.ui.mu.Lock()
if len(m.ui.forwardOrder) == 0 {
// No forwards to delete
m.ui.mu.Unlock()
return m, nil
}
if m.ui.mutator == nil {
// Dependencies not set up
m.ui.mu.Unlock()
return m, nil
}
// Get the currently selected forward
currentSelectedIndex := m.ui.selectedIndex
if currentSelectedIndex < 0 || currentSelectedIndex >= len(m.ui.forwardOrder) {
m.ui.mu.Unlock()
return m, nil
}
selectedID := m.ui.forwardOrder[currentSelectedIndex]
selectedForward, ok := m.ui.forwards[selectedID]
if !ok {
m.ui.mu.Unlock()
return m, nil
}
// Show confirmation dialog
m.ui.deleteConfirming = true
m.ui.deleteConfirmID = selectedID
m.ui.deleteConfirmAlias = selectedForward.Alias
m.ui.deleteConfirmCursor = 0 // Default to "No" for safety
m.ui.mu.Unlock()
return m, nil
}
return m, nil
}
// handleDeleteConfirmation handles keyboard input for delete confirmation dialog
func (m model) handleDeleteConfirmation(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
switch msg.String() {
case "ctrl+c", "esc":
// Cancel deletion
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, tea.ClearScreen
case "left", "h", "right", "l":
// Toggle between Yes/No
m.ui.deleteConfirmCursor = 1 - m.ui.deleteConfirmCursor
m.ui.mu.Unlock()
return m, nil
case "enter", "y":
// Confirm deletion (either Enter on Yes or pressing 'y')
if m.ui.deleteConfirmCursor == 0 || msg.String() == "y" {
id := m.ui.deleteConfirmID
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, removeForwardByIDCmd(m.ui.mutator, id)
}
// Enter on No = cancel
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, tea.ClearScreen
case "n":
// Quick 'n' for no
m.ui.resetDeleteConfirmation()
m.ui.mu.Unlock()
return m, tea.ClearScreen
}
m.ui.mu.Unlock()
return m, nil
}
// handleAddWizardKeys handles keyboard input in the add wizard
func (m model) handleAddWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
wizard := m.ui.addWizard
if wizard == nil {
return m, nil
}
switch msg.String() {
case "ctrl+c":
// Hard cancel
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
case "esc":
// If there's an active search filter, clear it instead of going back
if wizard.searchFilter != "" && isFilterableStep(wizard.step) {
wizard.clearSearchFilter()
return m, nil
}
// In edit mode, Esc always cancels (don't navigate back through skipped steps)
if wizard.isEditing {
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
// In add mode, go back or cancel
if wizard.step == StepSelectContext {
// On first step, cancel entirely
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
} else {
// Go back one step
wizard.step--
wizard.resetInput()
// Reset input mode based on the step we're going back to
switch wizard.step {
case StepSelectContext, StepSelectNamespace, StepSelectResourceType:
wizard.inputMode = InputModeList
case StepEnterResource:
if wizard.selectedResourceType == ResourceTypeService {
wizard.inputMode = InputModeList
} else {
wizard.inputMode = InputModeText
}
case StepEnterRemotePort, StepEnterLocalPort:
wizard.inputMode = InputModeText
case StepConfirmation:
wizard.inputMode = InputModeList
}
}
return m, nil
case "up", "k":
// In confirmation step, toggle between alias and buttons
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusButtons {
wizard.confirmationFocus = FocusAlias
}
} else {
wizard.moveCursor(-1)
}
case "down", "j":
// In confirmation step, toggle between alias and buttons
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
} else {
wizard.moveCursor(1) // Navigate between buttons
}
} else {
wizard.moveCursor(1)
}
case "tab":
// Tab moves between alias field and buttons in confirmation
if wizard.step == StepConfirmation {
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
} else {
wizard.confirmationFocus = FocusAlias
}
}
case "enter":
return m.handleAddWizardEnter()
case "backspace":
// Allow backspace in text input mode OR when focused on alias in confirmation OR when filtering
canBackspace := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step) && len(wizard.searchFilter) > 0)
if canBackspace {
if isFilterableStep(wizard.step) && wizard.inputMode == InputModeList && len(wizard.searchFilter) > 0 {
// Backspace in search filter
wizard.searchFilter = wizard.searchFilter[:len(wizard.searchFilter)-1]
wizard.cursor = 0
wizard.scrollOffset = 0
} else if len(wizard.textInput) > 0 {
wizard.textInput = wizard.textInput[:len(wizard.textInput)-1]
}
}
default:
// Handle text input
canTypeText := wizard.inputMode == InputModeText ||
(wizard.step == StepConfirmation && wizard.confirmationFocus == FocusAlias) ||
(wizard.inputMode == InputModeList && isFilterableStep(wizard.step))
if canTypeText && len(msg.String()) == 1 {
// If in list mode on filterable step, add to search filter instead of textInput
if wizard.inputMode == InputModeList && isFilterableStep(wizard.step) {
char := rune(msg.String()[0])
// Only allow printable characters
if char >= 32 && char < 127 {
wizard.searchFilter += string(char)
wizard.cursor = 0
wizard.scrollOffset = 0
}
} else {
wizard.handleTextInput(rune(msg.String()[0]))
// Trigger validation for selector
if wizard.step == StepEnterResource && wizard.selectedResourceType == ResourceTypePodSelector {
if len(wizard.textInput) > 0 {
wizard.loading = true
wizard.error = nil
return m, validateSelectorCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace, wizard.textInput)
}
}
}
}
}
return m, nil
}
// handleAddWizardEnter handles Enter key in the add wizard
func (m model) handleAddWizardEnter() (tea.Model, tea.Cmd) {
wizard := m.ui.addWizard
// Don't process Enter if we're currently loading
if wizard.loading {
return m, nil
}
switch wizard.step {
case StepSelectContext:
filteredContexts := wizard.getFilteredContexts()
if wizard.cursor >= 0 && wizard.cursor < len(filteredContexts) {
wizard.selectedContext = filteredContexts[wizard.cursor]
wizard.step = StepSelectNamespace
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.loading = true
return m, loadNamespacesCmd(m.ui.discovery, wizard.selectedContext)
}
case StepSelectNamespace:
filteredNamespaces := wizard.getFilteredNamespaces()
if wizard.cursor >= 0 && wizard.cursor < len(filteredNamespaces) {
wizard.selectedNamespace = filteredNamespaces[wizard.cursor]
wizard.step = StepSelectResourceType
wizard.cursor = 0
wizard.clearSearchFilter()
wizard.inputMode = InputModeList
}
case StepSelectResourceType:
if wizard.cursor >= 0 && wizard.cursor < 3 {
wizard.selectedResourceType = ResourceType(wizard.cursor)
wizard.step = StepEnterResource
wizard.cursor = 0
if wizard.selectedResourceType == ResourceTypeService {
wizard.inputMode = InputModeList
wizard.loading = true
return m, loadServicesCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace)
} else {
wizard.inputMode = InputModeText
wizard.loading = true
return m, loadPodsCmd(m.ui.discovery, wizard.selectedContext, wizard.selectedNamespace)
}
}
case StepEnterResource:
switch wizard.selectedResourceType {
case ResourceTypePodPrefix:
if wizard.textInput != "" {
wizard.resourceValue = wizard.textInput
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
// Detect ports from matching pods
wizard.detectedPorts = k8s.GetUniquePorts(wizard.pods)
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
case ResourceTypePodSelector:
if wizard.textInput != "" && len(wizard.matchingPods) > 0 {
wizard.resourceValue = "pod"
wizard.selector = wizard.textInput
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
// Detect ports from matching pods
wizard.detectedPorts = k8s.GetUniquePorts(wizard.matchingPods)
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
case ResourceTypeService:
filteredServices := wizard.getFilteredServices()
if wizard.cursor >= 0 && wizard.cursor < len(filteredServices) {
wizard.resourceValue = filteredServices[wizard.cursor].Name
// Get ports from selected service (must do this BEFORE clearing search filter)
wizard.detectedPorts = filteredServices[wizard.cursor].Ports
wizard.step = StepEnterRemotePort
wizard.clearTextInput()
wizard.clearSearchFilter()
if len(wizard.detectedPorts) > 0 {
wizard.inputMode = InputModeList
wizard.cursor = 0
} else {
wizard.inputMode = InputModeText
}
}
}
case StepEnterRemotePort:
if wizard.inputMode == InputModeList && len(wizard.detectedPorts) > 0 {
// List mode - user selected from detected ports
if wizard.cursor == len(wizard.detectedPorts) {
// Selected "Manual entry" option
wizard.inputMode = InputModeText
wizard.clearTextInput()
} else if wizard.cursor >= 0 && wizard.cursor < len(wizard.detectedPorts) {
// Selected a detected port
wizard.remotePort = int(wizard.detectedPorts[wizard.cursor].Port)
wizard.step = StepEnterLocalPort
wizard.clearTextInput()
wizard.inputMode = InputModeText
wizard.error = nil
}
} else {
// Text mode - manual entry
port, err := strconv.Atoi(wizard.textInput)
if err != nil || !config.IsValidPort(port) {
wizard.error = fmt.Errorf("invalid port number")
} else {
wizard.remotePort = port
wizard.step = StepEnterLocalPort
wizard.clearTextInput()
wizard.error = nil
}
}
case StepEnterLocalPort:
port, err := strconv.Atoi(wizard.textInput)
if err != nil || !config.IsValidPort(port) {
wizard.error = fmt.Errorf("invalid port number")
} else {
// Check port availability before proceeding
wizard.localPort = port
wizard.loading = true
wizard.error = nil
return m, checkPortCmd(port, m.ui.configPath)
}
case StepConfirmation:
// If focused on alias field, move to buttons
if wizard.confirmationFocus == FocusAlias {
wizard.confirmationFocus = FocusButtons
wizard.cursor = 0
return m, nil
}
// Handle button selection
if wizard.cursor == 0 {
// Check if port is available before saving
if !wizard.portAvailable {
wizard.error = fmt.Errorf("port %d is not available. Please choose a different port", wizard.localPort)
return m, nil
}
// Confirmed - save the forward
wizard.alias = wizard.textInput
// Build the forward config
fwd := config.Forward{
Protocol: "tcp",
Port: wizard.remotePort,
LocalPort: wizard.localPort,
Alias: wizard.alias,
}
if wizard.selectedResourceType == ResourceTypePodPrefix {
fwd.Resource = "pod/" + wizard.resourceValue
} else if wizard.selectedResourceType == ResourceTypePodSelector {
fwd.Resource = wizard.resourceValue
fwd.Selector = wizard.selector
} else if wizard.selectedResourceType == ResourceTypeService {
fwd.Resource = "service/" + wizard.resourceValue
}
wizard.loading = true
// If editing, use atomic update operation
if wizard.isEditing {
return m, updateForwardCmd(m.ui.mutator, wizard.originalID, wizard.selectedContext, wizard.selectedNamespace, fwd)
}
return m, saveForwardCmd(m.ui.mutator, wizard.selectedContext, wizard.selectedNamespace, fwd)
} else {
// Cancelled - return to main view with screen clear
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
case StepSuccess:
if wizard.cursor == 0 {
// Add another
m.ui.addWizard = newAddWizardState()
m.ui.addWizard.loading = true
return m, loadContextsCmd(m.ui.discovery)
} else {
// Return to main view with screen clear
m.ui.viewMode = ViewModeMain
m.ui.addWizard = nil
return m, tea.ClearScreen
}
}
return m, nil
}
// handleRemoveWizardKeys handles keyboard input in the remove wizard
func (m model) handleRemoveWizardKeys(msg tea.KeyMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
wizard := m.ui.removeWizard
if wizard == nil {
return m, nil
}
switch msg.String() {
case "ctrl+c":
// Hard cancel - always exit
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
return m, tea.ClearScreen
case "esc":
if wizard.confirming {
// In confirmation mode, Esc confirms the removal (same as pressing Yes)
selectedForwards := wizard.getSelectedForwards()
return m, removeForwardsCmd(m.ui.mutator, selectedForwards)
} else {
// Not confirming yet - cancel entirely
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
}
return m, tea.ClearScreen
case "up", "k":
wizard.moveCursor(-1)
case "down", "j":
wizard.moveCursor(1)
case " ":
if !wizard.confirming {
wizard.toggleSelection()
}
case "a":
wizard.selectAll()
case "n":
wizard.selectNone()
case "enter":
if !wizard.confirming {
if wizard.getSelectedCount() == 0 {
// Nothing selected
return m, nil
}
// Show confirmation
wizard.confirming = true
wizard.confirmCursor = 0
} else {
// Confirmed
if wizard.confirmCursor == 0 {
// Yes, remove
selectedForwards := wizard.getSelectedForwards()
return m, removeForwardsCmd(m.ui.mutator, selectedForwards)
} else {
// No, cancel
wizard.confirming = false
}
}
}
return m, nil
}
// Message handlers
func (m model) handleContextsLoaded(msg ContextsLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
// Get current context and move it to the top
currentCtx, err := m.ui.discovery.GetCurrentContext()
if err == nil && currentCtx != "" {
// Reorder contexts with current first
reordered := []string{currentCtx}
for _, ctx := range msg.contexts {
if ctx != currentCtx {
reordered = append(reordered, ctx)
}
}
m.ui.addWizard.contexts = reordered
} else {
m.ui.addWizard.contexts = msg.contexts
}
}
}
return m, nil
}
func (m model) handleNamespacesLoaded(msg NamespacesLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.namespaces = msg.namespaces
}
}
return m, nil
}
func (m model) handlePodsLoaded(msg PodsLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.pods = msg.pods
// If we're at the remote port step (edit mode), detect ports now
if m.ui.addWizard.step == StepEnterRemotePort {
m.ui.addWizard.detectedPorts = k8s.GetUniquePorts(msg.pods)
if len(m.ui.addWizard.detectedPorts) > 0 {
m.ui.addWizard.inputMode = InputModeList
m.ui.addWizard.cursor = 0
} else {
m.ui.addWizard.inputMode = InputModeText
m.ui.addWizard.textInput = fmt.Sprintf("%d", m.ui.addWizard.remotePort)
}
}
}
}
return m, nil
}
func (m model) handleServicesLoaded(msg ServicesLoadedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.err == nil {
m.ui.addWizard.services = msg.services
// If we're at the remote port step (edit mode), detect ports now
if m.ui.addWizard.step == StepEnterRemotePort {
// Find the service by name
for _, svc := range msg.services {
if svc.Name == m.ui.addWizard.resourceValue {
m.ui.addWizard.detectedPorts = svc.Ports
if len(m.ui.addWizard.detectedPorts) > 0 {
m.ui.addWizard.inputMode = InputModeList
m.ui.addWizard.cursor = 0
} else {
m.ui.addWizard.inputMode = InputModeText
m.ui.addWizard.textInput = fmt.Sprintf("%d", m.ui.addWizard.remotePort)
}
break
}
}
}
}
}
return m, nil
}
func (m model) handleSelectorValidated(msg SelectorValidatedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.error = msg.err
if msg.valid {
m.ui.addWizard.matchingPods = msg.pods
} else {
m.ui.addWizard.matchingPods = nil
}
}
return m, nil
}
func (m model) handlePortChecked(msg PortCheckedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
m.ui.addWizard.portAvailable = msg.available
m.ui.addWizard.portCheckMsg = msg.message
// Only proceed to confirmation if port is available
if msg.available {
m.ui.addWizard.step = StepConfirmation
m.ui.addWizard.clearTextInput()
m.ui.addWizard.cursor = 0
m.ui.addWizard.inputMode = InputModeList
} else {
// Port is not available - show error and stay on local port step
m.ui.addWizard.error = fmt.Errorf("port %d is in use, please choose another port", msg.port)
}
}
return m, nil
}
func (m model) handleForwardSaved(msg ForwardSavedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
if m.ui.addWizard != nil {
m.ui.addWizard.loading = false
if msg.success {
// Move to success step
m.ui.addWizard.step = StepSuccess
m.ui.addWizard.cursor = 0
m.ui.addWizard.inputMode = InputModeList
} else {
m.ui.addWizard.error = msg.err
}
}
return m, nil
}
func (m model) handleForwardsRemoved(msg ForwardsRemovedMsg) (tea.Model, tea.Cmd) {
m.ui.mu.Lock()
defer m.ui.mu.Unlock()
// Delete now happens directly without wizard
// Just ensure we're back in main view
m.ui.viewMode = ViewModeMain
m.ui.removeWizard = nil
// If there was an error, it will be logged but we don't show it in UI for now
// The config watcher will either reload (success) or keep old config (failure)
return m, tea.ClearScreen
}
+375
View File
@@ -0,0 +1,375 @@
package ui
import (
"strings"
"github.com/nvm/kportal/internal/k8s"
)
// filterStrings filters a slice of strings by a search filter (case-insensitive substring match)
func filterStrings(items []string, filter string) []string {
if filter == "" {
return items
}
filtered := []string{}
filterLower := strings.ToLower(filter)
for _, item := range items {
if strings.Contains(strings.ToLower(item), filterLower) {
filtered = append(filtered, item)
}
}
return filtered
}
// matchesFilter checks if a string matches the filter (case-insensitive substring match)
func matchesFilter(item, filter string) bool {
if filter == "" {
return true
}
return strings.Contains(strings.ToLower(item), strings.ToLower(filter))
}
// ViewMode represents the current view state of the UI
type ViewMode int
const (
ViewModeMain ViewMode = iota
ViewModeAddWizard
ViewModeRemoveWizard
)
// InputMode represents whether the wizard is in list selection or text input mode
type InputMode int
const (
InputModeList InputMode = iota
InputModeText
)
// AddWizardStep represents the current step in the add wizard flow
type AddWizardStep int
const (
StepSelectContext AddWizardStep = iota
StepSelectNamespace
StepSelectResourceType
StepEnterResource
StepEnterRemotePort
StepEnterLocalPort
StepConfirmation
StepSuccess
)
// ConfirmationFocus represents what the user is focused on in confirmation step
type ConfirmationFocus int
const (
FocusAlias ConfirmationFocus = iota
FocusButtons
)
// ResourceType represents the type of Kubernetes resource to forward to
type ResourceType int
const (
ResourceTypePodPrefix ResourceType = iota
ResourceTypePodSelector
ResourceTypeService
)
// String returns a human-readable name for the resource type
func (r ResourceType) String() string {
switch r {
case ResourceTypePodPrefix:
return "Pod (by name prefix)"
case ResourceTypePodSelector:
return "Pod (by label selector)"
case ResourceTypeService:
return "Service"
default:
return "Unknown"
}
}
// Description returns a description of the resource type
func (r ResourceType) Description() string {
switch r {
case ResourceTypePodPrefix:
return "Recommended for specific pod instances"
case ResourceTypePodSelector:
return "Flexible, survives pod restarts automatically"
case ResourceTypeService:
return "Most stable, load-balanced"
default:
return ""
}
}
// AddWizardState maintains the state for the add port forward wizard
type AddWizardState struct {
step AddWizardStep
inputMode InputMode
cursor int
scrollOffset int // For scrolling long lists
textInput string
searchFilter string // For filtering lists (contexts, namespaces, services)
loading bool
error error
// Selections made by user
selectedContext string
selectedNamespace string
selectedResourceType ResourceType
resourceValue string // pod prefix or service name
selector string // for pod selector type
remotePort int
localPort int
alias string
// Available options (loaded asynchronously from k8s)
contexts []string
namespaces []string
pods []k8s.PodInfo
services []k8s.ServiceInfo
// Validation state
portAvailable bool
portCheckMsg string
matchingPods []k8s.PodInfo
// Edit mode
isEditing bool
originalID string // ID of the forward being edited
// Detected ports from resources
detectedPorts []k8s.PortInfo
// Confirmation focus (alias field vs buttons)
confirmationFocus ConfirmationFocus
}
// newAddWizardState creates a new add wizard state initialized to the first step
func newAddWizardState() *AddWizardState {
return &AddWizardState{
step: StepSelectContext,
inputMode: InputModeList,
cursor: 0,
contexts: []string{},
}
}
// moveCursor moves the cursor up or down in list selection mode
func (w *AddWizardState) moveCursor(delta int) {
if w.inputMode != InputModeList {
return
}
var maxItems int
switch w.step {
case StepSelectContext:
maxItems = len(w.getFilteredContexts())
case StepSelectNamespace:
maxItems = len(w.getFilteredNamespaces())
case StepSelectResourceType:
maxItems = 3 // Three resource types
case StepEnterResource:
if w.selectedResourceType == ResourceTypeService {
maxItems = len(w.getFilteredServices())
}
case StepEnterRemotePort:
if len(w.detectedPorts) > 0 {
maxItems = len(w.detectedPorts) + 1 // +1 for "Manual entry" option
}
}
w.cursor += delta
if w.cursor < 0 {
w.cursor = 0
}
if w.cursor >= maxItems && maxItems > 0 {
w.cursor = maxItems - 1
}
// Adjust scroll offset to keep cursor visible
// Viewport shows max 20 items at a time
const viewportHeight = 20
// If cursor moved below visible area, scroll down
if w.cursor >= w.scrollOffset+viewportHeight {
w.scrollOffset = w.cursor - viewportHeight + 1
}
// If cursor moved above visible area, scroll up
if w.cursor < w.scrollOffset {
w.scrollOffset = w.cursor
}
// Ensure scroll offset is valid
if w.scrollOffset < 0 {
w.scrollOffset = 0
}
}
// handleTextInput handles a single character input in text mode
func (w *AddWizardState) handleTextInput(char rune) {
// Note: Caller already checks if text input is allowed (inputMode or confirmation step)
// so we don't need to check inputMode here
// Handle backspace
if char == 127 || char == 8 {
if len(w.textInput) > 0 {
w.textInput = w.textInput[:len(w.textInput)-1]
}
return
}
// Only allow printable characters
if char >= 32 && char < 127 {
w.textInput += string(char)
}
}
// clearTextInput clears the text input field
func (w *AddWizardState) clearTextInput() {
w.textInput = ""
}
// RemoveWizardState maintains the state for the remove port forward wizard
type RemoveWizardState struct {
forwards []RemovableForward
cursor int
selected map[int]bool
confirming bool
confirmCursor int // 0 = Yes, 1 = No
}
// RemovableForward represents a forward that can be removed
type RemovableForward struct {
ID string
Context string
Namespace string
Alias string
Resource string
Selector string
Port int
LocalPort int
}
// moveCursor moves the cursor up or down
func (w *RemoveWizardState) moveCursor(delta int) {
if w.confirming {
// Move between Yes/No in confirmation
w.confirmCursor += delta
if w.confirmCursor < 0 {
w.confirmCursor = 0
}
if w.confirmCursor > 1 {
w.confirmCursor = 1
}
} else {
// Move between forwards
w.cursor += delta
if w.cursor < 0 {
w.cursor = 0
}
if w.cursor >= len(w.forwards) {
w.cursor = len(w.forwards) - 1
}
}
}
// toggleSelection toggles the selection of the current forward
func (w *RemoveWizardState) toggleSelection() {
if w.confirming {
return
}
w.selected[w.cursor] = !w.selected[w.cursor]
}
// selectAll selects all forwards for removal
func (w *RemoveWizardState) selectAll() {
if w.confirming {
return
}
for i := range w.forwards {
w.selected[i] = true
}
}
// selectNone deselects all forwards
func (w *RemoveWizardState) selectNone() {
if w.confirming {
return
}
w.selected = make(map[int]bool)
}
// getSelectedCount returns the number of selected forwards
func (w *RemoveWizardState) getSelectedCount() int {
count := 0
for _, selected := range w.selected {
if selected {
count++
}
}
return count
}
// getSelectedForwards returns a list of selected forwards
func (w *RemoveWizardState) getSelectedForwards() []RemovableForward {
selected := make([]RemovableForward, 0)
for i, fwd := range w.forwards {
if w.selected[i] {
selected = append(selected, fwd)
}
}
return selected
}
// getFilteredContexts returns contexts filtered by search string
func (w *AddWizardState) getFilteredContexts() []string {
if w.searchFilter == "" {
return w.contexts
}
return filterStrings(w.contexts, w.searchFilter)
}
// getFilteredNamespaces returns namespaces filtered by search string
func (w *AddWizardState) getFilteredNamespaces() []string {
if w.searchFilter == "" {
return w.namespaces
}
return filterStrings(w.namespaces, w.searchFilter)
}
// getFilteredServices returns services filtered by search string
func (w *AddWizardState) getFilteredServices() []k8s.ServiceInfo {
if w.searchFilter == "" {
return w.services
}
filtered := []k8s.ServiceInfo{}
for _, svc := range w.services {
if matchesFilter(svc.Name, w.searchFilter) {
filtered = append(filtered, svc)
}
}
return filtered
}
// clearSearchFilter clears the search filter and resets cursor/scroll
func (w *AddWizardState) clearSearchFilter() {
w.searchFilter = ""
w.cursor = 0
w.scrollOffset = 0
}
// resetInput clears text input, search filter, and error state.
// Use this when navigating between wizard steps.
func (w *AddWizardState) resetInput() {
w.textInput = ""
w.searchFilter = ""
w.cursor = 0
w.scrollOffset = 0
w.error = nil
}
+350
View File
@@ -0,0 +1,350 @@
package ui
import (
"testing"
"github.com/nvm/kportal/internal/k8s"
"github.com/stretchr/testify/assert"
)
func TestFilterStrings(t *testing.T) {
tests := []struct {
name string
items []string
filter string
expected []string
}{
{
name: "empty filter returns all items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "",
expected: []string{"namespace-1", "namespace-2", "namespace-3"},
},
{
name: "filter matches multiple items",
items: []string{"prod-api", "prod-db", "staging-api", "dev-api"},
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter matches single item",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "2",
expected: []string{"namespace-2"},
},
{
name: "filter matches no items",
items: []string{"namespace-1", "namespace-2", "namespace-3"},
filter: "xyz",
expected: []string{},
},
{
name: "case insensitive matching",
items: []string{"Production", "Staging", "Development"},
filter: "prod",
expected: []string{"Production"},
},
{
name: "partial string matching",
items: []string{"my-app-frontend", "my-app-backend", "other-service"},
filter: "app",
expected: []string{"my-app-frontend", "my-app-backend"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterStrings(tt.items, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMatchesFilter(t *testing.T) {
tests := []struct {
name string
item string
filter string
expected bool
}{
{
name: "empty filter matches everything",
item: "namespace-1",
filter: "",
expected: true,
},
{
name: "exact match",
item: "namespace-1",
filter: "namespace-1",
expected: true,
},
{
name: "partial match",
item: "production-api",
filter: "prod",
expected: true,
},
{
name: "no match",
item: "namespace-1",
filter: "xyz",
expected: false,
},
{
name: "case insensitive match",
item: "Production",
filter: "prod",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchesFilter(tt.item, tt.filter)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredContexts(t *testing.T) {
wizard := &AddWizardState{
contexts: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-cluster"},
},
{
name: "filter by 'cluster'",
filter: "cluster",
expected: []string{"prod-cluster", "staging-cluster", "dev-cluster", "test-cluster"},
},
{
name: "filter by 'staging'",
filter: "staging",
expected: []string{"staging-cluster"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredContexts()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredNamespaces(t *testing.T) {
wizard := &AddWizardState{
namespaces: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{
"kube-system", "kube-public", "default",
"prod-api", "prod-db", "staging-api", "staging-db",
"monitoring", "logging",
},
},
{
name: "filter by 'prod'",
filter: "prod",
expected: []string{"prod-api", "prod-db"},
},
{
name: "filter by 'kube'",
filter: "kube",
expected: []string{"kube-system", "kube-public"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"prod-api", "staging-api"},
},
{
name: "filter by 'ing' (partial match)",
filter: "ing",
expected: []string{"staging-api", "staging-db", "monitoring", "logging"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredNamespaces()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetFilteredServices(t *testing.T) {
wizard := &AddWizardState{
services: []k8s.ServiceInfo{
{Name: "api-gateway"},
{Name: "api-backend"},
{Name: "database"},
{Name: "redis-cache"},
{Name: "postgres-db"},
},
}
tests := []struct {
name string
filter string
expected []string
}{
{
name: "no filter returns all",
filter: "",
expected: []string{"api-gateway", "api-backend", "database", "redis-cache", "postgres-db"},
},
{
name: "filter by 'api'",
filter: "api",
expected: []string{"api-gateway", "api-backend"},
},
{
name: "filter by 'db'",
filter: "db",
expected: []string{"postgres-db"},
},
{
name: "filter by 'base'",
filter: "base",
expected: []string{"database"},
},
{
name: "filter by 'redis'",
filter: "redis",
expected: []string{"redis-cache"},
},
{
name: "filter with no matches",
filter: "xyz",
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard.searchFilter = tt.filter
result := wizard.getFilteredServices()
resultNames := make([]string, len(result))
for i, svc := range result {
resultNames[i] = svc.Name
}
assert.Equal(t, tt.expected, resultNames)
})
}
}
func TestClearSearchFilter(t *testing.T) {
wizard := &AddWizardState{
searchFilter: "test",
cursor: 5,
scrollOffset: 10,
}
wizard.clearSearchFilter()
assert.Equal(t, "", wizard.searchFilter, "searchFilter should be cleared")
assert.Equal(t, 0, wizard.cursor, "cursor should be reset to 0")
assert.Equal(t, 0, wizard.scrollOffset, "scrollOffset should be reset to 0")
}
func TestMoveCursorWithFilteredLists(t *testing.T) {
tests := []struct {
name string
step AddWizardStep
contexts []string
namespaces []string
searchFilter string
initialCursor int
delta int
expectedCursor int
}{
{
name: "move down in filtered contexts",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 0,
delta: 1,
expectedCursor: 1,
},
{
name: "cannot move beyond filtered list",
step: StepSelectContext,
contexts: []string{"prod-1", "prod-2", "staging-1", "dev-1"},
searchFilter: "prod",
initialCursor: 1,
delta: 1,
expectedCursor: 1, // Should stay at 1 (last item in filtered list)
},
{
name: "move up in filtered list",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3", "other"},
searchFilter: "ns",
initialCursor: 2,
delta: -1,
expectedCursor: 1,
},
{
name: "cannot move above 0",
step: StepSelectNamespace,
namespaces: []string{"ns-1", "ns-2", "ns-3"},
searchFilter: "ns",
initialCursor: 0,
delta: -1,
expectedCursor: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wizard := &AddWizardState{
step: tt.step,
inputMode: InputModeList,
cursor: tt.initialCursor,
contexts: tt.contexts,
namespaces: tt.namespaces,
searchFilter: tt.searchFilter,
}
wizard.moveCursor(tt.delta)
assert.Equal(t, tt.expectedCursor, wizard.cursor)
})
}
}
+211
View File
@@ -0,0 +1,211 @@
package ui
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
)
// Color palette for wizards
var (
primaryColor = lipgloss.Color("205") // Pink/Magenta
successColor = lipgloss.Color("42") // Green
errorColor = lipgloss.Color("196") // Red
warningColor = lipgloss.Color("220") // Yellow
mutedColor = lipgloss.Color("241") // Gray
accentColor = lipgloss.Color("63") // Purple
highlightColor = lipgloss.Color("117") // Light blue
)
// Text styles
var (
wizardHeaderStyle = lipgloss.NewStyle().
Bold(true).
Foreground(primaryColor).
MarginBottom(0)
wizardStepStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true)
breadcrumbStyle = lipgloss.NewStyle().
Foreground(highlightColor).
Bold(true)
selectedStyle = lipgloss.NewStyle().
Foreground(primaryColor).
Bold(true)
successStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
errorStyle = lipgloss.NewStyle().
Foreground(errorColor).
Bold(true)
warningStyle = lipgloss.NewStyle().
Foreground(warningColor).
Bold(true)
mutedStyle = lipgloss.NewStyle().
Foreground(mutedColor)
helpStyle = lipgloss.NewStyle().
Foreground(mutedColor).
Italic(true)
spinnerStyle = lipgloss.NewStyle().
Foreground(accentColor).
Bold(true)
)
// Input styles
var (
inputStyle = lipgloss.NewStyle().
Foreground(lipgloss.Color("252"))
validInputStyle = lipgloss.NewStyle().
Foreground(successColor)
)
// Checkbox styles
var (
checkedBoxStyle = lipgloss.NewStyle().
Foreground(successColor).
Bold(true)
uncheckedBoxStyle = lipgloss.NewStyle().
Foreground(mutedColor)
)
// Container styles
var (
wizardBoxStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(accentColor).
Padding(1, 2).
Width(60)
)
// Helper functions for rendering
// renderProgress returns a step indicator like "Step 2/7"
func renderProgress(current, total int) string {
return wizardStepStyle.Render(fmt.Sprintf("Step %d/%d", current, total))
}
// renderHeader returns a formatted header with title and progress
func renderHeader(title, progress string) string {
header := wizardHeaderStyle.Render(title)
if progress != "" {
header += " " + progress
}
return header + "\n\n"
}
// renderBreadcrumb returns a formatted breadcrumb path
func renderBreadcrumb(parts ...string) string {
return breadcrumbStyle.Render(strings.Join(parts, " / "))
}
// renderList renders a list of items with cursor selection and viewport scrolling
func renderList(items []string, cursor int, prefix string, scrollOffset int) string {
var b strings.Builder
const viewportHeight = 20
totalItems := len(items)
// Show scroll up indicator if there are items above the viewport
if scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := scrollOffset
end := scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render visible items
for i := start; i < end; i++ {
cursorPrefix := prefix
if i == cursor {
cursorPrefix = "▸ "
b.WriteString(selectedStyle.Render(cursorPrefix + items[i]))
} else {
b.WriteString(cursorPrefix + items[i])
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
return b.String()
}
// renderTextInput renders a text input field with a cursor
func renderTextInput(label, value string, valid bool) string {
var b strings.Builder
b.WriteString(label)
inputText := value + "█"
if valid {
b.WriteString(validInputStyle.Render(inputText))
} else {
b.WriteString(inputStyle.Render(inputText))
}
return b.String()
}
// overlayContent overlays modal content centered on the base view
func overlayContent(base, modal string, termWidth, termHeight int) string {
baseLines := strings.Split(base, "\n")
modalLines := strings.Split(modal, "\n")
// Ensure base has enough lines
for len(baseLines) < termHeight {
baseLines = append(baseLines, "")
}
modalHeight := len(modalLines)
modalWidth := 0
for _, line := range modalLines {
w := lipgloss.Width(line)
if w > modalWidth {
modalWidth = w
}
}
// Calculate center position
startRow := (termHeight - modalHeight) / 2
if startRow < 0 {
startRow = 0
}
// Create result with modal overlaid
result := make([]string, len(baseLines))
copy(result, baseLines)
for i, modalLine := range modalLines {
row := startRow + i
if row >= 0 && row < len(result) {
// Center the modal line
padding := (termWidth - lipgloss.Width(modalLine)) / 2
if padding < 0 {
padding = 0
}
result[row] = strings.Repeat(" ", padding) + modalLine
}
}
return strings.Join(result, "\n")
}
+653
View File
@@ -0,0 +1,653 @@
package ui
import (
"fmt"
"strings"
)
// renderAddWizard renders the appropriate step of the add wizard
func (m model) renderAddWizard() string {
if m.ui.addWizard == nil {
return ""
}
wizard := m.ui.addWizard
var content string
switch wizard.step {
case StepSelectContext:
content = m.renderSelectContext()
case StepSelectNamespace:
content = m.renderSelectNamespace()
case StepSelectResourceType:
content = m.renderSelectResourceType()
case StepEnterResource:
content = m.renderEnterResource()
case StepEnterRemotePort:
content = m.renderEnterRemotePort()
case StepEnterLocalPort:
content = m.renderEnterLocalPort()
case StepConfirmation:
content = m.renderConfirmation()
case StepSuccess:
content = m.renderSuccess()
default:
content = "Unknown step"
}
return wizardBoxStyle.Render(content)
}
func (m model) renderSelectContext() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(1, 7)))
b.WriteString("Select Kubernetes Context:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading contexts..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Error: %v", wizard.error)))
} else if len(wizard.contexts) == 0 {
b.WriteString(mutedStyle.Render("No contexts found in kubeconfig"))
} else {
filteredContexts := wizard.getFilteredContexts()
if len(filteredContexts) == 0 {
b.WriteString(mutedStyle.Render("No matching contexts"))
} else {
const viewportHeight = 20
totalItems := len(filteredContexts)
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render visible contexts with (current) marker on first one (only if not filtered)
for i := start; i < end; i++ {
prefix := " "
text := filteredContexts[i]
// Only show (current) marker if no filter and this is the first item in original list
if wizard.searchFilter == "" && i == 0 {
text += mutedStyle.Render(" (current)")
}
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + text))
} else {
b.WriteString(prefix + text)
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
}
}
b.WriteString("\n")
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Cancel", len(wizard.getFilteredContexts()), len(wizard.contexts))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc/Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderSelectNamespace() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(2, 7)))
b.WriteString(fmt.Sprintf("Context: %s\n\n", breadcrumbStyle.Render(wizard.selectedContext)))
b.WriteString("Select Namespace:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading namespaces..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Error: %v\n", wizard.error)))
b.WriteString(mutedStyle.Render("\nCluster may be unreachable. Check context."))
} else if len(wizard.namespaces) == 0 {
b.WriteString(mutedStyle.Render("No namespaces found"))
} else {
filteredNamespaces := wizard.getFilteredNamespaces()
if len(filteredNamespaces) == 0 {
b.WriteString(mutedStyle.Render("No matching namespaces"))
} else {
b.WriteString(renderList(filteredNamespaces, wizard.cursor, " ", wizard.scrollOffset))
}
}
b.WriteString("\n")
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredNamespaces()), len(wizard.namespaces))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderSelectResourceType() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(3, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n\n")
b.WriteString("Select Resource Type:\n\n")
resourceTypes := []ResourceType{
ResourceTypePodPrefix,
ResourceTypePodSelector,
ResourceTypeService,
}
for i, rt := range resourceTypes {
prefix := " "
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + rt.String()))
b.WriteString("\n")
b.WriteString(mutedStyle.Render(" " + rt.Description()))
} else {
b.WriteString(prefix + rt.String())
}
b.WriteString("\n")
if i < len(resourceTypes)-1 {
b.WriteString("\n")
}
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
return b.String()
}
func (m model) renderEnterResource() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(4, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n\n")
switch wizard.selectedResourceType {
case ResourceTypePodPrefix:
b.WriteString("Enter pod name prefix:\n\n")
// Show running pods for reference
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading pods..."))
} else if len(wizard.pods) > 0 {
b.WriteString(mutedStyle.Render("Running pods:\n"))
showCount := 0
for _, pod := range wizard.pods {
if strings.HasPrefix(pod.Name, wizard.textInput) || wizard.textInput == "" {
if showCount < 5 { // Limit to 5 pods
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", pod.Name)))
showCount++
}
}
}
if showCount == 0 && wizard.textInput != "" {
b.WriteString(mutedStyle.Render(" (no matching pods)\n"))
} else if len(wizard.pods) > showCount {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" ... and %d more\n", len(wizard.pods)-showCount)))
}
b.WriteString("\n")
}
// Text input
b.WriteString(renderTextInput("Prefix: ", wizard.textInput, true))
b.WriteString("\n\n")
// Show match count
if wizard.textInput != "" {
matchCount := 0
for _, pod := range wizard.pods {
if strings.HasPrefix(pod.Name, wizard.textInput) {
matchCount++
}
}
if matchCount > 0 {
b.WriteString(successStyle.Render(fmt.Sprintf("✓ Matches %d pod(s)", matchCount)))
} else {
b.WriteString(warningStyle.Render("⚠ No matching pods (you can still proceed)"))
}
}
case ResourceTypePodSelector:
b.WriteString("Enter label selector:\n")
b.WriteString(mutedStyle.Render("Format: key=value,key2=value2\n\n"))
b.WriteString(renderTextInput("Selector: ", wizard.textInput, true))
b.WriteString("\n\n")
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Validating selector..."))
} else if len(wizard.matchingPods) > 0 {
b.WriteString(successStyle.Render(fmt.Sprintf("✓ Found %d matching pod(s):\n", len(wizard.matchingPods))))
showCount := 0
for _, pod := range wizard.matchingPods {
if showCount < 3 {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", pod.Name)))
showCount++
}
}
if len(wizard.matchingPods) > 3 {
b.WriteString(mutedStyle.Render(fmt.Sprintf(" ... and %d more\n", len(wizard.matchingPods)-3)))
}
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ Invalid selector: %v", wizard.error)))
}
case ResourceTypeService:
b.WriteString("Select service:\n\n")
// Show search input if there's a filter active
if wizard.searchFilter != "" {
b.WriteString(renderTextInput("Filter: ", wizard.searchFilter, true))
b.WriteString("\n\n")
}
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Loading services..."))
} else if len(wizard.services) == 0 {
b.WriteString(mutedStyle.Render("No services found"))
} else {
filteredServices := wizard.getFilteredServices()
if len(filteredServices) == 0 {
b.WriteString(mutedStyle.Render("No matching services"))
} else {
serviceNames := make([]string, len(filteredServices))
for i, svc := range filteredServices {
serviceNames[i] = svc.Name
}
b.WriteString(renderList(serviceNames, wizard.cursor, " ", wizard.scrollOffset))
}
}
}
b.WriteString("\n")
// Show appropriate help text based on resource type and filter state
if wizard.selectedResourceType == ResourceTypeService {
if wizard.searchFilter != "" {
b.WriteString(helpStyle.Render(fmt.Sprintf("↑/↓: Navigate Enter: Select Backspace: Clear filter (%d/%d) Esc: Back", len(wizard.getFilteredServices()), len(wizard.services))))
} else {
b.WriteString(helpStyle.Render("Type to filter ↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
}
} else {
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderEnterRemotePort() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(5, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n")
// Show resource selection
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("%s [%s]", wizard.resourceValue, wizard.selector)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf("Resource: %s\n\n", resourceInfo)))
// If we have detected ports and in list mode, show them as a list
if len(wizard.detectedPorts) > 0 && wizard.inputMode == InputModeList {
b.WriteString("Select remote port:\n\n")
const viewportHeight = 20
totalItems := len(wizard.detectedPorts) + 1 // +1 for manual entry option
// Show scroll up indicator if there are items above the viewport
if wizard.scrollOffset > 0 {
b.WriteString(mutedStyle.Render(" ↑ More above ↑") + "\n")
}
// Calculate visible range
start := wizard.scrollOffset
end := wizard.scrollOffset + viewportHeight
if end > totalItems {
end = totalItems
}
// Render detected ports within viewport
for i := start; i < end && i < len(wizard.detectedPorts); i++ {
port := wizard.detectedPorts[i]
portDesc := fmt.Sprintf("%d", port.Port)
if port.Name != "" {
portDesc += fmt.Sprintf(" (%s)", port.Name)
}
prefix := " "
if i == wizard.cursor {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + portDesc))
} else {
b.WriteString(prefix + portDesc)
}
b.WriteString("\n")
}
// Add "Manual entry" option if within viewport
manualIdx := len(wizard.detectedPorts)
if manualIdx >= start && manualIdx < end {
manualOption := "Manual entry (type port number)"
prefix := " "
if wizard.cursor == manualIdx {
prefix = "▸ "
b.WriteString(selectedStyle.Render(prefix + manualOption))
} else {
b.WriteString(prefix + mutedStyle.Render(manualOption))
}
b.WriteString("\n")
}
// Show scroll down indicator if there are items below the viewport
if end < totalItems {
b.WriteString(mutedStyle.Render(" ↓ More below ↓") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select Esc: Back Ctrl+C: Cancel"))
} else {
// Text input mode (no detected ports or user chose manual entry)
if len(wizard.detectedPorts) > 0 {
b.WriteString(mutedStyle.Render("Detected ports:\n"))
for _, port := range wizard.detectedPorts {
portDesc := fmt.Sprintf("%d", port.Port)
if port.Name != "" {
portDesc += fmt.Sprintf(" (%s)", port.Name)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf(" • %s\n", portDesc)))
}
b.WriteString("\n")
}
b.WriteString(renderTextInput("Remote port: ", wizard.textInput, wizard.error == nil))
b.WriteString("\n\n")
if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ %v", wizard.error)))
} else if wizard.textInput != "" {
b.WriteString(mutedStyle.Render("Press Enter to continue"))
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
}
return b.String()
}
func (m model) renderEnterLocalPort() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(6, 7)))
b.WriteString(renderBreadcrumb(wizard.selectedContext, wizard.selectedNamespace))
b.WriteString("\n")
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("%s [%s]", wizard.resourceValue, wizard.selector)
}
b.WriteString(mutedStyle.Render(fmt.Sprintf("Resource: %s\n", resourceInfo)))
b.WriteString(mutedStyle.Render(fmt.Sprintf("Remote port: %d\n\n", wizard.remotePort)))
b.WriteString(renderTextInput("Local port: ", wizard.textInput, wizard.error == nil))
b.WriteString("\n\n")
if wizard.loading {
b.WriteString(spinnerStyle.Render("⣾ Checking availability..."))
} else if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("✗ %v", wizard.error)))
} else if wizard.portCheckMsg != "" {
if wizard.portAvailable {
b.WriteString(successStyle.Render(wizard.portCheckMsg))
} else {
b.WriteString(errorStyle.Render(wizard.portCheckMsg))
}
} else if wizard.textInput != "" {
b.WriteString(mutedStyle.Render("Press Enter to check availability"))
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("Enter: Continue Esc: Back Ctrl+C: Cancel"))
return b.String()
}
func (m model) renderConfirmation() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(renderHeader("Add Port Forward", renderProgress(7, 7)))
b.WriteString("\n")
b.WriteString("Review Configuration:\n\n")
resourceInfo := wizard.resourceValue
if wizard.selector != "" {
resourceInfo = fmt.Sprintf("pod (selector: %s)", wizard.selector)
} else if wizard.selectedResourceType == ResourceTypePodPrefix {
resourceInfo = fmt.Sprintf("pod/%s", wizard.resourceValue)
} else if wizard.selectedResourceType == ResourceTypeService {
resourceInfo = fmt.Sprintf("service/%s", wizard.resourceValue)
}
b.WriteString(fmt.Sprintf(" Context: %s\n", wizard.selectedContext))
b.WriteString(fmt.Sprintf(" Namespace: %s\n", wizard.selectedNamespace))
b.WriteString(fmt.Sprintf(" Resource: %s\n", resourceInfo))
b.WriteString(fmt.Sprintf(" Remote Port: %d\n", wizard.remotePort))
b.WriteString(fmt.Sprintf(" Local Port: %d\n", wizard.localPort))
b.WriteString(" Protocol: tcp\n")
b.WriteString("\n")
// Show alias field with focus indicator
if wizard.confirmationFocus == FocusAlias {
b.WriteString(selectedStyle.Render("▸ Optional alias (friendly name):") + "\n")
b.WriteString(" Alias: " + validInputStyle.Render(wizard.textInput+"█") + "\n")
} else {
b.WriteString(mutedStyle.Render(" Optional alias (friendly name):") + "\n")
b.WriteString(mutedStyle.Render(" Alias: "+wizard.textInput) + "\n")
}
b.WriteString("\n")
// Show buttons with focus indicator
if wizard.confirmationFocus == FocusButtons {
if wizard.cursor == 0 {
b.WriteString(selectedStyle.Render("▸ Add to .kportal.yaml") + "\n")
b.WriteString(" Cancel\n")
} else {
b.WriteString(" Add to .kportal.yaml\n")
b.WriteString(selectedStyle.Render("▸ Cancel") + "\n")
}
} else {
b.WriteString(mutedStyle.Render(" Add to .kportal.yaml") + "\n")
b.WriteString(mutedStyle.Render(" Cancel") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓/Tab: Navigate Enter: Confirm Esc: Back"))
return b.String()
}
func (m model) renderSuccess() string {
wizard := m.ui.addWizard
var b strings.Builder
b.WriteString(successStyle.Render("Success! ✓"))
b.WriteString("\n\n")
if wizard.error != nil {
b.WriteString(errorStyle.Render(fmt.Sprintf("Error: %v", wizard.error)))
} else {
b.WriteString("Added to .kportal.yaml\n\n")
forwardDesc := fmt.Sprintf("localhost:%d → %s:%d",
wizard.localPort,
wizard.resourceValue,
wizard.remotePort)
if wizard.alias != "" {
forwardDesc = fmt.Sprintf("%s (%s)", wizard.alias, forwardDesc)
}
b.WriteString(successStyle.Render(forwardDesc))
b.WriteString("\n\n")
b.WriteString(mutedStyle.Render("The port forward will be active shortly."))
}
b.WriteString("\n\n")
b.WriteString("Would you like to:\n")
if wizard.cursor == 0 {
b.WriteString(selectedStyle.Render("▸ Add another port forward") + "\n")
b.WriteString(" Return to main view\n")
} else {
b.WriteString(" Add another port forward\n")
b.WriteString(selectedStyle.Render("▸ Return to main view") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Select"))
return b.String()
}
// renderRemoveWizard renders the remove wizard
func (m model) renderRemoveWizard() string {
if m.ui.removeWizard == nil {
return ""
}
wizard := m.ui.removeWizard
var content string
if wizard.confirming {
content = m.renderRemoveConfirmation()
} else {
content = m.renderRemoveSelection()
}
return wizardBoxStyle.Render(content)
}
func (m model) renderRemoveSelection() string {
wizard := m.ui.removeWizard
var b strings.Builder
b.WriteString(renderHeader("Remove Port Forwards", ""))
b.WriteString("\n")
b.WriteString("Select forwards to remove (Space to toggle):\n\n")
for i, fwd := range wizard.forwards {
isSelected := i == wizard.cursor
isChecked := wizard.selected[i]
line1 := fmt.Sprintf("%s:%d→%d", fwd.Alias, fwd.Port, fwd.LocalPort)
line2 := fmt.Sprintf(" %s/%s/%s", fwd.Context, fwd.Namespace, fwd.Resource)
checkbox := "[ ] "
if isChecked {
checkbox = "[✓] "
}
fullLine := checkbox + line1
if isSelected {
b.WriteString(selectedStyle.Render(fullLine))
} else {
if isChecked {
b.WriteString(checkedBoxStyle.Render(checkbox) + line1)
} else {
b.WriteString(uncheckedBoxStyle.Render(checkbox) + line1)
}
}
b.WriteString("\n")
b.WriteString(mutedStyle.Render(line2))
b.WriteString("\n\n")
}
selectedCount := wizard.getSelectedCount()
b.WriteString(fmt.Sprintf("%d of %d selected\n\n", selectedCount, len(wizard.forwards)))
b.WriteString(helpStyle.Render("Space: Toggle a: All n: None Enter: Remove Esc: Cancel"))
return b.String()
}
func (m model) renderRemoveConfirmation() string {
wizard := m.ui.removeWizard
var b strings.Builder
b.WriteString(renderHeader("Confirm Removal", ""))
b.WriteString("\n")
selectedCount := wizard.getSelectedCount()
b.WriteString(fmt.Sprintf("Remove %d port forward(s)?\n\n", selectedCount))
selectedForwards := wizard.getSelectedForwards()
for _, fwd := range selectedForwards {
b.WriteString(errorStyle.Render(fmt.Sprintf(" • %s:%d→%d\n", fwd.Alias, fwd.Port, fwd.LocalPort)))
b.WriteString(mutedStyle.Render(fmt.Sprintf(" %s/%s/%s\n", fwd.Context, fwd.Namespace, fwd.Resource)))
}
b.WriteString("\n")
b.WriteString(warningStyle.Render("This action cannot be undone."))
b.WriteString("\n\n")
// Yes/No buttons
if wizard.confirmCursor == 0 {
b.WriteString(selectedStyle.Render("▸ Yes, remove them") + "\n")
b.WriteString(" Cancel\n")
} else {
b.WriteString(" Yes, remove them\n")
b.WriteString(selectedStyle.Render("▸ Cancel") + "\n")
}
b.WriteString("\n")
b.WriteString(helpStyle.Render("↑/↓: Navigate Enter: Confirm Esc: Cancel"))
return b.String()
}
+158
View File
@@ -0,0 +1,158 @@
package version
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
const (
// GitHubAPIURL is the GitHub API endpoint for releases
githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest"
// requestTimeout is the timeout for HTTP requests
requestTimeout = 5 * time.Second
)
// ReleaseInfo contains information about a GitHub release
type ReleaseInfo struct {
TagName string `json:"tag_name"`
HTMLURL string `json:"html_url"`
Name string `json:"name"`
}
// UpdateInfo contains information about an available update
type UpdateInfo struct {
CurrentVersion string
LatestVersion string
ReleaseURL string
ReleaseName string
}
// Checker checks for new versions on GitHub
type Checker struct {
owner string
repo string
current string
client *http.Client
}
// NewChecker creates a new version checker
func NewChecker(owner, repo, currentVersion string) *Checker {
return &Checker{
owner: owner,
repo: repo,
current: normalizeVersion(currentVersion),
client: &http.Client{
Timeout: requestTimeout,
},
}
}
// CheckForUpdate checks if a newer version is available.
// Returns nil if current version is up to date or if check fails.
// This is designed to fail silently - network errors should not impact the user.
func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo {
release, err := c.fetchLatestRelease(ctx)
if err != nil {
return nil
}
latestVersion := normalizeVersion(release.TagName)
if isNewerVersion(latestVersion, c.current) {
return &UpdateInfo{
CurrentVersion: c.current,
LatestVersion: latestVersion,
ReleaseURL: release.HTMLURL,
ReleaseName: release.Name,
}
}
return nil
}
// fetchLatestRelease fetches the latest release info from GitHub API
func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) {
url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "kportal-version-checker")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
}
var release ReleaseInfo
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
// normalizeVersion removes 'v' or 'V' prefix and trims whitespace
func normalizeVersion(v string) string {
v = strings.TrimSpace(v)
v = strings.TrimPrefix(v, "v")
v = strings.TrimPrefix(v, "V")
return v
}
// isNewerVersion compares two semver-like versions.
// Returns true if latest is newer than current.
func isNewerVersion(latest, current string) bool {
latestParts := parseVersion(latest)
currentParts := parseVersion(current)
// Compare each part
for i := 0; i < len(latestParts) && i < len(currentParts); i++ {
if latestParts[i] > currentParts[i] {
return true
}
if latestParts[i] < currentParts[i] {
return false
}
}
// If all compared parts are equal, longer version is newer
// e.g., 1.0.1 > 1.0
return len(latestParts) > len(currentParts)
}
// parseVersion splits a version string into numeric parts
func parseVersion(v string) []int {
// Remove any suffix like -beta, -rc1, etc.
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.Split(v, ".")
result := make([]int, 0, len(parts))
for _, p := range parts {
var num int
fmt.Sscanf(p, "%d", &num)
result = append(result, num)
}
return result
}
// FormatUpdateMessage formats a user-friendly update notification
func (u *UpdateInfo) FormatUpdateMessage() string {
return fmt.Sprintf("New version available: %s (current: %s) - %s",
u.LatestVersion, u.CurrentVersion, u.ReleaseURL)
}
+90
View File
@@ -0,0 +1,90 @@
package version
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNormalizeVersion(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"v1.0.0", "1.0.0"},
{"1.0.0", "1.0.0"},
{" v2.1.3 ", "2.1.3"},
{"V1.0.0", "1.0.0"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseVersion(t *testing.T) {
tests := []struct {
input string
expected []int
}{
{"1.0.0", []int{1, 0, 0}},
{"2.1.3", []int{2, 1, 3}},
{"1.0", []int{1, 0}},
{"10.20.30", []int{10, 20, 30}},
{"1.0.0-beta", []int{1, 0, 0}},
{"1.0.0-rc1", []int{1, 0, 0}},
{"1.0.0+build123", []int{1, 0, 0}},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := parseVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
latest string
current string
expected bool
}{
{"major version bump", "2.0.0", "1.0.0", true},
{"minor version bump", "1.1.0", "1.0.0", true},
{"patch version bump", "1.0.1", "1.0.0", true},
{"same version", "1.0.0", "1.0.0", false},
{"current is newer major", "1.0.0", "2.0.0", false},
{"current is newer minor", "1.0.0", "1.1.0", false},
{"current is newer patch", "1.0.0", "1.0.1", false},
{"multi-digit versions", "1.10.0", "1.9.0", true},
{"longer version is newer", "1.0.1", "1.0", true},
{"shorter version is older", "1.0", "1.0.1", false},
{"complex comparison", "2.1.3", "2.1.2", true},
{"real world example", "0.2.0", "0.1.0", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNewerVersion(tt.latest, tt.current)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUpdateInfo_FormatUpdateMessage(t *testing.T) {
info := &UpdateInfo{
CurrentVersion: "0.1.0",
LatestVersion: "0.2.0",
ReleaseURL: "https://github.com/nvm/kportal/releases/tag/v0.2.0",
}
msg := info.FormatUpdateMessage()
assert.Contains(t, msg, "0.2.0")
assert.Contains(t, msg, "0.1.0")
assert.Contains(t, msg, "https://github.com/nvm/kportal/releases/tag/v0.2.0")
}
+1 -2
View File
@@ -2,7 +2,7 @@ version: 1
force:
major: 0
minor: 1
minor: 2
patch: 0
blacklist:
@@ -19,5 +19,4 @@ wording:
- "major"
- "BREAKING CHANGE"
release:
- "rc"
- "release candidate"