commit 22552aec99c17c48e9df2ddddde8fe590cbb032a Author: Lukasz Raczylo Date: Fri Nov 28 02:50:25 2025 +0000 Initial commit. diff --git a/.github/workflows/autoupdate.yaml b/.github/workflows/autoupdate.yaml new file mode 100644 index 0000000..0937faa --- /dev/null +++ b/.github/workflows/autoupdate.yaml @@ -0,0 +1,47 @@ +name: AutoUpdate + +on: + schedule: + - cron: "0 3 * * *" + workflow_dispatch: + +jobs: + prepare: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.24" + cache: true + + - name: Install dependencies + run: go get ./... + + test: + needs: prepare + runs-on: ubuntu-latest + container: golang:1 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install dependencies + run: apt-get update && apt-get install -y ca-certificates make + + - name: Tidy and update modules + run: | + go mod tidy + go get -u -v ./... + + - name: Run tests + run: CI_RUN=${CI} go test -v ./... + + - name: Commit changes + uses: stefanzweifel/git-auto-commit-action@v5 + with: + commit_message: "Update go.mod and go.sum" + file_pattern: "go.mod go.sum" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..0b1a6ea --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,83 @@ +name: Release + +on: + push: + branches: + - main + paths: + - "**.go" + - "go.mod" + - "go.sum" + workflow_dispatch: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.24" + + - name: Run tests + run: go test -race -v ./... + + version: + needs: test + runs-on: ubuntu-latest + outputs: + version: ${{ steps.semver.outputs.version }} + version_tag: ${{ steps.semver.outputs.version_tag }} + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Calculate version + id: semver + uses: lukaszraczylo/semver-generator@v1 + with: + config_file: semver.yaml + repository_local: true + + - name: Print version + run: | + echo "Version: ${{ steps.semver.outputs.version }}" + echo "Version tag: ${{ steps.semver.outputs.version_tag }}" + + release: + needs: version + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version: "1.24" + + - name: Create and push tag + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git tag -a ${{ needs.version.outputs.version_tag }} -m "Release ${{ needs.version.outputs.version }}" + git push origin ${{ needs.version.outputs.version_tag }} + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + with: + distribution: goreleaser + version: "~> v2" + args: release --clean + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + HOMEBREW_TAP_TOKEN: ${{ secrets.HOMEBREW_TAP_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2f3c7e2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +CLAUDE.md +build diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..953a22a --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,62 @@ +version: 2 + +before: + hooks: + - go mod tidy + +builds: + - id: lolcathost + main: ./cmd/lolcathost + binary: lolcathost + env: + - CGO_ENABLED=0 + goos: + - linux + - darwin + goarch: + - amd64 + - arm64 + ldflags: + - -s -w + - -X main.appVersion={{.Version}} + +archives: + - id: lolcathost + format: tar.gz + name_template: "lolcathost-{{ .Version }}-{{ .Os }}-{{ .Arch }}" + files: + - LICENSE + - README.md + +checksum: + name_template: "lolcathost-{{ .Version }}-checksums.txt" + algorithm: sha256 + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" + - "^Merge" + - "^WIP" + +release: + github: + owner: lukaszraczylo + name: lolcathost + name_template: "Release {{.Version}}" + draft: false + prerelease: auto + +brews: + - repository: + owner: lukaszraczylo + name: brew-taps + token: "{{ .Env.HOMEBREW_TAP_TOKEN }}" + directory: Formula + homepage: https://github.com/lukaszraczylo/lolcathost + description: "Dynamic host management tool for macOS and Linux with TUI" + license: MIT + test: | + system "#{bin}/lolcathost", "--version" diff --git a/Formula/lolcathost.rb b/Formula/lolcathost.rb new file mode 100644 index 0000000..e23b1be --- /dev/null +++ b/Formula/lolcathost.rb @@ -0,0 +1,57 @@ +class Lolcathost < Formula + desc "Dynamic host management tool for macOS and Linux with TUI" + homepage "https://github.com/lukaszraczylo/lolcathost" + license "MIT" + + version "0.1.0" + + on_macos do + on_arm do + url "https://github.com/lukaszraczylo/lolcathost/releases/download/v#{version}/lolcathost-#{version}-darwin-arm64.tar.gz" + sha256 "PLACEHOLDER_SHA256_DARWIN_ARM64" + end + + on_intel do + url "https://github.com/lukaszraczylo/lolcathost/releases/download/v#{version}/lolcathost-#{version}-darwin-amd64.tar.gz" + sha256 "PLACEHOLDER_SHA256_DARWIN_AMD64" + end + end + + on_linux do + on_arm do + url "https://github.com/lukaszraczylo/lolcathost/releases/download/v#{version}/lolcathost-#{version}-linux-arm64.tar.gz" + sha256 "PLACEHOLDER_SHA256_LINUX_ARM64" + end + + on_intel do + url "https://github.com/lukaszraczylo/lolcathost/releases/download/v#{version}/lolcathost-#{version}-linux-amd64.tar.gz" + sha256 "PLACEHOLDER_SHA256_LINUX_AMD64" + end + end + + def install + bin.install "lolcathost" + end + + def caveats + <<~EOS + lolcathost requires root access for the daemon to modify /etc/hosts. + + After installation: + 1. Run: sudo lolcathost --install + This will install the LaunchDaemon (macOS) or systemd service (Linux) + + 2. Create a config file at ~/.config/lolcathost/config.yaml + + 3. Run: lolcathost + This launches the TUI for managing host entries + + For more information: + https://github.com/lukaszraczylo/lolcathost + EOS + end + + test do + assert_match version.to_s, shell_output("#{bin}/lolcathost --version") + end +end diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..3e48e96 --- /dev/null +++ b/Makefile @@ -0,0 +1,120 @@ +.PHONY: build test lint vet staticcheck clean install uninstall fmt + +# Build variables +BINARY_NAME=lolcathost +VERSION?=1.0.0 +BUILD_DIR=./build +LDFLAGS=-ldflags "-s -w -X main.appVersion=$(VERSION)" + +# Go commands +GOCMD=go +GOBUILD=$(GOCMD) build +GOTEST=$(GOCMD) test +GOVET=$(GOCMD) vet +GOFMT=$(GOCMD) fmt +GOMOD=$(GOCMD) mod + +# Default target +all: lint test build + +# Build the binary +build: + $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME) ./cmd/lolcathost + +# Build for all platforms +build-all: build-darwin-arm64 build-darwin-amd64 build-linux-arm64 build-linux-amd64 + +build-darwin-arm64: + GOOS=darwin GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./cmd/lolcathost + +build-darwin-amd64: + GOOS=darwin GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 ./cmd/lolcathost + +build-linux-arm64: + GOOS=linux GOARCH=arm64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./cmd/lolcathost + +build-linux-amd64: + GOOS=linux GOARCH=amd64 $(GOBUILD) $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./cmd/lolcathost + +# Run tests +test: + $(GOTEST) -v ./... + +# Run tests with coverage +test-coverage: + $(GOTEST) -v -coverprofile=coverage.out ./... + $(GOCMD) tool cover -html=coverage.out -o coverage.html + +# Run single test +test-run: + $(GOTEST) -v -run $(TEST) ./... + +# Run benchmarks +bench: + $(GOTEST) -bench=. -benchmem ./... + +# Linting +lint: vet staticcheck + +vet: + $(GOVET) ./... + +staticcheck: + @command -v staticcheck >/dev/null 2>&1 || { echo "Installing staticcheck..."; go install honnef.co/go/tools/cmd/staticcheck@latest; } + staticcheck ./... + +# Format code +fmt: + $(GOFMT) ./... + +# Tidy dependencies +tidy: + $(GOMOD) tidy + +# Clean build artifacts +clean: + rm -rf $(BUILD_DIR) + rm -f coverage.out coverage.html + +# Install locally (for development) +install: build + sudo cp $(BUILD_DIR)/$(BINARY_NAME) /usr/local/bin/ + @echo "Installed to /usr/local/bin/$(BINARY_NAME)" + @echo "Run 'sudo lolcathost --install' to set up the daemon" + +# Uninstall +uninstall: + sudo rm -f /usr/local/bin/$(BINARY_NAME) + @echo "Removed /usr/local/bin/$(BINARY_NAME)" + @echo "Note: Run 'sudo lolcathost --uninstall' first to remove the daemon" + +# Development helpers +dev: fmt lint test build + +# Run the TUI (requires daemon to be installed) +run: build + $(BUILD_DIR)/$(BINARY_NAME) + +# Run as daemon (requires sudo) +run-daemon: build + sudo $(BUILD_DIR)/$(BINARY_NAME) --daemon + +# Show help +help: + @echo "Available targets:" + @echo " all - Lint, test, and build" + @echo " build - Build the binary" + @echo " build-all - Build for all platforms" + @echo " test - Run tests" + @echo " test-coverage - Run tests with coverage report" + @echo " test-run - Run specific test (use TEST=TestName)" + @echo " bench - Run benchmarks" + @echo " lint - Run linters (vet + staticcheck)" + @echo " fmt - Format code" + @echo " tidy - Tidy go.mod" + @echo " clean - Clean build artifacts" + @echo " install - Install binary to /usr/local/bin" + @echo " uninstall - Remove binary from /usr/local/bin" + @echo " dev - Format, lint, test, and build" + @echo " run - Run the TUI" + @echo " run-daemon - Run as daemon (requires sudo)" diff --git a/README.md b/README.md new file mode 100644 index 0000000..b7b02c4 --- /dev/null +++ b/README.md @@ -0,0 +1,332 @@ +

+ lolcathost logo +

+ +

+ lolcathost +

+ +

+ Release + License + Go Report Card +

+ +

+ Dynamic hosts file manager with interactive terminal UI +

+ +lolcathost manages your `/etc/hosts` file with an interactive terminal interface. It provides real-time management, automatic backups, group organization, presets, and a secure daemon-based architecture. + +## Features + +- **Interactive TUI** - Terminal interface with keyboard navigation +- **Live management** - Add, edit, and delete host entries without restarting +- **Groups** - Organize hosts into logical groups +- **Presets** - Save and apply preset configurations with a single command +- **Auto-backup** - Automatic backups before every change with rollback support +- **Secure daemon** - Privileged daemon handles file access via Unix socket IPC +- **Domain blocking** - Configurable blocklist to prevent dangerous entries +- **Cross-platform** - Works on macOS (LaunchDaemon) and Linux (systemd) +- **CLI & TUI** - Both command-line and interactive modes for flexibility +- **Auto-update check** - Notifies you when a new version is available + +## Comparison with Other Tools + +| Feature | lolcathost | [HostsMan](https://hostsfileman.github.io/) | [Gas Mask](https://github.com/2ndalpha/gasmask) | Manual editing | +|---------|------------|---------------------------------------------|------------------------------------------------|----------------| +| **Platform** | macOS/Linux | Windows | macOS only | All | +| **Interface** | Terminal TUI | Desktop GUI | Desktop GUI | Text editor | +| **Daemon architecture** | Yes (secure) | No | No | N/A | +| **Real-time sync** | Yes | No | Manual | Manual | +| **Groups** | Yes | Yes | Yes | Manual | +| **Presets** | Yes | Yes | Yes | No | +| **Auto-backup** | 10 rolling | Manual | Manual | No | +| **Rollback** | Yes | No | No | No | +| **CLI automation** | Yes | Limited | No | Yes | +| **Rate limiting** | Yes | No | No | N/A | +| **Domain blocking** | Yes | No | No | No | +| **Auto-update check** | Yes | No | No | N/A | + +## Installation + +### Homebrew (macOS/Linux) + +```bash +brew install lukaszraczylo/brew-taps/lolcathost +``` + +After Homebrew installation, run: + +```bash +sudo lolcathost --install +``` + +### Quick Install + +```bash +curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/lolcathost/main/install.sh | bash +``` + +### Manual Download + +Download binaries from the [releases page](https://github.com/lukaszraczylo/lolcathost/releases). + +### Build from Source + +```bash +git clone https://github.com/lukaszraczylo/lolcathost.git +cd lolcathost +make build +sudo ./build/lolcathost --install +``` + +### Post-Installation + +The installer will: +- Install the binary to `/usr/local/bin/lolcathost` +- Create a LaunchDaemon (macOS) or systemd service (Linux) +- Start the daemon automatically +- Create the default config at `~/.config/lolcathost/config.yaml` + +## Quick Start + +After installation, open a **new terminal** and run: + +```bash +lolcathost +``` + +### Keyboard Controls + +| Key | Action | +|-----|--------| +| `↑↓` / `j/k` | Navigate entries | +| `Space` / `Enter` | Toggle entry enabled/disabled | +| `n` | Add new host entry | +| `e` | Edit selected entry | +| `d` | Delete selected entry | +| `p` | Open preset picker | +| `g` | Open group manager | +| `/` | Search | +| `r` | Refresh list | +| `?` | Show help | +| `q` | Quit | + +## Configuration + +### Config File Location + +Default: `~/.config/lolcathost/config.yaml` + +### Example Configuration + +```yaml +# Groups for organizing host entries +groups: + - name: development + hosts: + - domain: myapp.local + ip: 127.0.0.1 + enabled: true + - domain: api.myapp.local + ip: 127.0.0.1 + enabled: true + + - name: staging + hosts: + - domain: staging.example.com + ip: 192.168.1.100 + enabled: false + +# Presets for quick configuration switching +presets: + - name: work + enable: + - myapp-local + - api-myapp-local + disable: + - staging-example-com + + - name: testing + enable: + - staging-example-com + disable: + - myapp-local + +# Domain blocklist (prevent adding these domains) +blocklist: + - google.com + - facebook.com + - github.com +``` + +### Host Entry Fields + +| Field | Required | Description | +|-------|----------|-------------| +| `domain` | Yes | The hostname (e.g., myapp.local) | +| `ip` | Yes | IP address to resolve to | +| `enabled` | No | Whether entry is active (default: false) | + +Note: Aliases are auto-generated from domain names (e.g., `myapp.local` becomes `myapp-local`). + +## CLI Commands + +```bash +lolcathost # Launch TUI +lolcathost list # List all entries +lolcathost on # Enable entry +lolcathost off # Disable entry +lolcathost preset # Apply preset +lolcathost status # Show daemon status +``` + +### Version & Updates + +```bash +lolcathost --version # Show current version +lolcathost --update # Check for updates +``` + +### Installation Commands + +```bash +sudo lolcathost --install # Install daemon +sudo lolcathost --uninstall # Uninstall daemon +``` + +## Status Indicators + +| Indicator | Description | +|-----------|-------------| +| `● Active` | Entry is enabled and in /etc/hosts | +| `○ Disabled` | Entry is disabled | +| `◐ Pending` | Operation in progress | +| `✗ Error` | Operation failed | + +## Architecture + +lolcathost uses a daemon-based architecture for security: + +``` +┌─────────────────┐ ┌─────────────────────┐ +│ lolcathost │ JSON │ Daemon │ +│ CLI / TUI │◄───────►│ (runs as root) │ +│ (runs as user) │ Unix │ │ +└─────────────────┘ Socket └──────────┬──────────┘ + │ + ┌────────▼────────┐ + │ /etc/hosts │ + └─────────────────┘ +``` + +**Daemon** (runs as root): +- Handles `/etc/hosts` modifications +- Creates automatic backups (10 rolling) +- Validates inputs (domain, IP) +- Rate limiting protection (100 req/min per PID) +- Flushes DNS cache automatically + +**Client** (CLI/TUI, runs as user): +- Connects via Unix socket +- JSON protocol for commands +- No sudo required for operations +- Real-time status updates + +Socket: `/var/run/lolcathost.sock` +Backups: `/var/backups/lolcathost/` + +## Troubleshooting + +### "daemon not running (socket not found)" + +The daemon isn't running. Install or reinstall: + +```bash +sudo lolcathost --uninstall +sudo lolcathost --install +``` + +Then open a **new terminal** for group membership to take effect. + +### Check Daemon Status + +```bash +# macOS +sudo launchctl list | grep lolcathost + +# Linux +sudo systemctl status lolcathost +``` + +### View Daemon Logs + +```bash +# macOS/Linux +cat /var/log/lolcathost/daemon.log +cat /var/log/lolcathost/daemon.err +``` + +### DNS Cache Not Flushing + +lolcathost automatically flushes the DNS cache after changes: + +- **macOS**: Uses `dscacheutil -flushcache` and `killall -HUP mDNSResponder` +- **Linux**: Uses `systemd-resolve --flush-caches` or `nscd -i hosts` + +If changes don't take effect, manually flush: + +```bash +# macOS +sudo dscacheutil -flushcache && sudo killall -HUP mDNSResponder + +# Linux (systemd) +sudo systemd-resolve --flush-caches +``` + +## Development + +### Prerequisites + +- Go 1.24+ +- macOS or Linux + +### Build + +```bash +make build # Build binary +make test # Run tests +make test-coverage # Tests with coverage +make lint # Run linters +make dev # Format, lint, test, build +``` + +### Project Structure + +``` +cmd/lolcathost/ - Entry point, CLI commands +internal/ + protocol/ - JSON message types (Unix socket) + config/ - YAML config parsing, hot-reload + daemon/ - Socket server, /etc/hosts management + client/ - Socket client library + installer/ - --install/--uninstall logic + tui/ - Bubble Tea TUI + version/ - Update checker +``` + +## License + +MIT License - see [LICENSE](LICENSE). + +## Acknowledgments + +- [Bubble Tea](https://github.com/charmbracelet/bubbletea) - Terminal UI framework +- [Lipgloss](https://github.com/charmbracelet/lipgloss) - Terminal styling + +## Links + +- [Website](https://lukaszraczylo.github.io/lolcathost) +- [Issues](https://github.com/lukaszraczylo/lolcathost/issues) +- [Releases](https://github.com/lukaszraczylo/lolcathost/releases) diff --git a/cmd/lolcathost/main.go b/cmd/lolcathost/main.go new file mode 100644 index 0000000..89a71e3 --- /dev/null +++ b/cmd/lolcathost/main.go @@ -0,0 +1,307 @@ +// Package main provides the entry point for the lolcathost application. +package main + +import ( + "context" + "flag" + "fmt" + "os" + "text/tabwriter" + "time" + + "github.com/lukaszraczylo/lolcathost/internal/client" + "github.com/lukaszraczylo/lolcathost/internal/config" + "github.com/lukaszraczylo/lolcathost/internal/daemon" + "github.com/lukaszraczylo/lolcathost/internal/installer" + "github.com/lukaszraczylo/lolcathost/internal/protocol" + "github.com/lukaszraczylo/lolcathost/internal/tui" + "github.com/lukaszraczylo/lolcathost/internal/version" +) + +// version is set at compile time via ldflags +var appVersion = "dev" + +const ( + githubOwner = "lukaszraczylo" + githubRepo = "lolcathost" +) + +func main() { + // Flags + daemonMode := flag.Bool("daemon", false, "Run as daemon (called by LaunchDaemon/systemd)") + installFlag := flag.Bool("install", false, "Install the daemon service (requires sudo)") + uninstallFlag := flag.Bool("uninstall", false, "Uninstall the daemon service (requires sudo)") + versionFlag := flag.Bool("version", false, "Show version") + updateFlag := flag.Bool("update", false, "Check for updates") + configPath := flag.String("config", config.DefaultConfigPath(), "Path to config file") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, "lolcathost - Dynamic Host Management\n\n") + fmt.Fprintf(os.Stderr, "Usage:\n") + fmt.Fprintf(os.Stderr, " lolcathost Launch TUI\n") + fmt.Fprintf(os.Stderr, " lolcathost list List all entries\n") + fmt.Fprintf(os.Stderr, " lolcathost on Enable entry\n") + fmt.Fprintf(os.Stderr, " lolcathost off Disable entry\n") + fmt.Fprintf(os.Stderr, " lolcathost preset Apply preset\n") + fmt.Fprintf(os.Stderr, " lolcathost status Show daemon status\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "Installation:\n") + fmt.Fprintf(os.Stderr, " sudo lolcathost --install Install daemon\n") + fmt.Fprintf(os.Stderr, " sudo lolcathost --uninstall Uninstall daemon\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "Options:\n") + flag.PrintDefaults() + } + + flag.Parse() + + // Version + if *versionFlag { + fmt.Printf("lolcathost version %s\n", appVersion) + os.Exit(0) + } + + // Update check + if *updateFlag { + checkForUpdates() + os.Exit(0) + } + + // Install/Uninstall + if *installFlag { + runInstall() + return + } + + if *uninstallFlag { + runUninstall() + return + } + + // Daemon mode + if *daemonMode { + runDaemon(*configPath) + return + } + + // Parse subcommand + args := flag.Args() + if len(args) == 0 { + // No subcommand - launch TUI + runTUI(*configPath) + return + } + + // Handle subcommands + switch args[0] { + case "list": + runList() + case "on": + if len(args) < 2 { + fmt.Fprintln(os.Stderr, "Usage: lolcathost on ") + os.Exit(1) + } + runOn(args[1]) + case "off": + if len(args) < 2 { + fmt.Fprintln(os.Stderr, "Usage: lolcathost off ") + os.Exit(1) + } + runOff(args[1]) + case "preset": + if len(args) < 2 { + fmt.Fprintln(os.Stderr, "Usage: lolcathost preset ") + os.Exit(1) + } + runPreset(args[1]) + case "status": + runStatus() + default: + fmt.Fprintf(os.Stderr, "Unknown command: %s\n", args[0]) + flag.Usage() + os.Exit(1) + } +} + +func runInstall() { + inst, err := installer.New() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if err := inst.Install(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func runUninstall() { + inst, err := installer.New() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if err := inst.Uninstall(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func runDaemon(configPath string) { + daemon.Version = appVersion + d, err := daemon.New(configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to create daemon: %v\n", err) + os.Exit(1) + } + + if err := d.Run(); err != nil { + fmt.Fprintf(os.Stderr, "Daemon error: %v\n", err) + os.Exit(1) + } +} + +func runTUI(configPath string) { + // Check installation + if err := installer.CheckInstallation(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + fmt.Fprintln(os.Stderr, "\nTo install, run: sudo lolcathost --install") + os.Exit(1) + } + + if err := tui.RunWithVersion(protocol.SocketPath, configPath, appVersion, githubOwner, githubRepo); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } +} + +func runList() { + c := connectClient() + defer c.Close() + + entries, err := c.List() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + if len(entries) == 0 { + fmt.Println("No entries configured.") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "STATUS\tDOMAIN\tIP\tALIAS\tGROUP") + fmt.Fprintln(w, "------\t------\t--\t-----\t-----") + + for _, e := range entries { + status := "○" + if e.Enabled { + status = "●" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", status, e.Domain, e.IP, e.Alias, e.Group) + } + + w.Flush() +} + +func runOn(alias string) { + c := connectClient() + defer c.Close() + + data, err := c.Enable(alias) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("✓ Enabled: %s → %s\n", alias, data.Domain) +} + +func runOff(alias string) { + c := connectClient() + defer c.Close() + + data, err := c.Disable(alias) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("✓ Disabled: %s → %s\n", alias, data.Domain) +} + +func runPreset(name string) { + c := connectClient() + defer c.Close() + + if err := c.ApplyPreset(name); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("✓ Applied preset: %s\n", name) +} + +func runStatus() { + c := connectClient() + defer c.Close() + + status, err := c.Status() + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Status: %s\n", greenIf("running", status.Running)) + fmt.Printf("Version: %s\n", status.Version) + fmt.Printf("Uptime: %d seconds\n", status.Uptime) + fmt.Printf("Active entries: %d\n", status.ActiveCount) + fmt.Printf("Total requests: %d\n", status.RequestCount) +} + +func connectClient() *client.Client { + // Check installation first + if err := installer.CheckInstallation(); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + fmt.Fprintln(os.Stderr, "\nTo install, run: sudo lolcathost --install") + os.Exit(1) + } + + c := client.New(protocol.SocketPath) + if err := c.Connect(); err != nil { + fmt.Fprintf(os.Stderr, "Failed to connect to daemon: %v\n", err) + os.Exit(1) + } + + return c +} + +func greenIf(s string, condition bool) string { + if condition { + return "\033[32m" + s + "\033[0m" + } + return "\033[31mnot " + s + "\033[0m" +} + +func checkForUpdates() { + fmt.Printf("lolcathost version %s\n", appVersion) + fmt.Println("Checking for updates...") + + checker := version.NewChecker(githubOwner, githubRepo, appVersion) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + update := checker.CheckForUpdate(ctx) + if update == nil { + fmt.Println("You are running the latest version.") + return + } + + fmt.Printf("\n\033[32mUpdate available: v%s\033[0m\n", update.LatestVersion) + fmt.Printf("Download: %s\n", update.ReleaseURL) + fmt.Println("\nTo update, download the latest release from the URL above") + fmt.Println("or use your package manager (e.g., 'brew upgrade lolcathost').") +} diff --git a/docs/index.html b/docs/index.html new file mode 100644 index 0000000..a1d2a64 --- /dev/null +++ b/docs/index.html @@ -0,0 +1,670 @@ + + + + + + lolcathost - Dynamic Hosts File Manager + + + + + + + + + + + + + + + +
+
+
+
+
+ +
+
+
+ lolcathost logo +
+
+
lolcathost
+
+

+ Dynamic Hosts File
Manager +

+

+ Terminal interface for managing your /etc/hosts file with automatic backups, groups, presets, and a secure daemon architecture. +

+ +
+ Version + License + Go Report +
+
+
+
+ + +
+
+
+

Features

+

Everything you need for managing local host entries

+
+
+
+
+
+ +
+
+

Interactive TUI

+

Beautiful terminal interface with real-time updates and keyboard navigation

+
+
+
+
+
+
+ +
+
+

Live Management

+

Add, edit, delete, and toggle host entries without restarting

+
+
+
+
+
+
+ +
+
+

Groups

+

Organize hosts into groups for better management

+
+
+
+
+
+
+ +
+
+

Presets

+

Save and apply preset configurations with a single command

+
+
+
+
+
+
+ +
+
+

Auto-Backup

+

Automatic backups before every change with rollback support

+
+
+
+
+
+
+ +
+
+

Secure Daemon

+

Privileged daemon handles file access via Unix socket IPC

+
+
+
+
+
+
+ +
+
+

Domain Blocking

+

Configurable domain blocklist to prevent dangerous entries

+
+
+
+
+
+
+ +
+
+

Cross-Platform

+

Works on macOS (LaunchDaemon) and Linux (systemd)

+
+
+
+
+
+
+ +
+
+

CLI & TUI

+

Both command-line and interactive modes for flexibility

+
+
+
+
+
+
+ + +
+
+
+

Installation

+

Get started in under a minute

+
+
+
+

+ + Quick Install +

+
curl -fsSL https://raw.githubusercontent.com/lukaszraczylo/lolcathost/main/install.sh | bash
+
+
+

+ + Build from Source +

+
git clone https://github.com/lukaszraczylo/lolcathost.git
+cd lolcathost
+make build
+sudo ./build/lolcathost --install
+
+
+

+ + Post-Installation +

+

The installer will:

+
    +
  • Install the binary to /usr/local/bin/lolcathost
  • +
  • Create a LaunchDaemon (macOS) or systemd service (Linux)
  • +
  • Start the daemon automatically
  • +
  • Create the default config at ~/.config/lolcathost/config.yaml
  • +
+
+
+
+
+ + +
+
+
+

Usage

+

Simple and intuitive interface

+
+
+
+

Interactive Mode (TUI)

+
lolcathost
+

Keyboard Controls

+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
KeyAction
up/down or j/kNavigate entries
Space or EnterToggle entry enabled/disabled
nAdd new host entry
eEdit selected entry
dDelete selected entry
pOpen preset picker
/Search
rRefresh list
?Show help
qQuit
+
+
+
+

Status Indicators

+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
IndicatorDescription
● ActiveEntry is enabled and in /etc/hosts
○ DisabledEntry is disabled
◔ PendingOperation in progress
✗ ErrorOperation failed
+
+
+
+
+
+ + +
+
+
+

Configuration

+

Flexible YAML-based configuration

+
+
+
+

Config File Location

+

Default: ~/.config/lolcathost/config.yaml

+
+
+

Example Configuration

+
# Groups for organizing host entries
+groups:
+  - name: development
+    hosts:
+      - domain: myapp.local
+        ip: 127.0.0.1
+        enabled: true
+      - domain: api.myapp.local
+        ip: 127.0.0.1
+        enabled: true
+
+  - name: staging
+    hosts:
+      - domain: staging.example.com
+        ip: 192.168.1.100
+        enabled: false
+
+# Presets for quick configuration switching
+presets:
+  - name: work
+    enable:
+      - myapp.local
+      - api.myapp.local
+    disable:
+      - staging.example.com
+
+  - name: testing
+    enable:
+      - staging.example.com
+    disable:
+      - myapp.local
+
+# Domain blocklist (prevent adding these domains)
+blocklist:
+  - google.com
+  - facebook.com
+  - github.com
+
+
+

Host Entry Fields

+
+ + + + + + + + + + + + + + + + + + + + + + + + + +
FieldRequiredDescription
domainYesThe hostname (e.g., myapp.local)
ipYesIP address to resolve to
enabledNoWhether entry is active (default: false)
+
+
+
+
+
+ + +
+
+
+

CLI Commands

+

Scriptable command-line interface

+
+
+
+
+ lolcathost list +
+

List all host entries

+
+
+
+ lolcathost enable <alias> +
+

Enable a host entry by its alias

+
+
+
+ lolcathost disable <alias> +
+

Disable a host entry by its alias

+
+
+
+ lolcathost add -d <domain> -i <ip> -g <group> +
+

Add a new host entry

+
+
+
+ lolcathost delete <alias> +
+

Delete a host entry by its alias

+
+
+
+ lolcathost preset <name> +
+

Apply a named preset

+
+
+
+ lolcathost rollback <backup> +
+

Restore from a backup

+
+
+
+ lolcathost status +
+

Show daemon status

+
+
+
+ sudo lolcathost --install +
+

Install the daemon service

+
+
+
+ sudo lolcathost --uninstall +
+

Uninstall the daemon service

+
+
+
+
+ + +
+
+
+

Architecture

+

Secure daemon-based design

+
+
+
+
+
+

+ + Daemon +

+
    +
  • • Runs as root (LaunchDaemon/systemd)
  • +
  • • Handles /etc/hosts modifications
  • +
  • • Creates automatic backups
  • +
  • • Validates inputs (domain, IP)
  • +
  • • Rate limiting protection
  • +
+
+
+

+ + Client (CLI/TUI) +

+
    +
  • • Runs as regular user
  • +
  • • Connects via Unix socket
  • +
  • • JSON protocol for commands
  • +
  • • No sudo required for operations
  • +
  • • Real-time status updates
  • +
+
+
+
+

+ + Socket: /var/run/lolcathost.sock +

+
+
+
+
+
+ + + + + + + diff --git a/docs/lolcathost.png b/docs/lolcathost.png new file mode 100644 index 0000000..a0f58d0 Binary files /dev/null and b/docs/lolcathost.png differ diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7fa90a3 --- /dev/null +++ b/go.mod @@ -0,0 +1,38 @@ +module github.com/lukaszraczylo/lolcathost + +go 1.24.2 + +require ( + github.com/charmbracelet/bubbles v0.21.0 + github.com/charmbracelet/bubbletea v1.3.10 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.9.0 + github.com/stretchr/testify v1.11.1 + golang.org/x/sys v0.38.0 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + github.com/atotto/clipboard v0.1.4 // indirect + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/colorprofile v0.3.3 // indirect + github.com/charmbracelet/x/ansi v0.11.2 // indirect + github.com/charmbracelet/x/cellbuf v0.0.14 // indirect + github.com/charmbracelet/x/term v0.2.2 // indirect + github.com/clipperhouse/displaywidth v0.6.0 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/lucasb-eyer/go-colorful v1.3.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + golang.org/x/text v0.31.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8a36fb8 --- /dev/null +++ b/go.sum @@ -0,0 +1,64 @@ +github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= +github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs= +github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg= +github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw= +github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4= +github.com/charmbracelet/colorprofile v0.3.3 h1:DjJzJtLP6/NZ8p7Cgjno0CKGr7wwRJGxWUwh2IyhfAI= +github.com/charmbracelet/colorprofile v0.3.3/go.mod h1:nB1FugsAbzq284eJcjfah2nhdSLppN2NqvfotkfRYP4= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.11.2 h1:XAG3FSjiVtFvgEgGrNBkCNNYrsucAt8c6bfxHyROLLs= +github.com/charmbracelet/x/ansi v0.11.2/go.mod h1:9tY2bzX5SiJCU0iWyskjBeI2BRQfvPqI+J760Mjf+Rg= +github.com/charmbracelet/x/cellbuf v0.0.14 h1:iUEMryGyFTelKW3THW4+FfPgi4fkmKnnaLOXuc+/Kj4= +github.com/charmbracelet/x/cellbuf v0.0.14/go.mod h1:P447lJl49ywBbil/KjCk2HexGh4tEY9LH0/1QrZZ9rA= +github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= +github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= +github.com/clipperhouse/displaywidth v0.6.0 h1:k32vueaksef9WIKCNcoqRNyKbyvkvkysNYnAWz2fN4s= +github.com/clipperhouse/displaywidth v0.6.0/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= +github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag= +github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= +github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/client/client.go b/internal/client/client.go new file mode 100644 index 0000000..4db93bd --- /dev/null +++ b/internal/client/client.go @@ -0,0 +1,427 @@ +// Package client provides a client library for communicating with the lolcathost daemon. +package client + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "sync" + "time" + + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// Client is a client for the lolcathost daemon. +type Client struct { + socketPath string + conn net.Conn + reader *bufio.Reader + timeout time.Duration + mu sync.Mutex +} + +// New creates a new client. +func New(socketPath string) *Client { + return &Client{ + socketPath: socketPath, + timeout: 5 * time.Second, + } +} + +// NewWithTimeout creates a new client with a custom timeout. +func NewWithTimeout(socketPath string, timeout time.Duration) *Client { + return &Client{ + socketPath: socketPath, + timeout: timeout, + } +} + +// Connect establishes a connection to the daemon. +func (c *Client) Connect() error { + c.mu.Lock() + defer c.mu.Unlock() + + // Close existing connection if any + if c.conn != nil { + c.conn.Close() + c.conn = nil + c.reader = nil + } + + conn, err := net.DialTimeout("unix", c.socketPath, c.timeout) + if err != nil { + return fmt.Errorf("failed to connect to daemon: %w", err) + } + + c.conn = conn + c.reader = bufio.NewReader(conn) + return nil +} + +// Close closes the connection. +func (c *Client) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + err := c.conn.Close() + c.conn = nil + c.reader = nil + return err + } + return nil +} + +// send sends a request and receives a response. +func (c *Client) send(req *protocol.Request) (*protocol.Response, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return nil, fmt.Errorf("not connected") + } + + // Set deadline + c.conn.SetDeadline(time.Now().Add(c.timeout)) + + // Send request + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + data = append(data, '\n') + + if _, err := c.conn.Write(data); err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // Read response + line, err := c.reader.ReadBytes('\n') + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var resp protocol.Response + if err := json.Unmarshal(line, &resp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return &resp, nil +} + +// Ping checks if the daemon is responsive. +func (c *Client) Ping() error { + req, _ := protocol.NewRequest(protocol.RequestPing, nil) + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("ping failed: %s", resp.Message) + } + return nil +} + +// Status returns the daemon's status. +func (c *Client) Status() (*protocol.StatusData, error) { + req, _ := protocol.NewRequest(protocol.RequestStatus, nil) + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("status failed: %s", resp.Message) + } + + var data protocol.StatusData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return &data, nil +} + +// List returns all host entries. +func (c *Client) List() ([]protocol.HostEntry, error) { + req, _ := protocol.NewRequest(protocol.RequestList, nil) + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("list failed: %s", resp.Message) + } + + var data protocol.ListData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return data.Entries, nil +} + +// Set enables or disables a host entry by alias. +func (c *Client) Set(alias string, enabled bool, force bool) (*protocol.SetData, error) { + req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{ + Alias: alias, + Enabled: enabled, + Force: force, + }) + + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + + var data protocol.SetData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return &data, nil +} + +// Enable enables a host entry by alias. +func (c *Client) Enable(alias string) (*protocol.SetData, error) { + return c.Set(alias, true, false) +} + +// Disable disables a host entry by alias. +func (c *Client) Disable(alias string) (*protocol.SetData, error) { + return c.Set(alias, false, false) +} + +// Add adds a new host entry. +func (c *Client) Add(domain, ip, alias, group string, enabled bool) (*protocol.SetData, error) { + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: domain, + IP: ip, + Alias: alias, + Group: group, + Enabled: enabled, + }) + + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + + var data protocol.SetData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return &data, nil +} + +// Delete removes a host entry by alias. +func (c *Client) Delete(alias string) error { + req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{ + Alias: alias, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// AddGroup adds a new group. +func (c *Client) AddGroup(name string) error { + req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{ + Name: name, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// DeleteGroup removes a group and all its hosts. +func (c *Client) DeleteGroup(name string) error { + req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{ + Name: name, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// ListGroups returns all group names. +func (c *Client) ListGroups() ([]string, error) { + req, _ := protocol.NewRequest(protocol.RequestListGroups, nil) + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + + var data protocol.GroupsData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return data.Groups, nil +} + +// Sync synchronizes the config to the hosts file. +func (c *Client) Sync() error { + req, _ := protocol.NewRequest(protocol.RequestSync, nil) + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("sync failed: %s", resp.Message) + } + return nil +} + +// ApplyPreset applies a named preset. +func (c *Client) ApplyPreset(name string) error { + req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{ + Name: name, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("preset failed: %s", resp.Message) + } + return nil +} + +// Rollback restores a backup by name. +func (c *Client) Rollback(backupName string) error { + req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{ + BackupName: backupName, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("rollback failed: %s", resp.Message) + } + return nil +} + +// ListBackups returns available backups. +func (c *Client) ListBackups() ([]protocol.BackupInfo, error) { + req, _ := protocol.NewRequest(protocol.RequestBackups, nil) + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("backups failed: %s", resp.Message) + } + + var data protocol.BackupsData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return data.Backups, nil +} + +// RenameGroup renames a group. +func (c *Client) RenameGroup(oldName, newName string) error { + req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{ + OldName: oldName, + NewName: newName, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// AddPreset adds a new preset. +func (c *Client) AddPreset(name string, enable, disable []string) error { + req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{ + Name: name, + Enable: enable, + Disable: disable, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// DeletePreset removes a preset by name. +func (c *Client) DeletePreset(name string) error { + req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{ + Name: name, + }) + + resp, err := c.send(req) + if err != nil { + return err + } + if !resp.IsOK() { + return fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + return nil +} + +// ListPresets returns all presets. +func (c *Client) ListPresets() ([]protocol.PresetInfo, error) { + req, _ := protocol.NewRequest(protocol.RequestListPresets, nil) + resp, err := c.send(req) + if err != nil { + return nil, err + } + if !resp.IsOK() { + return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message) + } + + var data protocol.PresetsData + if err := resp.ParseData(&data); err != nil { + return nil, err + } + return data.Presets, nil +} + +// IsConnected checks if the daemon is reachable. +func IsConnected(socketPath string) bool { + client := New(socketPath) + if err := client.Connect(); err != nil { + return false + } + defer client.Close() + + return client.Ping() == nil +} diff --git a/internal/client/client_test.go b/internal/client/client_test.go new file mode 100644 index 0000000..ac9fd15 --- /dev/null +++ b/internal/client/client_test.go @@ -0,0 +1,516 @@ +package client + +import ( + "bufio" + "encoding/json" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// mockServer creates a mock Unix socket server for testing +type mockServer struct { + listener net.Listener + path string + handler func(req *protocol.Request) *protocol.Response +} + +func newMockServer(t *testing.T) *mockServer { + // Use /tmp directly to avoid long paths (Unix socket paths have ~104 char limit on macOS) + tmpDir, err := os.MkdirTemp("/tmp", "lolcat") + require.NoError(t, err) + t.Cleanup(func() { os.RemoveAll(tmpDir) }) + + socketPath := filepath.Join(tmpDir, "s.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + ms := &mockServer{ + listener: listener, + path: socketPath, + } + + go ms.serve() + + return ms +} + +func (ms *mockServer) serve() { + for { + conn, err := ms.listener.Accept() + if err != nil { + return + } + go ms.handleConn(conn) + } +} + +func (ms *mockServer) handleConn(conn net.Conn) { + defer conn.Close() + + reader := bufio.NewReader(conn) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + return + } + + var req protocol.Request + if err := json.Unmarshal(line, &req); err != nil { + continue + } + + var resp *protocol.Response + if ms.handler != nil { + resp = ms.handler(&req) + } else { + resp, _ = protocol.NewOKResponse(nil) + } + + data, _ := json.Marshal(resp) + conn.Write(append(data, '\n')) + } +} + +func (ms *mockServer) close() { + ms.listener.Close() + os.Remove(ms.path) +} + +func TestClient_Connect(t *testing.T) { + t.Run("success", func(t *testing.T) { + server := newMockServer(t) + defer server.close() + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + assert.NotNil(t, client.conn) + assert.NotNil(t, client.reader) + }) + + t.Run("failure - socket not found", func(t *testing.T) { + client := New("/nonexistent/socket.sock") + err := client.Connect() + assert.Error(t, err) + }) +} + +func TestClient_Ping(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestPing { + resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected request") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.Ping() + assert.NoError(t, err) +} + +func TestClient_Status(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestStatus { + resp, _ := protocol.NewOKResponse(protocol.StatusData{ + Running: true, + Version: "1.0.0", + Uptime: 3600, + ActiveCount: 5, + RequestCount: 100, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + status, err := client.Status() + require.NoError(t, err) + + assert.True(t, status.Running) + assert.Equal(t, "1.0.0", status.Version) + assert.Equal(t, int64(3600), status.Uptime) + assert.Equal(t, 5, status.ActiveCount) + assert.Equal(t, int64(100), status.RequestCount) +} + +func TestClient_List(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestList { + resp, _ := protocol.NewOKResponse(protocol.ListData{ + Entries: []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + }, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + entries, err := client.List() + require.NoError(t, err) + + assert.Len(t, entries, 2) + assert.Equal(t, "a.com", entries[0].Domain) + assert.True(t, entries[0].Enabled) + assert.Equal(t, "b.com", entries[1].Domain) + assert.False(t, entries[1].Enabled) +} + +func TestClient_Set(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestSet { + var payload protocol.SetPayload + req.ParsePayload(&payload) + + resp, _ := protocol.NewOKResponse(protocol.SetData{ + Domain: "example.com", + Applied: true, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + data, err := client.Set("test", true, false) + require.NoError(t, err) + + assert.Equal(t, "example.com", data.Domain) + assert.True(t, data.Applied) +} + +func TestClient_Enable(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestSet { + var payload protocol.SetPayload + req.ParsePayload(&payload) + assert.True(t, payload.Enabled) + + resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + _, err = client.Enable("test") + assert.NoError(t, err) +} + +func TestClient_Disable(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestSet { + var payload protocol.SetPayload + req.ParsePayload(&payload) + assert.False(t, payload.Enabled) + + resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + _, err = client.Disable("test") + assert.NoError(t, err) +} + +func TestClient_Sync(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestSync { + resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.Sync() + assert.NoError(t, err) +} + +func TestClient_ApplyPreset(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestPreset { + var payload protocol.PresetPayload + req.ParsePayload(&payload) + assert.Equal(t, "local", payload.Name) + + resp, _ := protocol.NewOKResponse(map[string]string{"preset": "local"}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.ApplyPreset("local") + assert.NoError(t, err) +} + +func TestClient_Rollback(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestRollback { + var payload protocol.RollbackPayload + req.ParsePayload(&payload) + assert.Equal(t, "hosts.backup.bak", payload.BackupName) + + resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.Rollback("hosts.backup.bak") + assert.NoError(t, err) +} + +func TestClient_ListBackups(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestBackups { + resp, _ := protocol.NewOKResponse(protocol.BackupsData{ + Backups: []protocol.BackupInfo{ + {Name: "hosts.20231201.bak", Timestamp: 1701432000, Size: 1024}, + {Name: "hosts.20231130.bak", Timestamp: 1701345600, Size: 1000}, + }, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + backups, err := client.ListBackups() + require.NoError(t, err) + + assert.Len(t, backups, 2) + assert.Equal(t, "hosts.20231201.bak", backups[0].Name) +} + +func TestClient_ErrorResponse(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, "domain is blocked") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + _, err = client.Set("test", true, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "domain is blocked") +} + +func TestClient_NotConnected(t *testing.T) { + client := New("/nonexistent/socket.sock") + + _, err := client.Status() + assert.Error(t, err) + assert.Contains(t, err.Error(), "not connected") +} + +func TestClient_Timeout(t *testing.T) { + client := NewWithTimeout("/nonexistent.sock", 100*time.Millisecond) + assert.Equal(t, 100*time.Millisecond, client.timeout) +} + +func TestIsConnected(t *testing.T) { + t.Run("connected", func(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + resp, _ := protocol.NewOKResponse(nil) + return resp + } + + connected := IsConnected(server.path) + assert.True(t, connected) + }) + + t.Run("not connected", func(t *testing.T) { + connected := IsConnected("/nonexistent/socket.sock") + assert.False(t, connected) + }) +} + +// Matrix test for request types +func TestClient_RequestTypes_Matrix(t *testing.T) { + types := []struct { + name string + reqType protocol.RequestType + call func(*Client) error + }{ + {"ping", protocol.RequestPing, func(c *Client) error { return c.Ping() }}, + {"status", protocol.RequestStatus, func(c *Client) error { _, err := c.Status(); return err }}, + {"list", protocol.RequestList, func(c *Client) error { _, err := c.List(); return err }}, + {"sync", protocol.RequestSync, func(c *Client) error { return c.Sync() }}, + {"preset", protocol.RequestPreset, func(c *Client) error { return c.ApplyPreset("test") }}, + {"backups", protocol.RequestBackups, func(c *Client) error { _, err := c.ListBackups(); return err }}, + } + + for _, tt := range types { + t.Run(tt.name, func(t *testing.T) { + server := newMockServer(t) + defer server.close() + + receivedType := protocol.RequestType("") + server.handler = func(req *protocol.Request) *protocol.Response { + receivedType = req.Type + + switch req.Type { + case protocol.RequestStatus: + resp, _ := protocol.NewOKResponse(protocol.StatusData{}) + return resp + case protocol.RequestList: + resp, _ := protocol.NewOKResponse(protocol.ListData{}) + return resp + case protocol.RequestBackups: + resp, _ := protocol.NewOKResponse(protocol.BackupsData{}) + return resp + default: + resp, _ := protocol.NewOKResponse(nil) + return resp + } + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + _ = tt.call(client) + assert.Equal(t, tt.reqType, receivedType) + }) + } +} + +func BenchmarkClient_Ping(b *testing.B) { + tmpDir := b.TempDir() + socketPath := filepath.Join(tmpDir, "bench.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(b, err) + defer listener.Close() + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + go func(c net.Conn) { + defer c.Close() + reader := bufio.NewReader(c) + for { + _, err := reader.ReadBytes('\n') + if err != nil { + return + } + resp, _ := protocol.NewOKResponse(nil) + data, _ := json.Marshal(resp) + c.Write(append(data, '\n')) + } + }(conn) + } + }() + + client := New(socketPath) + err = client.Connect() + require.NoError(b, err) + defer client.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = client.Ping() + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..8fecbb8 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,541 @@ +// Package config handles YAML configuration parsing and hot-reload. +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/fsnotify/fsnotify" + "gopkg.in/yaml.v3" +) + +// SystemConfigDir is the system-wide config directory for the daemon. +const SystemConfigDir = "/etc/lolcathost" + +// SystemConfigPath is the system-wide config file path for the daemon. +const SystemConfigPath = "/etc/lolcathost/config.yaml" + +// DefaultConfigDir returns the default config directory path for users. +func DefaultConfigDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".config", "lolcathost") +} + +// DefaultConfigPath returns the default config file path for users. +func DefaultConfigPath() string { + return filepath.Join(DefaultConfigDir(), "config.yaml") +} + +// FlushMethod defines DNS cache flush methods. +type FlushMethod string + +const ( + FlushMethodAuto FlushMethod = "auto" + FlushMethodDscacheutil FlushMethod = "dscacheutil" + FlushMethodKillall FlushMethod = "killall" + FlushMethodBoth FlushMethod = "both" +) + +// Settings holds global configuration settings. +type Settings struct { + AutoApply bool `yaml:"autoApply"` + FlushMethod FlushMethod `yaml:"flushMethod"` +} + +// Host represents a single host entry in configuration. +type Host struct { + Domain string `yaml:"domain"` + IP string `yaml:"ip"` + Alias string `yaml:"alias"` + Enabled bool `yaml:"enabled"` +} + +// Group represents a group of host entries. +type Group struct { + Name string `yaml:"name"` + Hosts []Host `yaml:"hosts"` +} + +// Preset defines a named preset that enables/disables specific aliases. +type Preset struct { + Name string `yaml:"name"` + Enable []string `yaml:"enable,omitempty"` + Disable []string `yaml:"disable,omitempty"` +} + +// Config represents the complete configuration. +type Config struct { + Settings Settings `yaml:"settings"` + Groups []Group `yaml:"groups"` + Presets []Preset `yaml:"presets"` +} + +// Manager handles configuration loading and watching. +type Manager struct { + path string + config *Config + mu sync.RWMutex + watcher *fsnotify.Watcher + onChange func(*Config) + stopCh chan struct{} +} + +// NewManager creates a new config manager. +func NewManager(path string) *Manager { + return &Manager{ + path: path, + stopCh: make(chan struct{}), + } +} + +// Load reads and parses the configuration file. +func (m *Manager) Load() error { + data, err := os.ReadFile(m.path) + if err != nil { + return fmt.Errorf("failed to read config file: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return fmt.Errorf("failed to parse config file: %w", err) + } + + if err := ValidateConfig(&cfg); err != nil { + return fmt.Errorf("invalid config: %w", err) + } + + m.mu.Lock() + m.config = &cfg + m.mu.Unlock() + + return nil +} + +// Get returns the current configuration. +func (m *Manager) Get() *Config { + m.mu.RLock() + defer m.mu.RUnlock() + return m.config +} + +// Watch starts watching the config file for changes. +func (m *Manager) Watch(onChange func(*Config)) error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("failed to create watcher: %w", err) + } + + m.watcher = watcher + m.onChange = onChange + + go m.watchLoop() + + if err := watcher.Add(m.path); err != nil { + return fmt.Errorf("failed to watch config file: %w", err) + } + + return nil +} + +func (m *Manager) watchLoop() { + for { + select { + case event, ok := <-m.watcher.Events: + if !ok { + return + } + if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) { + if err := m.Load(); err == nil && m.onChange != nil { + m.onChange(m.Get()) + } + } + case <-m.watcher.Errors: + // Ignore watcher errors + case <-m.stopCh: + return + } + } +} + +// Stop stops watching the config file. +func (m *Manager) Stop() { + close(m.stopCh) + if m.watcher != nil { + m.watcher.Close() + } +} + +// GetAllHosts returns all hosts from all groups. +func (c *Config) GetAllHosts() []Host { + var hosts []Host + for _, g := range c.Groups { + hosts = append(hosts, g.Hosts...) + } + return hosts +} + +// FindHostByAlias finds a host by its alias. +func (c *Config) FindHostByAlias(alias string) (*Host, *Group) { + for i := range c.Groups { + for j := range c.Groups[i].Hosts { + if c.Groups[i].Hosts[j].Alias == alias { + return &c.Groups[i].Hosts[j], &c.Groups[i] + } + } + } + return nil, nil +} + +// FindPreset finds a preset by name. +func (c *Config) FindPreset(name string) *Preset { + for i := range c.Presets { + if c.Presets[i].Name == name { + return &c.Presets[i] + } + } + return nil +} + +// SetHostEnabled sets the enabled state of a host by alias. +func (c *Config) SetHostEnabled(alias string, enabled bool) bool { + for i := range c.Groups { + for j := range c.Groups[i].Hosts { + if c.Groups[i].Hosts[j].Alias == alias { + c.Groups[i].Hosts[j].Enabled = enabled + return true + } + } + } + return false +} + +// GenerateAlias creates a unique alias from a domain name. +func (c *Config) GenerateAlias(domain string) string { + // Convert domain to alias format: example.com -> example-com + alias := strings.ReplaceAll(domain, ".", "-") + alias = strings.ReplaceAll(alias, "_", "-") + alias = strings.ToLower(alias) + + // Check if alias exists, if so append a number + baseAlias := alias + counter := 1 + for { + if existing, _ := c.FindHostByAlias(alias); existing == nil { + break + } + counter++ + alias = fmt.Sprintf("%s-%d", baseAlias, counter) + } + + return alias +} + +// AddHost adds a new host to the configuration. +func (c *Config) AddHost(domain, ip, alias, groupName string, enabled bool) error { + // Auto-generate alias if empty + if alias == "" { + alias = c.GenerateAlias(domain) + } else { + // Check for duplicate alias + if existing, _ := c.FindHostByAlias(alias); existing != nil { + return fmt.Errorf("alias already exists: %s", alias) + } + } + + host := Host{ + Domain: domain, + IP: ip, + Alias: alias, + Enabled: enabled, + } + + // Find or create group + for i := range c.Groups { + if c.Groups[i].Name == groupName { + c.Groups[i].Hosts = append(c.Groups[i].Hosts, host) + return nil + } + } + + // Create new group + c.Groups = append(c.Groups, Group{ + Name: groupName, + Hosts: []Host{host}, + }) + return nil +} + +// AddGroup adds a new empty group. +func (c *Config) AddGroup(name string) error { + // Check if group already exists + for _, g := range c.Groups { + if g.Name == name { + return fmt.Errorf("group already exists: %s", name) + } + } + + c.Groups = append(c.Groups, Group{ + Name: name, + Hosts: []Host{}, + }) + return nil +} + +// DeleteGroup removes a group and all its hosts. +func (c *Config) DeleteGroup(name string) error { + for i, g := range c.Groups { + if g.Name == name { + c.Groups = append(c.Groups[:i], c.Groups[i+1:]...) + return nil + } + } + return fmt.Errorf("group not found: %s", name) +} + +// RenameGroup renames an existing group. +func (c *Config) RenameGroup(oldName, newName string) error { + // Check if new name already exists + for _, g := range c.Groups { + if g.Name == newName { + return fmt.Errorf("group already exists: %s", newName) + } + } + + for i := range c.Groups { + if c.Groups[i].Name == oldName { + c.Groups[i].Name = newName + return nil + } + } + return fmt.Errorf("group not found: %s", oldName) +} + +// GetGroups returns all group names. +func (c *Config) GetGroups() []string { + names := make([]string, len(c.Groups)) + for i, g := range c.Groups { + names[i] = g.Name + } + return names +} + +// DeleteHost removes a host by alias. +func (c *Config) DeleteHost(alias string) bool { + for i := range c.Groups { + for j := range c.Groups[i].Hosts { + if c.Groups[i].Hosts[j].Alias == alias { + c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...) + return true + } + } + } + return false +} + +// UpdateHost updates an existing host by alias. +func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error { + // Find the host + var foundGroup int = -1 + var foundHost int = -1 + for i := range c.Groups { + for j := range c.Groups[i].Hosts { + if c.Groups[i].Hosts[j].Alias == oldAlias { + foundGroup = i + foundHost = j + break + } + } + if foundHost >= 0 { + break + } + } + + if foundHost < 0 { + return fmt.Errorf("alias not found: %s", oldAlias) + } + + // Check for duplicate alias if alias is changing + if oldAlias != newAlias { + if existing, _ := c.FindHostByAlias(newAlias); existing != nil { + return fmt.Errorf("alias already exists: %s", newAlias) + } + } + + // Get current enabled state + enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled + + // If group is changing, move to new group + if c.Groups[foundGroup].Name != groupName { + // Remove from old group + c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...) + + // Add to new group + host := Host{ + Domain: domain, + IP: ip, + Alias: newAlias, + Enabled: enabled, + } + + // Find or create target group + found := false + for i := range c.Groups { + if c.Groups[i].Name == groupName { + c.Groups[i].Hosts = append(c.Groups[i].Hosts, host) + found = true + break + } + } + if !found { + c.Groups = append(c.Groups, Group{ + Name: groupName, + Hosts: []Host{host}, + }) + } + } else { + // Update in place + c.Groups[foundGroup].Hosts[foundHost].Domain = domain + c.Groups[foundGroup].Hosts[foundHost].IP = ip + c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias + } + + return nil +} + +// ApplyPreset applies a preset to the configuration. +func (c *Config) ApplyPreset(name string) error { + preset := c.FindPreset(name) + if preset == nil { + return fmt.Errorf("preset not found: %s", name) + } + + for _, alias := range preset.Enable { + c.SetHostEnabled(alias, true) + } + for _, alias := range preset.Disable { + c.SetHostEnabled(alias, false) + } + return nil +} + +// AddPreset adds a new preset. +func (c *Config) AddPreset(name string, enable, disable []string) error { + // Check if preset already exists + for _, p := range c.Presets { + if p.Name == name { + return fmt.Errorf("preset already exists: %s", name) + } + } + + c.Presets = append(c.Presets, Preset{ + Name: name, + Enable: enable, + Disable: disable, + }) + return nil +} + +// DeletePreset removes a preset by name. +func (c *Config) DeletePreset(name string) error { + for i, p := range c.Presets { + if p.Name == name { + c.Presets = append(c.Presets[:i], c.Presets[i+1:]...) + return nil + } + } + return fmt.Errorf("preset not found: %s", name) +} + +// GetPresets returns all presets. +func (c *Config) GetPresets() []Preset { + return c.Presets +} + +// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed. +func (c *Config) EnsureDefaultGroup() { + if len(c.Groups) == 0 { + c.Groups = append(c.Groups, Group{ + Name: "default", + Hosts: []Host{}, + }) + } +} + +// Save writes the configuration to the file. +func (m *Manager) Save() error { + m.mu.RLock() + cfg := m.config + m.mu.RUnlock() + + if cfg == nil { + return fmt.Errorf("no config loaded") + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + if err := os.WriteFile(m.path, data, 0644); err != nil { + return fmt.Errorf("failed to write config: %w", err) + } + + return nil +} + +// CreateDefault creates a default configuration file. +func CreateDefault(path string) error { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create config directory: %w", err) + } + + cfg := &Config{ + Settings: Settings{ + AutoApply: true, + FlushMethod: FlushMethodAuto, + }, + Groups: []Group{ + { + Name: "development", + Hosts: []Host{ + { + Domain: "example.local", + IP: "127.0.0.1", + Alias: "example-local", + Enabled: false, + }, + }, + }, + }, + Presets: []Preset{ + { + Name: "local", + Enable: []string{"example-local"}, + Disable: []string{}, + }, + { + Name: "clear", + Enable: []string{}, + Disable: []string{"example-local"}, + }, + }, + } + + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("failed to marshal default config: %w", err) + } + + if err := os.WriteFile(path, data, 0644); err != nil { + return fmt.Errorf("failed to write default config: %w", err) + } + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..ab58dfe --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,267 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfig_GetAllHosts(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false}, + }, + }, + { + Name: "staging", + Hosts: []Host{ + {Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true}, + }, + }, + }, + } + + hosts := cfg.GetAllHosts() + assert.Len(t, hosts, 3) + assert.Equal(t, "a.com", hosts[0].Domain) + assert.Equal(t, "b.com", hosts[1].Domain) + assert.Equal(t, "c.com", hosts[2].Domain) +} + +func TestConfig_FindHostByAlias(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true}, + }, + }, + }, + } + + t.Run("found", func(t *testing.T) { + host, group := cfg.FindHostByAlias("example") + require.NotNil(t, host) + require.NotNil(t, group) + assert.Equal(t, "example.com", host.Domain) + assert.Equal(t, "dev", group.Name) + }) + + t.Run("not found", func(t *testing.T) { + host, group := cfg.FindHostByAlias("nonexistent") + assert.Nil(t, host) + assert.Nil(t, group) + }) +} + +func TestConfig_FindPreset(t *testing.T) { + cfg := &Config{ + Presets: []Preset{ + {Name: "local", Enable: []string{"a"}, Disable: []string{"b"}}, + {Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}}, + }, + } + + t.Run("found", func(t *testing.T) { + preset := cfg.FindPreset("local") + require.NotNil(t, preset) + assert.Equal(t, "local", preset.Name) + assert.Equal(t, []string{"a"}, preset.Enable) + }) + + t.Run("not found", func(t *testing.T) { + preset := cfg.FindPreset("nonexistent") + assert.Nil(t, preset) + }) +} + +func TestConfig_SetHostEnabled(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: false}, + }, + }, + }, + } + + t.Run("enable existing", func(t *testing.T) { + result := cfg.SetHostEnabled("example", true) + assert.True(t, result) + assert.True(t, cfg.Groups[0].Hosts[0].Enabled) + }) + + t.Run("disable existing", func(t *testing.T) { + result := cfg.SetHostEnabled("example", false) + assert.True(t, result) + assert.False(t, cfg.Groups[0].Hosts[0].Enabled) + }) + + t.Run("nonexistent alias", func(t *testing.T) { + result := cfg.SetHostEnabled("nonexistent", true) + assert.False(t, result) + }) +} + +func TestConfig_ApplyPreset(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true}, + }, + }, + }, + Presets: []Preset{ + {Name: "swap", Enable: []string{"a"}, Disable: []string{"b"}}, + }, + } + + t.Run("valid preset", func(t *testing.T) { + err := cfg.ApplyPreset("swap") + require.NoError(t, err) + assert.True(t, cfg.Groups[0].Hosts[0].Enabled) + assert.False(t, cfg.Groups[0].Hosts[1].Enabled) + }) + + t.Run("nonexistent preset", func(t *testing.T) { + err := cfg.ApplyPreset("nonexistent") + assert.Error(t, err) + }) +} + +func TestManager_LoadAndGet(t *testing.T) { + // Create temp config file + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + configContent := ` +settings: + autoApply: true + flushMethod: auto +groups: + - name: development + hosts: + - domain: example.com + ip: 127.0.0.1 + alias: example-local + enabled: true +presets: + - name: local + enable: [example-local] + disable: [] +` + err := os.WriteFile(configPath, []byte(configContent), 0644) + require.NoError(t, err) + + manager := NewManager(configPath) + err = manager.Load() + require.NoError(t, err) + + cfg := manager.Get() + require.NotNil(t, cfg) + + assert.True(t, cfg.Settings.AutoApply) + assert.Equal(t, FlushMethodAuto, cfg.Settings.FlushMethod) + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "development", cfg.Groups[0].Name) + assert.Len(t, cfg.Groups[0].Hosts, 1) + assert.Equal(t, "example.com", cfg.Groups[0].Hosts[0].Domain) +} + +func TestManager_Save(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + // Create initial config + err := CreateDefault(configPath) + require.NoError(t, err) + + // Load and modify + manager := NewManager(configPath) + err = manager.Load() + require.NoError(t, err) + + cfg := manager.Get() + cfg.Groups[0].Hosts[0].Enabled = true + + // Save + err = manager.Save() + require.NoError(t, err) + + // Reload and verify + manager2 := NewManager(configPath) + err = manager2.Load() + require.NoError(t, err) + + cfg2 := manager2.Get() + assert.True(t, cfg2.Groups[0].Hosts[0].Enabled) +} + +func TestCreateDefault(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "subdir", "config.yaml") + + err := CreateDefault(configPath) + require.NoError(t, err) + + // Verify file exists + _, err = os.Stat(configPath) + require.NoError(t, err) + + // Verify content is valid + manager := NewManager(configPath) + err = manager.Load() + require.NoError(t, err) + + cfg := manager.Get() + require.NotNil(t, cfg) + assert.True(t, cfg.Settings.AutoApply) + assert.Len(t, cfg.Groups, 1) + assert.Len(t, cfg.Presets, 2) +} + +func TestManager_Load_InvalidYAML(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0644) + require.NoError(t, err) + + manager := NewManager(configPath) + err = manager.Load() + assert.Error(t, err) +} + +func TestManager_Load_FileNotFound(t *testing.T) { + manager := NewManager("/nonexistent/path/config.yaml") + err := manager.Load() + assert.Error(t, err) +} + +func TestFlushMethod(t *testing.T) { + methods := []FlushMethod{ + FlushMethodAuto, + FlushMethodDscacheutil, + FlushMethodKillall, + FlushMethodBoth, + } + + for _, m := range methods { + t.Run(string(m), func(t *testing.T) { + assert.NotEmpty(t, string(m)) + }) + } +} diff --git a/internal/config/validation.go b/internal/config/validation.go new file mode 100644 index 0000000..4a1059e --- /dev/null +++ b/internal/config/validation.go @@ -0,0 +1,211 @@ +// Package config provides validation functions for configuration. +package config + +import ( + "fmt" + "net" + "regexp" + "strings" +) + +// domainRegex validates domain names. +var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$|^localhost$`) + +// aliasRegex validates alias names. +var aliasRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`) + +// blockedDomains contains domains that cannot be modified. +var blockedDomains = map[string]bool{ + "apple.com": true, + "icloud.com": true, + "icloud-content.com": true, + "apple-dns.cn": true, + "apple-dns.net": true, + "mzstatic.com": true, + "itunes.apple.com": true, + "updates.apple.com": true, +} + +// ValidationError represents a configuration validation error. +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return fmt.Sprintf("%s: %s", e.Field, e.Message) +} + +// ValidateConfig validates the entire configuration. +func ValidateConfig(cfg *Config) error { + if cfg == nil { + return &ValidationError{Field: "config", Message: "config is nil"} + } + + if err := validateSettings(&cfg.Settings); err != nil { + return err + } + + // Track aliases for uniqueness + aliases := make(map[string]bool) + + for i, g := range cfg.Groups { + if err := validateGroup(&g, i, aliases); err != nil { + return err + } + } + + for i, p := range cfg.Presets { + if err := validatePreset(&p, i, aliases); err != nil { + return err + } + } + + return nil +} + +func validateSettings(s *Settings) error { + switch s.FlushMethod { + case FlushMethodAuto, FlushMethodDscacheutil, FlushMethodKillall, FlushMethodBoth, "": + // Valid + default: + return &ValidationError{ + Field: "settings.flushMethod", + Message: fmt.Sprintf("invalid flush method: %s", s.FlushMethod), + } + } + return nil +} + +func validateGroup(g *Group, index int, aliases map[string]bool) error { + if strings.TrimSpace(g.Name) == "" { + return &ValidationError{ + Field: fmt.Sprintf("groups[%d].name", index), + Message: "group name is required", + } + } + + for i, h := range g.Hosts { + if err := validateHost(&h, index, i, aliases); err != nil { + return err + } + } + + return nil +} + +func validateHost(h *Host, groupIndex, hostIndex int, aliases map[string]bool) error { + fieldPrefix := fmt.Sprintf("groups[%d].hosts[%d]", groupIndex, hostIndex) + + // Validate domain + if !ValidateDomain(h.Domain) { + return &ValidationError{ + Field: fieldPrefix + ".domain", + Message: fmt.Sprintf("invalid domain: %s", h.Domain), + } + } + + // Check blocked domains + if IsBlockedDomain(h.Domain) { + return &ValidationError{ + Field: fieldPrefix + ".domain", + Message: fmt.Sprintf("domain is blocked: %s", h.Domain), + } + } + + // Validate IP + if !ValidateIP(h.IP) { + return &ValidationError{ + Field: fieldPrefix + ".ip", + Message: fmt.Sprintf("invalid IP address: %s", h.IP), + } + } + + // Validate alias + if !ValidateAlias(h.Alias) { + return &ValidationError{ + Field: fieldPrefix + ".alias", + Message: fmt.Sprintf("invalid alias: %s", h.Alias), + } + } + + // Check alias uniqueness + if aliases[h.Alias] { + return &ValidationError{ + Field: fieldPrefix + ".alias", + Message: fmt.Sprintf("duplicate alias: %s", h.Alias), + } + } + aliases[h.Alias] = true + + return nil +} + +func validatePreset(p *Preset, index int, aliases map[string]bool) error { + fieldPrefix := fmt.Sprintf("presets[%d]", index) + + if strings.TrimSpace(p.Name) == "" { + return &ValidationError{ + Field: fieldPrefix + ".name", + Message: "preset name is required", + } + } + + // Note: We don't validate preset aliases strictly anymore. + // Unknown aliases in presets will simply be skipped when applying the preset. + // This allows presets to survive when hosts are removed from the config. + + return nil +} + +// ValidateDomain checks if a domain name is valid. +func ValidateDomain(domain string) bool { + if domain == "" { + return false + } + return domainRegex.MatchString(domain) +} + +// ValidateIP checks if an IP address is valid (IPv4 or IPv6). +func ValidateIP(ip string) bool { + if ip == "" { + return false + } + return net.ParseIP(ip) != nil +} + +// ValidateAlias checks if an alias is valid. +func ValidateAlias(alias string) bool { + if alias == "" { + return false + } + return aliasRegex.MatchString(alias) +} + +// IsBlockedDomain checks if a domain is in the blocklist. +func IsBlockedDomain(domain string) bool { + domain = strings.ToLower(domain) + + // Check exact match + if blockedDomains[domain] { + return true + } + + // Check if it's a subdomain of a blocked domain + for blocked := range blockedDomains { + if strings.HasSuffix(domain, "."+blocked) { + return true + } + } + + return false +} + +// GetBlockedDomains returns a copy of the blocked domains list. +func GetBlockedDomains() []string { + domains := make([]string, 0, len(blockedDomains)) + for d := range blockedDomains { + domains = append(domains, d) + } + return domains +} diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go new file mode 100644 index 0000000..36830e6 --- /dev/null +++ b/internal/config/validation_test.go @@ -0,0 +1,436 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidateDomain(t *testing.T) { + tests := []struct { + domain string + valid bool + }{ + {"example.com", true}, + {"sub.example.com", true}, + {"my-app.example.com", true}, + {"localhost", true}, + {"a.b.c.d.example.com", true}, + {"example123.com", true}, + + {"", false}, + {"-example.com", false}, + {"example-.com", false}, + {"example.c", false}, // TLD too short + {"example", false}, // No TLD + {".example.com", false}, + {"example..com", false}, + } + + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + result := ValidateDomain(tt.domain) + assert.Equal(t, tt.valid, result, "domain: %s", tt.domain) + }) + } +} + +func TestValidateIP(t *testing.T) { + tests := []struct { + ip string + valid bool + }{ + // Valid IPv4 + {"127.0.0.1", true}, + {"192.168.1.1", true}, + {"0.0.0.0", true}, + {"255.255.255.255", true}, + + // Valid IPv6 + {"::1", true}, + {"2001:db8::1", true}, + {"fe80::1", true}, + {"::ffff:192.168.1.1", true}, + + // Invalid + {"", false}, + {"256.0.0.1", false}, + {"192.168.1", false}, + {"not-an-ip", false}, + {"192.168.1.1.1", false}, + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + result := ValidateIP(tt.ip) + assert.Equal(t, tt.valid, result, "ip: %s", tt.ip) + }) + } +} + +func TestValidateAlias(t *testing.T) { + tests := []struct { + alias string + valid bool + }{ + {"my-alias", true}, + {"myalias", true}, + {"my_alias", true}, + {"alias123", true}, + {"a", true}, + {"a-b_c-d", true}, + + {"", false}, + {"-startswithdash", false}, + {"_startswithunderscore", false}, + {"has spaces", false}, + {"has.dot", false}, + } + + for _, tt := range tests { + t.Run(tt.alias, func(t *testing.T) { + result := ValidateAlias(tt.alias) + assert.Equal(t, tt.valid, result, "alias: %s", tt.alias) + }) + } +} + +func TestIsBlockedDomain(t *testing.T) { + tests := []struct { + domain string + blocked bool + }{ + // Blocked domains + {"apple.com", true}, + {"icloud.com", true}, + {"sub.apple.com", true}, + {"deep.sub.icloud.com", true}, + {"APPLE.COM", true}, // Case insensitive + + // Allowed domains + {"example.com", false}, + {"myapp.com", false}, + {"applestore.com", false}, // Not a subdomain + {"notapple.com", false}, + } + + for _, tt := range tests { + t.Run(tt.domain, func(t *testing.T) { + result := IsBlockedDomain(tt.domain) + assert.Equal(t, tt.blocked, result, "domain: %s", tt.domain) + }) + } +} + +func TestGetBlockedDomains(t *testing.T) { + domains := GetBlockedDomains() + assert.NotEmpty(t, domains) + assert.Contains(t, domains, "apple.com") + assert.Contains(t, domains, "icloud.com") +} + +func TestValidateConfig(t *testing.T) { + t.Run("valid config", func(t *testing.T) { + cfg := &Config{ + Settings: Settings{ + AutoApply: true, + FlushMethod: FlushMethodAuto, + }, + Groups: []Group{ + { + Name: "development", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true}, + }, + }, + }, + Presets: []Preset{ + {Name: "local", Enable: []string{"example"}, Disable: []string{}}, + }, + } + + err := ValidateConfig(cfg) + assert.NoError(t, err) + }) + + t.Run("nil config", func(t *testing.T) { + err := ValidateConfig(nil) + assert.Error(t, err) + }) + + t.Run("invalid flush method", func(t *testing.T) { + cfg := &Config{ + Settings: Settings{FlushMethod: "invalid"}, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("empty group name", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{{Name: "", Hosts: []Host{}}}, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("invalid domain", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "invalid", IP: "127.0.0.1", Alias: "test", Enabled: true}, + }, + }, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("blocked domain", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "apple.com", IP: "127.0.0.1", Alias: "test", Enabled: true}, + }, + }, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("invalid IP", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "invalid", Alias: "test", Enabled: true}, + }, + }, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("invalid alias", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "-invalid", Enabled: true}, + }, + }, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("duplicate alias", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "same", Enabled: true}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "same", Enabled: true}, + }, + }, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("empty preset name", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true}, + }, + }, + }, + Presets: []Preset{ + {Name: "", Enable: []string{}}, + }, + } + err := ValidateConfig(cfg) + assert.Error(t, err) + }) + + t.Run("preset with unknown alias is allowed", func(t *testing.T) { + // Unknown aliases in presets are now allowed (they're simply skipped when applied) + // This allows presets to survive when hosts are removed from the config + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true}, + }, + }, + }, + Presets: []Preset{ + {Name: "local", Enable: []string{"unknown"}}, + }, + } + err := ValidateConfig(cfg) + assert.NoError(t, err) + }) +} + +func TestValidationError(t *testing.T) { + err := &ValidationError{Field: "test.field", Message: "test message"} + assert.Equal(t, "test.field: test message", err.Error()) +} + +func TestValidateSettings(t *testing.T) { + tests := []struct { + name string + method FlushMethod + wantErr bool + }{ + {"auto", FlushMethodAuto, false}, + {"dscacheutil", FlushMethodDscacheutil, false}, + {"killall", FlushMethodKillall, false}, + {"both", FlushMethodBoth, false}, + {"empty", "", false}, + {"invalid", "invalid", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings := &Settings{FlushMethod: tt.method} + err := validateSettings(settings) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Matrix testing for domain validation +func TestValidateDomain_Matrix(t *testing.T) { + prefixes := []string{"", "sub.", "a.b."} + domains := []string{"example", "my-app", "test123"} + tlds := []string{".com", ".io", ".co.uk", ".dev"} + + for _, prefix := range prefixes { + for _, domain := range domains { + for _, tld := range tlds { + fullDomain := prefix + domain + tld + t.Run(fullDomain, func(t *testing.T) { + result := ValidateDomain(fullDomain) + assert.True(t, result, "expected %s to be valid", fullDomain) + }) + } + } + } +} + +// Matrix testing for IP validation +func TestValidateIP_Matrix(t *testing.T) { + octets := []string{"0", "127", "192", "255"} + + for _, o1 := range octets { + for _, o2 := range octets { + for _, o3 := range octets { + for _, o4 := range octets { + ip := o1 + "." + o2 + "." + o3 + "." + o4 + t.Run(ip, func(t *testing.T) { + result := ValidateIP(ip) + assert.True(t, result, "expected %s to be valid", ip) + }) + } + } + } + } +} + +// Benchmark tests +func BenchmarkValidateDomain(b *testing.B) { + domains := []string{ + "example.com", + "sub.example.com", + "very.long.subdomain.chain.example.com", + } + + for _, domain := range domains { + b.Run(domain, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ValidateDomain(domain) + } + }) + } +} + +func BenchmarkValidateIP(b *testing.B) { + ips := []string{ + "127.0.0.1", + "192.168.1.1", + "::1", + "2001:db8::1", + } + + for _, ip := range ips { + b.Run(ip, func(b *testing.B) { + for i := 0; i < b.N; i++ { + ValidateIP(ip) + } + }) + } +} + +func BenchmarkIsBlockedDomain(b *testing.B) { + domains := []string{ + "example.com", // not blocked + "apple.com", // blocked + "sub.icloud.com", // blocked subdomain + } + + for _, domain := range domains { + b.Run(domain, func(b *testing.B) { + for i := 0; i < b.N; i++ { + IsBlockedDomain(domain) + } + }) + } +} + +func BenchmarkValidateConfig(b *testing.B) { + cfg := &Config{ + Settings: Settings{AutoApply: true, FlushMethod: FlushMethodAuto}, + Groups: []Group{ + { + Name: "development", + Hosts: []Host{ + {Domain: "a.example.com", IP: "127.0.0.1", Alias: "a", Enabled: true}, + {Domain: "b.example.com", IP: "127.0.0.1", Alias: "b", Enabled: true}, + {Domain: "c.example.com", IP: "127.0.0.1", Alias: "c", Enabled: false}, + }, + }, + }, + Presets: []Preset{ + {Name: "local", Enable: []string{"a", "b"}, Disable: []string{"c"}}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := ValidateConfig(cfg) + require.NoError(b, err) + } +} diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go new file mode 100644 index 0000000..7fa32a9 --- /dev/null +++ b/internal/daemon/daemon.go @@ -0,0 +1,133 @@ +// Package daemon provides the main daemon loop and lifecycle management. +package daemon + +import ( + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/lukaszraczylo/lolcathost/internal/config" + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// Daemon represents the lolcathost daemon. +type Daemon struct { + server *Server + config *config.Manager + stopCh chan struct{} + cleanupCh chan struct{} +} + +// New creates a new daemon instance. +func New(configPath string) (*Daemon, error) { + cfgManager := config.NewManager(configPath) + + // Try to load config, create default if it doesn't exist + if err := cfgManager.Load(); err != nil { + if os.IsNotExist(err) { + if err := config.CreateDefault(configPath); err != nil { + return nil, fmt.Errorf("failed to create default config: %w", err) + } + if err := cfgManager.Load(); err != nil { + return nil, fmt.Errorf("failed to load default config: %w", err) + } + } else { + return nil, fmt.Errorf("failed to load config: %w", err) + } + } + + // Ensure at least one group exists + cfg := cfgManager.Get() + if cfg != nil { + cfg.EnsureDefaultGroup() + // Save if we added a default group + if len(cfg.Groups) == 1 && cfg.Groups[0].Name == "default" && len(cfg.Groups[0].Hosts) == 0 { + cfgManager.Save() + } + } + + server := NewServer(protocol.SocketPath, cfgManager) + + return &Daemon{ + server: server, + config: cfgManager, + stopCh: make(chan struct{}), + cleanupCh: make(chan struct{}), + }, nil +} + +// Run starts the daemon and blocks until stopped. +func (d *Daemon) Run() error { + // Verify we're running as root + if os.Geteuid() != 0 { + return fmt.Errorf("daemon must run as root") + } + + // Start the server + if err := d.server.Start(); err != nil { + return fmt.Errorf("failed to start server: %w", err) + } + + // Watch config for changes + if err := d.config.Watch(d.onConfigChange); err != nil { + fmt.Fprintf(os.Stderr, "warning: failed to watch config: %v\n", err) + } + + // Start cleanup goroutine + go d.cleanupLoop() + + // Wait for shutdown signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + select { + case <-sigCh: + fmt.Println("Received shutdown signal") + case <-d.stopCh: + fmt.Println("Shutdown requested") + } + + return d.shutdown() +} + +// Stop signals the daemon to stop. +func (d *Daemon) Stop() { + close(d.stopCh) +} + +func (d *Daemon) shutdown() error { + close(d.cleanupCh) + d.config.Stop() + + if err := d.server.Stop(); err != nil { + return fmt.Errorf("failed to stop server: %w", err) + } + + return nil +} + +func (d *Daemon) onConfigChange(cfg *config.Config) { + fmt.Println("Config changed, syncing hosts file...") + // The server will use the updated config on next request + // We could trigger a sync here if autoApply is enabled + if cfg != nil && cfg.Settings.AutoApply { + // Sync hosts file with new config + // This is handled by the server internally + } +} + +func (d *Daemon) cleanupLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + d.server.rateLimiter.Cleanup() + case <-d.cleanupCh: + return + } + } +} diff --git a/internal/daemon/dns.go b/internal/daemon/dns.go new file mode 100644 index 0000000..0397576 --- /dev/null +++ b/internal/daemon/dns.go @@ -0,0 +1,142 @@ +// Package daemon provides DNS cache flushing functionality. +package daemon + +import ( + "fmt" + "os/exec" + "runtime" +) + +// DNSFlusher handles DNS cache flushing. +type DNSFlusher struct { + method FlushMethod +} + +// FlushMethod defines the DNS flush method to use. +type FlushMethod string + +const ( + FlushMethodAuto FlushMethod = "auto" + FlushMethodDscacheutil FlushMethod = "dscacheutil" + FlushMethodKillall FlushMethod = "killall" + FlushMethodBoth FlushMethod = "both" + FlushMethodSystemd FlushMethod = "systemd" + FlushMethodNscd FlushMethod = "nscd" +) + +// NewDNSFlusher creates a new DNS flusher. +func NewDNSFlusher(method FlushMethod) *DNSFlusher { + return &DNSFlusher{method: method} +} + +// Flush flushes the DNS cache using the configured method. +func (f *DNSFlusher) Flush() error { + method := f.method + if method == FlushMethodAuto || method == "" { + method = f.detectMethod() + } + + switch runtime.GOOS { + case "darwin": + return f.flushDarwin(method) + case "linux": + return f.flushLinux(method) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +func (f *DNSFlusher) detectMethod() FlushMethod { + switch runtime.GOOS { + case "darwin": + return FlushMethodBoth + case "linux": + // Check for systemd-resolve first + if _, err := exec.LookPath("systemd-resolve"); err == nil { + return FlushMethodSystemd + } + if _, err := exec.LookPath("resolvectl"); err == nil { + return FlushMethodSystemd + } + // Fall back to nscd + if _, err := exec.LookPath("nscd"); err == nil { + return FlushMethodNscd + } + return FlushMethodAuto + default: + return FlushMethodAuto + } +} + +func (f *DNSFlusher) flushDarwin(method FlushMethod) error { + var errs []error + + switch method { + case FlushMethodDscacheutil: + if err := runCommand("dscacheutil", "-flushcache"); err != nil { + return fmt.Errorf("dscacheutil failed: %w", err) + } + case FlushMethodKillall: + if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil { + return fmt.Errorf("killall mDNSResponder failed: %w", err) + } + case FlushMethodBoth: + if err := runCommand("dscacheutil", "-flushcache"); err != nil { + errs = append(errs, fmt.Errorf("dscacheutil failed: %w", err)) + } + if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil { + errs = append(errs, fmt.Errorf("killall mDNSResponder failed: %w", err)) + } + if len(errs) == 2 { + return fmt.Errorf("all DNS flush methods failed: %v, %v", errs[0], errs[1]) + } + default: + // Auto - try both + _ = runCommand("dscacheutil", "-flushcache") + _ = runCommand("killall", "-HUP", "mDNSResponder") + } + + return nil +} + +func (f *DNSFlusher) flushLinux(method FlushMethod) error { + switch method { + case FlushMethodSystemd: + // Try resolvectl first (newer), then systemd-resolve (older) + if err := runCommand("resolvectl", "flush-caches"); err != nil { + if err := runCommand("systemd-resolve", "--flush-caches"); err != nil { + return fmt.Errorf("systemd DNS flush failed: %w", err) + } + } + case FlushMethodNscd: + // Try to restart nscd + if err := runCommand("nscd", "-i", "hosts"); err != nil { + // Try service restart as fallback + if err := runCommand("service", "nscd", "restart"); err != nil { + return fmt.Errorf("nscd flush failed: %w", err) + } + } + default: + // Auto - try all methods + // Try systemd first + if err := runCommand("resolvectl", "flush-caches"); err == nil { + return nil + } + if err := runCommand("systemd-resolve", "--flush-caches"); err == nil { + return nil + } + // Try nscd + if err := runCommand("nscd", "-i", "hosts"); err == nil { + return nil + } + // On many Linux systems, no explicit flush is needed as /etc/hosts is read directly + // So we return nil here + } + + return nil +} + +func runCommand(name string, args ...string) error { + cmd := exec.Command(name, args...) + return cmd.Run() +} diff --git a/internal/daemon/dns_test.go b/internal/daemon/dns_test.go new file mode 100644 index 0000000..2d647f4 --- /dev/null +++ b/internal/daemon/dns_test.go @@ -0,0 +1,108 @@ +package daemon + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewDNSFlusher(t *testing.T) { + tests := []FlushMethod{ + FlushMethodAuto, + FlushMethodDscacheutil, + FlushMethodKillall, + FlushMethodBoth, + FlushMethodSystemd, + FlushMethodNscd, + } + + for _, method := range tests { + t.Run(string(method), func(t *testing.T) { + flusher := NewDNSFlusher(method) + assert.NotNil(t, flusher) + assert.Equal(t, method, flusher.method) + }) + } +} + +func TestDNSFlusher_DetectMethod(t *testing.T) { + flusher := NewDNSFlusher(FlushMethodAuto) + + method := flusher.detectMethod() + + switch runtime.GOOS { + case "darwin": + assert.Equal(t, FlushMethodBoth, method) + case "linux": + // Could be systemd, nscd, or auto depending on system + assert.Contains(t, []FlushMethod{FlushMethodSystemd, FlushMethodNscd, FlushMethodAuto}, method) + } +} + +func TestFlushMethod_String(t *testing.T) { + methods := map[FlushMethod]string{ + FlushMethodAuto: "auto", + FlushMethodDscacheutil: "dscacheutil", + FlushMethodKillall: "killall", + FlushMethodBoth: "both", + FlushMethodSystemd: "systemd", + FlushMethodNscd: "nscd", + } + + for method, expected := range methods { + t.Run(expected, func(t *testing.T) { + assert.Equal(t, expected, string(method)) + }) + } +} + +// Note: Actually testing DNS flush requires root and modifies system state, +// so we skip those tests in unit tests. They would be integration tests. + +func TestDNSFlusher_Flush_UnsupportedOS(t *testing.T) { + // This test only makes sense if we're not on darwin or linux + if runtime.GOOS == "darwin" || runtime.GOOS == "linux" { + t.Skip("Test only applicable on unsupported OS") + } + + flusher := NewDNSFlusher(FlushMethodAuto) + err := flusher.Flush() + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported operating system") +} + +// Matrix test for flush methods +func TestFlushMethod_Matrix(t *testing.T) { + methods := []FlushMethod{ + FlushMethodAuto, + FlushMethodDscacheutil, + FlushMethodKillall, + FlushMethodBoth, + FlushMethodSystemd, + FlushMethodNscd, + } + + platforms := []string{"darwin", "linux"} + + for _, method := range methods { + for _, platform := range platforms { + t.Run(string(method)+"_"+platform, func(t *testing.T) { + flusher := NewDNSFlusher(method) + assert.NotNil(t, flusher) + + // Just verify no panic when checking method + _ = flusher.method + }) + } + } +} + +func BenchmarkDNSFlusher_DetectMethod(b *testing.B) { + flusher := NewDNSFlusher(FlushMethodAuto) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = flusher.detectMethod() + } +} diff --git a/internal/daemon/hosts.go b/internal/daemon/hosts.go new file mode 100644 index 0000000..c43fe28 --- /dev/null +++ b/internal/daemon/hosts.go @@ -0,0 +1,319 @@ +// Package daemon implements the privileged daemon that manages /etc/hosts. +package daemon + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + "time" +) + +const ( + // HostsPath is the path to the system hosts file. + HostsPath = "/etc/hosts" + // BackupDir is the directory for hosts file backups. + BackupDir = "/var/backups/lolcathost" + // MaxBackups is the maximum number of backups to keep. + MaxBackups = 10 + + // Markers for the managed section. + markerStart = "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========" + markerEnd = "# ========== END LOLCATHOST ==========" +) + +// HostEntry represents a single entry in the hosts file. +type HostEntry struct { + IP string + Domain string + Alias string + Enabled bool +} + +// HostsManager handles reading and writing the hosts file. +type HostsManager struct { + hostsPath string + backupDir string +} + +// NewHostsManager creates a new hosts manager. +func NewHostsManager() *HostsManager { + return &HostsManager{ + hostsPath: HostsPath, + backupDir: BackupDir, + } +} + +// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing). +func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager { + return &HostsManager{ + hostsPath: hostsPath, + backupDir: backupDir, + } +} + +// ReadManagedEntries reads the lolcathost-managed entries from the hosts file. +func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) { + file, err := os.Open(m.hostsPath) + if err != nil { + return nil, fmt.Errorf("failed to open hosts file: %w", err) + } + defer file.Close() + + var entries []HostEntry + inManagedSection := false + scanner := bufio.NewScanner(file) + entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + if line == markerStart { + inManagedSection = true + continue + } + if line == markerEnd { + inManagedSection = false + continue + } + + if inManagedSection && !strings.HasPrefix(line, "#") && line != "" { + matches := entryRegex.FindStringSubmatch(line) + if len(matches) == 4 { + entries = append(entries, HostEntry{ + IP: matches[1], + Domain: matches[2], + Alias: matches[3], + Enabled: true, + }) + } + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("failed to read hosts file: %w", err) + } + + return entries, nil +} + +// WriteManagedEntries writes the managed entries to the hosts file. +func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error { + // Create backup first + if err := m.CreateBackup(); err != nil { + return fmt.Errorf("failed to create backup: %w", err) + } + + // Read existing content + content, err := os.ReadFile(m.hostsPath) + if err != nil { + return fmt.Errorf("failed to read hosts file: %w", err) + } + + // Remove existing managed section + newContent := m.removeManagedSection(string(content)) + + // Build new managed section + managedSection := m.buildManagedSection(entries) + + // Append managed section + newContent = strings.TrimRight(newContent, "\n") + "\n\n" + managedSection + + // Write atomically + if err := m.writeAtomic(newContent); err != nil { + return fmt.Errorf("failed to write hosts file: %w", err) + } + + return nil +} + +func (m *HostsManager) removeManagedSection(content string) string { + lines := strings.Split(content, "\n") + var result []string + inManagedSection := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == markerStart { + inManagedSection = true + continue + } + if trimmed == markerEnd { + inManagedSection = false + continue + } + if !inManagedSection { + result = append(result, line) + } + } + + // Remove trailing empty lines + for len(result) > 0 && strings.TrimSpace(result[len(result)-1]) == "" { + result = result[:len(result)-1] + } + + return strings.Join(result, "\n") +} + +func (m *HostsManager) buildManagedSection(entries []HostEntry) string { + var sb strings.Builder + sb.WriteString(markerStart) + sb.WriteString("\n") + + for _, entry := range entries { + if entry.Enabled { + sb.WriteString(fmt.Sprintf("%s\t%s\t# lolcathost:%s\n", entry.IP, entry.Domain, entry.Alias)) + } + } + + sb.WriteString(markerEnd) + sb.WriteString("\n") + + return sb.String() +} + +func (m *HostsManager) writeAtomic(content string) error { + // Write to temp file first + tmpFile := m.hostsPath + ".tmp" + if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil { + return err + } + + // Rename atomically + if err := os.Rename(tmpFile, m.hostsPath); err != nil { + os.Remove(tmpFile) + return err + } + + return nil +} + +// CreateBackup creates a backup of the current hosts file. +func (m *HostsManager) CreateBackup() error { + if err := os.MkdirAll(m.backupDir, 0755); err != nil { + return fmt.Errorf("failed to create backup directory: %w", err) + } + + content, err := os.ReadFile(m.hostsPath) + if err != nil { + return fmt.Errorf("failed to read hosts file: %w", err) + } + + timestamp := time.Now().Format("20060102-150405") + backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp)) + + if err := os.WriteFile(backupPath, content, 0644); err != nil { + return fmt.Errorf("failed to write backup: %w", err) + } + + // Cleanup old backups + if err := m.cleanupBackups(); err != nil { + // Log but don't fail + fmt.Fprintf(os.Stderr, "warning: failed to cleanup backups: %v\n", err) + } + + return nil +} + +func (m *HostsManager) cleanupBackups() error { + entries, err := os.ReadDir(m.backupDir) + if err != nil { + return err + } + + var backups []os.DirEntry + for _, entry := range entries { + if !entry.IsDir() && strings.HasPrefix(entry.Name(), "hosts.") && strings.HasSuffix(entry.Name(), ".bak") { + backups = append(backups, entry) + } + } + + if len(backups) <= MaxBackups { + return nil + } + + // Sort by name (timestamp) descending + sort.Slice(backups, func(i, j int) bool { + return backups[i].Name() > backups[j].Name() + }) + + // Remove oldest backups + for i := MaxBackups; i < len(backups); i++ { + path := filepath.Join(m.backupDir, backups[i].Name()) + os.Remove(path) + } + + return nil +} + +// ListBackups returns a list of available backups. +func (m *HostsManager) ListBackups() ([]BackupInfo, error) { + entries, err := os.ReadDir(m.backupDir) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + var backups []BackupInfo + for _, entry := range entries { + if entry.IsDir() || !strings.HasPrefix(entry.Name(), "hosts.") || !strings.HasSuffix(entry.Name(), ".bak") { + continue + } + + info, err := entry.Info() + if err != nil { + continue + } + + backups = append(backups, BackupInfo{ + Name: entry.Name(), + Timestamp: info.ModTime().Unix(), + Size: info.Size(), + }) + } + + // Sort by timestamp descending + sort.Slice(backups, func(i, j int) bool { + return backups[i].Timestamp > backups[j].Timestamp + }) + + return backups, nil +} + +// BackupInfo holds information about a backup file. +type BackupInfo struct { + Name string + Timestamp int64 + Size int64 +} + +// RestoreBackup restores a backup by name. +func (m *HostsManager) RestoreBackup(name string) error { + backupPath := filepath.Join(m.backupDir, name) + + // Validate backup name to prevent path traversal + if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") { + return fmt.Errorf("invalid backup name") + } + + content, err := os.ReadFile(backupPath) + if err != nil { + return fmt.Errorf("failed to read backup: %w", err) + } + + // Create a backup of current state before restoring + if err := m.CreateBackup(); err != nil { + return fmt.Errorf("failed to create backup before restore: %w", err) + } + + if err := m.writeAtomic(string(content)); err != nil { + return fmt.Errorf("failed to restore backup: %w", err) + } + + return nil +} diff --git a/internal/daemon/hosts_test.go b/internal/daemon/hosts_test.go new file mode 100644 index 0000000..97b57ea --- /dev/null +++ b/internal/daemon/hosts_test.go @@ -0,0 +1,422 @@ +package daemon + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHostsManager_ReadManagedEntries(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + + hostsContent := `127.0.0.1 localhost +255.255.255.255 broadcasthost +::1 localhost + +# ========== LOLCATHOST MANAGED - DO NOT EDIT ========== +127.0.0.1 example.com # lolcathost:example-local +192.168.1.1 api.example.com # lolcathost:api-local +# ========== END LOLCATHOST ========== +` + err := os.WriteFile(hostsPath, []byte(hostsContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups")) + entries, err := manager.ReadManagedEntries() + require.NoError(t, err) + + assert.Len(t, entries, 2) + assert.Equal(t, "127.0.0.1", entries[0].IP) + assert.Equal(t, "example.com", entries[0].Domain) + assert.Equal(t, "example-local", entries[0].Alias) + assert.Equal(t, "192.168.1.1", entries[1].IP) + assert.Equal(t, "api.example.com", entries[1].Domain) + assert.Equal(t, "api-local", entries[1].Alias) +} + +func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + + hostsContent := `127.0.0.1 localhost +255.255.255.255 broadcasthost +` + err := os.WriteFile(hostsPath, []byte(hostsContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups")) + entries, err := manager.ReadManagedEntries() + require.NoError(t, err) + + assert.Empty(t, entries) +} + +func TestHostsManager_WriteManagedEntries(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + // Create initial hosts file + initialContent := `127.0.0.1 localhost +255.255.255.255 broadcasthost +` + err := os.WriteFile(hostsPath, []byte(initialContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + entries := []HostEntry{ + {IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true}, + {IP: "127.0.0.1", Domain: "api.myapp.com", Alias: "api-local", Enabled: true}, + {IP: "192.168.1.1", Domain: "staging.myapp.com", Alias: "staging", Enabled: false}, + } + + err = manager.WriteManagedEntries(entries) + require.NoError(t, err) + + // Read back + content, err := os.ReadFile(hostsPath) + require.NoError(t, err) + + contentStr := string(content) + assert.Contains(t, contentStr, "127.0.0.1\tlocalhost") + assert.Contains(t, contentStr, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========") + assert.Contains(t, contentStr, "127.0.0.1\tmyapp.com\t# lolcathost:myapp-local") + assert.Contains(t, contentStr, "127.0.0.1\tapi.myapp.com\t# lolcathost:api-local") + assert.NotContains(t, contentStr, "staging.myapp.com") // disabled + assert.Contains(t, contentStr, "# ========== END LOLCATHOST ==========") +} + +func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + // Create hosts file with existing managed section + initialContent := `127.0.0.1 localhost + +# ========== LOLCATHOST MANAGED - DO NOT EDIT ========== +127.0.0.1 old.com # lolcathost:old +# ========== END LOLCATHOST ========== +` + err := os.WriteFile(hostsPath, []byte(initialContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + entries := []HostEntry{ + {IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true}, + } + + err = manager.WriteManagedEntries(entries) + require.NoError(t, err) + + content, err := os.ReadFile(hostsPath) + require.NoError(t, err) + + contentStr := string(content) + assert.Contains(t, contentStr, "127.0.0.1\tlocalhost") + assert.Contains(t, contentStr, "new.com") + assert.NotContains(t, contentStr, "old.com") +} + +func TestHostsManager_CreateBackup(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + hostsContent := "127.0.0.1\tlocalhost\n" + err := os.WriteFile(hostsPath, []byte(hostsContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + err = manager.CreateBackup() + require.NoError(t, err) + + // Verify backup exists + entries, err := os.ReadDir(backupDir) + require.NoError(t, err) + assert.Len(t, entries, 1) + assert.True(t, strings.HasPrefix(entries[0].Name(), "hosts.")) + assert.True(t, strings.HasSuffix(entries[0].Name(), ".bak")) + + // Verify backup content + backupContent, err := os.ReadFile(filepath.Join(backupDir, entries[0].Name())) + require.NoError(t, err) + assert.Equal(t, hostsContent, string(backupContent)) +} + +func TestHostsManager_ListBackups(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + // Create hosts file + err := os.WriteFile(hostsPath, []byte("localhost"), 0644) + require.NoError(t, err) + + // Manually create backup files with different timestamps + err = os.MkdirAll(backupDir, 0755) + require.NoError(t, err) + + backupNames := []string{ + "hosts.20231201-120000.bak", + "hosts.20231201-120001.bak", + "hosts.20231201-120002.bak", + } + for _, name := range backupNames { + err = os.WriteFile(filepath.Join(backupDir, name), []byte("backup"), 0644) + require.NoError(t, err) + } + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + backups, err := manager.ListBackups() + require.NoError(t, err) + assert.Len(t, backups, 3) +} + +func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "nonexistent") + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + backups, err := manager.ListBackups() + require.NoError(t, err) + assert.Empty(t, backups) +} + +func TestHostsManager_RestoreBackup(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + // Create initial hosts file + initialContent := "initial content" + err := os.WriteFile(hostsPath, []byte(initialContent), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + // Create backup + err = manager.CreateBackup() + require.NoError(t, err) + + // Modify hosts file + err = os.WriteFile(hostsPath, []byte("modified content"), 0644) + require.NoError(t, err) + + // Get backup name + backups, err := manager.ListBackups() + require.NoError(t, err) + require.Len(t, backups, 1) + + // Restore + err = manager.RestoreBackup(backups[0].Name) + require.NoError(t, err) + + // Verify content restored + content, err := os.ReadFile(hostsPath) + require.NoError(t, err) + assert.Equal(t, initialContent, string(content)) +} + +func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) { + tmpDir := t.TempDir() + manager := NewHostsManagerWithPaths( + filepath.Join(tmpDir, "hosts"), + filepath.Join(tmpDir, "backups"), + ) + + tests := []string{ + "../../../etc/passwd", + "hosts.bak", // Missing timestamp + "notahosts.backup", // Wrong format + "", + } + + for _, name := range tests { + t.Run(name, func(t *testing.T) { + err := manager.RestoreBackup(name) + assert.Error(t, err) + }) + } +} + +func TestHostsManager_CleanupBackups(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + err := os.WriteFile(hostsPath, []byte("localhost"), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + // Create more than MaxBackups + for i := 0; i < MaxBackups+5; i++ { + err = manager.CreateBackup() + require.NoError(t, err) + } + + // Verify only MaxBackups remain + backups, err := manager.ListBackups() + require.NoError(t, err) + assert.LessOrEqual(t, len(backups), MaxBackups) +} + +func TestHostsManager_RemoveManagedSection(t *testing.T) { + manager := &HostsManager{} + + tests := []struct { + name string + input string + expected string + }{ + { + name: "with managed section", + input: `127.0.0.1 localhost + +# ========== LOLCATHOST MANAGED - DO NOT EDIT ========== +127.0.0.1 example.com # lolcathost:test +# ========== END LOLCATHOST ========== +`, + expected: "127.0.0.1\tlocalhost", + }, + { + name: "without managed section", + input: "127.0.0.1\tlocalhost\n", + expected: "127.0.0.1\tlocalhost", + }, + { + name: "multiple managed sections", + input: `127.0.0.1 localhost +# ========== LOLCATHOST MANAGED - DO NOT EDIT ========== +entry1 +# ========== END LOLCATHOST ========== +more content +# ========== LOLCATHOST MANAGED - DO NOT EDIT ========== +entry2 +# ========== END LOLCATHOST ========== +`, + expected: "127.0.0.1\tlocalhost\nmore content", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := manager.removeManagedSection(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHostsManager_BuildManagedSection(t *testing.T) { + manager := &HostsManager{} + + entries := []HostEntry{ + {IP: "127.0.0.1", Domain: "a.com", Alias: "a", Enabled: true}, + {IP: "192.168.1.1", Domain: "b.com", Alias: "b", Enabled: true}, + {IP: "10.0.0.1", Domain: "c.com", Alias: "c", Enabled: false}, + } + + result := manager.buildManagedSection(entries) + + assert.Contains(t, result, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========") + assert.Contains(t, result, "127.0.0.1\ta.com\t# lolcathost:a") + assert.Contains(t, result, "192.168.1.1\tb.com\t# lolcathost:b") + assert.NotContains(t, result, "c.com") // disabled + assert.Contains(t, result, "# ========== END LOLCATHOST ==========") +} + +// Matrix tests for hosts file parsing +func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) { + ips := []string{"127.0.0.1", "192.168.1.1", "::1"} + domains := []string{"example.com", "sub.example.com", "my-app.test"} + aliases := []string{"test", "my-alias", "app-1"} + + for _, ip := range ips { + for _, domain := range domains { + for _, alias := range aliases { + t.Run(ip+"/"+domain+"/"+alias, func(t *testing.T) { + tmpDir := t.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + + content := "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n" + content += ip + "\t" + domain + "\t# lolcathost:" + alias + "\n" + content += "# ========== END LOLCATHOST ==========\n" + + err := os.WriteFile(hostsPath, []byte(content), 0644) + require.NoError(t, err) + + manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups")) + entries, err := manager.ReadManagedEntries() + require.NoError(t, err) + require.Len(t, entries, 1) + + assert.Equal(t, ip, entries[0].IP) + assert.Equal(t, domain, entries[0].Domain) + assert.Equal(t, alias, entries[0].Alias) + }) + } + } + } +} + +func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) { + tmpDir := b.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + + // Create a hosts file with many entries + var content strings.Builder + content.WriteString("127.0.0.1\tlocalhost\n") + content.WriteString("# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n") + for i := 0; i < 100; i++ { + content.WriteString("127.0.0.1\texample" + string(rune('a'+i%26)) + ".com\t# lolcathost:alias" + string(rune('a'+i%26)) + "\n") + } + content.WriteString("# ========== END LOLCATHOST ==========\n") + + err := os.WriteFile(hostsPath, []byte(content.String()), 0644) + require.NoError(b, err) + + manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups")) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = manager.ReadManagedEntries() + } +} + +func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) { + tmpDir := b.TempDir() + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644) + require.NoError(b, err) + + manager := NewHostsManagerWithPaths(hostsPath, backupDir) + + entries := make([]HostEntry, 50) + for i := range entries { + entries[i] = HostEntry{ + IP: "127.0.0.1", + Domain: "example" + string(rune('a'+i%26)) + ".com", + Alias: "alias" + string(rune('a'+i%26)), + Enabled: i%2 == 0, + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = manager.WriteManagedEntries(entries) + } +} diff --git a/internal/daemon/peercred_darwin.go b/internal/daemon/peercred_darwin.go new file mode 100644 index 0000000..c9100c7 --- /dev/null +++ b/internal/daemon/peercred_darwin.go @@ -0,0 +1,57 @@ +//go:build darwin + +package daemon + +import ( + "net" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +// getPeerCredentials extracts peer credentials from a Unix socket connection on macOS. +// Note: macOS Xucred doesn't include PID, so we use LOCAL_PEERPID separately. +func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials { + unixConn, ok := conn.(*net.UnixConn) + if !ok { + return nil + } + + rawConn, err := unixConn.SyscallConn() + if err != nil { + return nil + } + + var creds *PeerCredentials + rawConn.Control(func(fd uintptr) { + xucred, err := unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED) + if err != nil { + return + } + + // Get PID separately using LOCAL_PEERPID + var pid int32 + pidLen := uint32(unsafe.Sizeof(pid)) + _, _, errno := syscall.Syscall6( + syscall.SYS_GETSOCKOPT, + fd, + unix.SOL_LOCAL, + 0x002, // LOCAL_PEERPID + uintptr(unsafe.Pointer(&pid)), + uintptr(unsafe.Pointer(&pidLen)), + 0, + ) + if errno != 0 { + pid = 0 + } + + creds = &PeerCredentials{ + UID: xucred.Uid, + GID: xucred.Groups[0], + PID: pid, + } + }) + + return creds +} diff --git a/internal/daemon/peercred_linux.go b/internal/daemon/peercred_linux.go new file mode 100644 index 0000000..ad9f21f --- /dev/null +++ b/internal/daemon/peercred_linux.go @@ -0,0 +1,37 @@ +//go:build linux + +package daemon + +import ( + "net" + + "golang.org/x/sys/unix" +) + +// getPeerCredentials extracts peer credentials from a Unix socket connection on Linux. +func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials { + unixConn, ok := conn.(*net.UnixConn) + if !ok { + return nil + } + + rawConn, err := unixConn.SyscallConn() + if err != nil { + return nil + } + + var creds *PeerCredentials + rawConn.Control(func(fd uintptr) { + ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED) + if err != nil { + return + } + creds = &PeerCredentials{ + UID: ucred.Uid, + GID: ucred.Gid, + PID: ucred.Pid, + } + }) + + return creds +} diff --git a/internal/daemon/security.go b/internal/daemon/security.go new file mode 100644 index 0000000..7659dad --- /dev/null +++ b/internal/daemon/security.go @@ -0,0 +1,196 @@ +// Package daemon provides security functions including rate limiting and audit logging. +package daemon + +import ( + "encoding/json" + "fmt" + "os" + "os/user" + "sync" + "time" +) + +const ( + // AuditLogPath is the path to the audit log file. + AuditLogPath = "/var/log/lolcathost/audit.log" + // RateLimit is the maximum requests per minute per PID. + RateLimit = 100 + // RateLimitWindow is the time window for rate limiting. + RateLimitWindow = time.Minute +) + +// RateLimiter implements per-PID rate limiting. +type RateLimiter struct { + mu sync.Mutex + requests map[int32][]time.Time + limit int + window time.Duration +} + +// NewRateLimiter creates a new rate limiter. +func NewRateLimiter(limit int, window time.Duration) *RateLimiter { + return &RateLimiter{ + requests: make(map[int32][]time.Time), + limit: limit, + window: window, + } +} + +// Allow checks if a request from the given PID should be allowed. +func (r *RateLimiter) Allow(pid int32) bool { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-r.window) + + // Get existing requests for this PID + reqs := r.requests[pid] + + // Filter out old requests + var validReqs []time.Time + for _, t := range reqs { + if t.After(cutoff) { + validReqs = append(validReqs, t) + } + } + + // Check if under limit + if len(validReqs) >= r.limit { + r.requests[pid] = validReqs + return false + } + + // Add new request + validReqs = append(validReqs, now) + r.requests[pid] = validReqs + + return true +} + +// Cleanup removes old entries from the rate limiter. +func (r *RateLimiter) Cleanup() { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-r.window) + + for pid, reqs := range r.requests { + var validReqs []time.Time + for _, t := range reqs { + if t.After(cutoff) { + validReqs = append(validReqs, t) + } + } + if len(validReqs) == 0 { + delete(r.requests, pid) + } else { + r.requests[pid] = validReqs + } + } +} + +// AuditLogger handles audit logging. +type AuditLogger struct { + mu sync.Mutex + file *os.File + path string + encoder *json.Encoder +} + +// AuditEntry represents a single audit log entry. +type AuditEntry struct { + Timestamp string `json:"timestamp"` + UID uint32 `json:"uid"` + PID int32 `json:"pid"` + Action string `json:"action"` + Details any `json:"details,omitempty"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// NewAuditLogger creates a new audit logger. +func NewAuditLogger(path string) (*AuditLogger, error) { + // Ensure directory exists + dir := path[:len(path)-len("/audit.log")] + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create log directory: %w", err) + } + + file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + return nil, fmt.Errorf("failed to open audit log: %w", err) + } + + return &AuditLogger{ + file: file, + path: path, + encoder: json.NewEncoder(file), + }, nil +} + +// Log writes an audit entry. +func (a *AuditLogger) Log(uid uint32, pid int32, action string, details any, success bool, errMsg string) { + a.mu.Lock() + defer a.mu.Unlock() + + entry := AuditEntry{ + Timestamp: time.Now().UTC().Format(time.RFC3339), + UID: uid, + PID: pid, + Action: action, + Details: details, + Success: success, + Error: errMsg, + } + + // Ignore encoding errors - audit logging should not fail the operation + _ = a.encoder.Encode(entry) +} + +// Close closes the audit logger. +func (a *AuditLogger) Close() error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.file != nil { + err := a.file.Close() + a.file = nil // Prevent double close + return err + } + return nil +} + +// PeerCredentials holds the credentials of a connected peer. +type PeerCredentials struct { + UID uint32 + GID uint32 + PID int32 +} + +// isUserInGroup checks if a user (by UID) is a member of a group (by GID). +// This checks supplementary groups, not just the primary GID. +func isUserInGroup(uid uint32, targetGID uint32) bool { + // Look up user by UID + u, err := user.LookupId(fmt.Sprintf("%d", uid)) + if err != nil { + return false + } + + // Get user's group IDs + groupIDs, err := u.GroupIds() + if err != nil { + return false + } + + // Check if target GID is in the list + targetGIDStr := fmt.Sprintf("%d", targetGID) + for _, gid := range groupIDs { + if gid == targetGIDStr { + return true + } + } + + return false +} diff --git a/internal/daemon/security_test.go b/internal/daemon/security_test.go new file mode 100644 index 0000000..4bf1ca4 --- /dev/null +++ b/internal/daemon/security_test.go @@ -0,0 +1,206 @@ +package daemon + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRateLimiter_Allow(t *testing.T) { + t.Run("under limit", func(t *testing.T) { + rl := NewRateLimiter(5, time.Minute) + + for i := 0; i < 5; i++ { + assert.True(t, rl.Allow(123), "request %d should be allowed", i) + } + }) + + t.Run("over limit", func(t *testing.T) { + rl := NewRateLimiter(3, time.Minute) + + for i := 0; i < 3; i++ { + assert.True(t, rl.Allow(123)) + } + + // 4th request should be blocked + assert.False(t, rl.Allow(123)) + }) + + t.Run("different PIDs", func(t *testing.T) { + rl := NewRateLimiter(2, time.Minute) + + // PID 1 + assert.True(t, rl.Allow(1)) + assert.True(t, rl.Allow(1)) + assert.False(t, rl.Allow(1)) + + // PID 2 should have its own limit + assert.True(t, rl.Allow(2)) + assert.True(t, rl.Allow(2)) + assert.False(t, rl.Allow(2)) + }) + + t.Run("window expiration", func(t *testing.T) { + rl := NewRateLimiter(2, 10*time.Millisecond) + + assert.True(t, rl.Allow(123)) + assert.True(t, rl.Allow(123)) + assert.False(t, rl.Allow(123)) + + // Wait for window to expire + time.Sleep(15 * time.Millisecond) + + // Should be allowed again + assert.True(t, rl.Allow(123)) + }) +} + +func TestRateLimiter_Cleanup(t *testing.T) { + rl := NewRateLimiter(10, 10*time.Millisecond) + + // Add requests from multiple PIDs + for pid := int32(1); pid <= 5; pid++ { + rl.Allow(pid) + } + + assert.Len(t, rl.requests, 5) + + // Wait for expiration + time.Sleep(15 * time.Millisecond) + + // Cleanup + rl.Cleanup() + + assert.Empty(t, rl.requests) +} + +func TestAuditLogger_Log(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + logger, err := NewAuditLogger(logPath) + require.NoError(t, err) + defer logger.Close() + + logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "") + logger.Log(1000, 12345, "sync", nil, false, "sync failed") + + // Read log file + content, err := os.ReadFile(logPath) + require.NoError(t, err) + + contentStr := string(content) + assert.Contains(t, contentStr, `"action":"set"`) + assert.Contains(t, contentStr, `"uid":1000`) + assert.Contains(t, contentStr, `"pid":12345`) + assert.Contains(t, contentStr, `"success":true`) + assert.Contains(t, contentStr, `"action":"sync"`) + assert.Contains(t, contentStr, `"success":false`) + assert.Contains(t, contentStr, `"error":"sync failed"`) +} + +func TestAuditLogger_CreatesDirectory(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "subdir", "audit.log") + + logger, err := NewAuditLogger(logPath) + require.NoError(t, err) + defer logger.Close() + + // Verify directory was created + _, err = os.Stat(filepath.Dir(logPath)) + assert.NoError(t, err) +} + +func TestAuditLogger_Close(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + logger, err := NewAuditLogger(logPath) + require.NoError(t, err) + + err = logger.Close() + assert.NoError(t, err) + + // Closing again should not error + err = logger.Close() + assert.NoError(t, err) +} + +func TestPeerCredentials(t *testing.T) { + creds := &PeerCredentials{ + UID: 501, + GID: 20, + PID: 12345, + } + + assert.Equal(t, uint32(501), creds.UID) + assert.Equal(t, uint32(20), creds.GID) + assert.Equal(t, int32(12345), creds.PID) +} + +// Matrix test for rate limiting +func TestRateLimiter_Matrix(t *testing.T) { + limits := []int{1, 5, 10, 100} + windows := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, time.Second} + + for _, limit := range limits { + for _, window := range windows { + t.Run( + "limit="+string(rune('0'+limit))+"_window="+window.String(), + func(t *testing.T) { + rl := NewRateLimiter(limit, window) + + // Should allow exactly 'limit' requests + for i := 0; i < limit; i++ { + assert.True(t, rl.Allow(1)) + } + + // Next should be blocked + assert.False(t, rl.Allow(1)) + }, + ) + } + } +} + +func BenchmarkRateLimiter_Allow(b *testing.B) { + rl := NewRateLimiter(RateLimit, RateLimitWindow) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rl.Allow(int32(i % 100)) + } +} + +func BenchmarkRateLimiter_Cleanup(b *testing.B) { + rl := NewRateLimiter(RateLimit, RateLimitWindow) + + // Pre-populate with requests + for i := 0; i < 1000; i++ { + rl.Allow(int32(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rl.Cleanup() + } +} + +func BenchmarkAuditLogger_Log(b *testing.B) { + tmpDir := b.TempDir() + logPath := filepath.Join(tmpDir, "audit.log") + + logger, err := NewAuditLogger(logPath) + require.NoError(b, err) + defer logger.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "") + } +} diff --git a/internal/daemon/server.go b/internal/daemon/server.go new file mode 100644 index 0000000..f1aa6cd --- /dev/null +++ b/internal/daemon/server.go @@ -0,0 +1,803 @@ +// Package daemon provides the Unix socket server for the daemon. +package daemon + +import ( + "bufio" + "encoding/json" + "fmt" + "net" + "os" + "sync" + "time" + + "github.com/lukaszraczylo/lolcathost/internal/config" + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// Version is set by the main package at startup +var Version = "dev" + +// Server is the daemon's Unix socket server. +type Server struct { + socketPath string + listener net.Listener + config *config.Manager + hosts *HostsManager + flusher *DNSFlusher + rateLimiter *RateLimiter + auditLogger *AuditLogger + mu sync.RWMutex + running bool + stopCh chan struct{} + requestCount int64 + startTime int64 +} + +// NewServer creates a new daemon server. +func NewServer(socketPath string, cfgManager *config.Manager) *Server { + return &Server{ + socketPath: socketPath, + config: cfgManager, + hosts: NewHostsManager(), + flusher: NewDNSFlusher(FlushMethodAuto), + rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow), + stopCh: make(chan struct{}), + } +} + +// Start starts the server. +func (s *Server) Start() error { + // Remove existing socket + os.Remove(s.socketPath) + + listener, err := net.Listen("unix", s.socketPath) + if err != nil { + return fmt.Errorf("failed to listen on socket: %w", err) + } + + // Set socket permissions: 0660 root:lolcathost + if err := os.Chmod(s.socketPath, 0660); err != nil { + listener.Close() + return fmt.Errorf("failed to set socket permissions: %w", err) + } + + // Set socket group to lolcathost (GID 850) + if err := os.Chown(s.socketPath, 0, 850); err != nil { + listener.Close() + return fmt.Errorf("failed to set socket ownership: %w", err) + } + + s.listener = listener + s.running = true + s.startTime = currentTimeUnix() + + // Try to create audit logger, but don't fail if it doesn't work + if logger, err := NewAuditLogger(AuditLogPath); err == nil { + s.auditLogger = logger + } + + go s.acceptLoop() + + return nil +} + +func currentTimeUnix() int64 { + return time.Now().Unix() +} + +// Stop stops the server. +func (s *Server) Stop() error { + s.mu.Lock() + s.running = false + s.mu.Unlock() + + close(s.stopCh) + + if s.listener != nil { + s.listener.Close() + } + + os.Remove(s.socketPath) + + if s.auditLogger != nil { + s.auditLogger.Close() + } + + return nil +} + +func (s *Server) acceptLoop() { + for { + conn, err := s.listener.Accept() + if err != nil { + select { + case <-s.stopCh: + return + default: + continue + } + } + + go s.handleConnection(conn) + } +} + +// LolcathostGID is the group ID for the lolcathost group. +const LolcathostGID = 850 + +func (s *Server) handleConnection(conn net.Conn) { + defer conn.Close() + + // Get peer credentials + creds := s.getPeerCredentials(conn) + + // Authorization check: verify peer is authorized + if !s.isAuthorized(creds) { + s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group")) + if s.auditLogger != nil { + var uid uint32 + var pid int32 + if creds != nil { + uid = creds.UID + pid = creds.PID + } + s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt") + } + return + } + + reader := bufio.NewReader(conn) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + return + } + + var req protocol.Request + if err := json.Unmarshal(line, &req); err != nil { + s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON")) + continue + } + + // Rate limiting + if creds != nil && !s.rateLimiter.Allow(creds.PID) { + s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded")) + continue + } + + s.mu.Lock() + s.requestCount++ + s.mu.Unlock() + + resp := s.handleRequest(&req, creds) + s.writeResponse(conn, resp) + } +} + +// isAuthorized checks if the peer is authorized to access the daemon. +// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850). +func (s *Server) isAuthorized(creds *PeerCredentials) bool { + if creds == nil { + // Can't verify credentials - deny by default + return false + } + + // Root is always authorized + if creds.UID == 0 { + return true + } + + // Check if user's primary GID is lolcathost + if creds.GID == LolcathostGID { + return true + } + + // Check supplementary groups (user might be in lolcathost as secondary group) + // This requires looking up the user's groups from the system + return isUserInGroup(creds.UID, LolcathostGID) +} + +func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) { + data, _ := json.Marshal(resp) + data = append(data, '\n') + conn.Write(data) +} + +func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response { + var uid uint32 + var pid int32 + if creds != nil { + uid = creds.UID + pid = creds.PID + } + + switch req.Type { + case protocol.RequestPing: + return s.handlePing() + + case protocol.RequestStatus: + return s.handleStatus() + + case protocol.RequestList: + return s.handleList() + + case protocol.RequestSet: + resp := s.handleSet(req) + if s.auditLogger != nil { + var payload protocol.SetPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestSync: + resp := s.handleSync() + if s.auditLogger != nil { + s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestPreset: + resp := s.handlePreset(req) + if s.auditLogger != nil { + var payload protocol.PresetPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestRollback: + resp := s.handleRollback(req) + if s.auditLogger != nil { + var payload protocol.RollbackPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestBackups: + return s.handleBackups() + + case protocol.RequestAdd: + resp := s.handleAdd(req) + if s.auditLogger != nil { + var payload protocol.AddPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestDelete: + resp := s.handleDelete(req) + if s.auditLogger != nil { + var payload protocol.DeletePayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestAddGroup: + resp := s.handleAddGroup(req) + if s.auditLogger != nil { + var payload protocol.GroupPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestDeleteGroup: + resp := s.handleDeleteGroup(req) + if s.auditLogger != nil { + var payload protocol.GroupPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestListGroups: + return s.handleListGroups() + + case protocol.RequestRenameGroup: + resp := s.handleRenameGroup(req) + if s.auditLogger != nil { + var payload protocol.RenameGroupPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestAddPreset: + resp := s.handleAddPreset(req) + if s.auditLogger != nil { + var payload protocol.AddPresetPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestDeletePreset: + resp := s.handleDeletePreset(req) + if s.auditLogger != nil { + var payload protocol.PresetPayload + _ = req.ParsePayload(&payload) + s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message) + } + return resp + + case protocol.RequestListPresets: + return s.handleListPresets() + + default: + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type)) + } +} + +func (s *Server) handlePing() *protocol.Response { + resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"}) + return resp +} + +func (s *Server) handleStatus() *protocol.Response { + s.mu.RLock() + reqCount := s.requestCount + startTime := s.startTime + s.mu.RUnlock() + + cfg := s.config.Get() + var activeCount int + if cfg != nil { + for _, h := range cfg.GetAllHosts() { + if h.Enabled { + activeCount++ + } + } + } + + data := protocol.StatusData{ + Running: true, + Version: Version, + Uptime: nowUnix() - startTime, + ActiveCount: activeCount, + RequestCount: reqCount, + } + + resp, _ := protocol.NewOKResponse(data) + return resp +} + +func nowUnix() int64 { + return time.Now().Unix() +} + +func (s *Server) handleList() *protocol.Response { + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + var entries []protocol.HostEntry + for _, g := range cfg.Groups { + for _, h := range g.Hosts { + entries = append(entries, protocol.HostEntry{ + Domain: h.Domain, + IP: h.IP, + Alias: h.Alias, + Enabled: h.Enabled, + Group: g.Name, + }) + } + } + + resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries}) + return resp +} + +func (s *Server) handleSet(req *protocol.Request) *protocol.Response { + var payload protocol.SetPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + host, _ := cfg.FindHostByAlias(payload.Alias) + if host == nil { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias)) + } + + // Check for conflicts if enabling + if payload.Enabled && !payload.Force { + for _, g := range cfg.Groups { + for _, h := range g.Hosts { + if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled { + return protocol.NewErrorResponse(protocol.ErrCodeConflict, + fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias)) + } + } + } + } + + // Update config + cfg.SetHostEnabled(payload.Alias, payload.Enabled) + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + // Sync to hosts file + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err)) + } + + resp, _ := protocol.NewOKResponse(protocol.SetData{ + Domain: host.Domain, + Applied: true, + }) + return resp +} + +func (s *Server) handleSync() *protocol.Response { + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true}) + return resp +} + +func (s *Server) handlePreset(req *protocol.Request) *protocol.Response { + var payload protocol.PresetPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.ApplyPreset(payload.Name); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + // Sync to hosts file + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"}) + return resp +} + +func (s *Server) handleRollback(req *protocol.Request) *protocol.Response { + var payload protocol.RollbackPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if err := s.hosts.RestoreBackup(payload.BackupName); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err)) + } + + // Flush DNS after restore + s.flusher.Flush() + + resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName}) + return resp +} + +func (s *Server) handleBackups() *protocol.Response { + backups, err := s.hosts.ListBackups() + if err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err)) + } + + var infos []protocol.BackupInfo + for _, b := range backups { + infos = append(infos, protocol.BackupInfo{ + Name: b.Name, + Timestamp: b.Timestamp, + Size: b.Size, + }) + } + + resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos}) + return resp +} + +func (s *Server) handleAdd(req *protocol.Request) *protocol.Response { + var payload protocol.AddPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + // Validate domain + if payload.Domain == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required") + } + + // Validate IP + if payload.IP == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required") + } + + // Validate group + if payload.Group == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required") + } + + // Check blocked domains + if config.IsBlockedDomain(payload.Domain) { + return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain)) + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + // Add to config (alias will be auto-generated if empty) + if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + // Sync to hosts file + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err)) + } + + resp, _ := protocol.NewOKResponse(protocol.SetData{ + Domain: payload.Domain, + Applied: true, + }) + return resp +} + +func (s *Server) handleDelete(req *protocol.Request) *protocol.Response { + var payload protocol.DeletePayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.Alias == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + // Delete from config + if !cfg.DeleteHost(payload.Alias) { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias)) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + // Sync to hosts file + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias}) + return resp +} + +func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response { + var payload protocol.GroupPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.Name == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.AddGroup(payload.Name); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name}) + return resp +} + +func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response { + var payload protocol.GroupPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.Name == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.DeleteGroup(payload.Name); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + // Sync to hosts file + if err := s.syncHostsFile(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name}) + return resp +} + +func (s *Server) handleListGroups() *protocol.Response { + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()}) + return resp +} + +func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response { + var payload protocol.RenameGroupPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.OldName == "" || payload.NewName == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName}) + return resp +} + +func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response { + var payload protocol.AddPresetPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.Name == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name}) + return resp +} + +func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response { + var payload protocol.PresetPayload + if err := req.ParsePayload(&payload); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload") + } + + if payload.Name == "" { + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required") + } + + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + if err := cfg.DeletePreset(payload.Name); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error()) + } + + // Save config + if err := s.config.Save(); err != nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err)) + } + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name}) + return resp +} + +func (s *Server) handleListPresets() *protocol.Response { + cfg := s.config.Get() + if cfg == nil { + return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded") + } + + presets := cfg.GetPresets() + infos := make([]protocol.PresetInfo, len(presets)) + for i, p := range presets { + infos[i] = protocol.PresetInfo{ + Name: p.Name, + Enable: p.Enable, + Disable: p.Disable, + } + } + + resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos}) + return resp +} + +func (s *Server) syncHostsFile() error { + cfg := s.config.Get() + if cfg == nil { + return fmt.Errorf("no configuration loaded") + } + + var entries []HostEntry + for _, g := range cfg.Groups { + for _, h := range g.Hosts { + entries = append(entries, HostEntry{ + IP: h.IP, + Domain: h.Domain, + Alias: h.Alias, + Enabled: h.Enabled, + }) + } + } + + if err := s.hosts.WriteManagedEntries(entries); err != nil { + return err + } + + // Flush DNS cache + return s.flusher.Flush() +} diff --git a/internal/installer/installer.go b/internal/installer/installer.go new file mode 100644 index 0000000..9f44c87 --- /dev/null +++ b/internal/installer/installer.go @@ -0,0 +1,474 @@ +// Package installer handles installation and uninstallation of the lolcathost daemon. +package installer + +import ( + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strconv" + "strings" + + "github.com/lukaszraczylo/lolcathost/internal/config" +) + +const ( + // GroupName is the name of the lolcathost group. + GroupName = "lolcathost" + // GroupGID is the GID for the lolcathost group (macOS). + GroupGID = 850 + + // Paths + LogDir = "/var/log/lolcathost" + BackupDir = "/var/backups/lolcathost" + SocketPath = "/var/run/lolcathost.sock" + LaunchDaemonDir = "/Library/LaunchDaemons" + SystemdDir = "/etc/systemd/system" +) + +// LaunchDaemonPlist is the macOS LaunchDaemon plist template. +const LaunchDaemonPlist = ` + + + + Label + com.lolcathost.daemon + ProgramArguments + + %s + --daemon + --config + /etc/lolcathost/config.yaml + + RunAtLoad + + KeepAlive + + StandardOutPath + /var/log/lolcathost/daemon.log + StandardErrorPath + /var/log/lolcathost/daemon.err + + +` + +// SystemdUnit is the Linux systemd unit template. +const SystemdUnit = `[Unit] +Description=lolcathost - Dynamic Host Management Daemon +After=network.target + +[Service] +Type=simple +ExecStart=%s --daemon --config /etc/lolcathost/config.yaml +Restart=always +RestartSec=5 +User=root +Group=root + +[Install] +WantedBy=multi-user.target +` + +// Installer handles installation and uninstallation. +type Installer struct { + binaryPath string + verbose bool +} + +// New creates a new installer. +func New() (*Installer, error) { + binaryPath, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("failed to get executable path: %w", err) + } + + // Resolve symlinks + binaryPath, err = filepath.EvalSymlinks(binaryPath) + if err != nil { + return nil, fmt.Errorf("failed to resolve executable path: %w", err) + } + + return &Installer{ + binaryPath: binaryPath, + verbose: true, + }, nil +} + +// Install performs the full installation. +func (i *Installer) Install() error { + if os.Geteuid() != 0 { + return fmt.Errorf("--install requires sudo") + } + + i.log("Installing lolcathost...") + + // Create group + if err := i.createGroup(); err != nil { + return fmt.Errorf("failed to create group: %w", err) + } + + // Add current user to group + if err := i.addCurrentUserToGroup(); err != nil { + return fmt.Errorf("failed to add user to group: %w", err) + } + + // Create directories + if err := i.createDirectories(); err != nil { + return fmt.Errorf("failed to create directories: %w", err) + } + + // Create system config for daemon + if err := i.createSystemConfig(); err != nil { + return fmt.Errorf("failed to create system config: %w", err) + } + + // Install service + if runtime.GOOS == "darwin" { + if err := i.installLaunchDaemon(); err != nil { + return fmt.Errorf("failed to install LaunchDaemon: %w", err) + } + } else if runtime.GOOS == "linux" { + if err := i.installSystemdService(); err != nil { + return fmt.Errorf("failed to install systemd service: %w", err) + } + } + + // Create default config for the invoking user + if err := i.createDefaultConfig(); err != nil { + i.log("Warning: failed to create default config: %v", err) + } + + i.log("") + i.log("✓ Installed successfully!") + i.log("") + i.log("Next steps:") + i.log(" 1. Open a NEW terminal (for group membership to take effect)") + i.log(" 2. Run 'lolcathost' to start the TUI") + i.log("") + + return nil +} + +// Uninstall removes the installation. +func (i *Installer) Uninstall() error { + if os.Geteuid() != 0 { + return fmt.Errorf("--uninstall requires sudo") + } + + i.log("Uninstalling lolcathost...") + + // Stop and remove service + if runtime.GOOS == "darwin" { + i.uninstallLaunchDaemon() + } else if runtime.GOOS == "linux" { + i.uninstallSystemdService() + } + + // Remove socket + os.Remove(SocketPath) + + // Note: We don't remove the group, logs, or backups + // The user may want to keep these + + i.log("") + i.log("✓ Uninstalled successfully!") + i.log("") + i.log("Note: Log files, backups, and the group were preserved.") + i.log("To fully remove, manually delete:") + i.log(" - /var/log/lolcathost/") + i.log(" - /var/backups/lolcathost/") + i.log(" - ~/.config/lolcathost/") + if runtime.GOOS == "darwin" { + i.log(" - Remove group: sudo dscl . -delete /Groups/%s", GroupName) + } else { + i.log(" - Remove group: sudo groupdel %s", GroupName) + } + i.log("") + + return nil +} + +func (i *Installer) log(format string, args ...any) { + if i.verbose { + fmt.Printf(format+"\n", args...) + } +} + +func (i *Installer) createGroup() error { + switch runtime.GOOS { + case "darwin": + return i.createGroupDarwin() + case "linux": + return i.createGroupLinux() + default: + return fmt.Errorf("unsupported OS: %s", runtime.GOOS) + } +} + +func (i *Installer) createGroupDarwin() error { + // Check if group exists + if _, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName).Output(); err == nil { + i.log(" Group '%s' already exists", GroupName) + return nil + } + + i.log(" Creating group '%s' (GID %d)...", GroupName, GroupGID) + + // Create group + cmds := [][]string{ + {"dscl", ".", "-create", "/Groups/" + GroupName}, + {"dscl", ".", "-create", "/Groups/" + GroupName, "PrimaryGroupID", strconv.Itoa(GroupGID)}, + {"dscl", ".", "-create", "/Groups/" + GroupName, "RealName", "lolcathost users"}, + } + + for _, args := range cmds { + if err := exec.Command(args[0], args[1:]...).Run(); err != nil { + return fmt.Errorf("command %v failed: %w", args, err) + } + } + + return nil +} + +func (i *Installer) createGroupLinux() error { + // Check if group exists + if _, err := exec.Command("getent", "group", GroupName).Output(); err == nil { + i.log(" Group '%s' already exists", GroupName) + return nil + } + + i.log(" Creating group '%s'...", GroupName) + + if err := exec.Command("groupadd", "-r", GroupName).Run(); err != nil { + return fmt.Errorf("groupadd failed: %w", err) + } + + return nil +} + +func (i *Installer) addCurrentUserToGroup() error { + // Get the real user (not root) + username := os.Getenv("SUDO_USER") + if username == "" { + // Fall back to current user + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + username = u.Username + } + + if username == "root" { + i.log(" Skipping adding root to group") + return nil + } + + switch runtime.GOOS { + case "darwin": + return i.addUserToGroupDarwin(username) + case "linux": + return i.addUserToGroupLinux(username) + default: + return fmt.Errorf("unsupported OS: %s", runtime.GOOS) + } +} + +func (i *Installer) addUserToGroupDarwin(username string) error { + // Check if user is already in group + output, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName, "GroupMembership").Output() + if err == nil && strings.Contains(string(output), username) { + i.log(" User '%s' already in group '%s'", username, GroupName) + return nil + } + + i.log(" Adding user '%s' to group '%s'...", username, GroupName) + + if err := exec.Command("dscl", ".", "-append", "/Groups/"+GroupName, "GroupMembership", username).Run(); err != nil { + return fmt.Errorf("failed to add user to group: %w", err) + } + + return nil +} + +func (i *Installer) addUserToGroupLinux(username string) error { + // Check if user is already in group + output, err := exec.Command("id", "-nG", username).Output() + if err == nil && strings.Contains(string(output), GroupName) { + i.log(" User '%s' already in group '%s'", username, GroupName) + return nil + } + + i.log(" Adding user '%s' to group '%s'...", username, GroupName) + + if err := exec.Command("usermod", "-aG", GroupName, username).Run(); err != nil { + return fmt.Errorf("failed to add user to group: %w", err) + } + + return nil +} + +func (i *Installer) createDirectories() error { + dirs := []string{LogDir, BackupDir, config.SystemConfigDir} + + for _, dir := range dirs { + i.log(" Creating directory '%s'...", dir) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("failed to create %s: %w", dir, err) + } + } + + return nil +} + +func (i *Installer) createSystemConfig() error { + // Check if system config already exists + if _, err := os.Stat(config.SystemConfigPath); err == nil { + i.log(" System config already exists at %s", config.SystemConfigPath) + return nil + } + + i.log(" Creating system config at %s...", config.SystemConfigPath) + return config.CreateDefault(config.SystemConfigPath) +} + +func (i *Installer) installLaunchDaemon() error { + plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist") + plistContent := fmt.Sprintf(LaunchDaemonPlist, i.binaryPath) + + i.log(" Writing LaunchDaemon plist...") + if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil { + return fmt.Errorf("failed to write plist: %w", err) + } + + // Unload if already loaded + exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run() + + // Bootstrap the daemon + i.log(" Starting daemon...") + if err := exec.Command("launchctl", "bootstrap", "system", plistPath).Run(); err != nil { + return fmt.Errorf("failed to bootstrap daemon: %w", err) + } + + return nil +} + +func (i *Installer) uninstallLaunchDaemon() { + plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist") + + i.log(" Stopping daemon...") + exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run() + + i.log(" Removing LaunchDaemon plist...") + os.Remove(plistPath) +} + +func (i *Installer) installSystemdService() error { + unitPath := filepath.Join(SystemdDir, "lolcathost.service") + unitContent := fmt.Sprintf(SystemdUnit, i.binaryPath) + + i.log(" Writing systemd unit...") + if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil { + return fmt.Errorf("failed to write unit file: %w", err) + } + + // Reload systemd + i.log(" Reloading systemd...") + if err := exec.Command("systemctl", "daemon-reload").Run(); err != nil { + return fmt.Errorf("failed to reload systemd: %w", err) + } + + // Enable and start the service + i.log(" Enabling and starting service...") + if err := exec.Command("systemctl", "enable", "--now", "lolcathost.service").Run(); err != nil { + return fmt.Errorf("failed to enable service: %w", err) + } + + return nil +} + +func (i *Installer) uninstallSystemdService() { + i.log(" Stopping and disabling service...") + exec.Command("systemctl", "disable", "--now", "lolcathost.service").Run() + + i.log(" Removing systemd unit...") + os.Remove(filepath.Join(SystemdDir, "lolcathost.service")) + + exec.Command("systemctl", "daemon-reload").Run() +} + +func (i *Installer) createDefaultConfig() error { + // Get the real user's home directory + username := os.Getenv("SUDO_USER") + if username == "" { + return nil // Can't determine user + } + + u, err := user.Lookup(username) + if err != nil { + return fmt.Errorf("failed to lookup user: %w", err) + } + + configPath := filepath.Join(u.HomeDir, ".config", "lolcathost", "config.yaml") + + // Check if config already exists + if _, err := os.Stat(configPath); err == nil { + i.log(" Config already exists at %s", configPath) + return nil + } + + i.log(" Creating default config at %s...", configPath) + + if err := config.CreateDefault(configPath); err != nil { + return err + } + + // Change ownership to the real user + uid, _ := strconv.Atoi(u.Uid) + gid, _ := strconv.Atoi(u.Gid) + + configDir := filepath.Dir(configPath) + os.Chown(configDir, uid, gid) + os.Chown(filepath.Dir(configDir), uid, gid) + os.Chown(configPath, uid, gid) + + return nil +} + +// CheckInstallation checks if the daemon is properly installed. +func CheckInstallation() error { + // Check if socket exists + if _, err := os.Stat(SocketPath); os.IsNotExist(err) { + return fmt.Errorf("daemon not running (socket not found)") + } + + // Check if user is in group + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + + groups, err := u.GroupIds() + if err != nil { + return fmt.Errorf("failed to get user groups: %w", err) + } + + inGroup := false + for _, gid := range groups { + g, err := user.LookupGroupId(gid) + if err != nil { + continue + } + if g.Name == GroupName { + inGroup = true + break + } + } + + if !inGroup { + return fmt.Errorf("user '%s' is not in group '%s'. Run 'sudo lolcathost --install' and open a new terminal", u.Username, GroupName) + } + + return nil +} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go new file mode 100644 index 0000000..00b540a --- /dev/null +++ b/internal/protocol/protocol.go @@ -0,0 +1,226 @@ +// Package protocol defines shared message types for client-daemon communication. +package protocol + +import ( + "encoding/json" + "fmt" +) + +// SocketPath is the Unix socket path for daemon communication. +const SocketPath = "/var/run/lolcathost.sock" + +// RequestType defines the type of request. +type RequestType string + +const ( + RequestPing RequestType = "ping" + RequestStatus RequestType = "status" + RequestList RequestType = "list" + RequestSet RequestType = "set" + RequestAdd RequestType = "add" + RequestDelete RequestType = "delete" + RequestSync RequestType = "sync" + RequestPreset RequestType = "preset" + RequestRollback RequestType = "rollback" + RequestBackups RequestType = "backups" + RequestAddGroup RequestType = "add_group" + RequestDeleteGroup RequestType = "delete_group" + RequestRenameGroup RequestType = "rename_group" + RequestListGroups RequestType = "list_groups" + RequestAddPreset RequestType = "add_preset" + RequestDeletePreset RequestType = "delete_preset" + RequestListPresets RequestType = "list_presets" +) + +// ErrorCode defines standard error codes. +type ErrorCode string + +const ( + ErrCodeInvalidRequest ErrorCode = "INVALID_REQUEST" + ErrCodeInvalidDomain ErrorCode = "INVALID_DOMAIN" + ErrCodeInvalidIP ErrorCode = "INVALID_IP" + ErrCodeBlockedDomain ErrorCode = "BLOCKED_DOMAIN" + ErrCodeRateLimited ErrorCode = "RATE_LIMITED" + ErrCodeUnauthorized ErrorCode = "UNAUTHORIZED" + ErrCodeNotFound ErrorCode = "NOT_FOUND" + ErrCodeConflict ErrorCode = "CONFLICT" + ErrCodeInternalError ErrorCode = "INTERNAL_ERROR" + ErrCodePermissionError ErrorCode = "PERMISSION_ERROR" +) + +// Request represents a client request to the daemon. +type Request struct { + Type RequestType `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +// SetPayload is the payload for set requests. +type SetPayload struct { + Alias string `json:"alias"` + Enabled bool `json:"enabled"` + Force bool `json:"force,omitempty"` +} + +// PresetPayload is the payload for preset requests. +type PresetPayload struct { + Name string `json:"name"` +} + +// RollbackPayload is the payload for rollback requests. +type RollbackPayload struct { + BackupName string `json:"backup_name"` +} + +// AddPayload is the payload for add requests. +type AddPayload struct { + Domain string `json:"domain"` + IP string `json:"ip"` + Alias string `json:"alias"` + Group string `json:"group"` + Enabled bool `json:"enabled"` +} + +// DeletePayload is the payload for delete requests. +type DeletePayload struct { + Alias string `json:"alias"` +} + +// GroupPayload is the payload for group add/delete requests. +type GroupPayload struct { + Name string `json:"name"` +} + +// RenameGroupPayload is the payload for rename_group requests. +type RenameGroupPayload struct { + OldName string `json:"old_name"` + NewName string `json:"new_name"` +} + +// GroupsData is the data for list_groups responses. +type GroupsData struct { + Groups []string `json:"groups"` +} + +// AddPresetPayload is the payload for add_preset requests. +type AddPresetPayload struct { + Name string `json:"name"` + Enable []string `json:"enable"` + Disable []string `json:"disable"` +} + +// PresetInfo represents a preset with its configuration. +type PresetInfo struct { + Name string `json:"name"` + Enable []string `json:"enable"` + Disable []string `json:"disable"` +} + +// PresetsData is the data for list_presets responses. +type PresetsData struct { + Presets []PresetInfo `json:"presets"` +} + +// Response represents a daemon response. +type Response struct { + Status string `json:"status"` + Data json.RawMessage `json:"data,omitempty"` + Message string `json:"message,omitempty"` + Code ErrorCode `json:"code,omitempty"` +} + +// StatusData is the data for status responses. +type StatusData struct { + Running bool `json:"running"` + Version string `json:"version"` + Uptime int64 `json:"uptime_seconds"` + ActiveCount int `json:"active_count"` + RequestCount int64 `json:"request_count"` +} + +// HostEntry represents a single host entry. +type HostEntry struct { + Domain string `json:"domain"` + IP string `json:"ip"` + Alias string `json:"alias"` + Enabled bool `json:"enabled"` + Group string `json:"group"` +} + +// ListData is the data for list responses. +type ListData struct { + Entries []HostEntry `json:"entries"` +} + +// SetData is the data for set responses. +type SetData struct { + Domain string `json:"domain"` + Applied bool `json:"applied"` +} + +// BackupsData is the data for backups responses. +type BackupsData struct { + Backups []BackupInfo `json:"backups"` +} + +// BackupInfo represents a backup file. +type BackupInfo struct { + Name string `json:"name"` + Timestamp int64 `json:"timestamp"` + Size int64 `json:"size"` +} + +// NewRequest creates a new request with the given type and payload. +func NewRequest(reqType RequestType, payload interface{}) (*Request, error) { + req := &Request{Type: reqType} + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal payload: %w", err) + } + req.Payload = data + } + return req, nil +} + +// NewOKResponse creates a success response with optional data. +func NewOKResponse(data interface{}) (*Response, error) { + resp := &Response{Status: "ok"} + if data != nil { + dataBytes, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal data: %w", err) + } + resp.Data = dataBytes + } + return resp, nil +} + +// NewErrorResponse creates an error response. +func NewErrorResponse(code ErrorCode, message string) *Response { + return &Response{ + Status: "error", + Code: code, + Message: message, + } +} + +// ParsePayload unmarshals the request payload into the given target. +func (r *Request) ParsePayload(target interface{}) error { + if r.Payload == nil { + return fmt.Errorf("no payload in request") + } + return json.Unmarshal(r.Payload, target) +} + +// ParseData unmarshals the response data into the given target. +func (r *Response) ParseData(target interface{}) error { + if r.Data == nil { + return fmt.Errorf("no data in response") + } + return json.Unmarshal(r.Data, target) +} + +// IsOK returns true if the response indicates success. +func (r *Response) IsOK() bool { + return r.Status == "ok" +} diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go new file mode 100644 index 0000000..fec06d0 --- /dev/null +++ b/internal/protocol/protocol_test.go @@ -0,0 +1,227 @@ +package protocol + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRequest(t *testing.T) { + tests := []struct { + name string + reqType RequestType + payload interface{} + wantErr bool + }{ + { + name: "ping request without payload", + reqType: RequestPing, + payload: nil, + wantErr: false, + }, + { + name: "set request with payload", + reqType: RequestSet, + payload: SetPayload{Alias: "test", Enabled: true}, + wantErr: false, + }, + { + name: "preset request with payload", + reqType: RequestPreset, + payload: PresetPayload{Name: "local"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, err := NewRequest(tt.reqType, tt.payload) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.reqType, req.Type) + if tt.payload != nil { + assert.NotNil(t, req.Payload) + } + }) + } +} + +func TestRequest_ParsePayload(t *testing.T) { + t.Run("valid payload", func(t *testing.T) { + payload := SetPayload{Alias: "test-alias", Enabled: true, Force: false} + req, err := NewRequest(RequestSet, payload) + require.NoError(t, err) + + var parsed SetPayload + err = req.ParsePayload(&parsed) + require.NoError(t, err) + assert.Equal(t, "test-alias", parsed.Alias) + assert.True(t, parsed.Enabled) + assert.False(t, parsed.Force) + }) + + t.Run("nil payload", func(t *testing.T) { + req := &Request{Type: RequestPing} + var parsed SetPayload + err := req.ParsePayload(&parsed) + assert.Error(t, err) + }) +} + +func TestNewOKResponse(t *testing.T) { + t.Run("with data", func(t *testing.T) { + data := StatusData{ + Running: true, + Version: "1.0.0", + Uptime: 3600, + ActiveCount: 5, + RequestCount: 100, + } + + resp, err := NewOKResponse(data) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Status) + assert.NotNil(t, resp.Data) + assert.True(t, resp.IsOK()) + }) + + t.Run("without data", func(t *testing.T) { + resp, err := NewOKResponse(nil) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Status) + assert.Nil(t, resp.Data) + }) +} + +func TestNewErrorResponse(t *testing.T) { + resp := NewErrorResponse(ErrCodeBlockedDomain, "domain is blocked") + + assert.Equal(t, "error", resp.Status) + assert.Equal(t, ErrCodeBlockedDomain, resp.Code) + assert.Equal(t, "domain is blocked", resp.Message) + assert.False(t, resp.IsOK()) +} + +func TestResponse_ParseData(t *testing.T) { + t.Run("valid data", func(t *testing.T) { + data := ListData{ + Entries: []HostEntry{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"}, + }, + } + resp, err := NewOKResponse(data) + require.NoError(t, err) + + var parsed ListData + err = resp.ParseData(&parsed) + require.NoError(t, err) + assert.Len(t, parsed.Entries, 1) + assert.Equal(t, "example.com", parsed.Entries[0].Domain) + }) + + t.Run("nil data", func(t *testing.T) { + resp := &Response{Status: "ok"} + var parsed ListData + err := resp.ParseData(&parsed) + assert.Error(t, err) + }) +} + +func TestRequestTypes(t *testing.T) { + types := []RequestType{ + RequestPing, + RequestStatus, + RequestList, + RequestSet, + RequestSync, + RequestPreset, + RequestRollback, + RequestBackups, + } + + for _, rt := range types { + t.Run(string(rt), func(t *testing.T) { + req, err := NewRequest(rt, nil) + require.NoError(t, err) + assert.Equal(t, rt, req.Type) + + // Verify JSON marshaling works + data, err := json.Marshal(req) + require.NoError(t, err) + assert.Contains(t, string(data), string(rt)) + }) + } +} + +func TestErrorCodes(t *testing.T) { + codes := []ErrorCode{ + ErrCodeInvalidRequest, + ErrCodeInvalidDomain, + ErrCodeInvalidIP, + ErrCodeBlockedDomain, + ErrCodeRateLimited, + ErrCodeNotFound, + ErrCodeConflict, + ErrCodeInternalError, + ErrCodePermissionError, + } + + for _, code := range codes { + t.Run(string(code), func(t *testing.T) { + resp := NewErrorResponse(code, "test error") + assert.Equal(t, code, resp.Code) + + // Verify JSON marshaling works + data, err := json.Marshal(resp) + require.NoError(t, err) + assert.Contains(t, string(data), string(code)) + }) + } +} + +func TestHostEntry(t *testing.T) { + entry := HostEntry{ + Domain: "example.com", + IP: "127.0.0.1", + Alias: "example-local", + Enabled: true, + Group: "development", + } + + data, err := json.Marshal(entry) + require.NoError(t, err) + + var parsed HostEntry + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, entry.Domain, parsed.Domain) + assert.Equal(t, entry.IP, parsed.IP) + assert.Equal(t, entry.Alias, parsed.Alias) + assert.Equal(t, entry.Enabled, parsed.Enabled) + assert.Equal(t, entry.Group, parsed.Group) +} + +func TestBackupInfo(t *testing.T) { + info := BackupInfo{ + Name: "hosts.20231201-120000.bak", + Timestamp: 1701432000, + Size: 1024, + } + + data, err := json.Marshal(info) + require.NoError(t, err) + + var parsed BackupInfo + err = json.Unmarshal(data, &parsed) + require.NoError(t, err) + + assert.Equal(t, info.Name, parsed.Name) + assert.Equal(t, info.Timestamp, parsed.Timestamp) + assert.Equal(t, info.Size, parsed.Size) +} diff --git a/internal/tui/app.go b/internal/tui/app.go new file mode 100644 index 0000000..3e9375f --- /dev/null +++ b/internal/tui/app.go @@ -0,0 +1,904 @@ +// Package tui provides the main Bubble Tea application. +package tui + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + + "github.com/lukaszraczylo/lolcathost/internal/client" + "github.com/lukaszraczylo/lolcathost/internal/config" + "github.com/lukaszraczylo/lolcathost/internal/protocol" + "github.com/lukaszraczylo/lolcathost/internal/version" +) + +// ViewMode represents the current view mode. +type ViewMode int + +const ( + ViewList ViewMode = iota + ViewForm + ViewPresets + ViewGroups + ViewHelp + ViewSearch +) + +// Model is the main Bubble Tea model. +type Model struct { + // Client + client *client.Client + connected bool + + // Config + configPath string + config *config.Manager + + // Views + mode ViewMode + list *ListView + form *Form + presetPicker *PresetPicker + groupPicker *GroupPicker + searchInput textinput.Model + + // State + width int + height int + message string + messageStyle string // "error" or "success" + messageTime time.Time + searchTerm string + allGroups []string // All groups including empty ones + + // Update notification + updateAvailable bool + updateVersion string + updateURL string + + // Version info for update checking + version string + githubOwner string + githubRepo string +} + +// Message types +type ( + connectMsg struct{ err error } + refreshMsg struct { + entries []protocol.HostEntry + err error + } + toggleMsg struct { + alias string + err error + } + presetMsg struct { + name string + err error + } + addMsg struct { + domain string + err error + } + deleteMsg struct { + alias string + err error + } + addPresetMsg struct { + name string + err error + } + deletePresetMsg struct { + name string + err error + } + refreshPresetsMsg struct { + presets []protocol.PresetInfo + err error + } + addGroupMsg struct { + name string + err error + } + renameGroupMsg struct { + name string + err error + } + deleteGroupMsg struct { + name string + err error + } + refreshGroupsMsg struct { + groups []string + err error + } + clearMsgMsg struct{} + tickMsg struct{} + updateMsg struct { + version string + url string + } +) + +// NewModel creates a new TUI model. +func NewModel(socketPath, configPath string) *Model { + searchInput := textinput.New() + searchInput.Placeholder = "Search..." + searchInput.CharLimit = 100 + searchInput.Width = 50 + + return &Model{ + client: client.New(socketPath), + configPath: configPath, + config: config.NewManager(configPath), + list: NewListView(), + form: NewForm(), + presetPicker: NewPresetPicker(), + groupPicker: NewGroupPicker(), + searchInput: searchInput, + mode: ViewList, + } +} + +// Init initializes the model. +func (m *Model) Init() tea.Cmd { + return tea.Batch( + m.connect(), + tea.SetWindowTitle("lolcathost"), + m.tick(), + m.checkForUpdate(), + ) +} + +func (m *Model) connect() tea.Cmd { + return func() tea.Msg { + if err := m.client.Connect(); err != nil { + return connectMsg{err: err} + } + return connectMsg{err: nil} + } +} + +func (m *Model) refresh() tea.Cmd { + return func() tea.Msg { + entries, err := m.client.List() + if err != nil { + return refreshMsg{entries: nil, err: err} + } + return refreshMsg{entries: entries, err: nil} + } +} + +func (m *Model) toggle(alias string, enabled bool) tea.Cmd { + return func() tea.Msg { + _, err := m.client.Set(alias, enabled, false) + return toggleMsg{alias: alias, err: err} + } +} + +func (m *Model) applyPreset(name string) tea.Cmd { + return func() tea.Msg { + err := m.client.ApplyPreset(name) + return presetMsg{name: name, err: err} + } +} + +func (m *Model) addHost(domain, ip, alias, group string) tea.Cmd { + return func() tea.Msg { + _, err := m.client.Add(domain, ip, alias, group, false) + return addMsg{domain: domain, err: err} + } +} + +func (m *Model) deleteHost(alias string) tea.Cmd { + return func() tea.Msg { + err := m.client.Delete(alias) + return deleteMsg{alias: alias, err: err} + } +} + +func (m *Model) addPreset(name string, enable, disable []string) tea.Cmd { + return func() tea.Msg { + err := m.client.AddPreset(name, enable, disable) + return addPresetMsg{name: name, err: err} + } +} + +func (m *Model) deletePreset(name string) tea.Cmd { + return func() tea.Msg { + err := m.client.DeletePreset(name) + return deletePresetMsg{name: name, err: err} + } +} + +func (m *Model) refreshPresets() tea.Cmd { + return func() tea.Msg { + presets, err := m.client.ListPresets() + return refreshPresetsMsg{presets: presets, err: err} + } +} + +func (m *Model) addGroup(name string) tea.Cmd { + return func() tea.Msg { + err := m.client.AddGroup(name) + return addGroupMsg{name: name, err: err} + } +} + +func (m *Model) renameGroup(oldName, newName string) tea.Cmd { + return func() tea.Msg { + err := m.client.RenameGroup(oldName, newName) + return renameGroupMsg{name: newName, err: err} + } +} + +func (m *Model) deleteGroup(name string) tea.Cmd { + return func() tea.Msg { + err := m.client.DeleteGroup(name) + return deleteGroupMsg{name: name, err: err} + } +} + +func (m *Model) refreshGroups() tea.Cmd { + return func() tea.Msg { + groups, err := m.client.ListGroups() + return refreshGroupsMsg{groups: groups, err: err} + } +} + +func (m *Model) tick() tea.Cmd { + return tea.Tick(time.Second*3, func(t time.Time) tea.Msg { + return tickMsg{} + }) +} + +func (m *Model) clearMsg() tea.Cmd { + return tea.Tick(time.Second*3, func(t time.Time) tea.Msg { + return clearMsgMsg{} + }) +} + +func (m *Model) checkForUpdate() tea.Cmd { + if m.githubOwner == "" || m.githubRepo == "" { + return nil + } + return func() tea.Msg { + checker := version.NewChecker(m.githubOwner, m.githubRepo, m.version) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if update := checker.CheckForUpdate(ctx); update != nil { + return updateMsg{version: update.LatestVersion, url: update.ReleaseURL} + } + return nil + } +} + +// Update handles messages. +func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + var cmds []tea.Cmd + + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.list.SetSize(msg.Width, msg.Height-10) + m.form.SetSize(msg.Width, msg.Height) + m.presetPicker.SetSize(msg.Width, msg.Height) + m.groupPicker.SetSize(msg.Width, msg.Height) + // Set search input width + searchWidth := msg.Width - 20 + if searchWidth > 60 { + searchWidth = 60 + } + m.searchInput.Width = searchWidth + + case tea.KeyMsg: + cmd := m.handleKey(msg) + if cmd != nil { + cmds = append(cmds, cmd) + } + + case connectMsg: + if msg.err != nil { + m.connected = false + m.setError(fmt.Sprintf("Failed to connect: %v", msg.err)) + } else { + m.connected = true + cmds = append(cmds, m.refresh()) + cmds = append(cmds, m.refreshPresets()) + cmds = append(cmds, m.refreshGroups()) + m.loadConfig() + } + + case refreshMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Refresh failed: %v", msg.err)) + // Mark as disconnected to trigger reconnect + m.connected = false + m.client.Close() + } else if msg.entries != nil { + m.list.SetItems(msg.entries) + } + + case toggleMsg: + if msg.err != nil { + m.list.SetError(msg.alias, true) + m.setError(fmt.Sprintf("Toggle failed: %v", msg.err)) + } else { + m.list.SetPending(msg.alias, false) + cmds = append(cmds, m.refresh()) + m.setSuccess("Entry toggled") + } + + case presetMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Preset failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refresh()) + m.setSuccess(fmt.Sprintf("Applied preset: %s", msg.name)) + } + m.mode = ViewList + + case addMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Add failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refresh()) + m.setSuccess(fmt.Sprintf("Added host: %s", msg.domain)) + } + m.mode = ViewList + + case deleteMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Delete failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refresh()) + m.setSuccess(fmt.Sprintf("Deleted: %s", msg.alias)) + } + + case addPresetMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Add preset failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refreshPresets()) + m.setSuccess(fmt.Sprintf("Added preset: %s", msg.name)) + } + m.presetPicker.CancelForm() + + case deletePresetMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Delete preset failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refreshPresets()) + m.setSuccess(fmt.Sprintf("Deleted preset: %s", msg.name)) + } + m.presetPicker.CancelForm() + + case refreshPresetsMsg: + if msg.err == nil && msg.presets != nil { + m.presetPicker.SetPresetsWithInfo(msg.presets) + } + + case addGroupMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Add group failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refreshGroups()) + cmds = append(cmds, m.refresh()) // Refresh list to show new group + m.setSuccess(fmt.Sprintf("Added group: %s", msg.name)) + } + m.groupPicker.CancelForm() + + case renameGroupMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Rename group failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refreshGroups()) + cmds = append(cmds, m.refresh()) + m.setSuccess(fmt.Sprintf("Renamed group to: %s", msg.name)) + } + m.groupPicker.CancelForm() + + case deleteGroupMsg: + if msg.err != nil { + m.setError(fmt.Sprintf("Delete group failed: %v", msg.err)) + } else { + cmds = append(cmds, m.refreshGroups()) + cmds = append(cmds, m.refresh()) + m.setSuccess(fmt.Sprintf("Deleted group: %s", msg.name)) + } + m.groupPicker.CancelForm() + + case refreshGroupsMsg: + if msg.err == nil && msg.groups != nil { + m.allGroups = msg.groups + m.groupPicker.SetGroups(msg.groups) + } + + case clearMsgMsg: + if time.Since(m.messageTime) >= time.Second*3 { + m.message = "" + } + + case tickMsg: + // Reconnect if disconnected + if !m.connected { + cmds = append(cmds, m.connect()) + } + cmds = append(cmds, m.tick()) + + case updateMsg: + if msg.version != "" { + m.updateAvailable = true + m.updateVersion = msg.version + m.updateURL = msg.url + } + } + + return m, tea.Batch(cmds...) +} + +func (m *Model) handleKey(msg tea.KeyMsg) tea.Cmd { + // Global keys + switch msg.String() { + case "ctrl+c": + return tea.Quit + } + + // Mode-specific keys + switch m.mode { + case ViewList: + return m.handleListKey(msg) + case ViewForm: + return m.handleFormKey(msg) + case ViewPresets: + return m.handlePresetKey(msg) + case ViewGroups: + return m.handleGroupKey(msg) + case ViewHelp: + return m.handleHelpKey(msg) + case ViewSearch: + return m.handleSearchKey(msg) + } + + return nil +} + +func (m *Model) handleListKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "q": + return tea.Quit + case "esc": + // Clear search if active + if m.searchTerm != "" { + m.searchTerm = "" + m.searchInput.Reset() + } + case "up", "k": + m.list.MoveUp() + case "down", "j": + m.list.MoveDown() + case " ", "enter": + return m.toggleSelected() + case "n": + m.mode = ViewForm + m.form.SetGroups(m.allGroups) + m.form.Init() + case "e": + if item := m.list.Selected(); item != nil { + m.mode = ViewForm + m.form.SetGroups(m.allGroups) + m.form.InitEdit(item.Entry.Domain, item.Entry.IP, item.Entry.Alias, item.Entry.Group) + } + case "d": + if item := m.list.Selected(); item != nil { + return m.deleteHost(item.Entry.Alias) + } + case "p": + m.mode = ViewPresets + case "g": + m.mode = ViewGroups + return m.refreshGroups() + case "/": + m.mode = ViewSearch + m.searchInput.Focus() + case "?": + m.mode = ViewHelp + case "r": + return m.refresh() + } + return nil +} + +func (m *Model) handleFormKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.mode = ViewList + return nil + case "enter": + if errMsg := m.form.Validate(); errMsg != "" { + m.setError(errMsg) + return m.clearMsg() + } + domain, ip, group := m.form.Values() + if m.form.IsEdit() { + // For edit, delete old and add new (simple approach) + oldAlias := m.form.EditAlias() + return tea.Sequence( + func() tea.Msg { + m.client.Delete(oldAlias) + return nil + }, + m.addHost(domain, ip, "", group), // Empty alias = auto-generate + ) + } + return m.addHost(domain, ip, "", group) // Empty alias = auto-generate + } + + return m.form.Update(msg) +} + +func (m *Model) handlePresetKey(msg tea.KeyMsg) tea.Cmd { + // Handle based on preset picker mode + switch m.presetPicker.Mode() { + case PresetModeSelect: + return m.handlePresetSelectKey(msg) + case PresetModeAdd, PresetModeEdit: + return m.handlePresetFormKey(msg) + case PresetModeConfirmDelete: + return m.handlePresetDeleteKey(msg) + } + return nil +} + +func (m *Model) handlePresetSelectKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc", "q": + m.mode = ViewList + case "up", "k": + m.presetPicker.MoveUp() + case "down", "j": + m.presetPicker.MoveDown() + case "enter": + if preset := m.presetPicker.Selected(); preset != "" { + return m.applyPreset(preset) + } + case "n": + m.presetPicker.InitAdd() + case "e": + m.presetPicker.InitEdit() + case "d": + m.presetPicker.InitDelete() + } + return nil +} + +func (m *Model) handlePresetFormKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.presetPicker.CancelForm() + return nil + case "enter": + if errMsg := m.presetPicker.ValidateForm(); errMsg != "" { + m.setError(errMsg) + return m.clearMsg() + } + name, enable, disable := m.presetPicker.FormValues() + if m.presetPicker.IsEdit() { + // For edit, delete old and add new + oldName := m.presetPicker.EditName() + return tea.Sequence( + func() tea.Msg { + m.client.DeletePreset(oldName) + return nil + }, + m.addPreset(name, enable, disable), + ) + } + return m.addPreset(name, enable, disable) + } + return m.presetPicker.Update(msg) +} + +func (m *Model) handlePresetDeleteKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "y", "Y": + if preset := m.presetPicker.Selected(); preset != "" { + return m.deletePreset(preset) + } + m.presetPicker.CancelForm() + case "n", "N", "esc": + m.presetPicker.CancelForm() + } + return nil +} + +func (m *Model) handleGroupKey(msg tea.KeyMsg) tea.Cmd { + // Handle based on group picker mode + switch m.groupPicker.Mode() { + case GroupModeSelect: + return m.handleGroupSelectKey(msg) + case GroupModeAdd, GroupModeRename: + return m.handleGroupFormKey(msg) + case GroupModeConfirmDelete: + return m.handleGroupDeleteKey(msg) + } + return nil +} + +func (m *Model) handleGroupSelectKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc", "q": + m.mode = ViewList + case "up", "k": + m.groupPicker.MoveUp() + case "down", "j": + m.groupPicker.MoveDown() + case "n": + m.groupPicker.InitAdd() + case "r": + m.groupPicker.InitRename() + case "d": + m.groupPicker.InitDelete() + } + return nil +} + +func (m *Model) handleGroupFormKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.groupPicker.CancelForm() + return nil + case "enter": + if errMsg := m.groupPicker.ValidateForm(); errMsg != "" { + m.setError(errMsg) + return m.clearMsg() + } + name := m.groupPicker.FormValue() + if m.groupPicker.IsRename() { + oldName := m.groupPicker.EditName() + return m.renameGroup(oldName, name) + } + return m.addGroup(name) + } + return m.groupPicker.Update(msg) +} + +func (m *Model) handleGroupDeleteKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "y", "Y": + if group := m.groupPicker.Selected(); group != "" { + return m.deleteGroup(group) + } + m.groupPicker.CancelForm() + case "n", "N", "esc": + m.groupPicker.CancelForm() + } + return nil +} + +func (m *Model) handleHelpKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc", "q", "?": + m.mode = ViewList + } + return nil +} + +func (m *Model) handleSearchKey(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "esc": + m.mode = ViewList + m.searchTerm = "" + m.searchInput.Reset() + return nil + case "enter": + m.searchTerm = m.searchInput.Value() + m.mode = ViewList + return nil + } + + var cmd tea.Cmd + m.searchInput, cmd = m.searchInput.Update(msg) + return cmd +} + +func (m *Model) toggleSelected() tea.Cmd { + item := m.list.Selected() + if item == nil { + return nil + } + + m.list.SetPending(item.Entry.Alias, true) + return m.toggle(item.Entry.Alias, !item.Entry.Enabled) +} + +func (m *Model) loadConfig() { + if err := m.config.Load(); err != nil { + return + } + + cfg := m.config.Get() + if cfg == nil { + return + } + + var presetNames []string + for _, p := range cfg.Presets { + presetNames = append(presetNames, p.Name) + } + m.presetPicker.SetPresets(presetNames) +} + +func (m *Model) setError(msg string) { + m.message = msg + m.messageStyle = "error" + m.messageTime = time.Now() +} + +func (m *Model) setSuccess(msg string) { + m.message = msg + m.messageStyle = "success" + m.messageTime = time.Now() +} + +// View renders the UI. +func (m *Model) View() string { + var sb strings.Builder + + // Title with version + title := titleStyle.Render("lolcathost - Host Management") + sb.WriteString(title) + + // Update notification + if m.updateAvailable { + sb.WriteString(" ") + sb.WriteString(updateStyle.Render(fmt.Sprintf("Update available: v%s", m.updateVersion))) + } + + sb.WriteString("\n\n") + + // Main content based on mode + switch m.mode { + case ViewList: + sb.WriteString(m.list.ViewFiltered(m.searchTerm)) + case ViewForm: + sb.WriteString(m.form.View()) + case ViewPresets: + sb.WriteString(m.presetPicker.View()) + case ViewGroups: + sb.WriteString(m.groupPicker.View()) + case ViewHelp: + sb.WriteString(m.helpView()) + case ViewSearch: + sb.WriteString(m.searchView()) + } + + // Message + if m.message != "" { + sb.WriteString("\n") + if m.messageStyle == "error" { + sb.WriteString(errorMsgStyle.Render(m.message)) + } else { + sb.WriteString(successMsgStyle.Render(m.message)) + } + } + + // Calculate remaining space for footer positioning + currentContent := sb.String() + currentLines := strings.Count(currentContent, "\n") + 1 + + // Fill space to push footer to bottom (reserve 3 lines for footer) + footerHeight := 3 + remainingLines := m.height - currentLines - footerHeight + if remainingLines > 0 { + sb.WriteString(strings.Repeat("\n", remainingLines)) + } + + // Footer (help bar + status bar) + if m.mode == ViewList { + sb.WriteString("\n") + sb.WriteString(m.helpBar()) + } + sb.WriteString("\n") + sb.WriteString(m.statusBar()) + + return sb.String() +} + +func (m *Model) helpBar() string { + return helpBarStyle.Render(fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: New %s: Edit %s: Delete %s: Presets %s: Groups %s: Search %s: Help %s: Quit", + helpKeyStyle.Render("↑↓"), + helpKeyStyle.Render("jk"), + helpKeyStyle.Render("Space"), + helpKeyStyle.Render("n"), + helpKeyStyle.Render("e"), + helpKeyStyle.Render("d"), + helpKeyStyle.Render("p"), + helpKeyStyle.Render("g"), + helpKeyStyle.Render("/"), + helpKeyStyle.Render("?"), + helpKeyStyle.Render("q"))) +} + +func (m *Model) statusBar() string { + var status string + if m.connected { + status = connectedStyle.String() + } else { + status = disconnectedStyle.String() + } + + active := fmt.Sprintf("%d active", m.list.ActiveCount()) + total := fmt.Sprintf("%d total", m.list.Len()) + + return statusBarStyle.Render(fmt.Sprintf("%s | %s | %s", status, active, total)) +} + +func (m *Model) helpView() string { + var sb strings.Builder + + sb.WriteString(titleStyle.Render("Help")) + sb.WriteString("\n\n") + + help := []struct{ key, desc string }{ + {"↑/↓ or j/k", "Navigate up/down"}, + {"Space/Enter", "Toggle entry on/off"}, + {"n", "Add new entry"}, + {"e", "Edit selected entry"}, + {"d", "Delete selected entry"}, + {"p", "Open preset manager"}, + {"g", "Open group manager"}, + {"/", "Search"}, + {"r", "Refresh list"}, + {"?", "Toggle this help"}, + {"q", "Quit"}, + } + + for _, h := range help { + sb.WriteString(fmt.Sprintf(" %s %s\n", + helpKeyStyle.Width(15).Render(h.key), + helpDescStyle.Render(h.desc))) + } + + sb.WriteString("\n") + sb.WriteString(helpDescStyle.Render("Press ? or Esc to close")) + + return dialogStyle.Render(sb.String()) +} + +func (m *Model) searchView() string { + var sb strings.Builder + + sb.WriteString(titleStyle.Render("Search")) + sb.WriteString("\n\n") + + sb.WriteString(inputFocusStyle.Render(m.searchInput.View())) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("Enter to search • Esc to cancel")) + + return dialogStyle.Render(sb.String()) +} + +// Run starts the TUI application. +func Run(socketPath, configPath string) error { + return RunWithVersion(socketPath, configPath, "dev", "", "") +} + +// RunWithVersion starts the TUI application with version info for update checking. +func RunWithVersion(socketPath, configPath, version, githubOwner, githubRepo string) error { + m := NewModel(socketPath, configPath) + m.version = version + m.githubOwner = githubOwner + m.githubRepo = githubRepo + p := tea.NewProgram(m, tea.WithAltScreen()) + + _, err := p.Run() + return err +} diff --git a/internal/tui/form.go b/internal/tui/form.go new file mode 100644 index 0000000..aa7cc69 --- /dev/null +++ b/internal/tui/form.go @@ -0,0 +1,336 @@ +// Package tui provides the form component for adding/editing entries. +package tui + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" +) + +// FormMode represents the form mode. +type FormMode int + +const ( + FormModeAdd FormMode = iota + FormModeEdit +) + +// FormField represents a form field index. +type FormField int + +const ( + FieldDomain FormField = iota + FieldIP + FieldGroup + FieldCount +) + +// Form handles the add/edit entry form. +type Form struct { + mode FormMode + fields []textinput.Model + focus FormField + width int + height int + editAlias string // Original alias when editing + + // Group dropdown + groups []string + groupCursor int + groupFocused bool +} + +// NewForm creates a new form. +func NewForm() *Form { + fields := make([]textinput.Model, FieldCount) + + // Domain field + fields[FieldDomain] = textinput.New() + fields[FieldDomain].Placeholder = "example.com" + fields[FieldDomain].CharLimit = 253 + + // IP field + fields[FieldIP] = textinput.New() + fields[FieldIP].Placeholder = "127.0.0.1" + fields[FieldIP].CharLimit = 45 // IPv6 max + + // Group field (not used as text input, but kept for compatibility) + fields[FieldGroup] = textinput.New() + fields[FieldGroup].Placeholder = "development" + fields[FieldGroup].CharLimit = 63 + + return &Form{ + fields: fields, + focus: FieldDomain, + groups: []string{"default"}, + } +} + +// SetGroups sets the available groups for the dropdown. +func (f *Form) SetGroups(groups []string) { + if len(groups) == 0 { + f.groups = []string{"default"} + } else { + f.groups = groups + } + // Reset cursor if out of bounds + if f.groupCursor >= len(f.groups) { + f.groupCursor = 0 + } +} + +// Init initializes the form for adding a new entry. +func (f *Form) Init() { + f.mode = FormModeAdd + f.editAlias = "" + + for i := range f.fields { + f.fields[i].Reset() + } + + f.fields[FieldIP].SetValue("127.0.0.1") + f.groupCursor = 0 + f.groupFocused = false + f.focus = FieldDomain + f.fields[FieldDomain].Focus() +} + +// InitEdit initializes the form for editing an existing entry. +func (f *Form) InitEdit(domain, ip, alias, group string) { + f.mode = FormModeEdit + f.editAlias = alias + + f.fields[FieldDomain].SetValue(domain) + f.fields[FieldIP].SetValue(ip) + + // Find the group in the list + f.groupCursor = 0 + for i, g := range f.groups { + if g == group { + f.groupCursor = i + break + } + } + + f.groupFocused = false + f.focus = FieldDomain + f.fields[FieldDomain].Focus() +} + +// SetSize sets the form dimensions. +func (f *Form) SetSize(width, height int) { + f.width = width + f.height = height + + inputWidth := min(50, width-10) + for i := range f.fields { + f.fields[i].Width = inputWidth + } +} + +// Update handles input events. +func (f *Form) Update(msg tea.Msg) tea.Cmd { + switch msg := msg.(type) { + case tea.KeyMsg: + // Handle group dropdown navigation + if f.focus == FieldGroup { + switch msg.String() { + case "tab": + f.nextField() + return nil + case "shift+tab": + f.prevField() + return nil + case "up", "k": + if f.groupCursor > 0 { + f.groupCursor-- + } + return nil + case "down", "j": + if f.groupCursor < len(f.groups)-1 { + f.groupCursor++ + } + return nil + case "left": + if f.groupCursor > 0 { + f.groupCursor-- + } + return nil + case "right": + if f.groupCursor < len(f.groups)-1 { + f.groupCursor++ + } + return nil + } + return nil + } + + // Handle text input fields + switch msg.String() { + case "tab", "down": + f.nextField() + return nil + case "shift+tab", "up": + f.prevField() + return nil + } + } + + // Update the focused text field (only for Domain and IP) + if f.focus != FieldGroup { + var cmd tea.Cmd + f.fields[f.focus], cmd = f.fields[f.focus].Update(msg) + return cmd + } + + return nil +} + +func (f *Form) nextField() { + if f.focus != FieldGroup { + f.fields[f.focus].Blur() + } + f.focus = (f.focus + 1) % FieldCount + if f.focus != FieldGroup { + f.fields[f.focus].Focus() + } +} + +func (f *Form) prevField() { + if f.focus != FieldGroup { + f.fields[f.focus].Blur() + } + f.focus = (f.focus - 1 + FieldCount) % FieldCount + if f.focus != FieldGroup { + f.fields[f.focus].Focus() + } +} + +// Values returns the form values (domain, ip, group). +func (f *Form) Values() (domain, ip, group string) { + group = "" + if f.groupCursor < len(f.groups) { + group = f.groups[f.groupCursor] + } + return strings.TrimSpace(f.fields[FieldDomain].Value()), + strings.TrimSpace(f.fields[FieldIP].Value()), + group +} + +// EditAlias returns the original alias when editing. +func (f *Form) EditAlias() string { + return f.editAlias +} + +// IsEdit returns true if in edit mode. +func (f *Form) IsEdit() bool { + return f.mode == FormModeEdit +} + +// Validate validates the form values. +func (f *Form) Validate() string { + domain, ip, group := f.Values() + + if domain == "" { + return "Domain is required" + } + if ip == "" { + return "IP address is required" + } + if group == "" { + return "Group is required" + } + + return "" +} + +// View renders the form. +func (f *Form) View() string { + var sb strings.Builder + + title := "Add New Entry" + if f.mode == FormModeEdit { + title = "Edit Entry" + } + + sb.WriteString(titleStyle.Render(title)) + sb.WriteString("\n\n") + + // Domain field + sb.WriteString(inputLabelStyle.Render("Domain:")) + sb.WriteString("\n") + style := inputStyle + if f.focus == FieldDomain { + style = inputFocusStyle + } + sb.WriteString(style.Render(f.fields[FieldDomain].View())) + sb.WriteString("\n\n") + + // IP field + sb.WriteString(inputLabelStyle.Render("IP Address:")) + sb.WriteString("\n") + style = inputStyle + if f.focus == FieldIP { + style = inputFocusStyle + } + sb.WriteString(style.Render(f.fields[FieldIP].View())) + sb.WriteString("\n\n") + + // Group dropdown + sb.WriteString(inputLabelStyle.Render("Group:")) + sb.WriteString("\n") + sb.WriteString(f.renderGroupDropdown()) + sb.WriteString("\n\n") + + sb.WriteString("\n") + sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • ←→ select group • Enter save • Esc cancel")) + + return dialogStyle.Render(sb.String()) +} + +func (f *Form) renderGroupDropdown() string { + isFocused := f.focus == FieldGroup + + // Get current group name + currentGroup := "default" + if f.groupCursor < len(f.groups) { + currentGroup = f.groups[f.groupCursor] + } + + // Build the selector content: ◀ group_name ▶ + var content string + if isFocused { + // Show arrows when focused + leftArrow := "◀" + rightArrow := "▶" + if f.groupCursor == 0 { + leftArrow = " " // dim or hide left arrow at start + } + if f.groupCursor >= len(f.groups)-1 { + rightArrow = " " // dim or hide right arrow at end + } + content = leftArrow + " " + currentGroup + " " + rightArrow + } else { + content = " " + currentGroup + " " + } + + // Show position indicator if multiple groups + if len(f.groups) > 1 { + content += fmt.Sprintf(" (%d/%d)", f.groupCursor+1, len(f.groups)) + } + + // Apply border style + if isFocused { + return inputFocusStyle.Render(content) + } + return inputStyle.Render(content) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/tui/groups.go b/internal/tui/groups.go new file mode 100644 index 0000000..55463e9 --- /dev/null +++ b/internal/tui/groups.go @@ -0,0 +1,232 @@ +// Package tui provides the group management component. +package tui + +import ( + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" +) + +// GroupMode represents the group view mode. +type GroupMode int + +const ( + GroupModeSelect GroupMode = iota + GroupModeAdd + GroupModeRename + GroupModeConfirmDelete +) + +// GroupPicker handles the group selection and management UI. +type GroupPicker struct { + groups []string + cursor int + width int + height int + mode GroupMode + input textinput.Model + editName string // Original name when renaming +} + +// NewGroupPicker creates a new group picker. +func NewGroupPicker() *GroupPicker { + input := textinput.New() + input.Placeholder = "group-name" + input.CharLimit = 63 + + return &GroupPicker{ + input: input, + mode: GroupModeSelect, + } +} + +// SetGroups updates the available groups. +func (g *GroupPicker) SetGroups(groups []string) { + g.groups = groups + if g.cursor >= len(groups) { + g.cursor = max(0, len(groups)-1) + } +} + +// SetSize sets the picker dimensions. +func (g *GroupPicker) SetSize(width, height int) { + g.width = width + g.height = height + g.input.Width = min(50, width-10) +} + +// MoveUp moves the cursor up. +func (g *GroupPicker) MoveUp() { + if g.cursor > 0 { + g.cursor-- + } +} + +// MoveDown moves the cursor down. +func (g *GroupPicker) MoveDown() { + if g.cursor < len(g.groups)-1 { + g.cursor++ + } +} + +// Selected returns the currently selected group. +func (g *GroupPicker) Selected() string { + if g.cursor >= 0 && g.cursor < len(g.groups) { + return g.groups[g.cursor] + } + return "" +} + +// Len returns the number of groups. +func (g *GroupPicker) Len() int { + return len(g.groups) +} + +// Mode returns the current mode. +func (g *GroupPicker) Mode() GroupMode { + return g.mode +} + +// InitAdd initializes the form for adding a new group. +func (g *GroupPicker) InitAdd() { + g.mode = GroupModeAdd + g.editName = "" + g.input.Reset() + g.input.Focus() +} + +// InitRename initializes the form for renaming an existing group. +func (g *GroupPicker) InitRename() { + selected := g.Selected() + if selected == "" { + return + } + + g.mode = GroupModeRename + g.editName = selected + g.input.SetValue(selected) + g.input.Focus() +} + +// InitDelete starts delete confirmation. +func (g *GroupPicker) InitDelete() { + if g.Selected() == "" { + return + } + g.mode = GroupModeConfirmDelete +} + +// CancelForm cancels the current form operation. +func (g *GroupPicker) CancelForm() { + g.mode = GroupModeSelect + g.editName = "" + g.input.Reset() + g.input.Blur() +} + +// Update handles input events for form mode. +func (g *GroupPicker) Update(msg tea.KeyMsg) tea.Cmd { + var cmd tea.Cmd + g.input, cmd = g.input.Update(msg) + return cmd +} + +// FormValue returns the form input value. +func (g *GroupPicker) FormValue() string { + return strings.TrimSpace(g.input.Value()) +} + +// EditName returns the original name when renaming. +func (g *GroupPicker) EditName() string { + return g.editName +} + +// IsRename returns true if in rename mode. +func (g *GroupPicker) IsRename() bool { + return g.mode == GroupModeRename +} + +// ValidateForm validates the form value. +func (g *GroupPicker) ValidateForm() string { + value := g.FormValue() + if value == "" { + return "Group name is required" + } + return "" +} + +// View renders the group picker. +func (g *GroupPicker) View() string { + switch g.mode { + case GroupModeAdd, GroupModeRename: + return g.formView() + case GroupModeConfirmDelete: + return g.deleteView() + default: + return g.selectView() + } +} + +func (g *GroupPicker) selectView() string { + var sb strings.Builder + + sb.WriteString(titleStyle.Render("Groups")) + sb.WriteString("\n\n") + + if len(g.groups) == 0 { + sb.WriteString(helpDescStyle.Render("No groups configured.")) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("Press 'n' to create one")) + } else { + for i, group := range g.groups { + if i == g.cursor { + sb.WriteString(presetSelectedStyle.Render("▸ " + group)) + } else { + sb.WriteString(presetItemStyle.Render(" " + group)) + } + sb.WriteString("\n") + } + } + + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("↑↓ navigate • n new • r rename • d delete • Esc back")) + + return dialogStyle.Render(sb.String()) +} + +func (g *GroupPicker) formView() string { + var sb strings.Builder + + title := "Add New Group" + if g.mode == GroupModeRename { + title = "Rename Group" + } + + sb.WriteString(titleStyle.Render(title)) + sb.WriteString("\n\n") + + sb.WriteString(inputLabelStyle.Render("Name:")) + sb.WriteString("\n") + sb.WriteString(inputFocusStyle.Render(g.input.View())) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("Enter save • Esc cancel")) + + return dialogStyle.Render(sb.String()) +} + +func (g *GroupPicker) deleteView() string { + var sb strings.Builder + + groupName := g.Selected() + + sb.WriteString(titleStyle.Render("Delete Group")) + sb.WriteString("\n\n") + sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete group '" + groupName + "'?")) + sb.WriteString("\n") + sb.WriteString(helpDescStyle.Render("This will remove all hosts in this group!")) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel")) + + return dialogStyle.Render(sb.String()) +} diff --git a/internal/tui/list.go b/internal/tui/list.go new file mode 100644 index 0000000..72a0188 --- /dev/null +++ b/internal/tui/list.go @@ -0,0 +1,429 @@ +// Package tui provides the list view component. +package tui + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/lipgloss/table" + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// EntryItem represents a displayable host entry. +type EntryItem struct { + Entry protocol.HostEntry + Pending bool + HasError bool +} + +// ListView handles the list of host entries. +type ListView struct { + items []EntryItem + groups map[string][]int // group name -> indices in items + groupOrder []string // ordered group names + cursor int + width int + height int +} + +// NewListView creates a new list view. +func NewListView() *ListView { + return &ListView{ + groups: make(map[string][]int), + } +} + +// SetItems updates the list items. +func (l *ListView) SetItems(entries []protocol.HostEntry) { + l.items = make([]EntryItem, len(entries)) + l.groups = make(map[string][]int) + l.groupOrder = nil + + groupSeen := make(map[string]bool) + + for i, e := range entries { + l.items[i] = EntryItem{Entry: e} + + if !groupSeen[e.Group] { + groupSeen[e.Group] = true + l.groupOrder = append(l.groupOrder, e.Group) + } + + l.groups[e.Group] = append(l.groups[e.Group], i) + } + + // Reset cursor if out of bounds + if l.cursor >= len(l.items) { + l.cursor = max(0, len(l.items)-1) + } +} + +// SetSize sets the view dimensions. +func (l *ListView) SetSize(width, height int) { + l.width = width + l.height = height +} + +// MoveUp moves the cursor up. +func (l *ListView) MoveUp() { + if l.cursor > 0 { + l.cursor-- + } +} + +// MoveDown moves the cursor down. +func (l *ListView) MoveDown() { + if l.cursor < len(l.items)-1 { + l.cursor++ + } +} + +// Selected returns the currently selected item. +func (l *ListView) Selected() *EntryItem { + if l.cursor >= 0 && l.cursor < len(l.items) { + return &l.items[l.cursor] + } + return nil +} + +// SelectedAlias returns the alias of the selected item. +func (l *ListView) SelectedAlias() string { + if item := l.Selected(); item != nil { + return item.Entry.Alias + } + return "" +} + +// SetPending marks an item as pending. +func (l *ListView) SetPending(alias string, pending bool) { + for i := range l.items { + if l.items[i].Entry.Alias == alias { + l.items[i].Pending = pending + break + } + } +} + +// SetError marks an item as having an error. +func (l *ListView) SetError(alias string, hasError bool) { + for i := range l.items { + if l.items[i].Entry.Alias == alias { + l.items[i].HasError = hasError + break + } + } +} + +// UpdateEntry updates an entry's enabled state. +func (l *ListView) UpdateEntry(alias string, enabled bool) { + for i := range l.items { + if l.items[i].Entry.Alias == alias { + l.items[i].Entry.Enabled = enabled + l.items[i].Pending = false + l.items[i].HasError = false + break + } + } +} + +// Len returns the number of items. +func (l *ListView) Len() int { + return len(l.items) +} + +// ActiveCount returns the number of enabled entries. +func (l *ListView) ActiveCount() int { + count := 0 + for _, item := range l.items { + if item.Entry.Enabled { + count++ + } + } + return count +} + +// FindByAlias finds an item by alias. +func (l *ListView) FindByAlias(alias string) *EntryItem { + for i := range l.items { + if l.items[i].Entry.Alias == alias { + return &l.items[i] + } + } + return nil +} + +// Filter filters items by search term. +func (l *ListView) Filter(term string) []EntryItem { + if term == "" { + return l.items + } + + term = strings.ToLower(term) + var filtered []EntryItem + for _, item := range l.items { + if strings.Contains(strings.ToLower(item.Entry.Domain), term) || + strings.Contains(strings.ToLower(item.Entry.Alias), term) || + strings.Contains(strings.ToLower(item.Entry.IP), term) || + strings.Contains(strings.ToLower(item.Entry.Group), term) { + filtered = append(filtered, item) + } + } + return filtered +} + +// ViewFiltered renders the list filtered by search term. +func (l *ListView) ViewFiltered(searchTerm string) string { + if searchTerm == "" { + return l.View() + } + + filtered := l.Filter(searchTerm) + if len(filtered) == 0 { + emptyStyle := lipgloss.NewStyle().Foreground(colorMuted) + return "\n" + emptyStyle.Render(fmt.Sprintf(" No results for '%s'. Press Esc to clear search.", searchTerm)) + "\n" + } + + var sb strings.Builder + + // Show search indicator + searchIndicator := lipgloss.NewStyle(). + Foreground(colorWarning). + Bold(true). + Render(fmt.Sprintf(" Search: %s (%d results)", searchTerm, len(filtered))) + sb.WriteString(searchIndicator) + sb.WriteString("\n") + + // Group header style - bright colors for dark terminals + groupHeaderStyle := lipgloss.NewStyle(). + Bold(true). + Foreground(colorGroupHeader). + Background(lipgloss.Color("238")). + Padding(0, 1). + MarginTop(1) + + // Organize filtered items by group + groupItems := make(map[string][]EntryItem) + var groupOrder []string + groupSeen := make(map[string]bool) + + for _, item := range filtered { + group := item.Entry.Group + if !groupSeen[group] { + groupSeen[group] = true + groupOrder = append(groupOrder, group) + } + groupItems[group] = append(groupItems[group], item) + } + + for _, groupName := range groupOrder { + items := groupItems[groupName] + if len(items) == 0 { + continue + } + + // Group header + headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(items)) + sb.WriteString(groupHeaderStyle.Render(headerText)) + sb.WriteString("\n") + + // Build rows for this group's table + var rows [][]string + for _, item := range items { + status := l.getStatusString(item) + rows = append(rows, []string{ + truncate(item.Entry.Domain, 30), + truncate(item.Entry.IP, 15), + status, + }) + } + + // Create table for this group + t := table.New(). + Border(lipgloss.HiddenBorder()). + Headers("DOMAIN", "IP ADDRESS", "STATUS"). + Rows(rows...). + StyleFunc(func(row, col int) lipgloss.Style { + // Header row + if row == table.HeaderRow { + return lipgloss.NewStyle(). + Bold(true). + Foreground(colorHeader). + Padding(0, 1) + } + + baseStyle := lipgloss.NewStyle().Padding(0, 1) + + if row >= 0 && row < len(items) { + item := items[row] + + // Disabled rows are muted + if !item.Entry.Enabled && !item.Pending && !item.HasError { + return baseStyle.Foreground(colorMuted) + } + + // Status column gets colored based on status + if col == 2 { // STATUS column + if item.HasError { + return baseStyle.Foreground(colorError) + } + if item.Pending { + return baseStyle.Foreground(colorWarning) + } + if item.Entry.Enabled { + return baseStyle.Foreground(colorSuccess) + } + } + } + + return baseStyle + }) + + sb.WriteString(t.Render()) + sb.WriteString("\n") + } + + return sb.String() +} + +// GroupCount returns the number of groups. +func (l *ListView) GroupCount() int { + return len(l.groupOrder) +} + +// GetGroups returns all group names. +func (l *ListView) GetGroups() []string { + return l.groupOrder +} + +// View renders the list with groups as headers. +func (l *ListView) View() string { + if len(l.items) == 0 { + emptyStyle := lipgloss.NewStyle().Foreground(colorMuted) + return "\n" + emptyStyle.Render(" No host entries configured. Press 'n' to add a new entry.") + "\n" + } + + var sb strings.Builder + + // Group header style - bright colors for dark terminals + groupHeaderStyle := lipgloss.NewStyle(). + Bold(true). + Foreground(colorGroupHeader). + Background(lipgloss.Color("238")). + Padding(0, 1). + MarginTop(1) + + for _, groupName := range l.groupOrder { + indices := l.groups[groupName] + if len(indices) == 0 { + continue + } + + // Group header + headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(indices)) + sb.WriteString(groupHeaderStyle.Render(headerText)) + sb.WriteString("\n") + + // Build rows for this group's table + var rows [][]string + // Store actual item indices for cursor matching + itemIndices := make([]int, len(indices)) + copy(itemIndices, indices) + + for _, idx := range indices { + item := l.items[idx] + status := l.getStatusString(item) + rows = append(rows, []string{ + truncate(item.Entry.Domain, 30), + truncate(item.Entry.IP, 15), + status, + }) + } + + // Create table for this group + t := table.New(). + Border(lipgloss.HiddenBorder()). + Headers("DOMAIN", "IP ADDRESS", "STATUS"). + Rows(rows...). + StyleFunc(func(row, col int) lipgloss.Style { + // Header row + if row == table.HeaderRow { + return lipgloss.NewStyle(). + Bold(true). + Foreground(colorHeader). + Padding(0, 1) + } + + baseStyle := lipgloss.NewStyle().Padding(0, 1) + + // Check if this row is selected + if row >= 0 && row < len(itemIndices) { + actualItemIdx := itemIndices[row] + isSelected := actualItemIdx == l.cursor + item := l.items[actualItemIdx] + + // Selected row gets background highlight + if isSelected { + return baseStyle. + Background(colorSelectedBg). + Foreground(colorSelectedFg) + } + + // Disabled rows are muted + if !item.Entry.Enabled && !item.Pending && !item.HasError { + return baseStyle.Foreground(colorMuted) + } + + // Status column gets colored based on status + if col == 2 { // STATUS column + if item.HasError { + return baseStyle.Foreground(colorError) + } + if item.Pending { + return baseStyle.Foreground(colorWarning) + } + if item.Entry.Enabled { + return baseStyle.Foreground(colorSuccess) + } + } + } + + return baseStyle + }) + + sb.WriteString(t.Render()) + sb.WriteString("\n") + } + + return sb.String() +} + +func (l *ListView) getStatusString(item EntryItem) string { + if item.HasError { + return "✗ Error" + } + if item.Pending { + return "◐ Pending" + } + if item.Entry.Enabled { + return "● Active" + } + return "○ Disabled" +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/tui/list_test.go b/internal/tui/list_test.go new file mode 100644 index 0000000..d0577de --- /dev/null +++ b/internal/tui/list_test.go @@ -0,0 +1,409 @@ +package tui + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +func TestListView_SetItems(t *testing.T) { + lv := NewListView() + + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + {Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"}, + } + + lv.SetItems(entries) + + assert.Equal(t, 3, lv.Len()) + assert.Len(t, lv.groups, 2) + assert.Contains(t, lv.groupOrder, "dev") + assert.Contains(t, lv.groupOrder, "staging") +} + +func TestListView_Navigation(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + {Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"}, + } + lv.SetItems(entries) + + // Initial position + assert.Equal(t, 0, lv.cursor) + + // Move down + lv.MoveDown() + assert.Equal(t, 1, lv.cursor) + + lv.MoveDown() + assert.Equal(t, 2, lv.cursor) + + // Can't move past end + lv.MoveDown() + assert.Equal(t, 2, lv.cursor) + + // Move up + lv.MoveUp() + assert.Equal(t, 1, lv.cursor) + + lv.MoveUp() + assert.Equal(t, 0, lv.cursor) + + // Can't move before start + lv.MoveUp() + assert.Equal(t, 0, lv.cursor) +} + +func TestListView_Selected(t *testing.T) { + lv := NewListView() + + t.Run("empty list", func(t *testing.T) { + item := lv.Selected() + assert.Nil(t, item) + }) + + t.Run("with items", func(t *testing.T) { + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + } + lv.SetItems(entries) + + item := lv.Selected() + require.NotNil(t, item) + assert.Equal(t, "a.com", item.Entry.Domain) + + lv.MoveDown() + item = lv.Selected() + require.NotNil(t, item) + assert.Equal(t, "b.com", item.Entry.Domain) + }) +} + +func TestListView_SelectedAlias(t *testing.T) { + lv := NewListView() + + t.Run("empty list", func(t *testing.T) { + alias := lv.SelectedAlias() + assert.Empty(t, alias) + }) + + t.Run("with items", func(t *testing.T) { + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "my-alias", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + + alias := lv.SelectedAlias() + assert.Equal(t, "my-alias", alias) + }) +} + +func TestListView_SetPending(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + + assert.False(t, lv.items[0].Pending) + + lv.SetPending("a", true) + assert.True(t, lv.items[0].Pending) + + lv.SetPending("a", false) + assert.False(t, lv.items[0].Pending) + + // Non-existent alias should not panic + lv.SetPending("nonexistent", true) +} + +func TestListView_SetError(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + + assert.False(t, lv.items[0].HasError) + + lv.SetError("a", true) + assert.True(t, lv.items[0].HasError) + + lv.SetError("a", false) + assert.False(t, lv.items[0].HasError) +} + +func TestListView_UpdateEntry(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false, Group: "dev"}, + } + lv.SetItems(entries) + lv.items[0].Pending = true + lv.items[0].HasError = true + + lv.UpdateEntry("a", true) + + assert.True(t, lv.items[0].Entry.Enabled) + assert.False(t, lv.items[0].Pending) + assert.False(t, lv.items[0].HasError) +} + +func TestListView_ActiveCount(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + {Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"}, + } + lv.SetItems(entries) + + assert.Equal(t, 2, lv.ActiveCount()) +} + +func TestListView_FindByAlias(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"}, + } + lv.SetItems(entries) + + t.Run("found", func(t *testing.T) { + item := lv.FindByAlias("b") + require.NotNil(t, item) + assert.Equal(t, "b.com", item.Entry.Domain) + }) + + t.Run("not found", func(t *testing.T) { + item := lv.FindByAlias("nonexistent") + assert.Nil(t, item) + }) +} + +func TestListView_Filter(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "myapp.com", IP: "127.0.0.1", Alias: "myapp", Enabled: true, Group: "dev"}, + {Domain: "api.myapp.com", IP: "127.0.0.1", Alias: "api", Enabled: false, Group: "dev"}, + {Domain: "other.com", IP: "192.168.1.1", Alias: "other", Enabled: true, Group: "staging"}, + } + lv.SetItems(entries) + + t.Run("empty term", func(t *testing.T) { + filtered := lv.Filter("") + assert.Len(t, filtered, 3) + }) + + t.Run("by domain", func(t *testing.T) { + filtered := lv.Filter("myapp") + assert.Len(t, filtered, 2) + }) + + t.Run("by alias", func(t *testing.T) { + filtered := lv.Filter("api") + assert.Len(t, filtered, 1) + assert.Equal(t, "api.myapp.com", filtered[0].Entry.Domain) + }) + + t.Run("by IP", func(t *testing.T) { + filtered := lv.Filter("192.168") + assert.Len(t, filtered, 1) + assert.Equal(t, "other.com", filtered[0].Entry.Domain) + }) + + t.Run("case insensitive", func(t *testing.T) { + filtered := lv.Filter("MYAPP") + assert.Len(t, filtered, 2) + }) + + t.Run("no match", func(t *testing.T) { + filtered := lv.Filter("nonexistent") + assert.Empty(t, filtered) + }) +} + +func TestListView_View(t *testing.T) { + t.Run("empty list", func(t *testing.T) { + lv := NewListView() + view := lv.View() + assert.Contains(t, view, "No host entries") + }) + + t.Run("with items", func(t *testing.T) { + lv := NewListView() + entries := []protocol.HostEntry{ + {Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + + view := lv.View() + // Group header is shown as section title (uppercase) + assert.Contains(t, view, "DEV") + // Table headers + assert.Contains(t, view, "DOMAIN") + assert.Contains(t, view, "IP ADDRESS") + assert.Contains(t, view, "STATUS") + // Data is in the view + assert.Contains(t, view, "example.com") + assert.Contains(t, view, "127.0.0.1") + assert.Contains(t, view, "Active") + }) +} + +func TestListView_SetSize(t *testing.T) { + lv := NewListView() + lv.SetSize(80, 24) + + assert.Equal(t, 80, lv.width) + assert.Equal(t, 24, lv.height) +} + +func TestListView_CursorBounds(t *testing.T) { + lv := NewListView() + + // Set items + entries := []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + {Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + lv.cursor = 1 + + // Set fewer items - cursor should be adjusted + entries = []protocol.HostEntry{ + {Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"}, + } + lv.SetItems(entries) + + assert.Equal(t, 0, lv.cursor) +} + +func TestTruncate(t *testing.T) { + tests := []struct { + input string + maxLen int + expected string + }{ + {"short", 10, "short"}, + {"exactly10!", 10, "exactly10!"}, + {"this is too long", 10, "this is..."}, + {"", 5, ""}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestMax(t *testing.T) { + assert.Equal(t, 5, max(3, 5)) + assert.Equal(t, 5, max(5, 3)) + assert.Equal(t, 5, max(5, 5)) + assert.Equal(t, 0, max(0, -1)) +} + +// Matrix test for navigation +func TestListView_Navigation_Matrix(t *testing.T) { + sizes := []int{1, 5, 10, 100} + + for _, size := range sizes { + t.Run("size="+string(rune('0'+size)), func(t *testing.T) { + lv := NewListView() + + entries := make([]protocol.HostEntry, size) + for i := range entries { + entries[i] = protocol.HostEntry{ + Domain: "domain" + string(rune('a'+i%26)) + ".com", + IP: "127.0.0.1", + Alias: "alias" + string(rune('a'+i%26)), + Enabled: true, + Group: "dev", + } + } + lv.SetItems(entries) + + // Move to end + for i := 0; i < size*2; i++ { + lv.MoveDown() + } + assert.Equal(t, size-1, lv.cursor) + + // Move to start + for i := 0; i < size*2; i++ { + lv.MoveUp() + } + assert.Equal(t, 0, lv.cursor) + }) + } +} + +func BenchmarkListView_SetItems(b *testing.B) { + entries := make([]protocol.HostEntry, 100) + for i := range entries { + entries[i] = protocol.HostEntry{ + Domain: "domain.com", + IP: "127.0.0.1", + Alias: "alias", + Enabled: true, + Group: "dev", + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + lv := NewListView() + lv.SetItems(entries) + } +} + +func BenchmarkListView_Filter(b *testing.B) { + lv := NewListView() + entries := make([]protocol.HostEntry, 100) + for i := range entries { + entries[i] = protocol.HostEntry{ + Domain: "domain" + string(rune('a'+i%26)) + ".com", + IP: "127.0.0.1", + Alias: "alias" + string(rune('a'+i%26)), + Enabled: true, + Group: "dev", + } + } + lv.SetItems(entries) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = lv.Filter("domain") + } +} + +func BenchmarkListView_View(b *testing.B) { + lv := NewListView() + entries := make([]protocol.HostEntry, 50) + for i := range entries { + entries[i] = protocol.HostEntry{ + Domain: "domain.com", + IP: "127.0.0.1", + Alias: "alias", + Enabled: i%2 == 0, + Group: "group" + string(rune('a'+i%5)), + } + } + lv.SetItems(entries) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = lv.View() + } +} diff --git a/internal/tui/presets.go b/internal/tui/presets.go new file mode 100644 index 0000000..c79df6f --- /dev/null +++ b/internal/tui/presets.go @@ -0,0 +1,356 @@ +// Package tui provides the preset picker component. +package tui + +import ( + "strings" + + "github.com/charmbracelet/bubbles/textinput" + tea "github.com/charmbracelet/bubbletea" + "github.com/lukaszraczylo/lolcathost/internal/protocol" +) + +// PresetMode represents the preset view mode. +type PresetMode int + +const ( + PresetModeSelect PresetMode = iota + PresetModeAdd + PresetModeEdit + PresetModeConfirmDelete +) + +// PresetFormField represents a form field index. +type PresetFormField int + +const ( + PresetFieldName PresetFormField = iota + PresetFieldEnable + PresetFieldDisable + PresetFieldCount +) + +// PresetPicker handles the preset selection and management UI. +type PresetPicker struct { + presets []protocol.PresetInfo + cursor int + width int + height int + mode PresetMode + fields []textinput.Model + focus PresetFormField + editName string // Original name when editing +} + +// NewPresetPicker creates a new preset picker. +func NewPresetPicker() *PresetPicker { + fields := make([]textinput.Model, PresetFieldCount) + + // Name field + fields[PresetFieldName] = textinput.New() + fields[PresetFieldName].Placeholder = "preset-name" + fields[PresetFieldName].CharLimit = 63 + + // Enable field + fields[PresetFieldEnable] = textinput.New() + fields[PresetFieldEnable].Placeholder = "alias1,alias2,alias3" + fields[PresetFieldEnable].CharLimit = 500 + + // Disable field + fields[PresetFieldDisable] = textinput.New() + fields[PresetFieldDisable].Placeholder = "alias1,alias2,alias3" + fields[PresetFieldDisable].CharLimit = 500 + + return &PresetPicker{ + fields: fields, + mode: PresetModeSelect, + } +} + +// SetPresets updates the available presets (legacy method for compatibility). +func (p *PresetPicker) SetPresets(presets []string) { + p.presets = make([]protocol.PresetInfo, len(presets)) + for i, name := range presets { + p.presets[i] = protocol.PresetInfo{Name: name} + } + if p.cursor >= len(presets) { + p.cursor = max(0, len(presets)-1) + } +} + +// SetPresetsWithInfo updates the available presets with full info. +func (p *PresetPicker) SetPresetsWithInfo(presets []protocol.PresetInfo) { + p.presets = presets + if p.cursor >= len(presets) { + p.cursor = max(0, len(presets)-1) + } +} + +// SetSize sets the picker dimensions. +func (p *PresetPicker) SetSize(width, height int) { + p.width = width + p.height = height + + inputWidth := min(60, width-10) + for i := range p.fields { + p.fields[i].Width = inputWidth + } +} + +// MoveUp moves the cursor up. +func (p *PresetPicker) MoveUp() { + if p.cursor > 0 { + p.cursor-- + } +} + +// MoveDown moves the cursor down. +func (p *PresetPicker) MoveDown() { + if p.cursor < len(p.presets)-1 { + p.cursor++ + } +} + +// Selected returns the currently selected preset name. +func (p *PresetPicker) Selected() string { + if p.cursor >= 0 && p.cursor < len(p.presets) { + return p.presets[p.cursor].Name + } + return "" +} + +// SelectedInfo returns the currently selected preset info. +func (p *PresetPicker) SelectedInfo() *protocol.PresetInfo { + if p.cursor >= 0 && p.cursor < len(p.presets) { + return &p.presets[p.cursor] + } + return nil +} + +// Len returns the number of presets. +func (p *PresetPicker) Len() int { + return len(p.presets) +} + +// Mode returns the current mode. +func (p *PresetPicker) Mode() PresetMode { + return p.mode +} + +// SetMode sets the mode. +func (p *PresetPicker) SetMode(mode PresetMode) { + p.mode = mode +} + +// InitAdd initializes the form for adding a new preset. +func (p *PresetPicker) InitAdd() { + p.mode = PresetModeAdd + p.editName = "" + for i := range p.fields { + p.fields[i].Reset() + } + p.focus = PresetFieldName + p.fields[PresetFieldName].Focus() +} + +// InitEdit initializes the form for editing an existing preset. +func (p *PresetPicker) InitEdit() { + preset := p.SelectedInfo() + if preset == nil { + return + } + + p.mode = PresetModeEdit + p.editName = preset.Name + + p.fields[PresetFieldName].SetValue(preset.Name) + p.fields[PresetFieldEnable].SetValue(strings.Join(preset.Enable, ",")) + p.fields[PresetFieldDisable].SetValue(strings.Join(preset.Disable, ",")) + + p.focus = PresetFieldName + p.fields[PresetFieldName].Focus() +} + +// InitDelete starts delete confirmation. +func (p *PresetPicker) InitDelete() { + if p.SelectedInfo() == nil { + return + } + p.mode = PresetModeConfirmDelete +} + +// CancelForm cancels the current form operation. +func (p *PresetPicker) CancelForm() { + p.mode = PresetModeSelect + p.editName = "" + for i := range p.fields { + p.fields[i].Reset() + p.fields[i].Blur() + } +} + +// Update handles input events for form mode. +func (p *PresetPicker) Update(msg tea.KeyMsg) tea.Cmd { + switch msg.String() { + case "tab", "down": + p.nextField() + return nil + case "shift+tab", "up": + p.prevField() + return nil + } + + // Update the focused field + var cmd tea.Cmd + p.fields[p.focus], cmd = p.fields[p.focus].Update(msg) + return cmd +} + +func (p *PresetPicker) nextField() { + p.fields[p.focus].Blur() + p.focus = (p.focus + 1) % PresetFieldCount + p.fields[p.focus].Focus() +} + +func (p *PresetPicker) prevField() { + p.fields[p.focus].Blur() + p.focus = (p.focus - 1 + PresetFieldCount) % PresetFieldCount + p.fields[p.focus].Focus() +} + +// FormValues returns the form values (name, enable list, disable list). +func (p *PresetPicker) FormValues() (name string, enable, disable []string) { + name = strings.TrimSpace(p.fields[PresetFieldName].Value()) + + enableStr := strings.TrimSpace(p.fields[PresetFieldEnable].Value()) + if enableStr != "" { + for _, s := range strings.Split(enableStr, ",") { + if trimmed := strings.TrimSpace(s); trimmed != "" { + enable = append(enable, trimmed) + } + } + } + + disableStr := strings.TrimSpace(p.fields[PresetFieldDisable].Value()) + if disableStr != "" { + for _, s := range strings.Split(disableStr, ",") { + if trimmed := strings.TrimSpace(s); trimmed != "" { + disable = append(disable, trimmed) + } + } + } + + return name, enable, disable +} + +// EditName returns the original name when editing. +func (p *PresetPicker) EditName() string { + return p.editName +} + +// IsEdit returns true if in edit mode. +func (p *PresetPicker) IsEdit() bool { + return p.mode == PresetModeEdit +} + +// ValidateForm validates the form values. +func (p *PresetPicker) ValidateForm() string { + name, enable, disable := p.FormValues() + + if name == "" { + return "Preset name is required" + } + if len(enable) == 0 && len(disable) == 0 { + return "At least one alias to enable or disable is required" + } + + return "" +} + +// View renders the preset picker. +func (p *PresetPicker) View() string { + switch p.mode { + case PresetModeAdd, PresetModeEdit: + return p.formView() + case PresetModeConfirmDelete: + return p.deleteView() + default: + return p.selectView() + } +} + +func (p *PresetPicker) selectView() string { + var sb strings.Builder + + sb.WriteString(titleStyle.Render("Presets")) + sb.WriteString("\n\n") + + if len(p.presets) == 0 { + sb.WriteString(helpDescStyle.Render("No presets configured.")) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("Press 'n' to create one")) + } else { + for i, preset := range p.presets { + if i == p.cursor { + sb.WriteString(presetSelectedStyle.Render("▸ " + preset.Name)) + } else { + sb.WriteString(presetItemStyle.Render(" " + preset.Name)) + } + sb.WriteString("\n") + } + } + + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("↑↓ navigate • Enter apply • n new • e edit • d delete • Esc cancel")) + + return dialogStyle.Render(sb.String()) +} + +func (p *PresetPicker) formView() string { + var sb strings.Builder + + title := "Add New Preset" + if p.mode == PresetModeEdit { + title = "Edit Preset" + } + + sb.WriteString(titleStyle.Render(title)) + sb.WriteString("\n\n") + + labels := []string{"Name:", "Enable aliases (comma-separated):", "Disable aliases (comma-separated):"} + + for i, label := range labels { + sb.WriteString(inputLabelStyle.Render(label)) + sb.WriteString("\n") + + style := inputStyle + if PresetFormField(i) == p.focus { + style = inputFocusStyle + } + + sb.WriteString(style.Render(p.fields[i].View())) + sb.WriteString("\n\n") + } + + sb.WriteString("\n") + sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • Enter save • Esc cancel")) + + return dialogStyle.Render(sb.String()) +} + +func (p *PresetPicker) deleteView() string { + var sb strings.Builder + + preset := p.SelectedInfo() + presetName := "" + if preset != nil { + presetName = preset.Name + } + + sb.WriteString(titleStyle.Render("Delete Preset")) + sb.WriteString("\n\n") + sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete preset '" + presetName + "'?")) + sb.WriteString("\n\n") + sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel")) + + return dialogStyle.Render(sb.String()) +} diff --git a/internal/tui/styles.go b/internal/tui/styles.go new file mode 100644 index 0000000..ffa2bb6 --- /dev/null +++ b/internal/tui/styles.go @@ -0,0 +1,150 @@ +// Package tui provides the terminal user interface. +package tui + +import ( + "github.com/charmbracelet/lipgloss" +) + +// Colors - matching kportal style, optimized for dark terminals +var ( + colorPrimary = lipgloss.Color("205") // Pink/Magenta + colorSuccess = lipgloss.Color("42") // Green + colorWarning = lipgloss.Color("220") // Yellow + colorError = lipgloss.Color("196") // Red + colorMuted = lipgloss.Color("245") // Gray (brighter for dark terminals) + colorAccent = lipgloss.Color("141") // Light purple (brighter for dark terminals) + colorHeader = lipgloss.Color("220") // Yellow for headers + colorSelectedBg = lipgloss.Color("236") // Gray background for selection + colorSelectedFg = lipgloss.Color("255") // White foreground for selection + colorGroupHeader = lipgloss.Color("213") // Light pink for group headers +) + +// Title and header styles +var ( + titleStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(colorHeader). + Padding(0, 1) +) + +// Status indicators +var ( + enabledStyle = lipgloss.NewStyle(). + Foreground(colorSuccess). + Bold(true) + + disabledStyle = lipgloss.NewStyle(). + Foreground(colorMuted) + + pendingStyle = lipgloss.NewStyle(). + Foreground(colorWarning) + + errorIndicatorStyle = lipgloss.NewStyle(). + Foreground(colorError) +) + +// Status bar and help +var ( + statusBarStyle = lipgloss.NewStyle(). + Foreground(colorMuted) + + connectedStyle = lipgloss.NewStyle(). + Foreground(colorSuccess). + SetString("Connected") + + disconnectedStyle = lipgloss.NewStyle(). + Foreground(colorError). + SetString("Disconnected") + + helpBarStyle = lipgloss.NewStyle(). + Foreground(colorMuted) + + helpKeyStyle = lipgloss.NewStyle(). + Foreground(colorHeader). + Bold(true) + + helpDescStyle = lipgloss.NewStyle(). + Foreground(colorMuted) +) + +// Message styles +var ( + errorMsgStyle = lipgloss.NewStyle(). + Foreground(colorError). + Bold(true). + MarginTop(1) + + successMsgStyle = lipgloss.NewStyle(). + Foreground(colorSuccess). + MarginTop(1) + + updateStyle = lipgloss.NewStyle(). + Foreground(colorSuccess). + Bold(true) +) + +// Form styles +var ( + inputLabelStyle = lipgloss.NewStyle(). + Foreground(colorPrimary). + Bold(true) + + inputStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorMuted). + Padding(0, 1) + + inputFocusStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorPrimary). + Padding(0, 1) +) + +// Dialog/modal styles +var ( + dialogStyle = lipgloss.NewStyle(). + Border(lipgloss.RoundedBorder()). + BorderForeground(colorAccent). + Padding(1, 2) + + presetItemStyle = lipgloss.NewStyle(). + Padding(0, 1) + + presetSelectedStyle = lipgloss.NewStyle(). + Background(colorSelectedBg). + Foreground(colorSelectedFg). + Padding(0, 1) +) + +// Indicator returns the appropriate status indicator string. +func Indicator(enabled bool, pending bool, hasError bool) string { + if hasError { + return errorIndicatorStyle.Render("✗") + } + if pending { + return pendingStyle.Render("◐") + } + if enabled { + return enabledStyle.Render("●") + } + return disabledStyle.Render("○") +} + +// StatusText returns the status text with appropriate styling +func StatusText(enabled bool, pending bool, hasError bool) string { + if hasError { + return errorIndicatorStyle.Render("✗ Error") + } + if pending { + return pendingStyle.Render("◐ Pending") + } + if enabled { + return enabledStyle.Render("● Active") + } + return disabledStyle.Render("○ Disabled") +} + +// HelpItem formats a help item. +func HelpItem(key, desc string) string { + return helpKeyStyle.Render(key) + " " + helpDescStyle.Render(desc) +} diff --git a/internal/version/checker.go b/internal/version/checker.go new file mode 100644 index 0000000..24519bf --- /dev/null +++ b/internal/version/checker.go @@ -0,0 +1,159 @@ +// Package version provides version checking against GitHub releases. +package version + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +const ( + // githubReleasesURL is the GitHub API endpoint for latest release + githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest" + // requestTimeout is the timeout for HTTP requests + requestTimeout = 5 * time.Second +) + +// ReleaseInfo contains information about a GitHub release +type ReleaseInfo struct { + TagName string `json:"tag_name"` + HTMLURL string `json:"html_url"` + Name string `json:"name"` +} + +// UpdateInfo contains information about an available update +type UpdateInfo struct { + CurrentVersion string + LatestVersion string + ReleaseURL string + ReleaseName string +} + +// Checker checks for new versions on GitHub +type Checker struct { + owner string + repo string + current string + client *http.Client +} + +// NewChecker creates a new version checker +func NewChecker(owner, repo, currentVersion string) *Checker { + return &Checker{ + owner: owner, + repo: repo, + current: normalizeVersion(currentVersion), + client: &http.Client{ + Timeout: requestTimeout, + }, + } +} + +// CheckForUpdate checks if a newer version is available. +// Returns nil if current version is up to date or if check fails. +// This is designed to fail silently - network errors should not impact the user. +func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo { + release, err := c.fetchLatestRelease(ctx) + if err != nil { + return nil + } + + latestVersion := normalizeVersion(release.TagName) + if isNewerVersion(latestVersion, c.current) { + return &UpdateInfo{ + CurrentVersion: c.current, + LatestVersion: latestVersion, + ReleaseURL: release.HTMLURL, + ReleaseName: release.Name, + } + } + + return nil +} + +// fetchLatestRelease fetches the latest release info from GitHub API +func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) { + url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("User-Agent", "lolcathost-version-checker") + + resp, err := c.client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode) + } + + var release ReleaseInfo + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, err + } + + return &release, nil +} + +// normalizeVersion removes 'v' or 'V' prefix and trims whitespace +func normalizeVersion(v string) string { + v = strings.TrimSpace(v) + v = strings.TrimPrefix(v, "v") + v = strings.TrimPrefix(v, "V") + return v +} + +// isNewerVersion compares two semver-like versions. +// Returns true if latest is newer than current. +func isNewerVersion(latest, current string) bool { + latestParts := parseVersion(latest) + currentParts := parseVersion(current) + + // Compare each part + for i := 0; i < len(latestParts) && i < len(currentParts); i++ { + if latestParts[i] > currentParts[i] { + return true + } + if latestParts[i] < currentParts[i] { + return false + } + } + + // If all compared parts are equal, longer version is newer + // e.g., 1.0.1 > 1.0 + return len(latestParts) > len(currentParts) +} + +// parseVersion splits a version string into numeric parts +func parseVersion(v string) []int { + // Remove any suffix like -beta, -rc1, etc. + if idx := strings.IndexAny(v, "-+"); idx != -1 { + v = v[:idx] + } + + parts := strings.Split(v, ".") + result := make([]int, 0, len(parts)) + + for _, p := range parts { + var num int + fmt.Sscanf(p, "%d", &num) + result = append(result, num) + } + + return result +} + +// FormatUpdateMessage formats a user-friendly update notification +func (u *UpdateInfo) FormatUpdateMessage() string { + return fmt.Sprintf("New version available: %s (current: %s) - %s", + u.LatestVersion, u.CurrentVersion, u.ReleaseURL) +} diff --git a/internal/version/checker_test.go b/internal/version/checker_test.go new file mode 100644 index 0000000..a7ba712 --- /dev/null +++ b/internal/version/checker_test.go @@ -0,0 +1,99 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeVersion(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"v1.0.0", "1.0.0"}, + {"1.0.0", "1.0.0"}, + {" v2.1.3 ", "2.1.3"}, + {"V1.0.0", "1.0.0"}, + {"v0.1.0", "0.1.0"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := normalizeVersion(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestParseVersion(t *testing.T) { + tests := []struct { + input string + expected []int + }{ + {"1.0.0", []int{1, 0, 0}}, + {"2.1.3", []int{2, 1, 3}}, + {"1.0", []int{1, 0}}, + {"10.20.30", []int{10, 20, 30}}, + {"1.0.0-beta", []int{1, 0, 0}}, + {"1.0.0-rc1", []int{1, 0, 0}}, + {"1.0.0+build123", []int{1, 0, 0}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := parseVersion(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsNewerVersion(t *testing.T) { + tests := []struct { + name string + latest string + current string + expected bool + }{ + {"major version bump", "2.0.0", "1.0.0", true}, + {"minor version bump", "1.1.0", "1.0.0", true}, + {"patch version bump", "1.0.1", "1.0.0", true}, + {"same version", "1.0.0", "1.0.0", false}, + {"current is newer major", "1.0.0", "2.0.0", false}, + {"current is newer minor", "1.0.0", "1.1.0", false}, + {"current is newer patch", "1.0.0", "1.0.1", false}, + {"longer version is newer", "1.0.1", "1.0", true}, + {"shorter version is older", "1.0", "1.0.1", false}, + {"double digit versions", "10.0.0", "9.0.0", true}, + {"with prerelease suffix", "1.1.0", "1.0.0-beta", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isNewerVersion(tt.latest, tt.current) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestUpdateInfo_FormatUpdateMessage(t *testing.T) { + info := &UpdateInfo{ + CurrentVersion: "1.0.0", + LatestVersion: "1.1.0", + ReleaseURL: "https://github.com/lukaszraczylo/lolcathost/releases/tag/v1.1.0", + } + + msg := info.FormatUpdateMessage() + assert.Contains(t, msg, "1.0.0") + assert.Contains(t, msg, "1.1.0") + assert.Contains(t, msg, "https://github.com") +} + +func TestNewChecker(t *testing.T) { + checker := NewChecker("lukaszraczylo", "lolcathost", "v1.0.0") + + assert.Equal(t, "lukaszraczylo", checker.owner) + assert.Equal(t, "lolcathost", checker.repo) + assert.Equal(t, "1.0.0", checker.current) // Should be normalized + assert.NotNil(t, checker.client) +} diff --git a/semver.yaml b/semver.yaml new file mode 100644 index 0000000..95d3a1d --- /dev/null +++ b/semver.yaml @@ -0,0 +1,22 @@ +version: 1 + +force: + major: 0 + minor: 1 + patch: 0 + +blacklist: + - "Merge branch" + - "Merge pull request" + - "WIP" + +wording: + minor: + - "feat" + - "feature" + major: + - "breaking" + - "major" + - "BREAKING CHANGE" + release: + - "release candidate"