Initial commit.

This commit is contained in:
2025-11-28 02:50:25 +00:00
commit 22552aec99
41 changed files with 10626 additions and 0 deletions
+427
View File
@@ -0,0 +1,427 @@
// Package client provides a client library for communicating with the lolcathost daemon.
package client
import (
"bufio"
"encoding/json"
"fmt"
"net"
"sync"
"time"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// Client is a client for the lolcathost daemon.
type Client struct {
socketPath string
conn net.Conn
reader *bufio.Reader
timeout time.Duration
mu sync.Mutex
}
// New creates a new client.
func New(socketPath string) *Client {
return &Client{
socketPath: socketPath,
timeout: 5 * time.Second,
}
}
// NewWithTimeout creates a new client with a custom timeout.
func NewWithTimeout(socketPath string, timeout time.Duration) *Client {
return &Client{
socketPath: socketPath,
timeout: timeout,
}
}
// Connect establishes a connection to the daemon.
func (c *Client) Connect() error {
c.mu.Lock()
defer c.mu.Unlock()
// Close existing connection if any
if c.conn != nil {
c.conn.Close()
c.conn = nil
c.reader = nil
}
conn, err := net.DialTimeout("unix", c.socketPath, c.timeout)
if err != nil {
return fmt.Errorf("failed to connect to daemon: %w", err)
}
c.conn = conn
c.reader = bufio.NewReader(conn)
return nil
}
// Close closes the connection.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn != nil {
err := c.conn.Close()
c.conn = nil
c.reader = nil
return err
}
return nil
}
// send sends a request and receives a response.
func (c *Client) send(req *protocol.Request) (*protocol.Response, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.conn == nil {
return nil, fmt.Errorf("not connected")
}
// Set deadline
c.conn.SetDeadline(time.Now().Add(c.timeout))
// Send request
data, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
data = append(data, '\n')
if _, err := c.conn.Write(data); err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
// Read response
line, err := c.reader.ReadBytes('\n')
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
var resp protocol.Response
if err := json.Unmarshal(line, &resp); err != nil {
return nil, fmt.Errorf("failed to parse response: %w", err)
}
return &resp, nil
}
// Ping checks if the daemon is responsive.
func (c *Client) Ping() error {
req, _ := protocol.NewRequest(protocol.RequestPing, nil)
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("ping failed: %s", resp.Message)
}
return nil
}
// Status returns the daemon's status.
func (c *Client) Status() (*protocol.StatusData, error) {
req, _ := protocol.NewRequest(protocol.RequestStatus, nil)
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("status failed: %s", resp.Message)
}
var data protocol.StatusData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return &data, nil
}
// List returns all host entries.
func (c *Client) List() ([]protocol.HostEntry, error) {
req, _ := protocol.NewRequest(protocol.RequestList, nil)
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("list failed: %s", resp.Message)
}
var data protocol.ListData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return data.Entries, nil
}
// Set enables or disables a host entry by alias.
func (c *Client) Set(alias string, enabled bool, force bool) (*protocol.SetData, error) {
req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{
Alias: alias,
Enabled: enabled,
Force: force,
})
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
var data protocol.SetData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return &data, nil
}
// Enable enables a host entry by alias.
func (c *Client) Enable(alias string) (*protocol.SetData, error) {
return c.Set(alias, true, false)
}
// Disable disables a host entry by alias.
func (c *Client) Disable(alias string) (*protocol.SetData, error) {
return c.Set(alias, false, false)
}
// Add adds a new host entry.
func (c *Client) Add(domain, ip, alias, group string, enabled bool) (*protocol.SetData, error) {
req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{
Domain: domain,
IP: ip,
Alias: alias,
Group: group,
Enabled: enabled,
})
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
var data protocol.SetData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return &data, nil
}
// Delete removes a host entry by alias.
func (c *Client) Delete(alias string) error {
req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{
Alias: alias,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// AddGroup adds a new group.
func (c *Client) AddGroup(name string) error {
req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{
Name: name,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// DeleteGroup removes a group and all its hosts.
func (c *Client) DeleteGroup(name string) error {
req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{
Name: name,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// ListGroups returns all group names.
func (c *Client) ListGroups() ([]string, error) {
req, _ := protocol.NewRequest(protocol.RequestListGroups, nil)
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
var data protocol.GroupsData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return data.Groups, nil
}
// Sync synchronizes the config to the hosts file.
func (c *Client) Sync() error {
req, _ := protocol.NewRequest(protocol.RequestSync, nil)
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("sync failed: %s", resp.Message)
}
return nil
}
// ApplyPreset applies a named preset.
func (c *Client) ApplyPreset(name string) error {
req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{
Name: name,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("preset failed: %s", resp.Message)
}
return nil
}
// Rollback restores a backup by name.
func (c *Client) Rollback(backupName string) error {
req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{
BackupName: backupName,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("rollback failed: %s", resp.Message)
}
return nil
}
// ListBackups returns available backups.
func (c *Client) ListBackups() ([]protocol.BackupInfo, error) {
req, _ := protocol.NewRequest(protocol.RequestBackups, nil)
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("backups failed: %s", resp.Message)
}
var data protocol.BackupsData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return data.Backups, nil
}
// RenameGroup renames a group.
func (c *Client) RenameGroup(oldName, newName string) error {
req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{
OldName: oldName,
NewName: newName,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// AddPreset adds a new preset.
func (c *Client) AddPreset(name string, enable, disable []string) error {
req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{
Name: name,
Enable: enable,
Disable: disable,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// DeletePreset removes a preset by name.
func (c *Client) DeletePreset(name string) error {
req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{
Name: name,
})
resp, err := c.send(req)
if err != nil {
return err
}
if !resp.IsOK() {
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
return nil
}
// ListPresets returns all presets.
func (c *Client) ListPresets() ([]protocol.PresetInfo, error) {
req, _ := protocol.NewRequest(protocol.RequestListPresets, nil)
resp, err := c.send(req)
if err != nil {
return nil, err
}
if !resp.IsOK() {
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
}
var data protocol.PresetsData
if err := resp.ParseData(&data); err != nil {
return nil, err
}
return data.Presets, nil
}
// IsConnected checks if the daemon is reachable.
func IsConnected(socketPath string) bool {
client := New(socketPath)
if err := client.Connect(); err != nil {
return false
}
defer client.Close()
return client.Ping() == nil
}
+516
View File
@@ -0,0 +1,516 @@
package client
import (
"bufio"
"encoding/json"
"net"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// mockServer creates a mock Unix socket server for testing
type mockServer struct {
listener net.Listener
path string
handler func(req *protocol.Request) *protocol.Response
}
func newMockServer(t *testing.T) *mockServer {
// Use /tmp directly to avoid long paths (Unix socket paths have ~104 char limit on macOS)
tmpDir, err := os.MkdirTemp("/tmp", "lolcat")
require.NoError(t, err)
t.Cleanup(func() { os.RemoveAll(tmpDir) })
socketPath := filepath.Join(tmpDir, "s.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(t, err)
ms := &mockServer{
listener: listener,
path: socketPath,
}
go ms.serve()
return ms
}
func (ms *mockServer) serve() {
for {
conn, err := ms.listener.Accept()
if err != nil {
return
}
go ms.handleConn(conn)
}
}
func (ms *mockServer) handleConn(conn net.Conn) {
defer conn.Close()
reader := bufio.NewReader(conn)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
return
}
var req protocol.Request
if err := json.Unmarshal(line, &req); err != nil {
continue
}
var resp *protocol.Response
if ms.handler != nil {
resp = ms.handler(&req)
} else {
resp, _ = protocol.NewOKResponse(nil)
}
data, _ := json.Marshal(resp)
conn.Write(append(data, '\n'))
}
}
func (ms *mockServer) close() {
ms.listener.Close()
os.Remove(ms.path)
}
func TestClient_Connect(t *testing.T) {
t.Run("success", func(t *testing.T) {
server := newMockServer(t)
defer server.close()
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
assert.NotNil(t, client.conn)
assert.NotNil(t, client.reader)
})
t.Run("failure - socket not found", func(t *testing.T) {
client := New("/nonexistent/socket.sock")
err := client.Connect()
assert.Error(t, err)
})
}
func TestClient_Ping(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestPing {
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected request")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
err = client.Ping()
assert.NoError(t, err)
}
func TestClient_Status(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestStatus {
resp, _ := protocol.NewOKResponse(protocol.StatusData{
Running: true,
Version: "1.0.0",
Uptime: 3600,
ActiveCount: 5,
RequestCount: 100,
})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
status, err := client.Status()
require.NoError(t, err)
assert.True(t, status.Running)
assert.Equal(t, "1.0.0", status.Version)
assert.Equal(t, int64(3600), status.Uptime)
assert.Equal(t, 5, status.ActiveCount)
assert.Equal(t, int64(100), status.RequestCount)
}
func TestClient_List(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestList {
resp, _ := protocol.NewOKResponse(protocol.ListData{
Entries: []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
},
})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
entries, err := client.List()
require.NoError(t, err)
assert.Len(t, entries, 2)
assert.Equal(t, "a.com", entries[0].Domain)
assert.True(t, entries[0].Enabled)
assert.Equal(t, "b.com", entries[1].Domain)
assert.False(t, entries[1].Enabled)
}
func TestClient_Set(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestSet {
var payload protocol.SetPayload
req.ParsePayload(&payload)
resp, _ := protocol.NewOKResponse(protocol.SetData{
Domain: "example.com",
Applied: true,
})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
data, err := client.Set("test", true, false)
require.NoError(t, err)
assert.Equal(t, "example.com", data.Domain)
assert.True(t, data.Applied)
}
func TestClient_Enable(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestSet {
var payload protocol.SetPayload
req.ParsePayload(&payload)
assert.True(t, payload.Enabled)
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
_, err = client.Enable("test")
assert.NoError(t, err)
}
func TestClient_Disable(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestSet {
var payload protocol.SetPayload
req.ParsePayload(&payload)
assert.False(t, payload.Enabled)
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
_, err = client.Disable("test")
assert.NoError(t, err)
}
func TestClient_Sync(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestSync {
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
err = client.Sync()
assert.NoError(t, err)
}
func TestClient_ApplyPreset(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestPreset {
var payload protocol.PresetPayload
req.ParsePayload(&payload)
assert.Equal(t, "local", payload.Name)
resp, _ := protocol.NewOKResponse(map[string]string{"preset": "local"})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
err = client.ApplyPreset("local")
assert.NoError(t, err)
}
func TestClient_Rollback(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestRollback {
var payload protocol.RollbackPayload
req.ParsePayload(&payload)
assert.Equal(t, "hosts.backup.bak", payload.BackupName)
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
err = client.Rollback("hosts.backup.bak")
assert.NoError(t, err)
}
func TestClient_ListBackups(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
if req.Type == protocol.RequestBackups {
resp, _ := protocol.NewOKResponse(protocol.BackupsData{
Backups: []protocol.BackupInfo{
{Name: "hosts.20231201.bak", Timestamp: 1701432000, Size: 1024},
{Name: "hosts.20231130.bak", Timestamp: 1701345600, Size: 1000},
},
})
return resp
}
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
backups, err := client.ListBackups()
require.NoError(t, err)
assert.Len(t, backups, 2)
assert.Equal(t, "hosts.20231201.bak", backups[0].Name)
}
func TestClient_ErrorResponse(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, "domain is blocked")
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
_, err = client.Set("test", true, false)
assert.Error(t, err)
assert.Contains(t, err.Error(), "domain is blocked")
}
func TestClient_NotConnected(t *testing.T) {
client := New("/nonexistent/socket.sock")
_, err := client.Status()
assert.Error(t, err)
assert.Contains(t, err.Error(), "not connected")
}
func TestClient_Timeout(t *testing.T) {
client := NewWithTimeout("/nonexistent.sock", 100*time.Millisecond)
assert.Equal(t, 100*time.Millisecond, client.timeout)
}
func TestIsConnected(t *testing.T) {
t.Run("connected", func(t *testing.T) {
server := newMockServer(t)
defer server.close()
server.handler = func(req *protocol.Request) *protocol.Response {
resp, _ := protocol.NewOKResponse(nil)
return resp
}
connected := IsConnected(server.path)
assert.True(t, connected)
})
t.Run("not connected", func(t *testing.T) {
connected := IsConnected("/nonexistent/socket.sock")
assert.False(t, connected)
})
}
// Matrix test for request types
func TestClient_RequestTypes_Matrix(t *testing.T) {
types := []struct {
name string
reqType protocol.RequestType
call func(*Client) error
}{
{"ping", protocol.RequestPing, func(c *Client) error { return c.Ping() }},
{"status", protocol.RequestStatus, func(c *Client) error { _, err := c.Status(); return err }},
{"list", protocol.RequestList, func(c *Client) error { _, err := c.List(); return err }},
{"sync", protocol.RequestSync, func(c *Client) error { return c.Sync() }},
{"preset", protocol.RequestPreset, func(c *Client) error { return c.ApplyPreset("test") }},
{"backups", protocol.RequestBackups, func(c *Client) error { _, err := c.ListBackups(); return err }},
}
for _, tt := range types {
t.Run(tt.name, func(t *testing.T) {
server := newMockServer(t)
defer server.close()
receivedType := protocol.RequestType("")
server.handler = func(req *protocol.Request) *protocol.Response {
receivedType = req.Type
switch req.Type {
case protocol.RequestStatus:
resp, _ := protocol.NewOKResponse(protocol.StatusData{})
return resp
case protocol.RequestList:
resp, _ := protocol.NewOKResponse(protocol.ListData{})
return resp
case protocol.RequestBackups:
resp, _ := protocol.NewOKResponse(protocol.BackupsData{})
return resp
default:
resp, _ := protocol.NewOKResponse(nil)
return resp
}
}
client := New(server.path)
err := client.Connect()
require.NoError(t, err)
defer client.Close()
_ = tt.call(client)
assert.Equal(t, tt.reqType, receivedType)
})
}
}
func BenchmarkClient_Ping(b *testing.B) {
tmpDir := b.TempDir()
socketPath := filepath.Join(tmpDir, "bench.sock")
listener, err := net.Listen("unix", socketPath)
require.NoError(b, err)
defer listener.Close()
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer c.Close()
reader := bufio.NewReader(c)
for {
_, err := reader.ReadBytes('\n')
if err != nil {
return
}
resp, _ := protocol.NewOKResponse(nil)
data, _ := json.Marshal(resp)
c.Write(append(data, '\n'))
}
}(conn)
}
}()
client := New(socketPath)
err = client.Connect()
require.NoError(b, err)
defer client.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = client.Ping()
}
}
+541
View File
@@ -0,0 +1,541 @@
// Package config handles YAML configuration parsing and hot-reload.
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/fsnotify/fsnotify"
"gopkg.in/yaml.v3"
)
// SystemConfigDir is the system-wide config directory for the daemon.
const SystemConfigDir = "/etc/lolcathost"
// SystemConfigPath is the system-wide config file path for the daemon.
const SystemConfigPath = "/etc/lolcathost/config.yaml"
// DefaultConfigDir returns the default config directory path for users.
func DefaultConfigDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "lolcathost")
}
// DefaultConfigPath returns the default config file path for users.
func DefaultConfigPath() string {
return filepath.Join(DefaultConfigDir(), "config.yaml")
}
// FlushMethod defines DNS cache flush methods.
type FlushMethod string
const (
FlushMethodAuto FlushMethod = "auto"
FlushMethodDscacheutil FlushMethod = "dscacheutil"
FlushMethodKillall FlushMethod = "killall"
FlushMethodBoth FlushMethod = "both"
)
// Settings holds global configuration settings.
type Settings struct {
AutoApply bool `yaml:"autoApply"`
FlushMethod FlushMethod `yaml:"flushMethod"`
}
// Host represents a single host entry in configuration.
type Host struct {
Domain string `yaml:"domain"`
IP string `yaml:"ip"`
Alias string `yaml:"alias"`
Enabled bool `yaml:"enabled"`
}
// Group represents a group of host entries.
type Group struct {
Name string `yaml:"name"`
Hosts []Host `yaml:"hosts"`
}
// Preset defines a named preset that enables/disables specific aliases.
type Preset struct {
Name string `yaml:"name"`
Enable []string `yaml:"enable,omitempty"`
Disable []string `yaml:"disable,omitempty"`
}
// Config represents the complete configuration.
type Config struct {
Settings Settings `yaml:"settings"`
Groups []Group `yaml:"groups"`
Presets []Preset `yaml:"presets"`
}
// Manager handles configuration loading and watching.
type Manager struct {
path string
config *Config
mu sync.RWMutex
watcher *fsnotify.Watcher
onChange func(*Config)
stopCh chan struct{}
}
// NewManager creates a new config manager.
func NewManager(path string) *Manager {
return &Manager{
path: path,
stopCh: make(chan struct{}),
}
}
// Load reads and parses the configuration file.
func (m *Manager) Load() error {
data, err := os.ReadFile(m.path)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse config file: %w", err)
}
if err := ValidateConfig(&cfg); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
m.mu.Lock()
m.config = &cfg
m.mu.Unlock()
return nil
}
// Get returns the current configuration.
func (m *Manager) Get() *Config {
m.mu.RLock()
defer m.mu.RUnlock()
return m.config
}
// Watch starts watching the config file for changes.
func (m *Manager) Watch(onChange func(*Config)) error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to create watcher: %w", err)
}
m.watcher = watcher
m.onChange = onChange
go m.watchLoop()
if err := watcher.Add(m.path); err != nil {
return fmt.Errorf("failed to watch config file: %w", err)
}
return nil
}
func (m *Manager) watchLoop() {
for {
select {
case event, ok := <-m.watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
if err := m.Load(); err == nil && m.onChange != nil {
m.onChange(m.Get())
}
}
case <-m.watcher.Errors:
// Ignore watcher errors
case <-m.stopCh:
return
}
}
}
// Stop stops watching the config file.
func (m *Manager) Stop() {
close(m.stopCh)
if m.watcher != nil {
m.watcher.Close()
}
}
// GetAllHosts returns all hosts from all groups.
func (c *Config) GetAllHosts() []Host {
var hosts []Host
for _, g := range c.Groups {
hosts = append(hosts, g.Hosts...)
}
return hosts
}
// FindHostByAlias finds a host by its alias.
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
return &c.Groups[i].Hosts[j], &c.Groups[i]
}
}
}
return nil, nil
}
// FindPreset finds a preset by name.
func (c *Config) FindPreset(name string) *Preset {
for i := range c.Presets {
if c.Presets[i].Name == name {
return &c.Presets[i]
}
}
return nil
}
// SetHostEnabled sets the enabled state of a host by alias.
func (c *Config) SetHostEnabled(alias string, enabled bool) bool {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
c.Groups[i].Hosts[j].Enabled = enabled
return true
}
}
}
return false
}
// GenerateAlias creates a unique alias from a domain name.
func (c *Config) GenerateAlias(domain string) string {
// Convert domain to alias format: example.com -> example-com
alias := strings.ReplaceAll(domain, ".", "-")
alias = strings.ReplaceAll(alias, "_", "-")
alias = strings.ToLower(alias)
// Check if alias exists, if so append a number
baseAlias := alias
counter := 1
for {
if existing, _ := c.FindHostByAlias(alias); existing == nil {
break
}
counter++
alias = fmt.Sprintf("%s-%d", baseAlias, counter)
}
return alias
}
// AddHost adds a new host to the configuration.
func (c *Config) AddHost(domain, ip, alias, groupName string, enabled bool) error {
// Auto-generate alias if empty
if alias == "" {
alias = c.GenerateAlias(domain)
} else {
// Check for duplicate alias
if existing, _ := c.FindHostByAlias(alias); existing != nil {
return fmt.Errorf("alias already exists: %s", alias)
}
}
host := Host{
Domain: domain,
IP: ip,
Alias: alias,
Enabled: enabled,
}
// Find or create group
for i := range c.Groups {
if c.Groups[i].Name == groupName {
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
return nil
}
}
// Create new group
c.Groups = append(c.Groups, Group{
Name: groupName,
Hosts: []Host{host},
})
return nil
}
// AddGroup adds a new empty group.
func (c *Config) AddGroup(name string) error {
// Check if group already exists
for _, g := range c.Groups {
if g.Name == name {
return fmt.Errorf("group already exists: %s", name)
}
}
c.Groups = append(c.Groups, Group{
Name: name,
Hosts: []Host{},
})
return nil
}
// DeleteGroup removes a group and all its hosts.
func (c *Config) DeleteGroup(name string) error {
for i, g := range c.Groups {
if g.Name == name {
c.Groups = append(c.Groups[:i], c.Groups[i+1:]...)
return nil
}
}
return fmt.Errorf("group not found: %s", name)
}
// RenameGroup renames an existing group.
func (c *Config) RenameGroup(oldName, newName string) error {
// Check if new name already exists
for _, g := range c.Groups {
if g.Name == newName {
return fmt.Errorf("group already exists: %s", newName)
}
}
for i := range c.Groups {
if c.Groups[i].Name == oldName {
c.Groups[i].Name = newName
return nil
}
}
return fmt.Errorf("group not found: %s", oldName)
}
// GetGroups returns all group names.
func (c *Config) GetGroups() []string {
names := make([]string, len(c.Groups))
for i, g := range c.Groups {
names[i] = g.Name
}
return names
}
// DeleteHost removes a host by alias.
func (c *Config) DeleteHost(alias string) bool {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...)
return true
}
}
}
return false
}
// UpdateHost updates an existing host by alias.
func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error {
// Find the host
var foundGroup int = -1
var foundHost int = -1
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == oldAlias {
foundGroup = i
foundHost = j
break
}
}
if foundHost >= 0 {
break
}
}
if foundHost < 0 {
return fmt.Errorf("alias not found: %s", oldAlias)
}
// Check for duplicate alias if alias is changing
if oldAlias != newAlias {
if existing, _ := c.FindHostByAlias(newAlias); existing != nil {
return fmt.Errorf("alias already exists: %s", newAlias)
}
}
// Get current enabled state
enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled
// If group is changing, move to new group
if c.Groups[foundGroup].Name != groupName {
// Remove from old group
c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...)
// Add to new group
host := Host{
Domain: domain,
IP: ip,
Alias: newAlias,
Enabled: enabled,
}
// Find or create target group
found := false
for i := range c.Groups {
if c.Groups[i].Name == groupName {
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
found = true
break
}
}
if !found {
c.Groups = append(c.Groups, Group{
Name: groupName,
Hosts: []Host{host},
})
}
} else {
// Update in place
c.Groups[foundGroup].Hosts[foundHost].Domain = domain
c.Groups[foundGroup].Hosts[foundHost].IP = ip
c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias
}
return nil
}
// ApplyPreset applies a preset to the configuration.
func (c *Config) ApplyPreset(name string) error {
preset := c.FindPreset(name)
if preset == nil {
return fmt.Errorf("preset not found: %s", name)
}
for _, alias := range preset.Enable {
c.SetHostEnabled(alias, true)
}
for _, alias := range preset.Disable {
c.SetHostEnabled(alias, false)
}
return nil
}
// AddPreset adds a new preset.
func (c *Config) AddPreset(name string, enable, disable []string) error {
// Check if preset already exists
for _, p := range c.Presets {
if p.Name == name {
return fmt.Errorf("preset already exists: %s", name)
}
}
c.Presets = append(c.Presets, Preset{
Name: name,
Enable: enable,
Disable: disable,
})
return nil
}
// DeletePreset removes a preset by name.
func (c *Config) DeletePreset(name string) error {
for i, p := range c.Presets {
if p.Name == name {
c.Presets = append(c.Presets[:i], c.Presets[i+1:]...)
return nil
}
}
return fmt.Errorf("preset not found: %s", name)
}
// GetPresets returns all presets.
func (c *Config) GetPresets() []Preset {
return c.Presets
}
// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed.
func (c *Config) EnsureDefaultGroup() {
if len(c.Groups) == 0 {
c.Groups = append(c.Groups, Group{
Name: "default",
Hosts: []Host{},
})
}
}
// Save writes the configuration to the file.
func (m *Manager) Save() error {
m.mu.RLock()
cfg := m.config
m.mu.RUnlock()
if cfg == nil {
return fmt.Errorf("no config loaded")
}
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
if err := os.WriteFile(m.path, data, 0644); err != nil {
return fmt.Errorf("failed to write config: %w", err)
}
return nil
}
// CreateDefault creates a default configuration file.
func CreateDefault(path string) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
cfg := &Config{
Settings: Settings{
AutoApply: true,
FlushMethod: FlushMethodAuto,
},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{
Domain: "example.local",
IP: "127.0.0.1",
Alias: "example-local",
Enabled: false,
},
},
},
},
Presets: []Preset{
{
Name: "local",
Enable: []string{"example-local"},
Disable: []string{},
},
{
Name: "clear",
Enable: []string{},
Disable: []string{"example-local"},
},
},
}
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal default config: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("failed to write default config: %w", err)
}
return nil
}
+267
View File
@@ -0,0 +1,267 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfig_GetAllHosts(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false},
},
},
{
Name: "staging",
Hosts: []Host{
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true},
},
},
},
}
hosts := cfg.GetAllHosts()
assert.Len(t, hosts, 3)
assert.Equal(t, "a.com", hosts[0].Domain)
assert.Equal(t, "b.com", hosts[1].Domain)
assert.Equal(t, "c.com", hosts[2].Domain)
}
func TestConfig_FindHostByAlias(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
},
},
},
}
t.Run("found", func(t *testing.T) {
host, group := cfg.FindHostByAlias("example")
require.NotNil(t, host)
require.NotNil(t, group)
assert.Equal(t, "example.com", host.Domain)
assert.Equal(t, "dev", group.Name)
})
t.Run("not found", func(t *testing.T) {
host, group := cfg.FindHostByAlias("nonexistent")
assert.Nil(t, host)
assert.Nil(t, group)
})
}
func TestConfig_FindPreset(t *testing.T) {
cfg := &Config{
Presets: []Preset{
{Name: "local", Enable: []string{"a"}, Disable: []string{"b"}},
{Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}},
},
}
t.Run("found", func(t *testing.T) {
preset := cfg.FindPreset("local")
require.NotNil(t, preset)
assert.Equal(t, "local", preset.Name)
assert.Equal(t, []string{"a"}, preset.Enable)
})
t.Run("not found", func(t *testing.T) {
preset := cfg.FindPreset("nonexistent")
assert.Nil(t, preset)
})
}
func TestConfig_SetHostEnabled(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: false},
},
},
},
}
t.Run("enable existing", func(t *testing.T) {
result := cfg.SetHostEnabled("example", true)
assert.True(t, result)
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
})
t.Run("disable existing", func(t *testing.T) {
result := cfg.SetHostEnabled("example", false)
assert.True(t, result)
assert.False(t, cfg.Groups[0].Hosts[0].Enabled)
})
t.Run("nonexistent alias", func(t *testing.T) {
result := cfg.SetHostEnabled("nonexistent", true)
assert.False(t, result)
})
}
func TestConfig_ApplyPreset(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "swap", Enable: []string{"a"}, Disable: []string{"b"}},
},
}
t.Run("valid preset", func(t *testing.T) {
err := cfg.ApplyPreset("swap")
require.NoError(t, err)
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
assert.False(t, cfg.Groups[0].Hosts[1].Enabled)
})
t.Run("nonexistent preset", func(t *testing.T) {
err := cfg.ApplyPreset("nonexistent")
assert.Error(t, err)
})
}
func TestManager_LoadAndGet(t *testing.T) {
// Create temp config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
configContent := `
settings:
autoApply: true
flushMethod: auto
groups:
- name: development
hosts:
- domain: example.com
ip: 127.0.0.1
alias: example-local
enabled: true
presets:
- name: local
enable: [example-local]
disable: []
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
require.NoError(t, err)
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
require.NotNil(t, cfg)
assert.True(t, cfg.Settings.AutoApply)
assert.Equal(t, FlushMethodAuto, cfg.Settings.FlushMethod)
assert.Len(t, cfg.Groups, 1)
assert.Equal(t, "development", cfg.Groups[0].Name)
assert.Len(t, cfg.Groups[0].Hosts, 1)
assert.Equal(t, "example.com", cfg.Groups[0].Hosts[0].Domain)
}
func TestManager_Save(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
// Create initial config
err := CreateDefault(configPath)
require.NoError(t, err)
// Load and modify
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
cfg.Groups[0].Hosts[0].Enabled = true
// Save
err = manager.Save()
require.NoError(t, err)
// Reload and verify
manager2 := NewManager(configPath)
err = manager2.Load()
require.NoError(t, err)
cfg2 := manager2.Get()
assert.True(t, cfg2.Groups[0].Hosts[0].Enabled)
}
func TestCreateDefault(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "subdir", "config.yaml")
err := CreateDefault(configPath)
require.NoError(t, err)
// Verify file exists
_, err = os.Stat(configPath)
require.NoError(t, err)
// Verify content is valid
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
require.NotNil(t, cfg)
assert.True(t, cfg.Settings.AutoApply)
assert.Len(t, cfg.Groups, 1)
assert.Len(t, cfg.Presets, 2)
}
func TestManager_Load_InvalidYAML(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0644)
require.NoError(t, err)
manager := NewManager(configPath)
err = manager.Load()
assert.Error(t, err)
}
func TestManager_Load_FileNotFound(t *testing.T) {
manager := NewManager("/nonexistent/path/config.yaml")
err := manager.Load()
assert.Error(t, err)
}
func TestFlushMethod(t *testing.T) {
methods := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
}
for _, m := range methods {
t.Run(string(m), func(t *testing.T) {
assert.NotEmpty(t, string(m))
})
}
}
+211
View File
@@ -0,0 +1,211 @@
// Package config provides validation functions for configuration.
package config
import (
"fmt"
"net"
"regexp"
"strings"
)
// domainRegex validates domain names.
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$|^localhost$`)
// aliasRegex validates alias names.
var aliasRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`)
// blockedDomains contains domains that cannot be modified.
var blockedDomains = map[string]bool{
"apple.com": true,
"icloud.com": true,
"icloud-content.com": true,
"apple-dns.cn": true,
"apple-dns.net": true,
"mzstatic.com": true,
"itunes.apple.com": true,
"updates.apple.com": true,
}
// ValidationError represents a configuration validation error.
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return fmt.Sprintf("%s: %s", e.Field, e.Message)
}
// ValidateConfig validates the entire configuration.
func ValidateConfig(cfg *Config) error {
if cfg == nil {
return &ValidationError{Field: "config", Message: "config is nil"}
}
if err := validateSettings(&cfg.Settings); err != nil {
return err
}
// Track aliases for uniqueness
aliases := make(map[string]bool)
for i, g := range cfg.Groups {
if err := validateGroup(&g, i, aliases); err != nil {
return err
}
}
for i, p := range cfg.Presets {
if err := validatePreset(&p, i, aliases); err != nil {
return err
}
}
return nil
}
func validateSettings(s *Settings) error {
switch s.FlushMethod {
case FlushMethodAuto, FlushMethodDscacheutil, FlushMethodKillall, FlushMethodBoth, "":
// Valid
default:
return &ValidationError{
Field: "settings.flushMethod",
Message: fmt.Sprintf("invalid flush method: %s", s.FlushMethod),
}
}
return nil
}
func validateGroup(g *Group, index int, aliases map[string]bool) error {
if strings.TrimSpace(g.Name) == "" {
return &ValidationError{
Field: fmt.Sprintf("groups[%d].name", index),
Message: "group name is required",
}
}
for i, h := range g.Hosts {
if err := validateHost(&h, index, i, aliases); err != nil {
return err
}
}
return nil
}
func validateHost(h *Host, groupIndex, hostIndex int, aliases map[string]bool) error {
fieldPrefix := fmt.Sprintf("groups[%d].hosts[%d]", groupIndex, hostIndex)
// Validate domain
if !ValidateDomain(h.Domain) {
return &ValidationError{
Field: fieldPrefix + ".domain",
Message: fmt.Sprintf("invalid domain: %s", h.Domain),
}
}
// Check blocked domains
if IsBlockedDomain(h.Domain) {
return &ValidationError{
Field: fieldPrefix + ".domain",
Message: fmt.Sprintf("domain is blocked: %s", h.Domain),
}
}
// Validate IP
if !ValidateIP(h.IP) {
return &ValidationError{
Field: fieldPrefix + ".ip",
Message: fmt.Sprintf("invalid IP address: %s", h.IP),
}
}
// Validate alias
if !ValidateAlias(h.Alias) {
return &ValidationError{
Field: fieldPrefix + ".alias",
Message: fmt.Sprintf("invalid alias: %s", h.Alias),
}
}
// Check alias uniqueness
if aliases[h.Alias] {
return &ValidationError{
Field: fieldPrefix + ".alias",
Message: fmt.Sprintf("duplicate alias: %s", h.Alias),
}
}
aliases[h.Alias] = true
return nil
}
func validatePreset(p *Preset, index int, aliases map[string]bool) error {
fieldPrefix := fmt.Sprintf("presets[%d]", index)
if strings.TrimSpace(p.Name) == "" {
return &ValidationError{
Field: fieldPrefix + ".name",
Message: "preset name is required",
}
}
// Note: We don't validate preset aliases strictly anymore.
// Unknown aliases in presets will simply be skipped when applying the preset.
// This allows presets to survive when hosts are removed from the config.
return nil
}
// ValidateDomain checks if a domain name is valid.
func ValidateDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}
// ValidateIP checks if an IP address is valid (IPv4 or IPv6).
func ValidateIP(ip string) bool {
if ip == "" {
return false
}
return net.ParseIP(ip) != nil
}
// ValidateAlias checks if an alias is valid.
func ValidateAlias(alias string) bool {
if alias == "" {
return false
}
return aliasRegex.MatchString(alias)
}
// IsBlockedDomain checks if a domain is in the blocklist.
func IsBlockedDomain(domain string) bool {
domain = strings.ToLower(domain)
// Check exact match
if blockedDomains[domain] {
return true
}
// Check if it's a subdomain of a blocked domain
for blocked := range blockedDomains {
if strings.HasSuffix(domain, "."+blocked) {
return true
}
}
return false
}
// GetBlockedDomains returns a copy of the blocked domains list.
func GetBlockedDomains() []string {
domains := make([]string, 0, len(blockedDomains))
for d := range blockedDomains {
domains = append(domains, d)
}
return domains
}
+436
View File
@@ -0,0 +1,436 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateDomain(t *testing.T) {
tests := []struct {
domain string
valid bool
}{
{"example.com", true},
{"sub.example.com", true},
{"my-app.example.com", true},
{"localhost", true},
{"a.b.c.d.example.com", true},
{"example123.com", true},
{"", false},
{"-example.com", false},
{"example-.com", false},
{"example.c", false}, // TLD too short
{"example", false}, // No TLD
{".example.com", false},
{"example..com", false},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
result := ValidateDomain(tt.domain)
assert.Equal(t, tt.valid, result, "domain: %s", tt.domain)
})
}
}
func TestValidateIP(t *testing.T) {
tests := []struct {
ip string
valid bool
}{
// Valid IPv4
{"127.0.0.1", true},
{"192.168.1.1", true},
{"0.0.0.0", true},
{"255.255.255.255", true},
// Valid IPv6
{"::1", true},
{"2001:db8::1", true},
{"fe80::1", true},
{"::ffff:192.168.1.1", true},
// Invalid
{"", false},
{"256.0.0.1", false},
{"192.168.1", false},
{"not-an-ip", false},
{"192.168.1.1.1", false},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := ValidateIP(tt.ip)
assert.Equal(t, tt.valid, result, "ip: %s", tt.ip)
})
}
}
func TestValidateAlias(t *testing.T) {
tests := []struct {
alias string
valid bool
}{
{"my-alias", true},
{"myalias", true},
{"my_alias", true},
{"alias123", true},
{"a", true},
{"a-b_c-d", true},
{"", false},
{"-startswithdash", false},
{"_startswithunderscore", false},
{"has spaces", false},
{"has.dot", false},
}
for _, tt := range tests {
t.Run(tt.alias, func(t *testing.T) {
result := ValidateAlias(tt.alias)
assert.Equal(t, tt.valid, result, "alias: %s", tt.alias)
})
}
}
func TestIsBlockedDomain(t *testing.T) {
tests := []struct {
domain string
blocked bool
}{
// Blocked domains
{"apple.com", true},
{"icloud.com", true},
{"sub.apple.com", true},
{"deep.sub.icloud.com", true},
{"APPLE.COM", true}, // Case insensitive
// Allowed domains
{"example.com", false},
{"myapp.com", false},
{"applestore.com", false}, // Not a subdomain
{"notapple.com", false},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
result := IsBlockedDomain(tt.domain)
assert.Equal(t, tt.blocked, result, "domain: %s", tt.domain)
})
}
}
func TestGetBlockedDomains(t *testing.T) {
domains := GetBlockedDomains()
assert.NotEmpty(t, domains)
assert.Contains(t, domains, "apple.com")
assert.Contains(t, domains, "icloud.com")
}
func TestValidateConfig(t *testing.T) {
t.Run("valid config", func(t *testing.T) {
cfg := &Config{
Settings: Settings{
AutoApply: true,
FlushMethod: FlushMethodAuto,
},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"example"}, Disable: []string{}},
},
}
err := ValidateConfig(cfg)
assert.NoError(t, err)
})
t.Run("nil config", func(t *testing.T) {
err := ValidateConfig(nil)
assert.Error(t, err)
})
t.Run("invalid flush method", func(t *testing.T) {
cfg := &Config{
Settings: Settings{FlushMethod: "invalid"},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("empty group name", func(t *testing.T) {
cfg := &Config{
Groups: []Group{{Name: "", Hosts: []Host{}}},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid domain", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "invalid", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("blocked domain", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "apple.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid IP", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "invalid", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid alias", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "-invalid", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("duplicate alias", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
{Domain: "b.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("empty preset name", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "", Enable: []string{}},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("preset with unknown alias is allowed", func(t *testing.T) {
// Unknown aliases in presets are now allowed (they're simply skipped when applied)
// This allows presets to survive when hosts are removed from the config
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"unknown"}},
},
}
err := ValidateConfig(cfg)
assert.NoError(t, err)
})
}
func TestValidationError(t *testing.T) {
err := &ValidationError{Field: "test.field", Message: "test message"}
assert.Equal(t, "test.field: test message", err.Error())
}
func TestValidateSettings(t *testing.T) {
tests := []struct {
name string
method FlushMethod
wantErr bool
}{
{"auto", FlushMethodAuto, false},
{"dscacheutil", FlushMethodDscacheutil, false},
{"killall", FlushMethodKillall, false},
{"both", FlushMethodBoth, false},
{"empty", "", false},
{"invalid", "invalid", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
settings := &Settings{FlushMethod: tt.method}
err := validateSettings(settings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// Matrix testing for domain validation
func TestValidateDomain_Matrix(t *testing.T) {
prefixes := []string{"", "sub.", "a.b."}
domains := []string{"example", "my-app", "test123"}
tlds := []string{".com", ".io", ".co.uk", ".dev"}
for _, prefix := range prefixes {
for _, domain := range domains {
for _, tld := range tlds {
fullDomain := prefix + domain + tld
t.Run(fullDomain, func(t *testing.T) {
result := ValidateDomain(fullDomain)
assert.True(t, result, "expected %s to be valid", fullDomain)
})
}
}
}
}
// Matrix testing for IP validation
func TestValidateIP_Matrix(t *testing.T) {
octets := []string{"0", "127", "192", "255"}
for _, o1 := range octets {
for _, o2 := range octets {
for _, o3 := range octets {
for _, o4 := range octets {
ip := o1 + "." + o2 + "." + o3 + "." + o4
t.Run(ip, func(t *testing.T) {
result := ValidateIP(ip)
assert.True(t, result, "expected %s to be valid", ip)
})
}
}
}
}
}
// Benchmark tests
func BenchmarkValidateDomain(b *testing.B) {
domains := []string{
"example.com",
"sub.example.com",
"very.long.subdomain.chain.example.com",
}
for _, domain := range domains {
b.Run(domain, func(b *testing.B) {
for i := 0; i < b.N; i++ {
ValidateDomain(domain)
}
})
}
}
func BenchmarkValidateIP(b *testing.B) {
ips := []string{
"127.0.0.1",
"192.168.1.1",
"::1",
"2001:db8::1",
}
for _, ip := range ips {
b.Run(ip, func(b *testing.B) {
for i := 0; i < b.N; i++ {
ValidateIP(ip)
}
})
}
}
func BenchmarkIsBlockedDomain(b *testing.B) {
domains := []string{
"example.com", // not blocked
"apple.com", // blocked
"sub.icloud.com", // blocked subdomain
}
for _, domain := range domains {
b.Run(domain, func(b *testing.B) {
for i := 0; i < b.N; i++ {
IsBlockedDomain(domain)
}
})
}
}
func BenchmarkValidateConfig(b *testing.B) {
cfg := &Config{
Settings: Settings{AutoApply: true, FlushMethod: FlushMethodAuto},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{Domain: "a.example.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
{Domain: "b.example.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
{Domain: "c.example.com", IP: "127.0.0.1", Alias: "c", Enabled: false},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"a", "b"}, Disable: []string{"c"}},
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := ValidateConfig(cfg)
require.NoError(b, err)
}
}
+133
View File
@@ -0,0 +1,133 @@
// Package daemon provides the main daemon loop and lifecycle management.
package daemon
import (
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/lukaszraczylo/lolcathost/internal/config"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// Daemon represents the lolcathost daemon.
type Daemon struct {
server *Server
config *config.Manager
stopCh chan struct{}
cleanupCh chan struct{}
}
// New creates a new daemon instance.
func New(configPath string) (*Daemon, error) {
cfgManager := config.NewManager(configPath)
// Try to load config, create default if it doesn't exist
if err := cfgManager.Load(); err != nil {
if os.IsNotExist(err) {
if err := config.CreateDefault(configPath); err != nil {
return nil, fmt.Errorf("failed to create default config: %w", err)
}
if err := cfgManager.Load(); err != nil {
return nil, fmt.Errorf("failed to load default config: %w", err)
}
} else {
return nil, fmt.Errorf("failed to load config: %w", err)
}
}
// Ensure at least one group exists
cfg := cfgManager.Get()
if cfg != nil {
cfg.EnsureDefaultGroup()
// Save if we added a default group
if len(cfg.Groups) == 1 && cfg.Groups[0].Name == "default" && len(cfg.Groups[0].Hosts) == 0 {
cfgManager.Save()
}
}
server := NewServer(protocol.SocketPath, cfgManager)
return &Daemon{
server: server,
config: cfgManager,
stopCh: make(chan struct{}),
cleanupCh: make(chan struct{}),
}, nil
}
// Run starts the daemon and blocks until stopped.
func (d *Daemon) Run() error {
// Verify we're running as root
if os.Geteuid() != 0 {
return fmt.Errorf("daemon must run as root")
}
// Start the server
if err := d.server.Start(); err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
// Watch config for changes
if err := d.config.Watch(d.onConfigChange); err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to watch config: %v\n", err)
}
// Start cleanup goroutine
go d.cleanupLoop()
// Wait for shutdown signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
fmt.Println("Received shutdown signal")
case <-d.stopCh:
fmt.Println("Shutdown requested")
}
return d.shutdown()
}
// Stop signals the daemon to stop.
func (d *Daemon) Stop() {
close(d.stopCh)
}
func (d *Daemon) shutdown() error {
close(d.cleanupCh)
d.config.Stop()
if err := d.server.Stop(); err != nil {
return fmt.Errorf("failed to stop server: %w", err)
}
return nil
}
func (d *Daemon) onConfigChange(cfg *config.Config) {
fmt.Println("Config changed, syncing hosts file...")
// The server will use the updated config on next request
// We could trigger a sync here if autoApply is enabled
if cfg != nil && cfg.Settings.AutoApply {
// Sync hosts file with new config
// This is handled by the server internally
}
}
func (d *Daemon) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
d.server.rateLimiter.Cleanup()
case <-d.cleanupCh:
return
}
}
}
+142
View File
@@ -0,0 +1,142 @@
// Package daemon provides DNS cache flushing functionality.
package daemon
import (
"fmt"
"os/exec"
"runtime"
)
// DNSFlusher handles DNS cache flushing.
type DNSFlusher struct {
method FlushMethod
}
// FlushMethod defines the DNS flush method to use.
type FlushMethod string
const (
FlushMethodAuto FlushMethod = "auto"
FlushMethodDscacheutil FlushMethod = "dscacheutil"
FlushMethodKillall FlushMethod = "killall"
FlushMethodBoth FlushMethod = "both"
FlushMethodSystemd FlushMethod = "systemd"
FlushMethodNscd FlushMethod = "nscd"
)
// NewDNSFlusher creates a new DNS flusher.
func NewDNSFlusher(method FlushMethod) *DNSFlusher {
return &DNSFlusher{method: method}
}
// Flush flushes the DNS cache using the configured method.
func (f *DNSFlusher) Flush() error {
method := f.method
if method == FlushMethodAuto || method == "" {
method = f.detectMethod()
}
switch runtime.GOOS {
case "darwin":
return f.flushDarwin(method)
case "linux":
return f.flushLinux(method)
default:
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
}
}
func (f *DNSFlusher) detectMethod() FlushMethod {
switch runtime.GOOS {
case "darwin":
return FlushMethodBoth
case "linux":
// Check for systemd-resolve first
if _, err := exec.LookPath("systemd-resolve"); err == nil {
return FlushMethodSystemd
}
if _, err := exec.LookPath("resolvectl"); err == nil {
return FlushMethodSystemd
}
// Fall back to nscd
if _, err := exec.LookPath("nscd"); err == nil {
return FlushMethodNscd
}
return FlushMethodAuto
default:
return FlushMethodAuto
}
}
func (f *DNSFlusher) flushDarwin(method FlushMethod) error {
var errs []error
switch method {
case FlushMethodDscacheutil:
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
return fmt.Errorf("dscacheutil failed: %w", err)
}
case FlushMethodKillall:
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
return fmt.Errorf("killall mDNSResponder failed: %w", err)
}
case FlushMethodBoth:
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
errs = append(errs, fmt.Errorf("dscacheutil failed: %w", err))
}
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
errs = append(errs, fmt.Errorf("killall mDNSResponder failed: %w", err))
}
if len(errs) == 2 {
return fmt.Errorf("all DNS flush methods failed: %v, %v", errs[0], errs[1])
}
default:
// Auto - try both
_ = runCommand("dscacheutil", "-flushcache")
_ = runCommand("killall", "-HUP", "mDNSResponder")
}
return nil
}
func (f *DNSFlusher) flushLinux(method FlushMethod) error {
switch method {
case FlushMethodSystemd:
// Try resolvectl first (newer), then systemd-resolve (older)
if err := runCommand("resolvectl", "flush-caches"); err != nil {
if err := runCommand("systemd-resolve", "--flush-caches"); err != nil {
return fmt.Errorf("systemd DNS flush failed: %w", err)
}
}
case FlushMethodNscd:
// Try to restart nscd
if err := runCommand("nscd", "-i", "hosts"); err != nil {
// Try service restart as fallback
if err := runCommand("service", "nscd", "restart"); err != nil {
return fmt.Errorf("nscd flush failed: %w", err)
}
}
default:
// Auto - try all methods
// Try systemd first
if err := runCommand("resolvectl", "flush-caches"); err == nil {
return nil
}
if err := runCommand("systemd-resolve", "--flush-caches"); err == nil {
return nil
}
// Try nscd
if err := runCommand("nscd", "-i", "hosts"); err == nil {
return nil
}
// On many Linux systems, no explicit flush is needed as /etc/hosts is read directly
// So we return nil here
}
return nil
}
func runCommand(name string, args ...string) error {
cmd := exec.Command(name, args...)
return cmd.Run()
}
+108
View File
@@ -0,0 +1,108 @@
package daemon
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewDNSFlusher(t *testing.T) {
tests := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
FlushMethodSystemd,
FlushMethodNscd,
}
for _, method := range tests {
t.Run(string(method), func(t *testing.T) {
flusher := NewDNSFlusher(method)
assert.NotNil(t, flusher)
assert.Equal(t, method, flusher.method)
})
}
}
func TestDNSFlusher_DetectMethod(t *testing.T) {
flusher := NewDNSFlusher(FlushMethodAuto)
method := flusher.detectMethod()
switch runtime.GOOS {
case "darwin":
assert.Equal(t, FlushMethodBoth, method)
case "linux":
// Could be systemd, nscd, or auto depending on system
assert.Contains(t, []FlushMethod{FlushMethodSystemd, FlushMethodNscd, FlushMethodAuto}, method)
}
}
func TestFlushMethod_String(t *testing.T) {
methods := map[FlushMethod]string{
FlushMethodAuto: "auto",
FlushMethodDscacheutil: "dscacheutil",
FlushMethodKillall: "killall",
FlushMethodBoth: "both",
FlushMethodSystemd: "systemd",
FlushMethodNscd: "nscd",
}
for method, expected := range methods {
t.Run(expected, func(t *testing.T) {
assert.Equal(t, expected, string(method))
})
}
}
// Note: Actually testing DNS flush requires root and modifies system state,
// so we skip those tests in unit tests. They would be integration tests.
func TestDNSFlusher_Flush_UnsupportedOS(t *testing.T) {
// This test only makes sense if we're not on darwin or linux
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
t.Skip("Test only applicable on unsupported OS")
}
flusher := NewDNSFlusher(FlushMethodAuto)
err := flusher.Flush()
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported operating system")
}
// Matrix test for flush methods
func TestFlushMethod_Matrix(t *testing.T) {
methods := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
FlushMethodSystemd,
FlushMethodNscd,
}
platforms := []string{"darwin", "linux"}
for _, method := range methods {
for _, platform := range platforms {
t.Run(string(method)+"_"+platform, func(t *testing.T) {
flusher := NewDNSFlusher(method)
assert.NotNil(t, flusher)
// Just verify no panic when checking method
_ = flusher.method
})
}
}
}
func BenchmarkDNSFlusher_DetectMethod(b *testing.B) {
flusher := NewDNSFlusher(FlushMethodAuto)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = flusher.detectMethod()
}
}
+319
View File
@@ -0,0 +1,319 @@
// Package daemon implements the privileged daemon that manages /etc/hosts.
package daemon
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
)
const (
// HostsPath is the path to the system hosts file.
HostsPath = "/etc/hosts"
// BackupDir is the directory for hosts file backups.
BackupDir = "/var/backups/lolcathost"
// MaxBackups is the maximum number of backups to keep.
MaxBackups = 10
// Markers for the managed section.
markerStart = "# ========== LOLCATHOST MANAGED - DO NOT EDIT =========="
markerEnd = "# ========== END LOLCATHOST =========="
)
// HostEntry represents a single entry in the hosts file.
type HostEntry struct {
IP string
Domain string
Alias string
Enabled bool
}
// HostsManager handles reading and writing the hosts file.
type HostsManager struct {
hostsPath string
backupDir string
}
// NewHostsManager creates a new hosts manager.
func NewHostsManager() *HostsManager {
return &HostsManager{
hostsPath: HostsPath,
backupDir: BackupDir,
}
}
// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
return &HostsManager{
hostsPath: hostsPath,
backupDir: backupDir,
}
}
// ReadManagedEntries reads the lolcathost-managed entries from the hosts file.
func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) {
file, err := os.Open(m.hostsPath)
if err != nil {
return nil, fmt.Errorf("failed to open hosts file: %w", err)
}
defer file.Close()
var entries []HostEntry
inManagedSection := false
scanner := bufio.NewScanner(file)
entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == markerStart {
inManagedSection = true
continue
}
if line == markerEnd {
inManagedSection = false
continue
}
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
matches := entryRegex.FindStringSubmatch(line)
if len(matches) == 4 {
entries = append(entries, HostEntry{
IP: matches[1],
Domain: matches[2],
Alias: matches[3],
Enabled: true,
})
}
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read hosts file: %w", err)
}
return entries, nil
}
// WriteManagedEntries writes the managed entries to the hosts file.
func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error {
// Create backup first
if err := m.CreateBackup(); err != nil {
return fmt.Errorf("failed to create backup: %w", err)
}
// Read existing content
content, err := os.ReadFile(m.hostsPath)
if err != nil {
return fmt.Errorf("failed to read hosts file: %w", err)
}
// Remove existing managed section
newContent := m.removeManagedSection(string(content))
// Build new managed section
managedSection := m.buildManagedSection(entries)
// Append managed section
newContent = strings.TrimRight(newContent, "\n") + "\n\n" + managedSection
// Write atomically
if err := m.writeAtomic(newContent); err != nil {
return fmt.Errorf("failed to write hosts file: %w", err)
}
return nil
}
func (m *HostsManager) removeManagedSection(content string) string {
lines := strings.Split(content, "\n")
var result []string
inManagedSection := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == markerStart {
inManagedSection = true
continue
}
if trimmed == markerEnd {
inManagedSection = false
continue
}
if !inManagedSection {
result = append(result, line)
}
}
// Remove trailing empty lines
for len(result) > 0 && strings.TrimSpace(result[len(result)-1]) == "" {
result = result[:len(result)-1]
}
return strings.Join(result, "\n")
}
func (m *HostsManager) buildManagedSection(entries []HostEntry) string {
var sb strings.Builder
sb.WriteString(markerStart)
sb.WriteString("\n")
for _, entry := range entries {
if entry.Enabled {
sb.WriteString(fmt.Sprintf("%s\t%s\t# lolcathost:%s\n", entry.IP, entry.Domain, entry.Alias))
}
}
sb.WriteString(markerEnd)
sb.WriteString("\n")
return sb.String()
}
func (m *HostsManager) writeAtomic(content string) error {
// Write to temp file first
tmpFile := m.hostsPath + ".tmp"
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
return err
}
// Rename atomically
if err := os.Rename(tmpFile, m.hostsPath); err != nil {
os.Remove(tmpFile)
return err
}
return nil
}
// CreateBackup creates a backup of the current hosts file.
func (m *HostsManager) CreateBackup() error {
if err := os.MkdirAll(m.backupDir, 0755); err != nil {
return fmt.Errorf("failed to create backup directory: %w", err)
}
content, err := os.ReadFile(m.hostsPath)
if err != nil {
return fmt.Errorf("failed to read hosts file: %w", err)
}
timestamp := time.Now().Format("20060102-150405")
backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp))
if err := os.WriteFile(backupPath, content, 0644); err != nil {
return fmt.Errorf("failed to write backup: %w", err)
}
// Cleanup old backups
if err := m.cleanupBackups(); err != nil {
// Log but don't fail
fmt.Fprintf(os.Stderr, "warning: failed to cleanup backups: %v\n", err)
}
return nil
}
func (m *HostsManager) cleanupBackups() error {
entries, err := os.ReadDir(m.backupDir)
if err != nil {
return err
}
var backups []os.DirEntry
for _, entry := range entries {
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "hosts.") && strings.HasSuffix(entry.Name(), ".bak") {
backups = append(backups, entry)
}
}
if len(backups) <= MaxBackups {
return nil
}
// Sort by name (timestamp) descending
sort.Slice(backups, func(i, j int) bool {
return backups[i].Name() > backups[j].Name()
})
// Remove oldest backups
for i := MaxBackups; i < len(backups); i++ {
path := filepath.Join(m.backupDir, backups[i].Name())
os.Remove(path)
}
return nil
}
// ListBackups returns a list of available backups.
func (m *HostsManager) ListBackups() ([]BackupInfo, error) {
entries, err := os.ReadDir(m.backupDir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var backups []BackupInfo
for _, entry := range entries {
if entry.IsDir() || !strings.HasPrefix(entry.Name(), "hosts.") || !strings.HasSuffix(entry.Name(), ".bak") {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
backups = append(backups, BackupInfo{
Name: entry.Name(),
Timestamp: info.ModTime().Unix(),
Size: info.Size(),
})
}
// Sort by timestamp descending
sort.Slice(backups, func(i, j int) bool {
return backups[i].Timestamp > backups[j].Timestamp
})
return backups, nil
}
// BackupInfo holds information about a backup file.
type BackupInfo struct {
Name string
Timestamp int64
Size int64
}
// RestoreBackup restores a backup by name.
func (m *HostsManager) RestoreBackup(name string) error {
backupPath := filepath.Join(m.backupDir, name)
// Validate backup name to prevent path traversal
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
return fmt.Errorf("invalid backup name")
}
content, err := os.ReadFile(backupPath)
if err != nil {
return fmt.Errorf("failed to read backup: %w", err)
}
// Create a backup of current state before restoring
if err := m.CreateBackup(); err != nil {
return fmt.Errorf("failed to create backup before restore: %w", err)
}
if err := m.writeAtomic(string(content)); err != nil {
return fmt.Errorf("failed to restore backup: %w", err)
}
return nil
}
+422
View File
@@ -0,0 +1,422 @@
package daemon
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostsManager_ReadManagedEntries(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
hostsContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
::1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 example.com # lolcathost:example-local
192.168.1.1 api.example.com # lolcathost:api-local
# ========== END LOLCATHOST ==========
`
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
assert.Len(t, entries, 2)
assert.Equal(t, "127.0.0.1", entries[0].IP)
assert.Equal(t, "example.com", entries[0].Domain)
assert.Equal(t, "example-local", entries[0].Alias)
assert.Equal(t, "192.168.1.1", entries[1].IP)
assert.Equal(t, "api.example.com", entries[1].Domain)
assert.Equal(t, "api-local", entries[1].Alias)
}
func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
hostsContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
`
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
assert.Empty(t, entries)
}
func TestHostsManager_WriteManagedEntries(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create initial hosts file
initialContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
`
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true},
{IP: "127.0.0.1", Domain: "api.myapp.com", Alias: "api-local", Enabled: true},
{IP: "192.168.1.1", Domain: "staging.myapp.com", Alias: "staging", Enabled: false},
}
err = manager.WriteManagedEntries(entries)
require.NoError(t, err)
// Read back
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
assert.Contains(t, contentStr, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
assert.Contains(t, contentStr, "127.0.0.1\tmyapp.com\t# lolcathost:myapp-local")
assert.Contains(t, contentStr, "127.0.0.1\tapi.myapp.com\t# lolcathost:api-local")
assert.NotContains(t, contentStr, "staging.myapp.com") // disabled
assert.Contains(t, contentStr, "# ========== END LOLCATHOST ==========")
}
func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create hosts file with existing managed section
initialContent := `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 old.com # lolcathost:old
# ========== END LOLCATHOST ==========
`
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true},
}
err = manager.WriteManagedEntries(entries)
require.NoError(t, err)
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
assert.Contains(t, contentStr, "new.com")
assert.NotContains(t, contentStr, "old.com")
}
func TestHostsManager_CreateBackup(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
hostsContent := "127.0.0.1\tlocalhost\n"
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
err = manager.CreateBackup()
require.NoError(t, err)
// Verify backup exists
entries, err := os.ReadDir(backupDir)
require.NoError(t, err)
assert.Len(t, entries, 1)
assert.True(t, strings.HasPrefix(entries[0].Name(), "hosts."))
assert.True(t, strings.HasSuffix(entries[0].Name(), ".bak"))
// Verify backup content
backupContent, err := os.ReadFile(filepath.Join(backupDir, entries[0].Name()))
require.NoError(t, err)
assert.Equal(t, hostsContent, string(backupContent))
}
func TestHostsManager_ListBackups(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create hosts file
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
require.NoError(t, err)
// Manually create backup files with different timestamps
err = os.MkdirAll(backupDir, 0755)
require.NoError(t, err)
backupNames := []string{
"hosts.20231201-120000.bak",
"hosts.20231201-120001.bak",
"hosts.20231201-120002.bak",
}
for _, name := range backupNames {
err = os.WriteFile(filepath.Join(backupDir, name), []byte("backup"), 0644)
require.NoError(t, err)
}
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.Len(t, backups, 3)
}
func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "nonexistent")
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.Empty(t, backups)
}
func TestHostsManager_RestoreBackup(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create initial hosts file
initialContent := "initial content"
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
// Create backup
err = manager.CreateBackup()
require.NoError(t, err)
// Modify hosts file
err = os.WriteFile(hostsPath, []byte("modified content"), 0644)
require.NoError(t, err)
// Get backup name
backups, err := manager.ListBackups()
require.NoError(t, err)
require.Len(t, backups, 1)
// Restore
err = manager.RestoreBackup(backups[0].Name)
require.NoError(t, err)
// Verify content restored
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
assert.Equal(t, initialContent, string(content))
}
func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) {
tmpDir := t.TempDir()
manager := NewHostsManagerWithPaths(
filepath.Join(tmpDir, "hosts"),
filepath.Join(tmpDir, "backups"),
)
tests := []string{
"../../../etc/passwd",
"hosts.bak", // Missing timestamp
"notahosts.backup", // Wrong format
"",
}
for _, name := range tests {
t.Run(name, func(t *testing.T) {
err := manager.RestoreBackup(name)
assert.Error(t, err)
})
}
}
func TestHostsManager_CleanupBackups(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
// Create more than MaxBackups
for i := 0; i < MaxBackups+5; i++ {
err = manager.CreateBackup()
require.NoError(t, err)
}
// Verify only MaxBackups remain
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.LessOrEqual(t, len(backups), MaxBackups)
}
func TestHostsManager_RemoveManagedSection(t *testing.T) {
manager := &HostsManager{}
tests := []struct {
name string
input string
expected string
}{
{
name: "with managed section",
input: `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 example.com # lolcathost:test
# ========== END LOLCATHOST ==========
`,
expected: "127.0.0.1\tlocalhost",
},
{
name: "without managed section",
input: "127.0.0.1\tlocalhost\n",
expected: "127.0.0.1\tlocalhost",
},
{
name: "multiple managed sections",
input: `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
entry1
# ========== END LOLCATHOST ==========
more content
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
entry2
# ========== END LOLCATHOST ==========
`,
expected: "127.0.0.1\tlocalhost\nmore content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.removeManagedSection(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestHostsManager_BuildManagedSection(t *testing.T) {
manager := &HostsManager{}
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "a.com", Alias: "a", Enabled: true},
{IP: "192.168.1.1", Domain: "b.com", Alias: "b", Enabled: true},
{IP: "10.0.0.1", Domain: "c.com", Alias: "c", Enabled: false},
}
result := manager.buildManagedSection(entries)
assert.Contains(t, result, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
assert.Contains(t, result, "127.0.0.1\ta.com\t# lolcathost:a")
assert.Contains(t, result, "192.168.1.1\tb.com\t# lolcathost:b")
assert.NotContains(t, result, "c.com") // disabled
assert.Contains(t, result, "# ========== END LOLCATHOST ==========")
}
// Matrix tests for hosts file parsing
func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
ips := []string{"127.0.0.1", "192.168.1.1", "::1"}
domains := []string{"example.com", "sub.example.com", "my-app.test"}
aliases := []string{"test", "my-alias", "app-1"}
for _, ip := range ips {
for _, domain := range domains {
for _, alias := range aliases {
t.Run(ip+"/"+domain+"/"+alias, func(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
content := "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n"
content += ip + "\t" + domain + "\t# lolcathost:" + alias + "\n"
content += "# ========== END LOLCATHOST ==========\n"
err := os.WriteFile(hostsPath, []byte(content), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Equal(t, ip, entries[0].IP)
assert.Equal(t, domain, entries[0].Domain)
assert.Equal(t, alias, entries[0].Alias)
})
}
}
}
}
func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
tmpDir := b.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
// Create a hosts file with many entries
var content strings.Builder
content.WriteString("127.0.0.1\tlocalhost\n")
content.WriteString("# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n")
for i := 0; i < 100; i++ {
content.WriteString("127.0.0.1\texample" + string(rune('a'+i%26)) + ".com\t# lolcathost:alias" + string(rune('a'+i%26)) + "\n")
}
content.WriteString("# ========== END LOLCATHOST ==========\n")
err := os.WriteFile(hostsPath, []byte(content.String()), 0644)
require.NoError(b, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = manager.ReadManagedEntries()
}
}
func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) {
tmpDir := b.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644)
require.NoError(b, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := make([]HostEntry, 50)
for i := range entries {
entries[i] = HostEntry{
IP: "127.0.0.1",
Domain: "example" + string(rune('a'+i%26)) + ".com",
Alias: "alias" + string(rune('a'+i%26)),
Enabled: i%2 == 0,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = manager.WriteManagedEntries(entries)
}
}
+57
View File
@@ -0,0 +1,57 @@
//go:build darwin
package daemon
import (
"net"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// getPeerCredentials extracts peer credentials from a Unix socket connection on macOS.
// Note: macOS Xucred doesn't include PID, so we use LOCAL_PEERPID separately.
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return nil
}
rawConn, err := unixConn.SyscallConn()
if err != nil {
return nil
}
var creds *PeerCredentials
rawConn.Control(func(fd uintptr) {
xucred, err := unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
if err != nil {
return
}
// Get PID separately using LOCAL_PEERPID
var pid int32
pidLen := uint32(unsafe.Sizeof(pid))
_, _, errno := syscall.Syscall6(
syscall.SYS_GETSOCKOPT,
fd,
unix.SOL_LOCAL,
0x002, // LOCAL_PEERPID
uintptr(unsafe.Pointer(&pid)),
uintptr(unsafe.Pointer(&pidLen)),
0,
)
if errno != 0 {
pid = 0
}
creds = &PeerCredentials{
UID: xucred.Uid,
GID: xucred.Groups[0],
PID: pid,
}
})
return creds
}
+37
View File
@@ -0,0 +1,37 @@
//go:build linux
package daemon
import (
"net"
"golang.org/x/sys/unix"
)
// getPeerCredentials extracts peer credentials from a Unix socket connection on Linux.
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return nil
}
rawConn, err := unixConn.SyscallConn()
if err != nil {
return nil
}
var creds *PeerCredentials
rawConn.Control(func(fd uintptr) {
ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
if err != nil {
return
}
creds = &PeerCredentials{
UID: ucred.Uid,
GID: ucred.Gid,
PID: ucred.Pid,
}
})
return creds
}
+196
View File
@@ -0,0 +1,196 @@
// Package daemon provides security functions including rate limiting and audit logging.
package daemon
import (
"encoding/json"
"fmt"
"os"
"os/user"
"sync"
"time"
)
const (
// AuditLogPath is the path to the audit log file.
AuditLogPath = "/var/log/lolcathost/audit.log"
// RateLimit is the maximum requests per minute per PID.
RateLimit = 100
// RateLimitWindow is the time window for rate limiting.
RateLimitWindow = time.Minute
)
// RateLimiter implements per-PID rate limiting.
type RateLimiter struct {
mu sync.Mutex
requests map[int32][]time.Time
limit int
window time.Duration
}
// NewRateLimiter creates a new rate limiter.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
return &RateLimiter{
requests: make(map[int32][]time.Time),
limit: limit,
window: window,
}
}
// Allow checks if a request from the given PID should be allowed.
func (r *RateLimiter) Allow(pid int32) bool {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
cutoff := now.Add(-r.window)
// Get existing requests for this PID
reqs := r.requests[pid]
// Filter out old requests
var validReqs []time.Time
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
// Check if under limit
if len(validReqs) >= r.limit {
r.requests[pid] = validReqs
return false
}
// Add new request
validReqs = append(validReqs, now)
r.requests[pid] = validReqs
return true
}
// Cleanup removes old entries from the rate limiter.
func (r *RateLimiter) Cleanup() {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
cutoff := now.Add(-r.window)
for pid, reqs := range r.requests {
var validReqs []time.Time
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
if len(validReqs) == 0 {
delete(r.requests, pid)
} else {
r.requests[pid] = validReqs
}
}
}
// AuditLogger handles audit logging.
type AuditLogger struct {
mu sync.Mutex
file *os.File
path string
encoder *json.Encoder
}
// AuditEntry represents a single audit log entry.
type AuditEntry struct {
Timestamp string `json:"timestamp"`
UID uint32 `json:"uid"`
PID int32 `json:"pid"`
Action string `json:"action"`
Details any `json:"details,omitempty"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// NewAuditLogger creates a new audit logger.
func NewAuditLogger(path string) (*AuditLogger, error) {
// Ensure directory exists
dir := path[:len(path)-len("/audit.log")]
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open audit log: %w", err)
}
return &AuditLogger{
file: file,
path: path,
encoder: json.NewEncoder(file),
}, nil
}
// Log writes an audit entry.
func (a *AuditLogger) Log(uid uint32, pid int32, action string, details any, success bool, errMsg string) {
a.mu.Lock()
defer a.mu.Unlock()
entry := AuditEntry{
Timestamp: time.Now().UTC().Format(time.RFC3339),
UID: uid,
PID: pid,
Action: action,
Details: details,
Success: success,
Error: errMsg,
}
// Ignore encoding errors - audit logging should not fail the operation
_ = a.encoder.Encode(entry)
}
// Close closes the audit logger.
func (a *AuditLogger) Close() error {
a.mu.Lock()
defer a.mu.Unlock()
if a.file != nil {
err := a.file.Close()
a.file = nil // Prevent double close
return err
}
return nil
}
// PeerCredentials holds the credentials of a connected peer.
type PeerCredentials struct {
UID uint32
GID uint32
PID int32
}
// isUserInGroup checks if a user (by UID) is a member of a group (by GID).
// This checks supplementary groups, not just the primary GID.
func isUserInGroup(uid uint32, targetGID uint32) bool {
// Look up user by UID
u, err := user.LookupId(fmt.Sprintf("%d", uid))
if err != nil {
return false
}
// Get user's group IDs
groupIDs, err := u.GroupIds()
if err != nil {
return false
}
// Check if target GID is in the list
targetGIDStr := fmt.Sprintf("%d", targetGID)
for _, gid := range groupIDs {
if gid == targetGIDStr {
return true
}
}
return false
}
+206
View File
@@ -0,0 +1,206 @@
package daemon
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimiter_Allow(t *testing.T) {
t.Run("under limit", func(t *testing.T) {
rl := NewRateLimiter(5, time.Minute)
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow(123), "request %d should be allowed", i)
}
})
t.Run("over limit", func(t *testing.T) {
rl := NewRateLimiter(3, time.Minute)
for i := 0; i < 3; i++ {
assert.True(t, rl.Allow(123))
}
// 4th request should be blocked
assert.False(t, rl.Allow(123))
})
t.Run("different PIDs", func(t *testing.T) {
rl := NewRateLimiter(2, time.Minute)
// PID 1
assert.True(t, rl.Allow(1))
assert.True(t, rl.Allow(1))
assert.False(t, rl.Allow(1))
// PID 2 should have its own limit
assert.True(t, rl.Allow(2))
assert.True(t, rl.Allow(2))
assert.False(t, rl.Allow(2))
})
t.Run("window expiration", func(t *testing.T) {
rl := NewRateLimiter(2, 10*time.Millisecond)
assert.True(t, rl.Allow(123))
assert.True(t, rl.Allow(123))
assert.False(t, rl.Allow(123))
// Wait for window to expire
time.Sleep(15 * time.Millisecond)
// Should be allowed again
assert.True(t, rl.Allow(123))
})
}
func TestRateLimiter_Cleanup(t *testing.T) {
rl := NewRateLimiter(10, 10*time.Millisecond)
// Add requests from multiple PIDs
for pid := int32(1); pid <= 5; pid++ {
rl.Allow(pid)
}
assert.Len(t, rl.requests, 5)
// Wait for expiration
time.Sleep(15 * time.Millisecond)
// Cleanup
rl.Cleanup()
assert.Empty(t, rl.requests)
}
func TestAuditLogger_Log(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
defer logger.Close()
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
logger.Log(1000, 12345, "sync", nil, false, "sync failed")
// Read log file
content, err := os.ReadFile(logPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, `"action":"set"`)
assert.Contains(t, contentStr, `"uid":1000`)
assert.Contains(t, contentStr, `"pid":12345`)
assert.Contains(t, contentStr, `"success":true`)
assert.Contains(t, contentStr, `"action":"sync"`)
assert.Contains(t, contentStr, `"success":false`)
assert.Contains(t, contentStr, `"error":"sync failed"`)
}
func TestAuditLogger_CreatesDirectory(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "subdir", "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
defer logger.Close()
// Verify directory was created
_, err = os.Stat(filepath.Dir(logPath))
assert.NoError(t, err)
}
func TestAuditLogger_Close(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
err = logger.Close()
assert.NoError(t, err)
// Closing again should not error
err = logger.Close()
assert.NoError(t, err)
}
func TestPeerCredentials(t *testing.T) {
creds := &PeerCredentials{
UID: 501,
GID: 20,
PID: 12345,
}
assert.Equal(t, uint32(501), creds.UID)
assert.Equal(t, uint32(20), creds.GID)
assert.Equal(t, int32(12345), creds.PID)
}
// Matrix test for rate limiting
func TestRateLimiter_Matrix(t *testing.T) {
limits := []int{1, 5, 10, 100}
windows := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, time.Second}
for _, limit := range limits {
for _, window := range windows {
t.Run(
"limit="+string(rune('0'+limit))+"_window="+window.String(),
func(t *testing.T) {
rl := NewRateLimiter(limit, window)
// Should allow exactly 'limit' requests
for i := 0; i < limit; i++ {
assert.True(t, rl.Allow(1))
}
// Next should be blocked
assert.False(t, rl.Allow(1))
},
)
}
}
}
func BenchmarkRateLimiter_Allow(b *testing.B) {
rl := NewRateLimiter(RateLimit, RateLimitWindow)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rl.Allow(int32(i % 100))
}
}
func BenchmarkRateLimiter_Cleanup(b *testing.B) {
rl := NewRateLimiter(RateLimit, RateLimitWindow)
// Pre-populate with requests
for i := 0; i < 1000; i++ {
rl.Allow(int32(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
rl.Cleanup()
}
}
func BenchmarkAuditLogger_Log(b *testing.B) {
tmpDir := b.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(b, err)
defer logger.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
}
}
+803
View File
@@ -0,0 +1,803 @@
// Package daemon provides the Unix socket server for the daemon.
package daemon
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
"github.com/lukaszraczylo/lolcathost/internal/config"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// Version is set by the main package at startup
var Version = "dev"
// Server is the daemon's Unix socket server.
type Server struct {
socketPath string
listener net.Listener
config *config.Manager
hosts *HostsManager
flusher *DNSFlusher
rateLimiter *RateLimiter
auditLogger *AuditLogger
mu sync.RWMutex
running bool
stopCh chan struct{}
requestCount int64
startTime int64
}
// NewServer creates a new daemon server.
func NewServer(socketPath string, cfgManager *config.Manager) *Server {
return &Server{
socketPath: socketPath,
config: cfgManager,
hosts: NewHostsManager(),
flusher: NewDNSFlusher(FlushMethodAuto),
rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow),
stopCh: make(chan struct{}),
}
}
// Start starts the server.
func (s *Server) Start() error {
// Remove existing socket
os.Remove(s.socketPath)
listener, err := net.Listen("unix", s.socketPath)
if err != nil {
return fmt.Errorf("failed to listen on socket: %w", err)
}
// Set socket permissions: 0660 root:lolcathost
if err := os.Chmod(s.socketPath, 0660); err != nil {
listener.Close()
return fmt.Errorf("failed to set socket permissions: %w", err)
}
// Set socket group to lolcathost (GID 850)
if err := os.Chown(s.socketPath, 0, 850); err != nil {
listener.Close()
return fmt.Errorf("failed to set socket ownership: %w", err)
}
s.listener = listener
s.running = true
s.startTime = currentTimeUnix()
// Try to create audit logger, but don't fail if it doesn't work
if logger, err := NewAuditLogger(AuditLogPath); err == nil {
s.auditLogger = logger
}
go s.acceptLoop()
return nil
}
func currentTimeUnix() int64 {
return time.Now().Unix()
}
// Stop stops the server.
func (s *Server) Stop() error {
s.mu.Lock()
s.running = false
s.mu.Unlock()
close(s.stopCh)
if s.listener != nil {
s.listener.Close()
}
os.Remove(s.socketPath)
if s.auditLogger != nil {
s.auditLogger.Close()
}
return nil
}
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.stopCh:
return
default:
continue
}
}
go s.handleConnection(conn)
}
}
// LolcathostGID is the group ID for the lolcathost group.
const LolcathostGID = 850
func (s *Server) handleConnection(conn net.Conn) {
defer conn.Close()
// Get peer credentials
creds := s.getPeerCredentials(conn)
// Authorization check: verify peer is authorized
if !s.isAuthorized(creds) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
if s.auditLogger != nil {
var uid uint32
var pid int32
if creds != nil {
uid = creds.UID
pid = creds.PID
}
s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt")
}
return
}
reader := bufio.NewReader(conn)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
return
}
var req protocol.Request
if err := json.Unmarshal(line, &req); err != nil {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
continue
}
// Rate limiting
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
continue
}
s.mu.Lock()
s.requestCount++
s.mu.Unlock()
resp := s.handleRequest(&req, creds)
s.writeResponse(conn, resp)
}
}
// isAuthorized checks if the peer is authorized to access the daemon.
// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850).
func (s *Server) isAuthorized(creds *PeerCredentials) bool {
if creds == nil {
// Can't verify credentials - deny by default
return false
}
// Root is always authorized
if creds.UID == 0 {
return true
}
// Check if user's primary GID is lolcathost
if creds.GID == LolcathostGID {
return true
}
// Check supplementary groups (user might be in lolcathost as secondary group)
// This requires looking up the user's groups from the system
return isUserInGroup(creds.UID, LolcathostGID)
}
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
data, _ := json.Marshal(resp)
data = append(data, '\n')
conn.Write(data)
}
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
var uid uint32
var pid int32
if creds != nil {
uid = creds.UID
pid = creds.PID
}
switch req.Type {
case protocol.RequestPing:
return s.handlePing()
case protocol.RequestStatus:
return s.handleStatus()
case protocol.RequestList:
return s.handleList()
case protocol.RequestSet:
resp := s.handleSet(req)
if s.auditLogger != nil {
var payload protocol.SetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestSync:
resp := s.handleSync()
if s.auditLogger != nil {
s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestPreset:
resp := s.handlePreset(req)
if s.auditLogger != nil {
var payload protocol.PresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestRollback:
resp := s.handleRollback(req)
if s.auditLogger != nil {
var payload protocol.RollbackPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestBackups:
return s.handleBackups()
case protocol.RequestAdd:
resp := s.handleAdd(req)
if s.auditLogger != nil {
var payload protocol.AddPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDelete:
resp := s.handleDelete(req)
if s.auditLogger != nil {
var payload protocol.DeletePayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestAddGroup:
resp := s.handleAddGroup(req)
if s.auditLogger != nil {
var payload protocol.GroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDeleteGroup:
resp := s.handleDeleteGroup(req)
if s.auditLogger != nil {
var payload protocol.GroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestListGroups:
return s.handleListGroups()
case protocol.RequestRenameGroup:
resp := s.handleRenameGroup(req)
if s.auditLogger != nil {
var payload protocol.RenameGroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestAddPreset:
resp := s.handleAddPreset(req)
if s.auditLogger != nil {
var payload protocol.AddPresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDeletePreset:
resp := s.handleDeletePreset(req)
if s.auditLogger != nil {
var payload protocol.PresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestListPresets:
return s.handleListPresets()
default:
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type))
}
}
func (s *Server) handlePing() *protocol.Response {
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
return resp
}
func (s *Server) handleStatus() *protocol.Response {
s.mu.RLock()
reqCount := s.requestCount
startTime := s.startTime
s.mu.RUnlock()
cfg := s.config.Get()
var activeCount int
if cfg != nil {
for _, h := range cfg.GetAllHosts() {
if h.Enabled {
activeCount++
}
}
}
data := protocol.StatusData{
Running: true,
Version: Version,
Uptime: nowUnix() - startTime,
ActiveCount: activeCount,
RequestCount: reqCount,
}
resp, _ := protocol.NewOKResponse(data)
return resp
}
func nowUnix() int64 {
return time.Now().Unix()
}
func (s *Server) handleList() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
var entries []protocol.HostEntry
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
entries = append(entries, protocol.HostEntry{
Domain: h.Domain,
IP: h.IP,
Alias: h.Alias,
Enabled: h.Enabled,
Group: g.Name,
})
}
}
resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries})
return resp
}
func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
var payload protocol.SetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
host, _ := cfg.FindHostByAlias(payload.Alias)
if host == nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
}
// Check for conflicts if enabling
if payload.Enabled && !payload.Force {
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled {
return protocol.NewErrorResponse(protocol.ErrCodeConflict,
fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias))
}
}
}
}
// Update config
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
Domain: host.Domain,
Applied: true,
})
return resp
}
func (s *Server) handleSync() *protocol.Response {
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
return resp
}
func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
var payload protocol.PresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.ApplyPreset(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
return resp
}
func (s *Server) handleRollback(req *protocol.Request) *protocol.Response {
var payload protocol.RollbackPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if err := s.hosts.RestoreBackup(payload.BackupName); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err))
}
// Flush DNS after restore
s.flusher.Flush()
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
return resp
}
func (s *Server) handleBackups() *protocol.Response {
backups, err := s.hosts.ListBackups()
if err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err))
}
var infos []protocol.BackupInfo
for _, b := range backups {
infos = append(infos, protocol.BackupInfo{
Name: b.Name,
Timestamp: b.Timestamp,
Size: b.Size,
})
}
resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos})
return resp
}
func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
var payload protocol.AddPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
// Validate domain
if payload.Domain == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required")
}
// Validate IP
if payload.IP == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required")
}
// Validate group
if payload.Group == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required")
}
// Check blocked domains
if config.IsBlockedDomain(payload.Domain) {
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain))
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
// Add to config (alias will be auto-generated if empty)
if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
Domain: payload.Domain,
Applied: true,
})
return resp
}
func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
var payload protocol.DeletePayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Alias == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
// Delete from config
if !cfg.DeleteHost(payload.Alias) {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
return resp
}
func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response {
var payload protocol.GroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.AddGroup(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
return resp
}
func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
var payload protocol.GroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.DeleteGroup(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
return resp
}
func (s *Server) handleListGroups() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()})
return resp
}
func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response {
var payload protocol.RenameGroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.OldName == "" || payload.NewName == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName})
return resp
}
func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response {
var payload protocol.AddPresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
return resp
}
func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response {
var payload protocol.PresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.DeletePreset(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
return resp
}
func (s *Server) handleListPresets() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
presets := cfg.GetPresets()
infos := make([]protocol.PresetInfo, len(presets))
for i, p := range presets {
infos[i] = protocol.PresetInfo{
Name: p.Name,
Enable: p.Enable,
Disable: p.Disable,
}
}
resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos})
return resp
}
func (s *Server) syncHostsFile() error {
cfg := s.config.Get()
if cfg == nil {
return fmt.Errorf("no configuration loaded")
}
var entries []HostEntry
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
entries = append(entries, HostEntry{
IP: h.IP,
Domain: h.Domain,
Alias: h.Alias,
Enabled: h.Enabled,
})
}
}
if err := s.hosts.WriteManagedEntries(entries); err != nil {
return err
}
// Flush DNS cache
return s.flusher.Flush()
}
+474
View File
@@ -0,0 +1,474 @@
// Package installer handles installation and uninstallation of the lolcathost daemon.
package installer
import (
"fmt"
"os"
"os/exec"
"os/user"
"path/filepath"
"runtime"
"strconv"
"strings"
"github.com/lukaszraczylo/lolcathost/internal/config"
)
const (
// GroupName is the name of the lolcathost group.
GroupName = "lolcathost"
// GroupGID is the GID for the lolcathost group (macOS).
GroupGID = 850
// Paths
LogDir = "/var/log/lolcathost"
BackupDir = "/var/backups/lolcathost"
SocketPath = "/var/run/lolcathost.sock"
LaunchDaemonDir = "/Library/LaunchDaemons"
SystemdDir = "/etc/systemd/system"
)
// LaunchDaemonPlist is the macOS LaunchDaemon plist template.
const LaunchDaemonPlist = `<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>com.lolcathost.daemon</string>
<key>ProgramArguments</key>
<array>
<string>%s</string>
<string>--daemon</string>
<string>--config</string>
<string>/etc/lolcathost/config.yaml</string>
</array>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/lolcathost/daemon.log</string>
<key>StandardErrorPath</key>
<string>/var/log/lolcathost/daemon.err</string>
</dict>
</plist>
`
// SystemdUnit is the Linux systemd unit template.
const SystemdUnit = `[Unit]
Description=lolcathost - Dynamic Host Management Daemon
After=network.target
[Service]
Type=simple
ExecStart=%s --daemon --config /etc/lolcathost/config.yaml
Restart=always
RestartSec=5
User=root
Group=root
[Install]
WantedBy=multi-user.target
`
// Installer handles installation and uninstallation.
type Installer struct {
binaryPath string
verbose bool
}
// New creates a new installer.
func New() (*Installer, error) {
binaryPath, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("failed to get executable path: %w", err)
}
// Resolve symlinks
binaryPath, err = filepath.EvalSymlinks(binaryPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve executable path: %w", err)
}
return &Installer{
binaryPath: binaryPath,
verbose: true,
}, nil
}
// Install performs the full installation.
func (i *Installer) Install() error {
if os.Geteuid() != 0 {
return fmt.Errorf("--install requires sudo")
}
i.log("Installing lolcathost...")
// Create group
if err := i.createGroup(); err != nil {
return fmt.Errorf("failed to create group: %w", err)
}
// Add current user to group
if err := i.addCurrentUserToGroup(); err != nil {
return fmt.Errorf("failed to add user to group: %w", err)
}
// Create directories
if err := i.createDirectories(); err != nil {
return fmt.Errorf("failed to create directories: %w", err)
}
// Create system config for daemon
if err := i.createSystemConfig(); err != nil {
return fmt.Errorf("failed to create system config: %w", err)
}
// Install service
if runtime.GOOS == "darwin" {
if err := i.installLaunchDaemon(); err != nil {
return fmt.Errorf("failed to install LaunchDaemon: %w", err)
}
} else if runtime.GOOS == "linux" {
if err := i.installSystemdService(); err != nil {
return fmt.Errorf("failed to install systemd service: %w", err)
}
}
// Create default config for the invoking user
if err := i.createDefaultConfig(); err != nil {
i.log("Warning: failed to create default config: %v", err)
}
i.log("")
i.log("✓ Installed successfully!")
i.log("")
i.log("Next steps:")
i.log(" 1. Open a NEW terminal (for group membership to take effect)")
i.log(" 2. Run 'lolcathost' to start the TUI")
i.log("")
return nil
}
// Uninstall removes the installation.
func (i *Installer) Uninstall() error {
if os.Geteuid() != 0 {
return fmt.Errorf("--uninstall requires sudo")
}
i.log("Uninstalling lolcathost...")
// Stop and remove service
if runtime.GOOS == "darwin" {
i.uninstallLaunchDaemon()
} else if runtime.GOOS == "linux" {
i.uninstallSystemdService()
}
// Remove socket
os.Remove(SocketPath)
// Note: We don't remove the group, logs, or backups
// The user may want to keep these
i.log("")
i.log("✓ Uninstalled successfully!")
i.log("")
i.log("Note: Log files, backups, and the group were preserved.")
i.log("To fully remove, manually delete:")
i.log(" - /var/log/lolcathost/")
i.log(" - /var/backups/lolcathost/")
i.log(" - ~/.config/lolcathost/")
if runtime.GOOS == "darwin" {
i.log(" - Remove group: sudo dscl . -delete /Groups/%s", GroupName)
} else {
i.log(" - Remove group: sudo groupdel %s", GroupName)
}
i.log("")
return nil
}
func (i *Installer) log(format string, args ...any) {
if i.verbose {
fmt.Printf(format+"\n", args...)
}
}
func (i *Installer) createGroup() error {
switch runtime.GOOS {
case "darwin":
return i.createGroupDarwin()
case "linux":
return i.createGroupLinux()
default:
return fmt.Errorf("unsupported OS: %s", runtime.GOOS)
}
}
func (i *Installer) createGroupDarwin() error {
// Check if group exists
if _, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName).Output(); err == nil {
i.log(" Group '%s' already exists", GroupName)
return nil
}
i.log(" Creating group '%s' (GID %d)...", GroupName, GroupGID)
// Create group
cmds := [][]string{
{"dscl", ".", "-create", "/Groups/" + GroupName},
{"dscl", ".", "-create", "/Groups/" + GroupName, "PrimaryGroupID", strconv.Itoa(GroupGID)},
{"dscl", ".", "-create", "/Groups/" + GroupName, "RealName", "lolcathost users"},
}
for _, args := range cmds {
if err := exec.Command(args[0], args[1:]...).Run(); err != nil {
return fmt.Errorf("command %v failed: %w", args, err)
}
}
return nil
}
func (i *Installer) createGroupLinux() error {
// Check if group exists
if _, err := exec.Command("getent", "group", GroupName).Output(); err == nil {
i.log(" Group '%s' already exists", GroupName)
return nil
}
i.log(" Creating group '%s'...", GroupName)
if err := exec.Command("groupadd", "-r", GroupName).Run(); err != nil {
return fmt.Errorf("groupadd failed: %w", err)
}
return nil
}
func (i *Installer) addCurrentUserToGroup() error {
// Get the real user (not root)
username := os.Getenv("SUDO_USER")
if username == "" {
// Fall back to current user
u, err := user.Current()
if err != nil {
return fmt.Errorf("failed to get current user: %w", err)
}
username = u.Username
}
if username == "root" {
i.log(" Skipping adding root to group")
return nil
}
switch runtime.GOOS {
case "darwin":
return i.addUserToGroupDarwin(username)
case "linux":
return i.addUserToGroupLinux(username)
default:
return fmt.Errorf("unsupported OS: %s", runtime.GOOS)
}
}
func (i *Installer) addUserToGroupDarwin(username string) error {
// Check if user is already in group
output, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName, "GroupMembership").Output()
if err == nil && strings.Contains(string(output), username) {
i.log(" User '%s' already in group '%s'", username, GroupName)
return nil
}
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
if err := exec.Command("dscl", ".", "-append", "/Groups/"+GroupName, "GroupMembership", username).Run(); err != nil {
return fmt.Errorf("failed to add user to group: %w", err)
}
return nil
}
func (i *Installer) addUserToGroupLinux(username string) error {
// Check if user is already in group
output, err := exec.Command("id", "-nG", username).Output()
if err == nil && strings.Contains(string(output), GroupName) {
i.log(" User '%s' already in group '%s'", username, GroupName)
return nil
}
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
if err := exec.Command("usermod", "-aG", GroupName, username).Run(); err != nil {
return fmt.Errorf("failed to add user to group: %w", err)
}
return nil
}
func (i *Installer) createDirectories() error {
dirs := []string{LogDir, BackupDir, config.SystemConfigDir}
for _, dir := range dirs {
i.log(" Creating directory '%s'...", dir)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create %s: %w", dir, err)
}
}
return nil
}
func (i *Installer) createSystemConfig() error {
// Check if system config already exists
if _, err := os.Stat(config.SystemConfigPath); err == nil {
i.log(" System config already exists at %s", config.SystemConfigPath)
return nil
}
i.log(" Creating system config at %s...", config.SystemConfigPath)
return config.CreateDefault(config.SystemConfigPath)
}
func (i *Installer) installLaunchDaemon() error {
plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist")
plistContent := fmt.Sprintf(LaunchDaemonPlist, i.binaryPath)
i.log(" Writing LaunchDaemon plist...")
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
return fmt.Errorf("failed to write plist: %w", err)
}
// Unload if already loaded
exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
// Bootstrap the daemon
i.log(" Starting daemon...")
if err := exec.Command("launchctl", "bootstrap", "system", plistPath).Run(); err != nil {
return fmt.Errorf("failed to bootstrap daemon: %w", err)
}
return nil
}
func (i *Installer) uninstallLaunchDaemon() {
plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist")
i.log(" Stopping daemon...")
exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
i.log(" Removing LaunchDaemon plist...")
os.Remove(plistPath)
}
func (i *Installer) installSystemdService() error {
unitPath := filepath.Join(SystemdDir, "lolcathost.service")
unitContent := fmt.Sprintf(SystemdUnit, i.binaryPath)
i.log(" Writing systemd unit...")
if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil {
return fmt.Errorf("failed to write unit file: %w", err)
}
// Reload systemd
i.log(" Reloading systemd...")
if err := exec.Command("systemctl", "daemon-reload").Run(); err != nil {
return fmt.Errorf("failed to reload systemd: %w", err)
}
// Enable and start the service
i.log(" Enabling and starting service...")
if err := exec.Command("systemctl", "enable", "--now", "lolcathost.service").Run(); err != nil {
return fmt.Errorf("failed to enable service: %w", err)
}
return nil
}
func (i *Installer) uninstallSystemdService() {
i.log(" Stopping and disabling service...")
exec.Command("systemctl", "disable", "--now", "lolcathost.service").Run()
i.log(" Removing systemd unit...")
os.Remove(filepath.Join(SystemdDir, "lolcathost.service"))
exec.Command("systemctl", "daemon-reload").Run()
}
func (i *Installer) createDefaultConfig() error {
// Get the real user's home directory
username := os.Getenv("SUDO_USER")
if username == "" {
return nil // Can't determine user
}
u, err := user.Lookup(username)
if err != nil {
return fmt.Errorf("failed to lookup user: %w", err)
}
configPath := filepath.Join(u.HomeDir, ".config", "lolcathost", "config.yaml")
// Check if config already exists
if _, err := os.Stat(configPath); err == nil {
i.log(" Config already exists at %s", configPath)
return nil
}
i.log(" Creating default config at %s...", configPath)
if err := config.CreateDefault(configPath); err != nil {
return err
}
// Change ownership to the real user
uid, _ := strconv.Atoi(u.Uid)
gid, _ := strconv.Atoi(u.Gid)
configDir := filepath.Dir(configPath)
os.Chown(configDir, uid, gid)
os.Chown(filepath.Dir(configDir), uid, gid)
os.Chown(configPath, uid, gid)
return nil
}
// CheckInstallation checks if the daemon is properly installed.
func CheckInstallation() error {
// Check if socket exists
if _, err := os.Stat(SocketPath); os.IsNotExist(err) {
return fmt.Errorf("daemon not running (socket not found)")
}
// Check if user is in group
u, err := user.Current()
if err != nil {
return fmt.Errorf("failed to get current user: %w", err)
}
groups, err := u.GroupIds()
if err != nil {
return fmt.Errorf("failed to get user groups: %w", err)
}
inGroup := false
for _, gid := range groups {
g, err := user.LookupGroupId(gid)
if err != nil {
continue
}
if g.Name == GroupName {
inGroup = true
break
}
}
if !inGroup {
return fmt.Errorf("user '%s' is not in group '%s'. Run 'sudo lolcathost --install' and open a new terminal", u.Username, GroupName)
}
return nil
}
+226
View File
@@ -0,0 +1,226 @@
// Package protocol defines shared message types for client-daemon communication.
package protocol
import (
"encoding/json"
"fmt"
)
// SocketPath is the Unix socket path for daemon communication.
const SocketPath = "/var/run/lolcathost.sock"
// RequestType defines the type of request.
type RequestType string
const (
RequestPing RequestType = "ping"
RequestStatus RequestType = "status"
RequestList RequestType = "list"
RequestSet RequestType = "set"
RequestAdd RequestType = "add"
RequestDelete RequestType = "delete"
RequestSync RequestType = "sync"
RequestPreset RequestType = "preset"
RequestRollback RequestType = "rollback"
RequestBackups RequestType = "backups"
RequestAddGroup RequestType = "add_group"
RequestDeleteGroup RequestType = "delete_group"
RequestRenameGroup RequestType = "rename_group"
RequestListGroups RequestType = "list_groups"
RequestAddPreset RequestType = "add_preset"
RequestDeletePreset RequestType = "delete_preset"
RequestListPresets RequestType = "list_presets"
)
// ErrorCode defines standard error codes.
type ErrorCode string
const (
ErrCodeInvalidRequest ErrorCode = "INVALID_REQUEST"
ErrCodeInvalidDomain ErrorCode = "INVALID_DOMAIN"
ErrCodeInvalidIP ErrorCode = "INVALID_IP"
ErrCodeBlockedDomain ErrorCode = "BLOCKED_DOMAIN"
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
ErrCodeUnauthorized ErrorCode = "UNAUTHORIZED"
ErrCodeNotFound ErrorCode = "NOT_FOUND"
ErrCodeConflict ErrorCode = "CONFLICT"
ErrCodeInternalError ErrorCode = "INTERNAL_ERROR"
ErrCodePermissionError ErrorCode = "PERMISSION_ERROR"
)
// Request represents a client request to the daemon.
type Request struct {
Type RequestType `json:"type"`
Payload json.RawMessage `json:"payload,omitempty"`
}
// SetPayload is the payload for set requests.
type SetPayload struct {
Alias string `json:"alias"`
Enabled bool `json:"enabled"`
Force bool `json:"force,omitempty"`
}
// PresetPayload is the payload for preset requests.
type PresetPayload struct {
Name string `json:"name"`
}
// RollbackPayload is the payload for rollback requests.
type RollbackPayload struct {
BackupName string `json:"backup_name"`
}
// AddPayload is the payload for add requests.
type AddPayload struct {
Domain string `json:"domain"`
IP string `json:"ip"`
Alias string `json:"alias"`
Group string `json:"group"`
Enabled bool `json:"enabled"`
}
// DeletePayload is the payload for delete requests.
type DeletePayload struct {
Alias string `json:"alias"`
}
// GroupPayload is the payload for group add/delete requests.
type GroupPayload struct {
Name string `json:"name"`
}
// RenameGroupPayload is the payload for rename_group requests.
type RenameGroupPayload struct {
OldName string `json:"old_name"`
NewName string `json:"new_name"`
}
// GroupsData is the data for list_groups responses.
type GroupsData struct {
Groups []string `json:"groups"`
}
// AddPresetPayload is the payload for add_preset requests.
type AddPresetPayload struct {
Name string `json:"name"`
Enable []string `json:"enable"`
Disable []string `json:"disable"`
}
// PresetInfo represents a preset with its configuration.
type PresetInfo struct {
Name string `json:"name"`
Enable []string `json:"enable"`
Disable []string `json:"disable"`
}
// PresetsData is the data for list_presets responses.
type PresetsData struct {
Presets []PresetInfo `json:"presets"`
}
// Response represents a daemon response.
type Response struct {
Status string `json:"status"`
Data json.RawMessage `json:"data,omitempty"`
Message string `json:"message,omitempty"`
Code ErrorCode `json:"code,omitempty"`
}
// StatusData is the data for status responses.
type StatusData struct {
Running bool `json:"running"`
Version string `json:"version"`
Uptime int64 `json:"uptime_seconds"`
ActiveCount int `json:"active_count"`
RequestCount int64 `json:"request_count"`
}
// HostEntry represents a single host entry.
type HostEntry struct {
Domain string `json:"domain"`
IP string `json:"ip"`
Alias string `json:"alias"`
Enabled bool `json:"enabled"`
Group string `json:"group"`
}
// ListData is the data for list responses.
type ListData struct {
Entries []HostEntry `json:"entries"`
}
// SetData is the data for set responses.
type SetData struct {
Domain string `json:"domain"`
Applied bool `json:"applied"`
}
// BackupsData is the data for backups responses.
type BackupsData struct {
Backups []BackupInfo `json:"backups"`
}
// BackupInfo represents a backup file.
type BackupInfo struct {
Name string `json:"name"`
Timestamp int64 `json:"timestamp"`
Size int64 `json:"size"`
}
// NewRequest creates a new request with the given type and payload.
func NewRequest(reqType RequestType, payload interface{}) (*Request, error) {
req := &Request{Type: reqType}
if payload != nil {
data, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %w", err)
}
req.Payload = data
}
return req, nil
}
// NewOKResponse creates a success response with optional data.
func NewOKResponse(data interface{}) (*Response, error) {
resp := &Response{Status: "ok"}
if data != nil {
dataBytes, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("failed to marshal data: %w", err)
}
resp.Data = dataBytes
}
return resp, nil
}
// NewErrorResponse creates an error response.
func NewErrorResponse(code ErrorCode, message string) *Response {
return &Response{
Status: "error",
Code: code,
Message: message,
}
}
// ParsePayload unmarshals the request payload into the given target.
func (r *Request) ParsePayload(target interface{}) error {
if r.Payload == nil {
return fmt.Errorf("no payload in request")
}
return json.Unmarshal(r.Payload, target)
}
// ParseData unmarshals the response data into the given target.
func (r *Response) ParseData(target interface{}) error {
if r.Data == nil {
return fmt.Errorf("no data in response")
}
return json.Unmarshal(r.Data, target)
}
// IsOK returns true if the response indicates success.
func (r *Response) IsOK() bool {
return r.Status == "ok"
}
+227
View File
@@ -0,0 +1,227 @@
package protocol
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewRequest(t *testing.T) {
tests := []struct {
name string
reqType RequestType
payload interface{}
wantErr bool
}{
{
name: "ping request without payload",
reqType: RequestPing,
payload: nil,
wantErr: false,
},
{
name: "set request with payload",
reqType: RequestSet,
payload: SetPayload{Alias: "test", Enabled: true},
wantErr: false,
},
{
name: "preset request with payload",
reqType: RequestPreset,
payload: PresetPayload{Name: "local"},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req, err := NewRequest(tt.reqType, tt.payload)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.reqType, req.Type)
if tt.payload != nil {
assert.NotNil(t, req.Payload)
}
})
}
}
func TestRequest_ParsePayload(t *testing.T) {
t.Run("valid payload", func(t *testing.T) {
payload := SetPayload{Alias: "test-alias", Enabled: true, Force: false}
req, err := NewRequest(RequestSet, payload)
require.NoError(t, err)
var parsed SetPayload
err = req.ParsePayload(&parsed)
require.NoError(t, err)
assert.Equal(t, "test-alias", parsed.Alias)
assert.True(t, parsed.Enabled)
assert.False(t, parsed.Force)
})
t.Run("nil payload", func(t *testing.T) {
req := &Request{Type: RequestPing}
var parsed SetPayload
err := req.ParsePayload(&parsed)
assert.Error(t, err)
})
}
func TestNewOKResponse(t *testing.T) {
t.Run("with data", func(t *testing.T) {
data := StatusData{
Running: true,
Version: "1.0.0",
Uptime: 3600,
ActiveCount: 5,
RequestCount: 100,
}
resp, err := NewOKResponse(data)
require.NoError(t, err)
assert.Equal(t, "ok", resp.Status)
assert.NotNil(t, resp.Data)
assert.True(t, resp.IsOK())
})
t.Run("without data", func(t *testing.T) {
resp, err := NewOKResponse(nil)
require.NoError(t, err)
assert.Equal(t, "ok", resp.Status)
assert.Nil(t, resp.Data)
})
}
func TestNewErrorResponse(t *testing.T) {
resp := NewErrorResponse(ErrCodeBlockedDomain, "domain is blocked")
assert.Equal(t, "error", resp.Status)
assert.Equal(t, ErrCodeBlockedDomain, resp.Code)
assert.Equal(t, "domain is blocked", resp.Message)
assert.False(t, resp.IsOK())
}
func TestResponse_ParseData(t *testing.T) {
t.Run("valid data", func(t *testing.T) {
data := ListData{
Entries: []HostEntry{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"},
},
}
resp, err := NewOKResponse(data)
require.NoError(t, err)
var parsed ListData
err = resp.ParseData(&parsed)
require.NoError(t, err)
assert.Len(t, parsed.Entries, 1)
assert.Equal(t, "example.com", parsed.Entries[0].Domain)
})
t.Run("nil data", func(t *testing.T) {
resp := &Response{Status: "ok"}
var parsed ListData
err := resp.ParseData(&parsed)
assert.Error(t, err)
})
}
func TestRequestTypes(t *testing.T) {
types := []RequestType{
RequestPing,
RequestStatus,
RequestList,
RequestSet,
RequestSync,
RequestPreset,
RequestRollback,
RequestBackups,
}
for _, rt := range types {
t.Run(string(rt), func(t *testing.T) {
req, err := NewRequest(rt, nil)
require.NoError(t, err)
assert.Equal(t, rt, req.Type)
// Verify JSON marshaling works
data, err := json.Marshal(req)
require.NoError(t, err)
assert.Contains(t, string(data), string(rt))
})
}
}
func TestErrorCodes(t *testing.T) {
codes := []ErrorCode{
ErrCodeInvalidRequest,
ErrCodeInvalidDomain,
ErrCodeInvalidIP,
ErrCodeBlockedDomain,
ErrCodeRateLimited,
ErrCodeNotFound,
ErrCodeConflict,
ErrCodeInternalError,
ErrCodePermissionError,
}
for _, code := range codes {
t.Run(string(code), func(t *testing.T) {
resp := NewErrorResponse(code, "test error")
assert.Equal(t, code, resp.Code)
// Verify JSON marshaling works
data, err := json.Marshal(resp)
require.NoError(t, err)
assert.Contains(t, string(data), string(code))
})
}
}
func TestHostEntry(t *testing.T) {
entry := HostEntry{
Domain: "example.com",
IP: "127.0.0.1",
Alias: "example-local",
Enabled: true,
Group: "development",
}
data, err := json.Marshal(entry)
require.NoError(t, err)
var parsed HostEntry
err = json.Unmarshal(data, &parsed)
require.NoError(t, err)
assert.Equal(t, entry.Domain, parsed.Domain)
assert.Equal(t, entry.IP, parsed.IP)
assert.Equal(t, entry.Alias, parsed.Alias)
assert.Equal(t, entry.Enabled, parsed.Enabled)
assert.Equal(t, entry.Group, parsed.Group)
}
func TestBackupInfo(t *testing.T) {
info := BackupInfo{
Name: "hosts.20231201-120000.bak",
Timestamp: 1701432000,
Size: 1024,
}
data, err := json.Marshal(info)
require.NoError(t, err)
var parsed BackupInfo
err = json.Unmarshal(data, &parsed)
require.NoError(t, err)
assert.Equal(t, info.Name, parsed.Name)
assert.Equal(t, info.Timestamp, parsed.Timestamp)
assert.Equal(t, info.Size, parsed.Size)
}
+904
View File
@@ -0,0 +1,904 @@
// Package tui provides the main Bubble Tea application.
package tui
import (
"context"
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/lukaszraczylo/lolcathost/internal/client"
"github.com/lukaszraczylo/lolcathost/internal/config"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
"github.com/lukaszraczylo/lolcathost/internal/version"
)
// ViewMode represents the current view mode.
type ViewMode int
const (
ViewList ViewMode = iota
ViewForm
ViewPresets
ViewGroups
ViewHelp
ViewSearch
)
// Model is the main Bubble Tea model.
type Model struct {
// Client
client *client.Client
connected bool
// Config
configPath string
config *config.Manager
// Views
mode ViewMode
list *ListView
form *Form
presetPicker *PresetPicker
groupPicker *GroupPicker
searchInput textinput.Model
// State
width int
height int
message string
messageStyle string // "error" or "success"
messageTime time.Time
searchTerm string
allGroups []string // All groups including empty ones
// Update notification
updateAvailable bool
updateVersion string
updateURL string
// Version info for update checking
version string
githubOwner string
githubRepo string
}
// Message types
type (
connectMsg struct{ err error }
refreshMsg struct {
entries []protocol.HostEntry
err error
}
toggleMsg struct {
alias string
err error
}
presetMsg struct {
name string
err error
}
addMsg struct {
domain string
err error
}
deleteMsg struct {
alias string
err error
}
addPresetMsg struct {
name string
err error
}
deletePresetMsg struct {
name string
err error
}
refreshPresetsMsg struct {
presets []protocol.PresetInfo
err error
}
addGroupMsg struct {
name string
err error
}
renameGroupMsg struct {
name string
err error
}
deleteGroupMsg struct {
name string
err error
}
refreshGroupsMsg struct {
groups []string
err error
}
clearMsgMsg struct{}
tickMsg struct{}
updateMsg struct {
version string
url string
}
)
// NewModel creates a new TUI model.
func NewModel(socketPath, configPath string) *Model {
searchInput := textinput.New()
searchInput.Placeholder = "Search..."
searchInput.CharLimit = 100
searchInput.Width = 50
return &Model{
client: client.New(socketPath),
configPath: configPath,
config: config.NewManager(configPath),
list: NewListView(),
form: NewForm(),
presetPicker: NewPresetPicker(),
groupPicker: NewGroupPicker(),
searchInput: searchInput,
mode: ViewList,
}
}
// Init initializes the model.
func (m *Model) Init() tea.Cmd {
return tea.Batch(
m.connect(),
tea.SetWindowTitle("lolcathost"),
m.tick(),
m.checkForUpdate(),
)
}
func (m *Model) connect() tea.Cmd {
return func() tea.Msg {
if err := m.client.Connect(); err != nil {
return connectMsg{err: err}
}
return connectMsg{err: nil}
}
}
func (m *Model) refresh() tea.Cmd {
return func() tea.Msg {
entries, err := m.client.List()
if err != nil {
return refreshMsg{entries: nil, err: err}
}
return refreshMsg{entries: entries, err: nil}
}
}
func (m *Model) toggle(alias string, enabled bool) tea.Cmd {
return func() tea.Msg {
_, err := m.client.Set(alias, enabled, false)
return toggleMsg{alias: alias, err: err}
}
}
func (m *Model) applyPreset(name string) tea.Cmd {
return func() tea.Msg {
err := m.client.ApplyPreset(name)
return presetMsg{name: name, err: err}
}
}
func (m *Model) addHost(domain, ip, alias, group string) tea.Cmd {
return func() tea.Msg {
_, err := m.client.Add(domain, ip, alias, group, false)
return addMsg{domain: domain, err: err}
}
}
func (m *Model) deleteHost(alias string) tea.Cmd {
return func() tea.Msg {
err := m.client.Delete(alias)
return deleteMsg{alias: alias, err: err}
}
}
func (m *Model) addPreset(name string, enable, disable []string) tea.Cmd {
return func() tea.Msg {
err := m.client.AddPreset(name, enable, disable)
return addPresetMsg{name: name, err: err}
}
}
func (m *Model) deletePreset(name string) tea.Cmd {
return func() tea.Msg {
err := m.client.DeletePreset(name)
return deletePresetMsg{name: name, err: err}
}
}
func (m *Model) refreshPresets() tea.Cmd {
return func() tea.Msg {
presets, err := m.client.ListPresets()
return refreshPresetsMsg{presets: presets, err: err}
}
}
func (m *Model) addGroup(name string) tea.Cmd {
return func() tea.Msg {
err := m.client.AddGroup(name)
return addGroupMsg{name: name, err: err}
}
}
func (m *Model) renameGroup(oldName, newName string) tea.Cmd {
return func() tea.Msg {
err := m.client.RenameGroup(oldName, newName)
return renameGroupMsg{name: newName, err: err}
}
}
func (m *Model) deleteGroup(name string) tea.Cmd {
return func() tea.Msg {
err := m.client.DeleteGroup(name)
return deleteGroupMsg{name: name, err: err}
}
}
func (m *Model) refreshGroups() tea.Cmd {
return func() tea.Msg {
groups, err := m.client.ListGroups()
return refreshGroupsMsg{groups: groups, err: err}
}
}
func (m *Model) tick() tea.Cmd {
return tea.Tick(time.Second*3, func(t time.Time) tea.Msg {
return tickMsg{}
})
}
func (m *Model) clearMsg() tea.Cmd {
return tea.Tick(time.Second*3, func(t time.Time) tea.Msg {
return clearMsgMsg{}
})
}
func (m *Model) checkForUpdate() tea.Cmd {
if m.githubOwner == "" || m.githubRepo == "" {
return nil
}
return func() tea.Msg {
checker := version.NewChecker(m.githubOwner, m.githubRepo, m.version)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if update := checker.CheckForUpdate(ctx); update != nil {
return updateMsg{version: update.LatestVersion, url: update.ReleaseURL}
}
return nil
}
}
// Update handles messages.
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
var cmds []tea.Cmd
switch msg := msg.(type) {
case tea.WindowSizeMsg:
m.width = msg.Width
m.height = msg.Height
m.list.SetSize(msg.Width, msg.Height-10)
m.form.SetSize(msg.Width, msg.Height)
m.presetPicker.SetSize(msg.Width, msg.Height)
m.groupPicker.SetSize(msg.Width, msg.Height)
// Set search input width
searchWidth := msg.Width - 20
if searchWidth > 60 {
searchWidth = 60
}
m.searchInput.Width = searchWidth
case tea.KeyMsg:
cmd := m.handleKey(msg)
if cmd != nil {
cmds = append(cmds, cmd)
}
case connectMsg:
if msg.err != nil {
m.connected = false
m.setError(fmt.Sprintf("Failed to connect: %v", msg.err))
} else {
m.connected = true
cmds = append(cmds, m.refresh())
cmds = append(cmds, m.refreshPresets())
cmds = append(cmds, m.refreshGroups())
m.loadConfig()
}
case refreshMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Refresh failed: %v", msg.err))
// Mark as disconnected to trigger reconnect
m.connected = false
m.client.Close()
} else if msg.entries != nil {
m.list.SetItems(msg.entries)
}
case toggleMsg:
if msg.err != nil {
m.list.SetError(msg.alias, true)
m.setError(fmt.Sprintf("Toggle failed: %v", msg.err))
} else {
m.list.SetPending(msg.alias, false)
cmds = append(cmds, m.refresh())
m.setSuccess("Entry toggled")
}
case presetMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Preset failed: %v", msg.err))
} else {
cmds = append(cmds, m.refresh())
m.setSuccess(fmt.Sprintf("Applied preset: %s", msg.name))
}
m.mode = ViewList
case addMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Add failed: %v", msg.err))
} else {
cmds = append(cmds, m.refresh())
m.setSuccess(fmt.Sprintf("Added host: %s", msg.domain))
}
m.mode = ViewList
case deleteMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Delete failed: %v", msg.err))
} else {
cmds = append(cmds, m.refresh())
m.setSuccess(fmt.Sprintf("Deleted: %s", msg.alias))
}
case addPresetMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Add preset failed: %v", msg.err))
} else {
cmds = append(cmds, m.refreshPresets())
m.setSuccess(fmt.Sprintf("Added preset: %s", msg.name))
}
m.presetPicker.CancelForm()
case deletePresetMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Delete preset failed: %v", msg.err))
} else {
cmds = append(cmds, m.refreshPresets())
m.setSuccess(fmt.Sprintf("Deleted preset: %s", msg.name))
}
m.presetPicker.CancelForm()
case refreshPresetsMsg:
if msg.err == nil && msg.presets != nil {
m.presetPicker.SetPresetsWithInfo(msg.presets)
}
case addGroupMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Add group failed: %v", msg.err))
} else {
cmds = append(cmds, m.refreshGroups())
cmds = append(cmds, m.refresh()) // Refresh list to show new group
m.setSuccess(fmt.Sprintf("Added group: %s", msg.name))
}
m.groupPicker.CancelForm()
case renameGroupMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Rename group failed: %v", msg.err))
} else {
cmds = append(cmds, m.refreshGroups())
cmds = append(cmds, m.refresh())
m.setSuccess(fmt.Sprintf("Renamed group to: %s", msg.name))
}
m.groupPicker.CancelForm()
case deleteGroupMsg:
if msg.err != nil {
m.setError(fmt.Sprintf("Delete group failed: %v", msg.err))
} else {
cmds = append(cmds, m.refreshGroups())
cmds = append(cmds, m.refresh())
m.setSuccess(fmt.Sprintf("Deleted group: %s", msg.name))
}
m.groupPicker.CancelForm()
case refreshGroupsMsg:
if msg.err == nil && msg.groups != nil {
m.allGroups = msg.groups
m.groupPicker.SetGroups(msg.groups)
}
case clearMsgMsg:
if time.Since(m.messageTime) >= time.Second*3 {
m.message = ""
}
case tickMsg:
// Reconnect if disconnected
if !m.connected {
cmds = append(cmds, m.connect())
}
cmds = append(cmds, m.tick())
case updateMsg:
if msg.version != "" {
m.updateAvailable = true
m.updateVersion = msg.version
m.updateURL = msg.url
}
}
return m, tea.Batch(cmds...)
}
func (m *Model) handleKey(msg tea.KeyMsg) tea.Cmd {
// Global keys
switch msg.String() {
case "ctrl+c":
return tea.Quit
}
// Mode-specific keys
switch m.mode {
case ViewList:
return m.handleListKey(msg)
case ViewForm:
return m.handleFormKey(msg)
case ViewPresets:
return m.handlePresetKey(msg)
case ViewGroups:
return m.handleGroupKey(msg)
case ViewHelp:
return m.handleHelpKey(msg)
case ViewSearch:
return m.handleSearchKey(msg)
}
return nil
}
func (m *Model) handleListKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "q":
return tea.Quit
case "esc":
// Clear search if active
if m.searchTerm != "" {
m.searchTerm = ""
m.searchInput.Reset()
}
case "up", "k":
m.list.MoveUp()
case "down", "j":
m.list.MoveDown()
case " ", "enter":
return m.toggleSelected()
case "n":
m.mode = ViewForm
m.form.SetGroups(m.allGroups)
m.form.Init()
case "e":
if item := m.list.Selected(); item != nil {
m.mode = ViewForm
m.form.SetGroups(m.allGroups)
m.form.InitEdit(item.Entry.Domain, item.Entry.IP, item.Entry.Alias, item.Entry.Group)
}
case "d":
if item := m.list.Selected(); item != nil {
return m.deleteHost(item.Entry.Alias)
}
case "p":
m.mode = ViewPresets
case "g":
m.mode = ViewGroups
return m.refreshGroups()
case "/":
m.mode = ViewSearch
m.searchInput.Focus()
case "?":
m.mode = ViewHelp
case "r":
return m.refresh()
}
return nil
}
func (m *Model) handleFormKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.mode = ViewList
return nil
case "enter":
if errMsg := m.form.Validate(); errMsg != "" {
m.setError(errMsg)
return m.clearMsg()
}
domain, ip, group := m.form.Values()
if m.form.IsEdit() {
// For edit, delete old and add new (simple approach)
oldAlias := m.form.EditAlias()
return tea.Sequence(
func() tea.Msg {
m.client.Delete(oldAlias)
return nil
},
m.addHost(domain, ip, "", group), // Empty alias = auto-generate
)
}
return m.addHost(domain, ip, "", group) // Empty alias = auto-generate
}
return m.form.Update(msg)
}
func (m *Model) handlePresetKey(msg tea.KeyMsg) tea.Cmd {
// Handle based on preset picker mode
switch m.presetPicker.Mode() {
case PresetModeSelect:
return m.handlePresetSelectKey(msg)
case PresetModeAdd, PresetModeEdit:
return m.handlePresetFormKey(msg)
case PresetModeConfirmDelete:
return m.handlePresetDeleteKey(msg)
}
return nil
}
func (m *Model) handlePresetSelectKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc", "q":
m.mode = ViewList
case "up", "k":
m.presetPicker.MoveUp()
case "down", "j":
m.presetPicker.MoveDown()
case "enter":
if preset := m.presetPicker.Selected(); preset != "" {
return m.applyPreset(preset)
}
case "n":
m.presetPicker.InitAdd()
case "e":
m.presetPicker.InitEdit()
case "d":
m.presetPicker.InitDelete()
}
return nil
}
func (m *Model) handlePresetFormKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.presetPicker.CancelForm()
return nil
case "enter":
if errMsg := m.presetPicker.ValidateForm(); errMsg != "" {
m.setError(errMsg)
return m.clearMsg()
}
name, enable, disable := m.presetPicker.FormValues()
if m.presetPicker.IsEdit() {
// For edit, delete old and add new
oldName := m.presetPicker.EditName()
return tea.Sequence(
func() tea.Msg {
m.client.DeletePreset(oldName)
return nil
},
m.addPreset(name, enable, disable),
)
}
return m.addPreset(name, enable, disable)
}
return m.presetPicker.Update(msg)
}
func (m *Model) handlePresetDeleteKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "y", "Y":
if preset := m.presetPicker.Selected(); preset != "" {
return m.deletePreset(preset)
}
m.presetPicker.CancelForm()
case "n", "N", "esc":
m.presetPicker.CancelForm()
}
return nil
}
func (m *Model) handleGroupKey(msg tea.KeyMsg) tea.Cmd {
// Handle based on group picker mode
switch m.groupPicker.Mode() {
case GroupModeSelect:
return m.handleGroupSelectKey(msg)
case GroupModeAdd, GroupModeRename:
return m.handleGroupFormKey(msg)
case GroupModeConfirmDelete:
return m.handleGroupDeleteKey(msg)
}
return nil
}
func (m *Model) handleGroupSelectKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc", "q":
m.mode = ViewList
case "up", "k":
m.groupPicker.MoveUp()
case "down", "j":
m.groupPicker.MoveDown()
case "n":
m.groupPicker.InitAdd()
case "r":
m.groupPicker.InitRename()
case "d":
m.groupPicker.InitDelete()
}
return nil
}
func (m *Model) handleGroupFormKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.groupPicker.CancelForm()
return nil
case "enter":
if errMsg := m.groupPicker.ValidateForm(); errMsg != "" {
m.setError(errMsg)
return m.clearMsg()
}
name := m.groupPicker.FormValue()
if m.groupPicker.IsRename() {
oldName := m.groupPicker.EditName()
return m.renameGroup(oldName, name)
}
return m.addGroup(name)
}
return m.groupPicker.Update(msg)
}
func (m *Model) handleGroupDeleteKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "y", "Y":
if group := m.groupPicker.Selected(); group != "" {
return m.deleteGroup(group)
}
m.groupPicker.CancelForm()
case "n", "N", "esc":
m.groupPicker.CancelForm()
}
return nil
}
func (m *Model) handleHelpKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc", "q", "?":
m.mode = ViewList
}
return nil
}
func (m *Model) handleSearchKey(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "esc":
m.mode = ViewList
m.searchTerm = ""
m.searchInput.Reset()
return nil
case "enter":
m.searchTerm = m.searchInput.Value()
m.mode = ViewList
return nil
}
var cmd tea.Cmd
m.searchInput, cmd = m.searchInput.Update(msg)
return cmd
}
func (m *Model) toggleSelected() tea.Cmd {
item := m.list.Selected()
if item == nil {
return nil
}
m.list.SetPending(item.Entry.Alias, true)
return m.toggle(item.Entry.Alias, !item.Entry.Enabled)
}
func (m *Model) loadConfig() {
if err := m.config.Load(); err != nil {
return
}
cfg := m.config.Get()
if cfg == nil {
return
}
var presetNames []string
for _, p := range cfg.Presets {
presetNames = append(presetNames, p.Name)
}
m.presetPicker.SetPresets(presetNames)
}
func (m *Model) setError(msg string) {
m.message = msg
m.messageStyle = "error"
m.messageTime = time.Now()
}
func (m *Model) setSuccess(msg string) {
m.message = msg
m.messageStyle = "success"
m.messageTime = time.Now()
}
// View renders the UI.
func (m *Model) View() string {
var sb strings.Builder
// Title with version
title := titleStyle.Render("lolcathost - Host Management")
sb.WriteString(title)
// Update notification
if m.updateAvailable {
sb.WriteString(" ")
sb.WriteString(updateStyle.Render(fmt.Sprintf("Update available: v%s", m.updateVersion)))
}
sb.WriteString("\n\n")
// Main content based on mode
switch m.mode {
case ViewList:
sb.WriteString(m.list.ViewFiltered(m.searchTerm))
case ViewForm:
sb.WriteString(m.form.View())
case ViewPresets:
sb.WriteString(m.presetPicker.View())
case ViewGroups:
sb.WriteString(m.groupPicker.View())
case ViewHelp:
sb.WriteString(m.helpView())
case ViewSearch:
sb.WriteString(m.searchView())
}
// Message
if m.message != "" {
sb.WriteString("\n")
if m.messageStyle == "error" {
sb.WriteString(errorMsgStyle.Render(m.message))
} else {
sb.WriteString(successMsgStyle.Render(m.message))
}
}
// Calculate remaining space for footer positioning
currentContent := sb.String()
currentLines := strings.Count(currentContent, "\n") + 1
// Fill space to push footer to bottom (reserve 3 lines for footer)
footerHeight := 3
remainingLines := m.height - currentLines - footerHeight
if remainingLines > 0 {
sb.WriteString(strings.Repeat("\n", remainingLines))
}
// Footer (help bar + status bar)
if m.mode == ViewList {
sb.WriteString("\n")
sb.WriteString(m.helpBar())
}
sb.WriteString("\n")
sb.WriteString(m.statusBar())
return sb.String()
}
func (m *Model) helpBar() string {
return helpBarStyle.Render(fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: New %s: Edit %s: Delete %s: Presets %s: Groups %s: Search %s: Help %s: Quit",
helpKeyStyle.Render("↑↓"),
helpKeyStyle.Render("jk"),
helpKeyStyle.Render("Space"),
helpKeyStyle.Render("n"),
helpKeyStyle.Render("e"),
helpKeyStyle.Render("d"),
helpKeyStyle.Render("p"),
helpKeyStyle.Render("g"),
helpKeyStyle.Render("/"),
helpKeyStyle.Render("?"),
helpKeyStyle.Render("q")))
}
func (m *Model) statusBar() string {
var status string
if m.connected {
status = connectedStyle.String()
} else {
status = disconnectedStyle.String()
}
active := fmt.Sprintf("%d active", m.list.ActiveCount())
total := fmt.Sprintf("%d total", m.list.Len())
return statusBarStyle.Render(fmt.Sprintf("%s | %s | %s", status, active, total))
}
func (m *Model) helpView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render("Help"))
sb.WriteString("\n\n")
help := []struct{ key, desc string }{
{"↑/↓ or j/k", "Navigate up/down"},
{"Space/Enter", "Toggle entry on/off"},
{"n", "Add new entry"},
{"e", "Edit selected entry"},
{"d", "Delete selected entry"},
{"p", "Open preset manager"},
{"g", "Open group manager"},
{"/", "Search"},
{"r", "Refresh list"},
{"?", "Toggle this help"},
{"q", "Quit"},
}
for _, h := range help {
sb.WriteString(fmt.Sprintf(" %s %s\n",
helpKeyStyle.Width(15).Render(h.key),
helpDescStyle.Render(h.desc)))
}
sb.WriteString("\n")
sb.WriteString(helpDescStyle.Render("Press ? or Esc to close"))
return dialogStyle.Render(sb.String())
}
func (m *Model) searchView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render("Search"))
sb.WriteString("\n\n")
sb.WriteString(inputFocusStyle.Render(m.searchInput.View()))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("Enter to search • Esc to cancel"))
return dialogStyle.Render(sb.String())
}
// Run starts the TUI application.
func Run(socketPath, configPath string) error {
return RunWithVersion(socketPath, configPath, "dev", "", "")
}
// RunWithVersion starts the TUI application with version info for update checking.
func RunWithVersion(socketPath, configPath, version, githubOwner, githubRepo string) error {
m := NewModel(socketPath, configPath)
m.version = version
m.githubOwner = githubOwner
m.githubRepo = githubRepo
p := tea.NewProgram(m, tea.WithAltScreen())
_, err := p.Run()
return err
}
+336
View File
@@ -0,0 +1,336 @@
// Package tui provides the form component for adding/editing entries.
package tui
import (
"fmt"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
// FormMode represents the form mode.
type FormMode int
const (
FormModeAdd FormMode = iota
FormModeEdit
)
// FormField represents a form field index.
type FormField int
const (
FieldDomain FormField = iota
FieldIP
FieldGroup
FieldCount
)
// Form handles the add/edit entry form.
type Form struct {
mode FormMode
fields []textinput.Model
focus FormField
width int
height int
editAlias string // Original alias when editing
// Group dropdown
groups []string
groupCursor int
groupFocused bool
}
// NewForm creates a new form.
func NewForm() *Form {
fields := make([]textinput.Model, FieldCount)
// Domain field
fields[FieldDomain] = textinput.New()
fields[FieldDomain].Placeholder = "example.com"
fields[FieldDomain].CharLimit = 253
// IP field
fields[FieldIP] = textinput.New()
fields[FieldIP].Placeholder = "127.0.0.1"
fields[FieldIP].CharLimit = 45 // IPv6 max
// Group field (not used as text input, but kept for compatibility)
fields[FieldGroup] = textinput.New()
fields[FieldGroup].Placeholder = "development"
fields[FieldGroup].CharLimit = 63
return &Form{
fields: fields,
focus: FieldDomain,
groups: []string{"default"},
}
}
// SetGroups sets the available groups for the dropdown.
func (f *Form) SetGroups(groups []string) {
if len(groups) == 0 {
f.groups = []string{"default"}
} else {
f.groups = groups
}
// Reset cursor if out of bounds
if f.groupCursor >= len(f.groups) {
f.groupCursor = 0
}
}
// Init initializes the form for adding a new entry.
func (f *Form) Init() {
f.mode = FormModeAdd
f.editAlias = ""
for i := range f.fields {
f.fields[i].Reset()
}
f.fields[FieldIP].SetValue("127.0.0.1")
f.groupCursor = 0
f.groupFocused = false
f.focus = FieldDomain
f.fields[FieldDomain].Focus()
}
// InitEdit initializes the form for editing an existing entry.
func (f *Form) InitEdit(domain, ip, alias, group string) {
f.mode = FormModeEdit
f.editAlias = alias
f.fields[FieldDomain].SetValue(domain)
f.fields[FieldIP].SetValue(ip)
// Find the group in the list
f.groupCursor = 0
for i, g := range f.groups {
if g == group {
f.groupCursor = i
break
}
}
f.groupFocused = false
f.focus = FieldDomain
f.fields[FieldDomain].Focus()
}
// SetSize sets the form dimensions.
func (f *Form) SetSize(width, height int) {
f.width = width
f.height = height
inputWidth := min(50, width-10)
for i := range f.fields {
f.fields[i].Width = inputWidth
}
}
// Update handles input events.
func (f *Form) Update(msg tea.Msg) tea.Cmd {
switch msg := msg.(type) {
case tea.KeyMsg:
// Handle group dropdown navigation
if f.focus == FieldGroup {
switch msg.String() {
case "tab":
f.nextField()
return nil
case "shift+tab":
f.prevField()
return nil
case "up", "k":
if f.groupCursor > 0 {
f.groupCursor--
}
return nil
case "down", "j":
if f.groupCursor < len(f.groups)-1 {
f.groupCursor++
}
return nil
case "left":
if f.groupCursor > 0 {
f.groupCursor--
}
return nil
case "right":
if f.groupCursor < len(f.groups)-1 {
f.groupCursor++
}
return nil
}
return nil
}
// Handle text input fields
switch msg.String() {
case "tab", "down":
f.nextField()
return nil
case "shift+tab", "up":
f.prevField()
return nil
}
}
// Update the focused text field (only for Domain and IP)
if f.focus != FieldGroup {
var cmd tea.Cmd
f.fields[f.focus], cmd = f.fields[f.focus].Update(msg)
return cmd
}
return nil
}
func (f *Form) nextField() {
if f.focus != FieldGroup {
f.fields[f.focus].Blur()
}
f.focus = (f.focus + 1) % FieldCount
if f.focus != FieldGroup {
f.fields[f.focus].Focus()
}
}
func (f *Form) prevField() {
if f.focus != FieldGroup {
f.fields[f.focus].Blur()
}
f.focus = (f.focus - 1 + FieldCount) % FieldCount
if f.focus != FieldGroup {
f.fields[f.focus].Focus()
}
}
// Values returns the form values (domain, ip, group).
func (f *Form) Values() (domain, ip, group string) {
group = ""
if f.groupCursor < len(f.groups) {
group = f.groups[f.groupCursor]
}
return strings.TrimSpace(f.fields[FieldDomain].Value()),
strings.TrimSpace(f.fields[FieldIP].Value()),
group
}
// EditAlias returns the original alias when editing.
func (f *Form) EditAlias() string {
return f.editAlias
}
// IsEdit returns true if in edit mode.
func (f *Form) IsEdit() bool {
return f.mode == FormModeEdit
}
// Validate validates the form values.
func (f *Form) Validate() string {
domain, ip, group := f.Values()
if domain == "" {
return "Domain is required"
}
if ip == "" {
return "IP address is required"
}
if group == "" {
return "Group is required"
}
return ""
}
// View renders the form.
func (f *Form) View() string {
var sb strings.Builder
title := "Add New Entry"
if f.mode == FormModeEdit {
title = "Edit Entry"
}
sb.WriteString(titleStyle.Render(title))
sb.WriteString("\n\n")
// Domain field
sb.WriteString(inputLabelStyle.Render("Domain:"))
sb.WriteString("\n")
style := inputStyle
if f.focus == FieldDomain {
style = inputFocusStyle
}
sb.WriteString(style.Render(f.fields[FieldDomain].View()))
sb.WriteString("\n\n")
// IP field
sb.WriteString(inputLabelStyle.Render("IP Address:"))
sb.WriteString("\n")
style = inputStyle
if f.focus == FieldIP {
style = inputFocusStyle
}
sb.WriteString(style.Render(f.fields[FieldIP].View()))
sb.WriteString("\n\n")
// Group dropdown
sb.WriteString(inputLabelStyle.Render("Group:"))
sb.WriteString("\n")
sb.WriteString(f.renderGroupDropdown())
sb.WriteString("\n\n")
sb.WriteString("\n")
sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • ←→ select group • Enter save • Esc cancel"))
return dialogStyle.Render(sb.String())
}
func (f *Form) renderGroupDropdown() string {
isFocused := f.focus == FieldGroup
// Get current group name
currentGroup := "default"
if f.groupCursor < len(f.groups) {
currentGroup = f.groups[f.groupCursor]
}
// Build the selector content: ◀ group_name ▶
var content string
if isFocused {
// Show arrows when focused
leftArrow := "◀"
rightArrow := "▶"
if f.groupCursor == 0 {
leftArrow = " " // dim or hide left arrow at start
}
if f.groupCursor >= len(f.groups)-1 {
rightArrow = " " // dim or hide right arrow at end
}
content = leftArrow + " " + currentGroup + " " + rightArrow
} else {
content = " " + currentGroup + " "
}
// Show position indicator if multiple groups
if len(f.groups) > 1 {
content += fmt.Sprintf(" (%d/%d)", f.groupCursor+1, len(f.groups))
}
// Apply border style
if isFocused {
return inputFocusStyle.Render(content)
}
return inputStyle.Render(content)
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
+232
View File
@@ -0,0 +1,232 @@
// Package tui provides the group management component.
package tui
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
)
// GroupMode represents the group view mode.
type GroupMode int
const (
GroupModeSelect GroupMode = iota
GroupModeAdd
GroupModeRename
GroupModeConfirmDelete
)
// GroupPicker handles the group selection and management UI.
type GroupPicker struct {
groups []string
cursor int
width int
height int
mode GroupMode
input textinput.Model
editName string // Original name when renaming
}
// NewGroupPicker creates a new group picker.
func NewGroupPicker() *GroupPicker {
input := textinput.New()
input.Placeholder = "group-name"
input.CharLimit = 63
return &GroupPicker{
input: input,
mode: GroupModeSelect,
}
}
// SetGroups updates the available groups.
func (g *GroupPicker) SetGroups(groups []string) {
g.groups = groups
if g.cursor >= len(groups) {
g.cursor = max(0, len(groups)-1)
}
}
// SetSize sets the picker dimensions.
func (g *GroupPicker) SetSize(width, height int) {
g.width = width
g.height = height
g.input.Width = min(50, width-10)
}
// MoveUp moves the cursor up.
func (g *GroupPicker) MoveUp() {
if g.cursor > 0 {
g.cursor--
}
}
// MoveDown moves the cursor down.
func (g *GroupPicker) MoveDown() {
if g.cursor < len(g.groups)-1 {
g.cursor++
}
}
// Selected returns the currently selected group.
func (g *GroupPicker) Selected() string {
if g.cursor >= 0 && g.cursor < len(g.groups) {
return g.groups[g.cursor]
}
return ""
}
// Len returns the number of groups.
func (g *GroupPicker) Len() int {
return len(g.groups)
}
// Mode returns the current mode.
func (g *GroupPicker) Mode() GroupMode {
return g.mode
}
// InitAdd initializes the form for adding a new group.
func (g *GroupPicker) InitAdd() {
g.mode = GroupModeAdd
g.editName = ""
g.input.Reset()
g.input.Focus()
}
// InitRename initializes the form for renaming an existing group.
func (g *GroupPicker) InitRename() {
selected := g.Selected()
if selected == "" {
return
}
g.mode = GroupModeRename
g.editName = selected
g.input.SetValue(selected)
g.input.Focus()
}
// InitDelete starts delete confirmation.
func (g *GroupPicker) InitDelete() {
if g.Selected() == "" {
return
}
g.mode = GroupModeConfirmDelete
}
// CancelForm cancels the current form operation.
func (g *GroupPicker) CancelForm() {
g.mode = GroupModeSelect
g.editName = ""
g.input.Reset()
g.input.Blur()
}
// Update handles input events for form mode.
func (g *GroupPicker) Update(msg tea.KeyMsg) tea.Cmd {
var cmd tea.Cmd
g.input, cmd = g.input.Update(msg)
return cmd
}
// FormValue returns the form input value.
func (g *GroupPicker) FormValue() string {
return strings.TrimSpace(g.input.Value())
}
// EditName returns the original name when renaming.
func (g *GroupPicker) EditName() string {
return g.editName
}
// IsRename returns true if in rename mode.
func (g *GroupPicker) IsRename() bool {
return g.mode == GroupModeRename
}
// ValidateForm validates the form value.
func (g *GroupPicker) ValidateForm() string {
value := g.FormValue()
if value == "" {
return "Group name is required"
}
return ""
}
// View renders the group picker.
func (g *GroupPicker) View() string {
switch g.mode {
case GroupModeAdd, GroupModeRename:
return g.formView()
case GroupModeConfirmDelete:
return g.deleteView()
default:
return g.selectView()
}
}
func (g *GroupPicker) selectView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render("Groups"))
sb.WriteString("\n\n")
if len(g.groups) == 0 {
sb.WriteString(helpDescStyle.Render("No groups configured."))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("Press 'n' to create one"))
} else {
for i, group := range g.groups {
if i == g.cursor {
sb.WriteString(presetSelectedStyle.Render("▸ " + group))
} else {
sb.WriteString(presetItemStyle.Render(" " + group))
}
sb.WriteString("\n")
}
}
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("↑↓ navigate • n new • r rename • d delete • Esc back"))
return dialogStyle.Render(sb.String())
}
func (g *GroupPicker) formView() string {
var sb strings.Builder
title := "Add New Group"
if g.mode == GroupModeRename {
title = "Rename Group"
}
sb.WriteString(titleStyle.Render(title))
sb.WriteString("\n\n")
sb.WriteString(inputLabelStyle.Render("Name:"))
sb.WriteString("\n")
sb.WriteString(inputFocusStyle.Render(g.input.View()))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("Enter save • Esc cancel"))
return dialogStyle.Render(sb.String())
}
func (g *GroupPicker) deleteView() string {
var sb strings.Builder
groupName := g.Selected()
sb.WriteString(titleStyle.Render("Delete Group"))
sb.WriteString("\n\n")
sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete group '" + groupName + "'?"))
sb.WriteString("\n")
sb.WriteString(helpDescStyle.Render("This will remove all hosts in this group!"))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel"))
return dialogStyle.Render(sb.String())
}
+429
View File
@@ -0,0 +1,429 @@
// Package tui provides the list view component.
package tui
import (
"fmt"
"strings"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/lipgloss/table"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// EntryItem represents a displayable host entry.
type EntryItem struct {
Entry protocol.HostEntry
Pending bool
HasError bool
}
// ListView handles the list of host entries.
type ListView struct {
items []EntryItem
groups map[string][]int // group name -> indices in items
groupOrder []string // ordered group names
cursor int
width int
height int
}
// NewListView creates a new list view.
func NewListView() *ListView {
return &ListView{
groups: make(map[string][]int),
}
}
// SetItems updates the list items.
func (l *ListView) SetItems(entries []protocol.HostEntry) {
l.items = make([]EntryItem, len(entries))
l.groups = make(map[string][]int)
l.groupOrder = nil
groupSeen := make(map[string]bool)
for i, e := range entries {
l.items[i] = EntryItem{Entry: e}
if !groupSeen[e.Group] {
groupSeen[e.Group] = true
l.groupOrder = append(l.groupOrder, e.Group)
}
l.groups[e.Group] = append(l.groups[e.Group], i)
}
// Reset cursor if out of bounds
if l.cursor >= len(l.items) {
l.cursor = max(0, len(l.items)-1)
}
}
// SetSize sets the view dimensions.
func (l *ListView) SetSize(width, height int) {
l.width = width
l.height = height
}
// MoveUp moves the cursor up.
func (l *ListView) MoveUp() {
if l.cursor > 0 {
l.cursor--
}
}
// MoveDown moves the cursor down.
func (l *ListView) MoveDown() {
if l.cursor < len(l.items)-1 {
l.cursor++
}
}
// Selected returns the currently selected item.
func (l *ListView) Selected() *EntryItem {
if l.cursor >= 0 && l.cursor < len(l.items) {
return &l.items[l.cursor]
}
return nil
}
// SelectedAlias returns the alias of the selected item.
func (l *ListView) SelectedAlias() string {
if item := l.Selected(); item != nil {
return item.Entry.Alias
}
return ""
}
// SetPending marks an item as pending.
func (l *ListView) SetPending(alias string, pending bool) {
for i := range l.items {
if l.items[i].Entry.Alias == alias {
l.items[i].Pending = pending
break
}
}
}
// SetError marks an item as having an error.
func (l *ListView) SetError(alias string, hasError bool) {
for i := range l.items {
if l.items[i].Entry.Alias == alias {
l.items[i].HasError = hasError
break
}
}
}
// UpdateEntry updates an entry's enabled state.
func (l *ListView) UpdateEntry(alias string, enabled bool) {
for i := range l.items {
if l.items[i].Entry.Alias == alias {
l.items[i].Entry.Enabled = enabled
l.items[i].Pending = false
l.items[i].HasError = false
break
}
}
}
// Len returns the number of items.
func (l *ListView) Len() int {
return len(l.items)
}
// ActiveCount returns the number of enabled entries.
func (l *ListView) ActiveCount() int {
count := 0
for _, item := range l.items {
if item.Entry.Enabled {
count++
}
}
return count
}
// FindByAlias finds an item by alias.
func (l *ListView) FindByAlias(alias string) *EntryItem {
for i := range l.items {
if l.items[i].Entry.Alias == alias {
return &l.items[i]
}
}
return nil
}
// Filter filters items by search term.
func (l *ListView) Filter(term string) []EntryItem {
if term == "" {
return l.items
}
term = strings.ToLower(term)
var filtered []EntryItem
for _, item := range l.items {
if strings.Contains(strings.ToLower(item.Entry.Domain), term) ||
strings.Contains(strings.ToLower(item.Entry.Alias), term) ||
strings.Contains(strings.ToLower(item.Entry.IP), term) ||
strings.Contains(strings.ToLower(item.Entry.Group), term) {
filtered = append(filtered, item)
}
}
return filtered
}
// ViewFiltered renders the list filtered by search term.
func (l *ListView) ViewFiltered(searchTerm string) string {
if searchTerm == "" {
return l.View()
}
filtered := l.Filter(searchTerm)
if len(filtered) == 0 {
emptyStyle := lipgloss.NewStyle().Foreground(colorMuted)
return "\n" + emptyStyle.Render(fmt.Sprintf(" No results for '%s'. Press Esc to clear search.", searchTerm)) + "\n"
}
var sb strings.Builder
// Show search indicator
searchIndicator := lipgloss.NewStyle().
Foreground(colorWarning).
Bold(true).
Render(fmt.Sprintf(" Search: %s (%d results)", searchTerm, len(filtered)))
sb.WriteString(searchIndicator)
sb.WriteString("\n")
// Group header style - bright colors for dark terminals
groupHeaderStyle := lipgloss.NewStyle().
Bold(true).
Foreground(colorGroupHeader).
Background(lipgloss.Color("238")).
Padding(0, 1).
MarginTop(1)
// Organize filtered items by group
groupItems := make(map[string][]EntryItem)
var groupOrder []string
groupSeen := make(map[string]bool)
for _, item := range filtered {
group := item.Entry.Group
if !groupSeen[group] {
groupSeen[group] = true
groupOrder = append(groupOrder, group)
}
groupItems[group] = append(groupItems[group], item)
}
for _, groupName := range groupOrder {
items := groupItems[groupName]
if len(items) == 0 {
continue
}
// Group header
headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(items))
sb.WriteString(groupHeaderStyle.Render(headerText))
sb.WriteString("\n")
// Build rows for this group's table
var rows [][]string
for _, item := range items {
status := l.getStatusString(item)
rows = append(rows, []string{
truncate(item.Entry.Domain, 30),
truncate(item.Entry.IP, 15),
status,
})
}
// Create table for this group
t := table.New().
Border(lipgloss.HiddenBorder()).
Headers("DOMAIN", "IP ADDRESS", "STATUS").
Rows(rows...).
StyleFunc(func(row, col int) lipgloss.Style {
// Header row
if row == table.HeaderRow {
return lipgloss.NewStyle().
Bold(true).
Foreground(colorHeader).
Padding(0, 1)
}
baseStyle := lipgloss.NewStyle().Padding(0, 1)
if row >= 0 && row < len(items) {
item := items[row]
// Disabled rows are muted
if !item.Entry.Enabled && !item.Pending && !item.HasError {
return baseStyle.Foreground(colorMuted)
}
// Status column gets colored based on status
if col == 2 { // STATUS column
if item.HasError {
return baseStyle.Foreground(colorError)
}
if item.Pending {
return baseStyle.Foreground(colorWarning)
}
if item.Entry.Enabled {
return baseStyle.Foreground(colorSuccess)
}
}
}
return baseStyle
})
sb.WriteString(t.Render())
sb.WriteString("\n")
}
return sb.String()
}
// GroupCount returns the number of groups.
func (l *ListView) GroupCount() int {
return len(l.groupOrder)
}
// GetGroups returns all group names.
func (l *ListView) GetGroups() []string {
return l.groupOrder
}
// View renders the list with groups as headers.
func (l *ListView) View() string {
if len(l.items) == 0 {
emptyStyle := lipgloss.NewStyle().Foreground(colorMuted)
return "\n" + emptyStyle.Render(" No host entries configured. Press 'n' to add a new entry.") + "\n"
}
var sb strings.Builder
// Group header style - bright colors for dark terminals
groupHeaderStyle := lipgloss.NewStyle().
Bold(true).
Foreground(colorGroupHeader).
Background(lipgloss.Color("238")).
Padding(0, 1).
MarginTop(1)
for _, groupName := range l.groupOrder {
indices := l.groups[groupName]
if len(indices) == 0 {
continue
}
// Group header
headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(indices))
sb.WriteString(groupHeaderStyle.Render(headerText))
sb.WriteString("\n")
// Build rows for this group's table
var rows [][]string
// Store actual item indices for cursor matching
itemIndices := make([]int, len(indices))
copy(itemIndices, indices)
for _, idx := range indices {
item := l.items[idx]
status := l.getStatusString(item)
rows = append(rows, []string{
truncate(item.Entry.Domain, 30),
truncate(item.Entry.IP, 15),
status,
})
}
// Create table for this group
t := table.New().
Border(lipgloss.HiddenBorder()).
Headers("DOMAIN", "IP ADDRESS", "STATUS").
Rows(rows...).
StyleFunc(func(row, col int) lipgloss.Style {
// Header row
if row == table.HeaderRow {
return lipgloss.NewStyle().
Bold(true).
Foreground(colorHeader).
Padding(0, 1)
}
baseStyle := lipgloss.NewStyle().Padding(0, 1)
// Check if this row is selected
if row >= 0 && row < len(itemIndices) {
actualItemIdx := itemIndices[row]
isSelected := actualItemIdx == l.cursor
item := l.items[actualItemIdx]
// Selected row gets background highlight
if isSelected {
return baseStyle.
Background(colorSelectedBg).
Foreground(colorSelectedFg)
}
// Disabled rows are muted
if !item.Entry.Enabled && !item.Pending && !item.HasError {
return baseStyle.Foreground(colorMuted)
}
// Status column gets colored based on status
if col == 2 { // STATUS column
if item.HasError {
return baseStyle.Foreground(colorError)
}
if item.Pending {
return baseStyle.Foreground(colorWarning)
}
if item.Entry.Enabled {
return baseStyle.Foreground(colorSuccess)
}
}
}
return baseStyle
})
sb.WriteString(t.Render())
sb.WriteString("\n")
}
return sb.String()
}
func (l *ListView) getStatusString(item EntryItem) string {
if item.HasError {
return "✗ Error"
}
if item.Pending {
return "◐ Pending"
}
if item.Entry.Enabled {
return "● Active"
}
return "○ Disabled"
}
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
if maxLen <= 3 {
return s[:maxLen]
}
return s[:maxLen-3] + "..."
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
+409
View File
@@ -0,0 +1,409 @@
package tui
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
func TestListView_SetItems(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
}
lv.SetItems(entries)
assert.Equal(t, 3, lv.Len())
assert.Len(t, lv.groups, 2)
assert.Contains(t, lv.groupOrder, "dev")
assert.Contains(t, lv.groupOrder, "staging")
}
func TestListView_Navigation(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
}
lv.SetItems(entries)
// Initial position
assert.Equal(t, 0, lv.cursor)
// Move down
lv.MoveDown()
assert.Equal(t, 1, lv.cursor)
lv.MoveDown()
assert.Equal(t, 2, lv.cursor)
// Can't move past end
lv.MoveDown()
assert.Equal(t, 2, lv.cursor)
// Move up
lv.MoveUp()
assert.Equal(t, 1, lv.cursor)
lv.MoveUp()
assert.Equal(t, 0, lv.cursor)
// Can't move before start
lv.MoveUp()
assert.Equal(t, 0, lv.cursor)
}
func TestListView_Selected(t *testing.T) {
lv := NewListView()
t.Run("empty list", func(t *testing.T) {
item := lv.Selected()
assert.Nil(t, item)
})
t.Run("with items", func(t *testing.T) {
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
}
lv.SetItems(entries)
item := lv.Selected()
require.NotNil(t, item)
assert.Equal(t, "a.com", item.Entry.Domain)
lv.MoveDown()
item = lv.Selected()
require.NotNil(t, item)
assert.Equal(t, "b.com", item.Entry.Domain)
})
}
func TestListView_SelectedAlias(t *testing.T) {
lv := NewListView()
t.Run("empty list", func(t *testing.T) {
alias := lv.SelectedAlias()
assert.Empty(t, alias)
})
t.Run("with items", func(t *testing.T) {
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "my-alias", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
alias := lv.SelectedAlias()
assert.Equal(t, "my-alias", alias)
})
}
func TestListView_SetPending(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
assert.False(t, lv.items[0].Pending)
lv.SetPending("a", true)
assert.True(t, lv.items[0].Pending)
lv.SetPending("a", false)
assert.False(t, lv.items[0].Pending)
// Non-existent alias should not panic
lv.SetPending("nonexistent", true)
}
func TestListView_SetError(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
assert.False(t, lv.items[0].HasError)
lv.SetError("a", true)
assert.True(t, lv.items[0].HasError)
lv.SetError("a", false)
assert.False(t, lv.items[0].HasError)
}
func TestListView_UpdateEntry(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false, Group: "dev"},
}
lv.SetItems(entries)
lv.items[0].Pending = true
lv.items[0].HasError = true
lv.UpdateEntry("a", true)
assert.True(t, lv.items[0].Entry.Enabled)
assert.False(t, lv.items[0].Pending)
assert.False(t, lv.items[0].HasError)
}
func TestListView_ActiveCount(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
}
lv.SetItems(entries)
assert.Equal(t, 2, lv.ActiveCount())
}
func TestListView_FindByAlias(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
}
lv.SetItems(entries)
t.Run("found", func(t *testing.T) {
item := lv.FindByAlias("b")
require.NotNil(t, item)
assert.Equal(t, "b.com", item.Entry.Domain)
})
t.Run("not found", func(t *testing.T) {
item := lv.FindByAlias("nonexistent")
assert.Nil(t, item)
})
}
func TestListView_Filter(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "myapp.com", IP: "127.0.0.1", Alias: "myapp", Enabled: true, Group: "dev"},
{Domain: "api.myapp.com", IP: "127.0.0.1", Alias: "api", Enabled: false, Group: "dev"},
{Domain: "other.com", IP: "192.168.1.1", Alias: "other", Enabled: true, Group: "staging"},
}
lv.SetItems(entries)
t.Run("empty term", func(t *testing.T) {
filtered := lv.Filter("")
assert.Len(t, filtered, 3)
})
t.Run("by domain", func(t *testing.T) {
filtered := lv.Filter("myapp")
assert.Len(t, filtered, 2)
})
t.Run("by alias", func(t *testing.T) {
filtered := lv.Filter("api")
assert.Len(t, filtered, 1)
assert.Equal(t, "api.myapp.com", filtered[0].Entry.Domain)
})
t.Run("by IP", func(t *testing.T) {
filtered := lv.Filter("192.168")
assert.Len(t, filtered, 1)
assert.Equal(t, "other.com", filtered[0].Entry.Domain)
})
t.Run("case insensitive", func(t *testing.T) {
filtered := lv.Filter("MYAPP")
assert.Len(t, filtered, 2)
})
t.Run("no match", func(t *testing.T) {
filtered := lv.Filter("nonexistent")
assert.Empty(t, filtered)
})
}
func TestListView_View(t *testing.T) {
t.Run("empty list", func(t *testing.T) {
lv := NewListView()
view := lv.View()
assert.Contains(t, view, "No host entries")
})
t.Run("with items", func(t *testing.T) {
lv := NewListView()
entries := []protocol.HostEntry{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
view := lv.View()
// Group header is shown as section title (uppercase)
assert.Contains(t, view, "DEV")
// Table headers
assert.Contains(t, view, "DOMAIN")
assert.Contains(t, view, "IP ADDRESS")
assert.Contains(t, view, "STATUS")
// Data is in the view
assert.Contains(t, view, "example.com")
assert.Contains(t, view, "127.0.0.1")
assert.Contains(t, view, "Active")
})
}
func TestListView_SetSize(t *testing.T) {
lv := NewListView()
lv.SetSize(80, 24)
assert.Equal(t, 80, lv.width)
assert.Equal(t, 24, lv.height)
}
func TestListView_CursorBounds(t *testing.T) {
lv := NewListView()
// Set items
entries := []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
lv.cursor = 1
// Set fewer items - cursor should be adjusted
entries = []protocol.HostEntry{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
}
lv.SetItems(entries)
assert.Equal(t, 0, lv.cursor)
}
func TestTruncate(t *testing.T) {
tests := []struct {
input string
maxLen int
expected string
}{
{"short", 10, "short"},
{"exactly10!", 10, "exactly10!"},
{"this is too long", 10, "this is..."},
{"", 5, ""},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := truncate(tt.input, tt.maxLen)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMax(t *testing.T) {
assert.Equal(t, 5, max(3, 5))
assert.Equal(t, 5, max(5, 3))
assert.Equal(t, 5, max(5, 5))
assert.Equal(t, 0, max(0, -1))
}
// Matrix test for navigation
func TestListView_Navigation_Matrix(t *testing.T) {
sizes := []int{1, 5, 10, 100}
for _, size := range sizes {
t.Run("size="+string(rune('0'+size)), func(t *testing.T) {
lv := NewListView()
entries := make([]protocol.HostEntry, size)
for i := range entries {
entries[i] = protocol.HostEntry{
Domain: "domain" + string(rune('a'+i%26)) + ".com",
IP: "127.0.0.1",
Alias: "alias" + string(rune('a'+i%26)),
Enabled: true,
Group: "dev",
}
}
lv.SetItems(entries)
// Move to end
for i := 0; i < size*2; i++ {
lv.MoveDown()
}
assert.Equal(t, size-1, lv.cursor)
// Move to start
for i := 0; i < size*2; i++ {
lv.MoveUp()
}
assert.Equal(t, 0, lv.cursor)
})
}
}
func BenchmarkListView_SetItems(b *testing.B) {
entries := make([]protocol.HostEntry, 100)
for i := range entries {
entries[i] = protocol.HostEntry{
Domain: "domain.com",
IP: "127.0.0.1",
Alias: "alias",
Enabled: true,
Group: "dev",
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
lv := NewListView()
lv.SetItems(entries)
}
}
func BenchmarkListView_Filter(b *testing.B) {
lv := NewListView()
entries := make([]protocol.HostEntry, 100)
for i := range entries {
entries[i] = protocol.HostEntry{
Domain: "domain" + string(rune('a'+i%26)) + ".com",
IP: "127.0.0.1",
Alias: "alias" + string(rune('a'+i%26)),
Enabled: true,
Group: "dev",
}
}
lv.SetItems(entries)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = lv.Filter("domain")
}
}
func BenchmarkListView_View(b *testing.B) {
lv := NewListView()
entries := make([]protocol.HostEntry, 50)
for i := range entries {
entries[i] = protocol.HostEntry{
Domain: "domain.com",
IP: "127.0.0.1",
Alias: "alias",
Enabled: i%2 == 0,
Group: "group" + string(rune('a'+i%5)),
}
}
lv.SetItems(entries)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = lv.View()
}
}
+356
View File
@@ -0,0 +1,356 @@
// Package tui provides the preset picker component.
package tui
import (
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// PresetMode represents the preset view mode.
type PresetMode int
const (
PresetModeSelect PresetMode = iota
PresetModeAdd
PresetModeEdit
PresetModeConfirmDelete
)
// PresetFormField represents a form field index.
type PresetFormField int
const (
PresetFieldName PresetFormField = iota
PresetFieldEnable
PresetFieldDisable
PresetFieldCount
)
// PresetPicker handles the preset selection and management UI.
type PresetPicker struct {
presets []protocol.PresetInfo
cursor int
width int
height int
mode PresetMode
fields []textinput.Model
focus PresetFormField
editName string // Original name when editing
}
// NewPresetPicker creates a new preset picker.
func NewPresetPicker() *PresetPicker {
fields := make([]textinput.Model, PresetFieldCount)
// Name field
fields[PresetFieldName] = textinput.New()
fields[PresetFieldName].Placeholder = "preset-name"
fields[PresetFieldName].CharLimit = 63
// Enable field
fields[PresetFieldEnable] = textinput.New()
fields[PresetFieldEnable].Placeholder = "alias1,alias2,alias3"
fields[PresetFieldEnable].CharLimit = 500
// Disable field
fields[PresetFieldDisable] = textinput.New()
fields[PresetFieldDisable].Placeholder = "alias1,alias2,alias3"
fields[PresetFieldDisable].CharLimit = 500
return &PresetPicker{
fields: fields,
mode: PresetModeSelect,
}
}
// SetPresets updates the available presets (legacy method for compatibility).
func (p *PresetPicker) SetPresets(presets []string) {
p.presets = make([]protocol.PresetInfo, len(presets))
for i, name := range presets {
p.presets[i] = protocol.PresetInfo{Name: name}
}
if p.cursor >= len(presets) {
p.cursor = max(0, len(presets)-1)
}
}
// SetPresetsWithInfo updates the available presets with full info.
func (p *PresetPicker) SetPresetsWithInfo(presets []protocol.PresetInfo) {
p.presets = presets
if p.cursor >= len(presets) {
p.cursor = max(0, len(presets)-1)
}
}
// SetSize sets the picker dimensions.
func (p *PresetPicker) SetSize(width, height int) {
p.width = width
p.height = height
inputWidth := min(60, width-10)
for i := range p.fields {
p.fields[i].Width = inputWidth
}
}
// MoveUp moves the cursor up.
func (p *PresetPicker) MoveUp() {
if p.cursor > 0 {
p.cursor--
}
}
// MoveDown moves the cursor down.
func (p *PresetPicker) MoveDown() {
if p.cursor < len(p.presets)-1 {
p.cursor++
}
}
// Selected returns the currently selected preset name.
func (p *PresetPicker) Selected() string {
if p.cursor >= 0 && p.cursor < len(p.presets) {
return p.presets[p.cursor].Name
}
return ""
}
// SelectedInfo returns the currently selected preset info.
func (p *PresetPicker) SelectedInfo() *protocol.PresetInfo {
if p.cursor >= 0 && p.cursor < len(p.presets) {
return &p.presets[p.cursor]
}
return nil
}
// Len returns the number of presets.
func (p *PresetPicker) Len() int {
return len(p.presets)
}
// Mode returns the current mode.
func (p *PresetPicker) Mode() PresetMode {
return p.mode
}
// SetMode sets the mode.
func (p *PresetPicker) SetMode(mode PresetMode) {
p.mode = mode
}
// InitAdd initializes the form for adding a new preset.
func (p *PresetPicker) InitAdd() {
p.mode = PresetModeAdd
p.editName = ""
for i := range p.fields {
p.fields[i].Reset()
}
p.focus = PresetFieldName
p.fields[PresetFieldName].Focus()
}
// InitEdit initializes the form for editing an existing preset.
func (p *PresetPicker) InitEdit() {
preset := p.SelectedInfo()
if preset == nil {
return
}
p.mode = PresetModeEdit
p.editName = preset.Name
p.fields[PresetFieldName].SetValue(preset.Name)
p.fields[PresetFieldEnable].SetValue(strings.Join(preset.Enable, ","))
p.fields[PresetFieldDisable].SetValue(strings.Join(preset.Disable, ","))
p.focus = PresetFieldName
p.fields[PresetFieldName].Focus()
}
// InitDelete starts delete confirmation.
func (p *PresetPicker) InitDelete() {
if p.SelectedInfo() == nil {
return
}
p.mode = PresetModeConfirmDelete
}
// CancelForm cancels the current form operation.
func (p *PresetPicker) CancelForm() {
p.mode = PresetModeSelect
p.editName = ""
for i := range p.fields {
p.fields[i].Reset()
p.fields[i].Blur()
}
}
// Update handles input events for form mode.
func (p *PresetPicker) Update(msg tea.KeyMsg) tea.Cmd {
switch msg.String() {
case "tab", "down":
p.nextField()
return nil
case "shift+tab", "up":
p.prevField()
return nil
}
// Update the focused field
var cmd tea.Cmd
p.fields[p.focus], cmd = p.fields[p.focus].Update(msg)
return cmd
}
func (p *PresetPicker) nextField() {
p.fields[p.focus].Blur()
p.focus = (p.focus + 1) % PresetFieldCount
p.fields[p.focus].Focus()
}
func (p *PresetPicker) prevField() {
p.fields[p.focus].Blur()
p.focus = (p.focus - 1 + PresetFieldCount) % PresetFieldCount
p.fields[p.focus].Focus()
}
// FormValues returns the form values (name, enable list, disable list).
func (p *PresetPicker) FormValues() (name string, enable, disable []string) {
name = strings.TrimSpace(p.fields[PresetFieldName].Value())
enableStr := strings.TrimSpace(p.fields[PresetFieldEnable].Value())
if enableStr != "" {
for _, s := range strings.Split(enableStr, ",") {
if trimmed := strings.TrimSpace(s); trimmed != "" {
enable = append(enable, trimmed)
}
}
}
disableStr := strings.TrimSpace(p.fields[PresetFieldDisable].Value())
if disableStr != "" {
for _, s := range strings.Split(disableStr, ",") {
if trimmed := strings.TrimSpace(s); trimmed != "" {
disable = append(disable, trimmed)
}
}
}
return name, enable, disable
}
// EditName returns the original name when editing.
func (p *PresetPicker) EditName() string {
return p.editName
}
// IsEdit returns true if in edit mode.
func (p *PresetPicker) IsEdit() bool {
return p.mode == PresetModeEdit
}
// ValidateForm validates the form values.
func (p *PresetPicker) ValidateForm() string {
name, enable, disable := p.FormValues()
if name == "" {
return "Preset name is required"
}
if len(enable) == 0 && len(disable) == 0 {
return "At least one alias to enable or disable is required"
}
return ""
}
// View renders the preset picker.
func (p *PresetPicker) View() string {
switch p.mode {
case PresetModeAdd, PresetModeEdit:
return p.formView()
case PresetModeConfirmDelete:
return p.deleteView()
default:
return p.selectView()
}
}
func (p *PresetPicker) selectView() string {
var sb strings.Builder
sb.WriteString(titleStyle.Render("Presets"))
sb.WriteString("\n\n")
if len(p.presets) == 0 {
sb.WriteString(helpDescStyle.Render("No presets configured."))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("Press 'n' to create one"))
} else {
for i, preset := range p.presets {
if i == p.cursor {
sb.WriteString(presetSelectedStyle.Render("▸ " + preset.Name))
} else {
sb.WriteString(presetItemStyle.Render(" " + preset.Name))
}
sb.WriteString("\n")
}
}
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("↑↓ navigate • Enter apply • n new • e edit • d delete • Esc cancel"))
return dialogStyle.Render(sb.String())
}
func (p *PresetPicker) formView() string {
var sb strings.Builder
title := "Add New Preset"
if p.mode == PresetModeEdit {
title = "Edit Preset"
}
sb.WriteString(titleStyle.Render(title))
sb.WriteString("\n\n")
labels := []string{"Name:", "Enable aliases (comma-separated):", "Disable aliases (comma-separated):"}
for i, label := range labels {
sb.WriteString(inputLabelStyle.Render(label))
sb.WriteString("\n")
style := inputStyle
if PresetFormField(i) == p.focus {
style = inputFocusStyle
}
sb.WriteString(style.Render(p.fields[i].View()))
sb.WriteString("\n\n")
}
sb.WriteString("\n")
sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • Enter save • Esc cancel"))
return dialogStyle.Render(sb.String())
}
func (p *PresetPicker) deleteView() string {
var sb strings.Builder
preset := p.SelectedInfo()
presetName := ""
if preset != nil {
presetName = preset.Name
}
sb.WriteString(titleStyle.Render("Delete Preset"))
sb.WriteString("\n\n")
sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete preset '" + presetName + "'?"))
sb.WriteString("\n\n")
sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel"))
return dialogStyle.Render(sb.String())
}
+150
View File
@@ -0,0 +1,150 @@
// Package tui provides the terminal user interface.
package tui
import (
"github.com/charmbracelet/lipgloss"
)
// Colors - matching kportal style, optimized for dark terminals
var (
colorPrimary = lipgloss.Color("205") // Pink/Magenta
colorSuccess = lipgloss.Color("42") // Green
colorWarning = lipgloss.Color("220") // Yellow
colorError = lipgloss.Color("196") // Red
colorMuted = lipgloss.Color("245") // Gray (brighter for dark terminals)
colorAccent = lipgloss.Color("141") // Light purple (brighter for dark terminals)
colorHeader = lipgloss.Color("220") // Yellow for headers
colorSelectedBg = lipgloss.Color("236") // Gray background for selection
colorSelectedFg = lipgloss.Color("255") // White foreground for selection
colorGroupHeader = lipgloss.Color("213") // Light pink for group headers
)
// Title and header styles
var (
titleStyle = lipgloss.NewStyle().
Bold(true).
Foreground(colorHeader).
Padding(0, 1)
)
// Status indicators
var (
enabledStyle = lipgloss.NewStyle().
Foreground(colorSuccess).
Bold(true)
disabledStyle = lipgloss.NewStyle().
Foreground(colorMuted)
pendingStyle = lipgloss.NewStyle().
Foreground(colorWarning)
errorIndicatorStyle = lipgloss.NewStyle().
Foreground(colorError)
)
// Status bar and help
var (
statusBarStyle = lipgloss.NewStyle().
Foreground(colorMuted)
connectedStyle = lipgloss.NewStyle().
Foreground(colorSuccess).
SetString("Connected")
disconnectedStyle = lipgloss.NewStyle().
Foreground(colorError).
SetString("Disconnected")
helpBarStyle = lipgloss.NewStyle().
Foreground(colorMuted)
helpKeyStyle = lipgloss.NewStyle().
Foreground(colorHeader).
Bold(true)
helpDescStyle = lipgloss.NewStyle().
Foreground(colorMuted)
)
// Message styles
var (
errorMsgStyle = lipgloss.NewStyle().
Foreground(colorError).
Bold(true).
MarginTop(1)
successMsgStyle = lipgloss.NewStyle().
Foreground(colorSuccess).
MarginTop(1)
updateStyle = lipgloss.NewStyle().
Foreground(colorSuccess).
Bold(true)
)
// Form styles
var (
inputLabelStyle = lipgloss.NewStyle().
Foreground(colorPrimary).
Bold(true)
inputStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorMuted).
Padding(0, 1)
inputFocusStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorPrimary).
Padding(0, 1)
)
// Dialog/modal styles
var (
dialogStyle = lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(colorAccent).
Padding(1, 2)
presetItemStyle = lipgloss.NewStyle().
Padding(0, 1)
presetSelectedStyle = lipgloss.NewStyle().
Background(colorSelectedBg).
Foreground(colorSelectedFg).
Padding(0, 1)
)
// Indicator returns the appropriate status indicator string.
func Indicator(enabled bool, pending bool, hasError bool) string {
if hasError {
return errorIndicatorStyle.Render("✗")
}
if pending {
return pendingStyle.Render("◐")
}
if enabled {
return enabledStyle.Render("●")
}
return disabledStyle.Render("○")
}
// StatusText returns the status text with appropriate styling
func StatusText(enabled bool, pending bool, hasError bool) string {
if hasError {
return errorIndicatorStyle.Render("✗ Error")
}
if pending {
return pendingStyle.Render("◐ Pending")
}
if enabled {
return enabledStyle.Render("● Active")
}
return disabledStyle.Render("○ Disabled")
}
// HelpItem formats a help item.
func HelpItem(key, desc string) string {
return helpKeyStyle.Render(key) + " " + helpDescStyle.Render(desc)
}
+159
View File
@@ -0,0 +1,159 @@
// Package version provides version checking against GitHub releases.
package version
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)
const (
// githubReleasesURL is the GitHub API endpoint for latest release
githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest"
// requestTimeout is the timeout for HTTP requests
requestTimeout = 5 * time.Second
)
// ReleaseInfo contains information about a GitHub release
type ReleaseInfo struct {
TagName string `json:"tag_name"`
HTMLURL string `json:"html_url"`
Name string `json:"name"`
}
// UpdateInfo contains information about an available update
type UpdateInfo struct {
CurrentVersion string
LatestVersion string
ReleaseURL string
ReleaseName string
}
// Checker checks for new versions on GitHub
type Checker struct {
owner string
repo string
current string
client *http.Client
}
// NewChecker creates a new version checker
func NewChecker(owner, repo, currentVersion string) *Checker {
return &Checker{
owner: owner,
repo: repo,
current: normalizeVersion(currentVersion),
client: &http.Client{
Timeout: requestTimeout,
},
}
}
// CheckForUpdate checks if a newer version is available.
// Returns nil if current version is up to date or if check fails.
// This is designed to fail silently - network errors should not impact the user.
func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo {
release, err := c.fetchLatestRelease(ctx)
if err != nil {
return nil
}
latestVersion := normalizeVersion(release.TagName)
if isNewerVersion(latestVersion, c.current) {
return &UpdateInfo{
CurrentVersion: c.current,
LatestVersion: latestVersion,
ReleaseURL: release.HTMLURL,
ReleaseName: release.Name,
}
}
return nil
}
// fetchLatestRelease fetches the latest release info from GitHub API
func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) {
url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/vnd.github.v3+json")
req.Header.Set("User-Agent", "lolcathost-version-checker")
resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
}
var release ReleaseInfo
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
return nil, err
}
return &release, nil
}
// normalizeVersion removes 'v' or 'V' prefix and trims whitespace
func normalizeVersion(v string) string {
v = strings.TrimSpace(v)
v = strings.TrimPrefix(v, "v")
v = strings.TrimPrefix(v, "V")
return v
}
// isNewerVersion compares two semver-like versions.
// Returns true if latest is newer than current.
func isNewerVersion(latest, current string) bool {
latestParts := parseVersion(latest)
currentParts := parseVersion(current)
// Compare each part
for i := 0; i < len(latestParts) && i < len(currentParts); i++ {
if latestParts[i] > currentParts[i] {
return true
}
if latestParts[i] < currentParts[i] {
return false
}
}
// If all compared parts are equal, longer version is newer
// e.g., 1.0.1 > 1.0
return len(latestParts) > len(currentParts)
}
// parseVersion splits a version string into numeric parts
func parseVersion(v string) []int {
// Remove any suffix like -beta, -rc1, etc.
if idx := strings.IndexAny(v, "-+"); idx != -1 {
v = v[:idx]
}
parts := strings.Split(v, ".")
result := make([]int, 0, len(parts))
for _, p := range parts {
var num int
fmt.Sscanf(p, "%d", &num)
result = append(result, num)
}
return result
}
// FormatUpdateMessage formats a user-friendly update notification
func (u *UpdateInfo) FormatUpdateMessage() string {
return fmt.Sprintf("New version available: %s (current: %s) - %s",
u.LatestVersion, u.CurrentVersion, u.ReleaseURL)
}
+99
View File
@@ -0,0 +1,99 @@
package version
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNormalizeVersion(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"v1.0.0", "1.0.0"},
{"1.0.0", "1.0.0"},
{" v2.1.3 ", "2.1.3"},
{"V1.0.0", "1.0.0"},
{"v0.1.0", "0.1.0"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := normalizeVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestParseVersion(t *testing.T) {
tests := []struct {
input string
expected []int
}{
{"1.0.0", []int{1, 0, 0}},
{"2.1.3", []int{2, 1, 3}},
{"1.0", []int{1, 0}},
{"10.20.30", []int{10, 20, 30}},
{"1.0.0-beta", []int{1, 0, 0}},
{"1.0.0-rc1", []int{1, 0, 0}},
{"1.0.0+build123", []int{1, 0, 0}},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := parseVersion(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestIsNewerVersion(t *testing.T) {
tests := []struct {
name string
latest string
current string
expected bool
}{
{"major version bump", "2.0.0", "1.0.0", true},
{"minor version bump", "1.1.0", "1.0.0", true},
{"patch version bump", "1.0.1", "1.0.0", true},
{"same version", "1.0.0", "1.0.0", false},
{"current is newer major", "1.0.0", "2.0.0", false},
{"current is newer minor", "1.0.0", "1.1.0", false},
{"current is newer patch", "1.0.0", "1.0.1", false},
{"longer version is newer", "1.0.1", "1.0", true},
{"shorter version is older", "1.0", "1.0.1", false},
{"double digit versions", "10.0.0", "9.0.0", true},
{"with prerelease suffix", "1.1.0", "1.0.0-beta", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNewerVersion(tt.latest, tt.current)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUpdateInfo_FormatUpdateMessage(t *testing.T) {
info := &UpdateInfo{
CurrentVersion: "1.0.0",
LatestVersion: "1.1.0",
ReleaseURL: "https://github.com/lukaszraczylo/lolcathost/releases/tag/v1.1.0",
}
msg := info.FormatUpdateMessage()
assert.Contains(t, msg, "1.0.0")
assert.Contains(t, msg, "1.1.0")
assert.Contains(t, msg, "https://github.com")
}
func TestNewChecker(t *testing.T) {
checker := NewChecker("lukaszraczylo", "lolcathost", "v1.0.0")
assert.Equal(t, "lukaszraczylo", checker.owner)
assert.Equal(t, "lolcathost", checker.repo)
assert.Equal(t, "1.0.0", checker.current) // Should be normalized
assert.NotNil(t, checker.client)
}