mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-05 23:29:18 +00:00
Initial commit.
This commit is contained in:
@@ -0,0 +1,427 @@
|
||||
// Package client provides a client library for communicating with the lolcathost daemon.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Client is a client for the lolcathost daemon.
|
||||
type Client struct {
|
||||
socketPath string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
timeout time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new client.
|
||||
func New(socketPath string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithTimeout creates a new client with a custom timeout.
|
||||
func NewWithTimeout(socketPath string, timeout time.Duration) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the daemon.
|
||||
func (c *Client) Connect() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
c.reader = nil
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unix", c.socketPath, c.timeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.reader = bufio.NewReader(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
c.reader = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// send sends a request and receives a response.
|
||||
func (c *Client) send(req *protocol.Request) (*protocol.Response, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
// Set deadline
|
||||
c.conn.SetDeadline(time.Now().Add(c.timeout))
|
||||
|
||||
// Send request
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
if _, err := c.conn.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
line, err := c.reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var resp protocol.Response
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Ping checks if the daemon is responsive.
|
||||
func (c *Client) Ping() error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestPing, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("ping failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status returns the daemon's status.
|
||||
func (c *Client) Status() (*protocol.StatusData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestStatus, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("status failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.StatusData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// List returns all host entries.
|
||||
func (c *Client) List() ([]protocol.HostEntry, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestList, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("list failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.ListData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Entries, nil
|
||||
}
|
||||
|
||||
// Set enables or disables a host entry by alias.
|
||||
func (c *Client) Set(alias string, enabled bool, force bool) (*protocol.SetData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{
|
||||
Alias: alias,
|
||||
Enabled: enabled,
|
||||
Force: force,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.SetData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// Enable enables a host entry by alias.
|
||||
func (c *Client) Enable(alias string) (*protocol.SetData, error) {
|
||||
return c.Set(alias, true, false)
|
||||
}
|
||||
|
||||
// Disable disables a host entry by alias.
|
||||
func (c *Client) Disable(alias string) (*protocol.SetData, error) {
|
||||
return c.Set(alias, false, false)
|
||||
}
|
||||
|
||||
// Add adds a new host entry.
|
||||
func (c *Client) Add(domain, ip, alias, group string, enabled bool) (*protocol.SetData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: alias,
|
||||
Group: group,
|
||||
Enabled: enabled,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.SetData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// Delete removes a host entry by alias.
|
||||
func (c *Client) Delete(alias string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{
|
||||
Alias: alias,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddGroup adds a new group.
|
||||
func (c *Client) AddGroup(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup removes a group and all its hosts.
|
||||
func (c *Client) DeleteGroup(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups returns all group names.
|
||||
func (c *Client) ListGroups() ([]string, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestListGroups, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.GroupsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Groups, nil
|
||||
}
|
||||
|
||||
// Sync synchronizes the config to the hosts file.
|
||||
func (c *Client) Sync() error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestSync, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("sync failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPreset applies a named preset.
|
||||
func (c *Client) ApplyPreset(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("preset failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback restores a backup by name.
|
||||
func (c *Client) Rollback(backupName string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{
|
||||
BackupName: backupName,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("rollback failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBackups returns available backups.
|
||||
func (c *Client) ListBackups() ([]protocol.BackupInfo, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestBackups, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("backups failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.BackupsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Backups, nil
|
||||
}
|
||||
|
||||
// RenameGroup renames a group.
|
||||
func (c *Client) RenameGroup(oldName, newName string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{
|
||||
OldName: oldName,
|
||||
NewName: newName,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPreset adds a new preset.
|
||||
func (c *Client) AddPreset(name string, enable, disable []string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{
|
||||
Name: name,
|
||||
Enable: enable,
|
||||
Disable: disable,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePreset removes a preset by name.
|
||||
func (c *Client) DeletePreset(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPresets returns all presets.
|
||||
func (c *Client) ListPresets() ([]protocol.PresetInfo, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestListPresets, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.PresetsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Presets, nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the daemon is reachable.
|
||||
func IsConnected(socketPath string) bool {
|
||||
client := New(socketPath)
|
||||
if err := client.Connect(); err != nil {
|
||||
return false
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
return client.Ping() == nil
|
||||
}
|
||||
@@ -0,0 +1,516 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// mockServer creates a mock Unix socket server for testing
|
||||
type mockServer struct {
|
||||
listener net.Listener
|
||||
path string
|
||||
handler func(req *protocol.Request) *protocol.Response
|
||||
}
|
||||
|
||||
func newMockServer(t *testing.T) *mockServer {
|
||||
// Use /tmp directly to avoid long paths (Unix socket paths have ~104 char limit on macOS)
|
||||
tmpDir, err := os.MkdirTemp("/tmp", "lolcat")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
socketPath := filepath.Join(tmpDir, "s.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ms := &mockServer{
|
||||
listener: listener,
|
||||
path: socketPath,
|
||||
}
|
||||
|
||||
go ms.serve()
|
||||
|
||||
return ms
|
||||
}
|
||||
|
||||
func (ms *mockServer) serve() {
|
||||
for {
|
||||
conn, err := ms.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go ms.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mockServer) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var resp *protocol.Response
|
||||
if ms.handler != nil {
|
||||
resp = ms.handler(&req)
|
||||
} else {
|
||||
resp, _ = protocol.NewOKResponse(nil)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(resp)
|
||||
conn.Write(append(data, '\n'))
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mockServer) close() {
|
||||
ms.listener.Close()
|
||||
os.Remove(ms.path)
|
||||
}
|
||||
|
||||
func TestClient_Connect(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
assert.NotNil(t, client.conn)
|
||||
assert.NotNil(t, client.reader)
|
||||
})
|
||||
|
||||
t.Run("failure - socket not found", func(t *testing.T) {
|
||||
client := New("/nonexistent/socket.sock")
|
||||
err := client.Connect()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_Ping(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestPing {
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected request")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Ping()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Status(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestStatus {
|
||||
resp, _ := protocol.NewOKResponse(protocol.StatusData{
|
||||
Running: true,
|
||||
Version: "1.0.0",
|
||||
Uptime: 3600,
|
||||
ActiveCount: 5,
|
||||
RequestCount: 100,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, status.Running)
|
||||
assert.Equal(t, "1.0.0", status.Version)
|
||||
assert.Equal(t, int64(3600), status.Uptime)
|
||||
assert.Equal(t, 5, status.ActiveCount)
|
||||
assert.Equal(t, int64(100), status.RequestCount)
|
||||
}
|
||||
|
||||
func TestClient_List(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestList {
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{
|
||||
Entries: []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
},
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
entries, err := client.List()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, entries, 2)
|
||||
assert.Equal(t, "a.com", entries[0].Domain)
|
||||
assert.True(t, entries[0].Enabled)
|
||||
assert.Equal(t, "b.com", entries[1].Domain)
|
||||
assert.False(t, entries[1].Enabled)
|
||||
}
|
||||
|
||||
func TestClient_Set(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: "example.com",
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
data, err := client.Set("test", true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "example.com", data.Domain)
|
||||
assert.True(t, data.Applied)
|
||||
}
|
||||
|
||||
func TestClient_Enable(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.True(t, payload.Enabled)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Enable("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Disable(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.False(t, payload.Enabled)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Disable("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Sync(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSync {
|
||||
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Sync()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_ApplyPreset(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestPreset {
|
||||
var payload protocol.PresetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.Equal(t, "local", payload.Name)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": "local"})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.ApplyPreset("local")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Rollback(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestRollback {
|
||||
var payload protocol.RollbackPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.Equal(t, "hosts.backup.bak", payload.BackupName)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Rollback("hosts.backup.bak")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_ListBackups(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestBackups {
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{
|
||||
Backups: []protocol.BackupInfo{
|
||||
{Name: "hosts.20231201.bak", Timestamp: 1701432000, Size: 1024},
|
||||
{Name: "hosts.20231130.bak", Timestamp: 1701345600, Size: 1000},
|
||||
},
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
backups, err := client.ListBackups()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, backups, 2)
|
||||
assert.Equal(t, "hosts.20231201.bak", backups[0].Name)
|
||||
}
|
||||
|
||||
func TestClient_ErrorResponse(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, "domain is blocked")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Set("test", true, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "domain is blocked")
|
||||
}
|
||||
|
||||
func TestClient_NotConnected(t *testing.T) {
|
||||
client := New("/nonexistent/socket.sock")
|
||||
|
||||
_, err := client.Status()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not connected")
|
||||
}
|
||||
|
||||
func TestClient_Timeout(t *testing.T) {
|
||||
client := NewWithTimeout("/nonexistent.sock", 100*time.Millisecond)
|
||||
assert.Equal(t, 100*time.Millisecond, client.timeout)
|
||||
}
|
||||
|
||||
func TestIsConnected(t *testing.T) {
|
||||
t.Run("connected", func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
return resp
|
||||
}
|
||||
|
||||
connected := IsConnected(server.path)
|
||||
assert.True(t, connected)
|
||||
})
|
||||
|
||||
t.Run("not connected", func(t *testing.T) {
|
||||
connected := IsConnected("/nonexistent/socket.sock")
|
||||
assert.False(t, connected)
|
||||
})
|
||||
}
|
||||
|
||||
// Matrix test for request types
|
||||
func TestClient_RequestTypes_Matrix(t *testing.T) {
|
||||
types := []struct {
|
||||
name string
|
||||
reqType protocol.RequestType
|
||||
call func(*Client) error
|
||||
}{
|
||||
{"ping", protocol.RequestPing, func(c *Client) error { return c.Ping() }},
|
||||
{"status", protocol.RequestStatus, func(c *Client) error { _, err := c.Status(); return err }},
|
||||
{"list", protocol.RequestList, func(c *Client) error { _, err := c.List(); return err }},
|
||||
{"sync", protocol.RequestSync, func(c *Client) error { return c.Sync() }},
|
||||
{"preset", protocol.RequestPreset, func(c *Client) error { return c.ApplyPreset("test") }},
|
||||
{"backups", protocol.RequestBackups, func(c *Client) error { _, err := c.ListBackups(); return err }},
|
||||
}
|
||||
|
||||
for _, tt := range types {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
receivedType := protocol.RequestType("")
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
receivedType = req.Type
|
||||
|
||||
switch req.Type {
|
||||
case protocol.RequestStatus:
|
||||
resp, _ := protocol.NewOKResponse(protocol.StatusData{})
|
||||
return resp
|
||||
case protocol.RequestList:
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{})
|
||||
return resp
|
||||
case protocol.RequestBackups:
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{})
|
||||
return resp
|
||||
default:
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_ = tt.call(client)
|
||||
assert.Equal(t, tt.reqType, receivedType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClient_Ping(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(b, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
reader := bufio.NewReader(c)
|
||||
for {
|
||||
_, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Write(append(data, '\n'))
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
client := New(socketPath)
|
||||
err = client.Connect()
|
||||
require.NoError(b, err)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = client.Ping()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,541 @@
|
||||
// Package config handles YAML configuration parsing and hot-reload.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SystemConfigDir is the system-wide config directory for the daemon.
|
||||
const SystemConfigDir = "/etc/lolcathost"
|
||||
|
||||
// SystemConfigPath is the system-wide config file path for the daemon.
|
||||
const SystemConfigPath = "/etc/lolcathost/config.yaml"
|
||||
|
||||
// DefaultConfigDir returns the default config directory path for users.
|
||||
func DefaultConfigDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".config", "lolcathost")
|
||||
}
|
||||
|
||||
// DefaultConfigPath returns the default config file path for users.
|
||||
func DefaultConfigPath() string {
|
||||
return filepath.Join(DefaultConfigDir(), "config.yaml")
|
||||
}
|
||||
|
||||
// FlushMethod defines DNS cache flush methods.
|
||||
type FlushMethod string
|
||||
|
||||
const (
|
||||
FlushMethodAuto FlushMethod = "auto"
|
||||
FlushMethodDscacheutil FlushMethod = "dscacheutil"
|
||||
FlushMethodKillall FlushMethod = "killall"
|
||||
FlushMethodBoth FlushMethod = "both"
|
||||
)
|
||||
|
||||
// Settings holds global configuration settings.
|
||||
type Settings struct {
|
||||
AutoApply bool `yaml:"autoApply"`
|
||||
FlushMethod FlushMethod `yaml:"flushMethod"`
|
||||
}
|
||||
|
||||
// Host represents a single host entry in configuration.
|
||||
type Host struct {
|
||||
Domain string `yaml:"domain"`
|
||||
IP string `yaml:"ip"`
|
||||
Alias string `yaml:"alias"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// Group represents a group of host entries.
|
||||
type Group struct {
|
||||
Name string `yaml:"name"`
|
||||
Hosts []Host `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// Preset defines a named preset that enables/disables specific aliases.
|
||||
type Preset struct {
|
||||
Name string `yaml:"name"`
|
||||
Enable []string `yaml:"enable,omitempty"`
|
||||
Disable []string `yaml:"disable,omitempty"`
|
||||
}
|
||||
|
||||
// Config represents the complete configuration.
|
||||
type Config struct {
|
||||
Settings Settings `yaml:"settings"`
|
||||
Groups []Group `yaml:"groups"`
|
||||
Presets []Preset `yaml:"presets"`
|
||||
}
|
||||
|
||||
// Manager handles configuration loading and watching.
|
||||
type Manager struct {
|
||||
path string
|
||||
config *Config
|
||||
mu sync.RWMutex
|
||||
watcher *fsnotify.Watcher
|
||||
onChange func(*Config)
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewManager creates a new config manager.
|
||||
func NewManager(path string) *Manager {
|
||||
return &Manager{
|
||||
path: path,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads and parses the configuration file.
|
||||
func (m *Manager) Load() error {
|
||||
data, err := os.ReadFile(m.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if err := ValidateConfig(&cfg); err != nil {
|
||||
return fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.config = &cfg
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the current configuration.
|
||||
func (m *Manager) Get() *Config {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.config
|
||||
}
|
||||
|
||||
// Watch starts watching the config file for changes.
|
||||
func (m *Manager) Watch(onChange func(*Config)) error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create watcher: %w", err)
|
||||
}
|
||||
|
||||
m.watcher = watcher
|
||||
m.onChange = onChange
|
||||
|
||||
go m.watchLoop()
|
||||
|
||||
if err := watcher.Add(m.path); err != nil {
|
||||
return fmt.Errorf("failed to watch config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) watchLoop() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-m.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
if err := m.Load(); err == nil && m.onChange != nil {
|
||||
m.onChange(m.Get())
|
||||
}
|
||||
}
|
||||
case <-m.watcher.Errors:
|
||||
// Ignore watcher errors
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops watching the config file.
|
||||
func (m *Manager) Stop() {
|
||||
close(m.stopCh)
|
||||
if m.watcher != nil {
|
||||
m.watcher.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllHosts returns all hosts from all groups.
|
||||
func (c *Config) GetAllHosts() []Host {
|
||||
var hosts []Host
|
||||
for _, g := range c.Groups {
|
||||
hosts = append(hosts, g.Hosts...)
|
||||
}
|
||||
return hosts
|
||||
}
|
||||
|
||||
// FindHostByAlias finds a host by its alias.
|
||||
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
return &c.Groups[i].Hosts[j], &c.Groups[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// FindPreset finds a preset by name.
|
||||
func (c *Config) FindPreset(name string) *Preset {
|
||||
for i := range c.Presets {
|
||||
if c.Presets[i].Name == name {
|
||||
return &c.Presets[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHostEnabled sets the enabled state of a host by alias.
|
||||
func (c *Config) SetHostEnabled(alias string, enabled bool) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts[j].Enabled = enabled
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateAlias creates a unique alias from a domain name.
|
||||
func (c *Config) GenerateAlias(domain string) string {
|
||||
// Convert domain to alias format: example.com -> example-com
|
||||
alias := strings.ReplaceAll(domain, ".", "-")
|
||||
alias = strings.ReplaceAll(alias, "_", "-")
|
||||
alias = strings.ToLower(alias)
|
||||
|
||||
// Check if alias exists, if so append a number
|
||||
baseAlias := alias
|
||||
counter := 1
|
||||
for {
|
||||
if existing, _ := c.FindHostByAlias(alias); existing == nil {
|
||||
break
|
||||
}
|
||||
counter++
|
||||
alias = fmt.Sprintf("%s-%d", baseAlias, counter)
|
||||
}
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
// AddHost adds a new host to the configuration.
|
||||
func (c *Config) AddHost(domain, ip, alias, groupName string, enabled bool) error {
|
||||
// Auto-generate alias if empty
|
||||
if alias == "" {
|
||||
alias = c.GenerateAlias(domain)
|
||||
} else {
|
||||
// Check for duplicate alias
|
||||
if existing, _ := c.FindHostByAlias(alias); existing != nil {
|
||||
return fmt.Errorf("alias already exists: %s", alias)
|
||||
}
|
||||
}
|
||||
|
||||
host := Host{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: alias,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
// Find or create group
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == groupName {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: groupName,
|
||||
Hosts: []Host{host},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddGroup adds a new empty group.
|
||||
func (c *Config) AddGroup(name string) error {
|
||||
// Check if group already exists
|
||||
for _, g := range c.Groups {
|
||||
if g.Name == name {
|
||||
return fmt.Errorf("group already exists: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: name,
|
||||
Hosts: []Host{},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup removes a group and all its hosts.
|
||||
func (c *Config) DeleteGroup(name string) error {
|
||||
for i, g := range c.Groups {
|
||||
if g.Name == name {
|
||||
c.Groups = append(c.Groups[:i], c.Groups[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("group not found: %s", name)
|
||||
}
|
||||
|
||||
// RenameGroup renames an existing group.
|
||||
func (c *Config) RenameGroup(oldName, newName string) error {
|
||||
// Check if new name already exists
|
||||
for _, g := range c.Groups {
|
||||
if g.Name == newName {
|
||||
return fmt.Errorf("group already exists: %s", newName)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == oldName {
|
||||
c.Groups[i].Name = newName
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("group not found: %s", oldName)
|
||||
}
|
||||
|
||||
// GetGroups returns all group names.
|
||||
func (c *Config) GetGroups() []string {
|
||||
names := make([]string, len(c.Groups))
|
||||
for i, g := range c.Groups {
|
||||
names[i] = g.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// DeleteHost removes a host by alias.
|
||||
func (c *Config) DeleteHost(alias string) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateHost updates an existing host by alias.
|
||||
func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error {
|
||||
// Find the host
|
||||
var foundGroup int = -1
|
||||
var foundHost int = -1
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == oldAlias {
|
||||
foundGroup = i
|
||||
foundHost = j
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundHost >= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundHost < 0 {
|
||||
return fmt.Errorf("alias not found: %s", oldAlias)
|
||||
}
|
||||
|
||||
// Check for duplicate alias if alias is changing
|
||||
if oldAlias != newAlias {
|
||||
if existing, _ := c.FindHostByAlias(newAlias); existing != nil {
|
||||
return fmt.Errorf("alias already exists: %s", newAlias)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current enabled state
|
||||
enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled
|
||||
|
||||
// If group is changing, move to new group
|
||||
if c.Groups[foundGroup].Name != groupName {
|
||||
// Remove from old group
|
||||
c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...)
|
||||
|
||||
// Add to new group
|
||||
host := Host{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: newAlias,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
// Find or create target group
|
||||
found := false
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == groupName {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: groupName,
|
||||
Hosts: []Host{host},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Update in place
|
||||
c.Groups[foundGroup].Hosts[foundHost].Domain = domain
|
||||
c.Groups[foundGroup].Hosts[foundHost].IP = ip
|
||||
c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPreset applies a preset to the configuration.
|
||||
func (c *Config) ApplyPreset(name string) error {
|
||||
preset := c.FindPreset(name)
|
||||
if preset == nil {
|
||||
return fmt.Errorf("preset not found: %s", name)
|
||||
}
|
||||
|
||||
for _, alias := range preset.Enable {
|
||||
c.SetHostEnabled(alias, true)
|
||||
}
|
||||
for _, alias := range preset.Disable {
|
||||
c.SetHostEnabled(alias, false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPreset adds a new preset.
|
||||
func (c *Config) AddPreset(name string, enable, disable []string) error {
|
||||
// Check if preset already exists
|
||||
for _, p := range c.Presets {
|
||||
if p.Name == name {
|
||||
return fmt.Errorf("preset already exists: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
c.Presets = append(c.Presets, Preset{
|
||||
Name: name,
|
||||
Enable: enable,
|
||||
Disable: disable,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePreset removes a preset by name.
|
||||
func (c *Config) DeletePreset(name string) error {
|
||||
for i, p := range c.Presets {
|
||||
if p.Name == name {
|
||||
c.Presets = append(c.Presets[:i], c.Presets[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("preset not found: %s", name)
|
||||
}
|
||||
|
||||
// GetPresets returns all presets.
|
||||
func (c *Config) GetPresets() []Preset {
|
||||
return c.Presets
|
||||
}
|
||||
|
||||
// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed.
|
||||
func (c *Config) EnsureDefaultGroup() {
|
||||
if len(c.Groups) == 0 {
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: "default",
|
||||
Hosts: []Host{},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Save writes the configuration to the file.
|
||||
func (m *Manager) Save() error {
|
||||
m.mu.RLock()
|
||||
cfg := m.config
|
||||
m.mu.RUnlock()
|
||||
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("no config loaded")
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDefault creates a default configuration file.
|
||||
func CreateDefault(path string) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Settings: Settings{
|
||||
AutoApply: true,
|
||||
FlushMethod: FlushMethodAuto,
|
||||
},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{
|
||||
Domain: "example.local",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "example-local",
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{
|
||||
Name: "local",
|
||||
Enable: []string{"example-local"},
|
||||
Disable: []string{},
|
||||
},
|
||||
{
|
||||
Name: "clear",
|
||||
Enable: []string{},
|
||||
Disable: []string{"example-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal default config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write default config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_GetAllHosts(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "staging",
|
||||
Hosts: []Host{
|
||||
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
hosts := cfg.GetAllHosts()
|
||||
assert.Len(t, hosts, 3)
|
||||
assert.Equal(t, "a.com", hosts[0].Domain)
|
||||
assert.Equal(t, "b.com", hosts[1].Domain)
|
||||
assert.Equal(t, "c.com", hosts[2].Domain)
|
||||
}
|
||||
|
||||
func TestConfig_FindHostByAlias(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
host, group := cfg.FindHostByAlias("example")
|
||||
require.NotNil(t, host)
|
||||
require.NotNil(t, group)
|
||||
assert.Equal(t, "example.com", host.Domain)
|
||||
assert.Equal(t, "dev", group.Name)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
host, group := cfg.FindHostByAlias("nonexistent")
|
||||
assert.Nil(t, host)
|
||||
assert.Nil(t, group)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_FindPreset(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"a"}, Disable: []string{"b"}},
|
||||
{Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
preset := cfg.FindPreset("local")
|
||||
require.NotNil(t, preset)
|
||||
assert.Equal(t, "local", preset.Name)
|
||||
assert.Equal(t, []string{"a"}, preset.Enable)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
preset := cfg.FindPreset("nonexistent")
|
||||
assert.Nil(t, preset)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_SetHostEnabled(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("enable existing", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("example", true)
|
||||
assert.True(t, result)
|
||||
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
})
|
||||
|
||||
t.Run("disable existing", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("example", false)
|
||||
assert.True(t, result)
|
||||
assert.False(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
})
|
||||
|
||||
t.Run("nonexistent alias", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("nonexistent", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_ApplyPreset(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "swap", Enable: []string{"a"}, Disable: []string{"b"}},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("valid preset", func(t *testing.T) {
|
||||
err := cfg.ApplyPreset("swap")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
assert.False(t, cfg.Groups[0].Hosts[1].Enabled)
|
||||
})
|
||||
|
||||
t.Run("nonexistent preset", func(t *testing.T) {
|
||||
err := cfg.ApplyPreset("nonexistent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_LoadAndGet(t *testing.T) {
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
configContent := `
|
||||
settings:
|
||||
autoApply: true
|
||||
flushMethod: auto
|
||||
groups:
|
||||
- name: development
|
||||
hosts:
|
||||
- domain: example.com
|
||||
ip: 127.0.0.1
|
||||
alias: example-local
|
||||
enabled: true
|
||||
presets:
|
||||
- name: local
|
||||
enable: [example-local]
|
||||
disable: []
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Settings.AutoApply)
|
||||
assert.Equal(t, FlushMethodAuto, cfg.Settings.FlushMethod)
|
||||
assert.Len(t, cfg.Groups, 1)
|
||||
assert.Equal(t, "development", cfg.Groups[0].Name)
|
||||
assert.Len(t, cfg.Groups[0].Hosts, 1)
|
||||
assert.Equal(t, "example.com", cfg.Groups[0].Hosts[0].Domain)
|
||||
}
|
||||
|
||||
func TestManager_Save(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Create initial config
|
||||
err := CreateDefault(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load and modify
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
cfg.Groups[0].Hosts[0].Enabled = true
|
||||
|
||||
// Save
|
||||
err = manager.Save()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reload and verify
|
||||
manager2 := NewManager(configPath)
|
||||
err = manager2.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg2 := manager2.Get()
|
||||
assert.True(t, cfg2.Groups[0].Hosts[0].Enabled)
|
||||
}
|
||||
|
||||
func TestCreateDefault(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "subdir", "config.yaml")
|
||||
|
||||
err := CreateDefault(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file exists
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify content is valid
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
require.NotNil(t, cfg)
|
||||
assert.True(t, cfg.Settings.AutoApply)
|
||||
assert.Len(t, cfg.Groups, 1)
|
||||
assert.Len(t, cfg.Presets, 2)
|
||||
}
|
||||
|
||||
func TestManager_Load_InvalidYAML(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Load_FileNotFound(t *testing.T) {
|
||||
manager := NewManager("/nonexistent/path/config.yaml")
|
||||
err := manager.Load()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFlushMethod(t *testing.T) {
|
||||
methods := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
}
|
||||
|
||||
for _, m := range methods {
|
||||
t.Run(string(m), func(t *testing.T) {
|
||||
assert.NotEmpty(t, string(m))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
// Package config provides validation functions for configuration.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// domainRegex validates domain names.
|
||||
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$|^localhost$`)
|
||||
|
||||
// aliasRegex validates alias names.
|
||||
var aliasRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`)
|
||||
|
||||
// blockedDomains contains domains that cannot be modified.
|
||||
var blockedDomains = map[string]bool{
|
||||
"apple.com": true,
|
||||
"icloud.com": true,
|
||||
"icloud-content.com": true,
|
||||
"apple-dns.cn": true,
|
||||
"apple-dns.net": true,
|
||||
"mzstatic.com": true,
|
||||
"itunes.apple.com": true,
|
||||
"updates.apple.com": true,
|
||||
}
|
||||
|
||||
// ValidationError represents a configuration validation error.
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidateConfig validates the entire configuration.
|
||||
func ValidateConfig(cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return &ValidationError{Field: "config", Message: "config is nil"}
|
||||
}
|
||||
|
||||
if err := validateSettings(&cfg.Settings); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Track aliases for uniqueness
|
||||
aliases := make(map[string]bool)
|
||||
|
||||
for i, g := range cfg.Groups {
|
||||
if err := validateGroup(&g, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i, p := range cfg.Presets {
|
||||
if err := validatePreset(&p, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSettings(s *Settings) error {
|
||||
switch s.FlushMethod {
|
||||
case FlushMethodAuto, FlushMethodDscacheutil, FlushMethodKillall, FlushMethodBoth, "":
|
||||
// Valid
|
||||
default:
|
||||
return &ValidationError{
|
||||
Field: "settings.flushMethod",
|
||||
Message: fmt.Sprintf("invalid flush method: %s", s.FlushMethod),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroup(g *Group, index int, aliases map[string]bool) error {
|
||||
if strings.TrimSpace(g.Name) == "" {
|
||||
return &ValidationError{
|
||||
Field: fmt.Sprintf("groups[%d].name", index),
|
||||
Message: "group name is required",
|
||||
}
|
||||
}
|
||||
|
||||
for i, h := range g.Hosts {
|
||||
if err := validateHost(&h, index, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHost(h *Host, groupIndex, hostIndex int, aliases map[string]bool) error {
|
||||
fieldPrefix := fmt.Sprintf("groups[%d].hosts[%d]", groupIndex, hostIndex)
|
||||
|
||||
// Validate domain
|
||||
if !ValidateDomain(h.Domain) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".domain",
|
||||
Message: fmt.Sprintf("invalid domain: %s", h.Domain),
|
||||
}
|
||||
}
|
||||
|
||||
// Check blocked domains
|
||||
if IsBlockedDomain(h.Domain) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".domain",
|
||||
Message: fmt.Sprintf("domain is blocked: %s", h.Domain),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate IP
|
||||
if !ValidateIP(h.IP) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".ip",
|
||||
Message: fmt.Sprintf("invalid IP address: %s", h.IP),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate alias
|
||||
if !ValidateAlias(h.Alias) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".alias",
|
||||
Message: fmt.Sprintf("invalid alias: %s", h.Alias),
|
||||
}
|
||||
}
|
||||
|
||||
// Check alias uniqueness
|
||||
if aliases[h.Alias] {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".alias",
|
||||
Message: fmt.Sprintf("duplicate alias: %s", h.Alias),
|
||||
}
|
||||
}
|
||||
aliases[h.Alias] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePreset(p *Preset, index int, aliases map[string]bool) error {
|
||||
fieldPrefix := fmt.Sprintf("presets[%d]", index)
|
||||
|
||||
if strings.TrimSpace(p.Name) == "" {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".name",
|
||||
Message: "preset name is required",
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We don't validate preset aliases strictly anymore.
|
||||
// Unknown aliases in presets will simply be skipped when applying the preset.
|
||||
// This allows presets to survive when hosts are removed from the config.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDomain checks if a domain name is valid.
|
||||
func ValidateDomain(domain string) bool {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
return domainRegex.MatchString(domain)
|
||||
}
|
||||
|
||||
// ValidateIP checks if an IP address is valid (IPv4 or IPv6).
|
||||
func ValidateIP(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
return net.ParseIP(ip) != nil
|
||||
}
|
||||
|
||||
// ValidateAlias checks if an alias is valid.
|
||||
func ValidateAlias(alias string) bool {
|
||||
if alias == "" {
|
||||
return false
|
||||
}
|
||||
return aliasRegex.MatchString(alias)
|
||||
}
|
||||
|
||||
// IsBlockedDomain checks if a domain is in the blocklist.
|
||||
func IsBlockedDomain(domain string) bool {
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
// Check exact match
|
||||
if blockedDomains[domain] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's a subdomain of a blocked domain
|
||||
for blocked := range blockedDomains {
|
||||
if strings.HasSuffix(domain, "."+blocked) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetBlockedDomains returns a copy of the blocked domains list.
|
||||
func GetBlockedDomains() []string {
|
||||
domains := make([]string, 0, len(blockedDomains))
|
||||
for d := range blockedDomains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
@@ -0,0 +1,436 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
valid bool
|
||||
}{
|
||||
{"example.com", true},
|
||||
{"sub.example.com", true},
|
||||
{"my-app.example.com", true},
|
||||
{"localhost", true},
|
||||
{"a.b.c.d.example.com", true},
|
||||
{"example123.com", true},
|
||||
|
||||
{"", false},
|
||||
{"-example.com", false},
|
||||
{"example-.com", false},
|
||||
{"example.c", false}, // TLD too short
|
||||
{"example", false}, // No TLD
|
||||
{".example.com", false},
|
||||
{"example..com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
result := ValidateDomain(tt.domain)
|
||||
assert.Equal(t, tt.valid, result, "domain: %s", tt.domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
valid bool
|
||||
}{
|
||||
// Valid IPv4
|
||||
{"127.0.0.1", true},
|
||||
{"192.168.1.1", true},
|
||||
{"0.0.0.0", true},
|
||||
{"255.255.255.255", true},
|
||||
|
||||
// Valid IPv6
|
||||
{"::1", true},
|
||||
{"2001:db8::1", true},
|
||||
{"fe80::1", true},
|
||||
{"::ffff:192.168.1.1", true},
|
||||
|
||||
// Invalid
|
||||
{"", false},
|
||||
{"256.0.0.1", false},
|
||||
{"192.168.1", false},
|
||||
{"not-an-ip", false},
|
||||
{"192.168.1.1.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := ValidateIP(tt.ip)
|
||||
assert.Equal(t, tt.valid, result, "ip: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
alias string
|
||||
valid bool
|
||||
}{
|
||||
{"my-alias", true},
|
||||
{"myalias", true},
|
||||
{"my_alias", true},
|
||||
{"alias123", true},
|
||||
{"a", true},
|
||||
{"a-b_c-d", true},
|
||||
|
||||
{"", false},
|
||||
{"-startswithdash", false},
|
||||
{"_startswithunderscore", false},
|
||||
{"has spaces", false},
|
||||
{"has.dot", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.alias, func(t *testing.T) {
|
||||
result := ValidateAlias(tt.alias)
|
||||
assert.Equal(t, tt.valid, result, "alias: %s", tt.alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
blocked bool
|
||||
}{
|
||||
// Blocked domains
|
||||
{"apple.com", true},
|
||||
{"icloud.com", true},
|
||||
{"sub.apple.com", true},
|
||||
{"deep.sub.icloud.com", true},
|
||||
{"APPLE.COM", true}, // Case insensitive
|
||||
|
||||
// Allowed domains
|
||||
{"example.com", false},
|
||||
{"myapp.com", false},
|
||||
{"applestore.com", false}, // Not a subdomain
|
||||
{"notapple.com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
result := IsBlockedDomain(tt.domain)
|
||||
assert.Equal(t, tt.blocked, result, "domain: %s", tt.domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBlockedDomains(t *testing.T) {
|
||||
domains := GetBlockedDomains()
|
||||
assert.NotEmpty(t, domains)
|
||||
assert.Contains(t, domains, "apple.com")
|
||||
assert.Contains(t, domains, "icloud.com")
|
||||
}
|
||||
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{
|
||||
AutoApply: true,
|
||||
FlushMethod: FlushMethodAuto,
|
||||
},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"example"}, Disable: []string{}},
|
||||
},
|
||||
}
|
||||
|
||||
err := ValidateConfig(cfg)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil config", func(t *testing.T) {
|
||||
err := ValidateConfig(nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid flush method", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{FlushMethod: "invalid"},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty group name", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{{Name: "", Hosts: []Host{}}},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid domain", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "invalid", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("blocked domain", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "apple.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid IP", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "invalid", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid alias", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "-invalid", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("duplicate alias", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty preset name", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "", Enable: []string{}},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("preset with unknown alias is allowed", func(t *testing.T) {
|
||||
// Unknown aliases in presets are now allowed (they're simply skipped when applied)
|
||||
// This allows presets to survive when hosts are removed from the config
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"unknown"}},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidationError(t *testing.T) {
|
||||
err := &ValidationError{Field: "test.field", Message: "test message"}
|
||||
assert.Equal(t, "test.field: test message", err.Error())
|
||||
}
|
||||
|
||||
func TestValidateSettings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method FlushMethod
|
||||
wantErr bool
|
||||
}{
|
||||
{"auto", FlushMethodAuto, false},
|
||||
{"dscacheutil", FlushMethodDscacheutil, false},
|
||||
{"killall", FlushMethodKillall, false},
|
||||
{"both", FlushMethodBoth, false},
|
||||
{"empty", "", false},
|
||||
{"invalid", "invalid", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
settings := &Settings{FlushMethod: tt.method}
|
||||
err := validateSettings(settings)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix testing for domain validation
|
||||
func TestValidateDomain_Matrix(t *testing.T) {
|
||||
prefixes := []string{"", "sub.", "a.b."}
|
||||
domains := []string{"example", "my-app", "test123"}
|
||||
tlds := []string{".com", ".io", ".co.uk", ".dev"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
for _, domain := range domains {
|
||||
for _, tld := range tlds {
|
||||
fullDomain := prefix + domain + tld
|
||||
t.Run(fullDomain, func(t *testing.T) {
|
||||
result := ValidateDomain(fullDomain)
|
||||
assert.True(t, result, "expected %s to be valid", fullDomain)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix testing for IP validation
|
||||
func TestValidateIP_Matrix(t *testing.T) {
|
||||
octets := []string{"0", "127", "192", "255"}
|
||||
|
||||
for _, o1 := range octets {
|
||||
for _, o2 := range octets {
|
||||
for _, o3 := range octets {
|
||||
for _, o4 := range octets {
|
||||
ip := o1 + "." + o2 + "." + o3 + "." + o4
|
||||
t.Run(ip, func(t *testing.T) {
|
||||
result := ValidateIP(ip)
|
||||
assert.True(t, result, "expected %s to be valid", ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkValidateDomain(b *testing.B) {
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"sub.example.com",
|
||||
"very.long.subdomain.chain.example.com",
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
b.Run(domain, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateDomain(domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateIP(b *testing.B) {
|
||||
ips := []string{
|
||||
"127.0.0.1",
|
||||
"192.168.1.1",
|
||||
"::1",
|
||||
"2001:db8::1",
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
b.Run(ip, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateIP(ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsBlockedDomain(b *testing.B) {
|
||||
domains := []string{
|
||||
"example.com", // not blocked
|
||||
"apple.com", // blocked
|
||||
"sub.icloud.com", // blocked subdomain
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
b.Run(domain, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsBlockedDomain(domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateConfig(b *testing.B) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{AutoApply: true, FlushMethod: FlushMethodAuto},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.example.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
|
||||
{Domain: "b.example.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
|
||||
{Domain: "c.example.com", IP: "127.0.0.1", Alias: "c", Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"a", "b"}, Disable: []string{"c"}},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := ValidateConfig(cfg)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
// Package daemon provides the main daemon loop and lifecycle management.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Daemon represents the lolcathost daemon.
|
||||
type Daemon struct {
|
||||
server *Server
|
||||
config *config.Manager
|
||||
stopCh chan struct{}
|
||||
cleanupCh chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new daemon instance.
|
||||
func New(configPath string) (*Daemon, error) {
|
||||
cfgManager := config.NewManager(configPath)
|
||||
|
||||
// Try to load config, create default if it doesn't exist
|
||||
if err := cfgManager.Load(); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err := config.CreateDefault(configPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
if err := cfgManager.Load(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load default config: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure at least one group exists
|
||||
cfg := cfgManager.Get()
|
||||
if cfg != nil {
|
||||
cfg.EnsureDefaultGroup()
|
||||
// Save if we added a default group
|
||||
if len(cfg.Groups) == 1 && cfg.Groups[0].Name == "default" && len(cfg.Groups[0].Hosts) == 0 {
|
||||
cfgManager.Save()
|
||||
}
|
||||
}
|
||||
|
||||
server := NewServer(protocol.SocketPath, cfgManager)
|
||||
|
||||
return &Daemon{
|
||||
server: server,
|
||||
config: cfgManager,
|
||||
stopCh: make(chan struct{}),
|
||||
cleanupCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run starts the daemon and blocks until stopped.
|
||||
func (d *Daemon) Run() error {
|
||||
// Verify we're running as root
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("daemon must run as root")
|
||||
}
|
||||
|
||||
// Start the server
|
||||
if err := d.server.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
// Watch config for changes
|
||||
if err := d.config.Watch(d.onConfigChange); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to watch config: %v\n", err)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go d.cleanupLoop()
|
||||
|
||||
// Wait for shutdown signal
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
fmt.Println("Received shutdown signal")
|
||||
case <-d.stopCh:
|
||||
fmt.Println("Shutdown requested")
|
||||
}
|
||||
|
||||
return d.shutdown()
|
||||
}
|
||||
|
||||
// Stop signals the daemon to stop.
|
||||
func (d *Daemon) Stop() {
|
||||
close(d.stopCh)
|
||||
}
|
||||
|
||||
func (d *Daemon) shutdown() error {
|
||||
close(d.cleanupCh)
|
||||
d.config.Stop()
|
||||
|
||||
if err := d.server.Stop(); err != nil {
|
||||
return fmt.Errorf("failed to stop server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) onConfigChange(cfg *config.Config) {
|
||||
fmt.Println("Config changed, syncing hosts file...")
|
||||
// The server will use the updated config on next request
|
||||
// We could trigger a sync here if autoApply is enabled
|
||||
if cfg != nil && cfg.Settings.AutoApply {
|
||||
// Sync hosts file with new config
|
||||
// This is handled by the server internally
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
d.server.rateLimiter.Cleanup()
|
||||
case <-d.cleanupCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
// Package daemon provides DNS cache flushing functionality.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// DNSFlusher handles DNS cache flushing.
|
||||
type DNSFlusher struct {
|
||||
method FlushMethod
|
||||
}
|
||||
|
||||
// FlushMethod defines the DNS flush method to use.
|
||||
type FlushMethod string
|
||||
|
||||
const (
|
||||
FlushMethodAuto FlushMethod = "auto"
|
||||
FlushMethodDscacheutil FlushMethod = "dscacheutil"
|
||||
FlushMethodKillall FlushMethod = "killall"
|
||||
FlushMethodBoth FlushMethod = "both"
|
||||
FlushMethodSystemd FlushMethod = "systemd"
|
||||
FlushMethodNscd FlushMethod = "nscd"
|
||||
)
|
||||
|
||||
// NewDNSFlusher creates a new DNS flusher.
|
||||
func NewDNSFlusher(method FlushMethod) *DNSFlusher {
|
||||
return &DNSFlusher{method: method}
|
||||
}
|
||||
|
||||
// Flush flushes the DNS cache using the configured method.
|
||||
func (f *DNSFlusher) Flush() error {
|
||||
method := f.method
|
||||
if method == FlushMethodAuto || method == "" {
|
||||
method = f.detectMethod()
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return f.flushDarwin(method)
|
||||
case "linux":
|
||||
return f.flushLinux(method)
|
||||
default:
|
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) detectMethod() FlushMethod {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return FlushMethodBoth
|
||||
case "linux":
|
||||
// Check for systemd-resolve first
|
||||
if _, err := exec.LookPath("systemd-resolve"); err == nil {
|
||||
return FlushMethodSystemd
|
||||
}
|
||||
if _, err := exec.LookPath("resolvectl"); err == nil {
|
||||
return FlushMethodSystemd
|
||||
}
|
||||
// Fall back to nscd
|
||||
if _, err := exec.LookPath("nscd"); err == nil {
|
||||
return FlushMethodNscd
|
||||
}
|
||||
return FlushMethodAuto
|
||||
default:
|
||||
return FlushMethodAuto
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) flushDarwin(method FlushMethod) error {
|
||||
var errs []error
|
||||
|
||||
switch method {
|
||||
case FlushMethodDscacheutil:
|
||||
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
|
||||
return fmt.Errorf("dscacheutil failed: %w", err)
|
||||
}
|
||||
case FlushMethodKillall:
|
||||
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
|
||||
return fmt.Errorf("killall mDNSResponder failed: %w", err)
|
||||
}
|
||||
case FlushMethodBoth:
|
||||
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
|
||||
errs = append(errs, fmt.Errorf("dscacheutil failed: %w", err))
|
||||
}
|
||||
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
|
||||
errs = append(errs, fmt.Errorf("killall mDNSResponder failed: %w", err))
|
||||
}
|
||||
if len(errs) == 2 {
|
||||
return fmt.Errorf("all DNS flush methods failed: %v, %v", errs[0], errs[1])
|
||||
}
|
||||
default:
|
||||
// Auto - try both
|
||||
_ = runCommand("dscacheutil", "-flushcache")
|
||||
_ = runCommand("killall", "-HUP", "mDNSResponder")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) flushLinux(method FlushMethod) error {
|
||||
switch method {
|
||||
case FlushMethodSystemd:
|
||||
// Try resolvectl first (newer), then systemd-resolve (older)
|
||||
if err := runCommand("resolvectl", "flush-caches"); err != nil {
|
||||
if err := runCommand("systemd-resolve", "--flush-caches"); err != nil {
|
||||
return fmt.Errorf("systemd DNS flush failed: %w", err)
|
||||
}
|
||||
}
|
||||
case FlushMethodNscd:
|
||||
// Try to restart nscd
|
||||
if err := runCommand("nscd", "-i", "hosts"); err != nil {
|
||||
// Try service restart as fallback
|
||||
if err := runCommand("service", "nscd", "restart"); err != nil {
|
||||
return fmt.Errorf("nscd flush failed: %w", err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Auto - try all methods
|
||||
// Try systemd first
|
||||
if err := runCommand("resolvectl", "flush-caches"); err == nil {
|
||||
return nil
|
||||
}
|
||||
if err := runCommand("systemd-resolve", "--flush-caches"); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Try nscd
|
||||
if err := runCommand("nscd", "-i", "hosts"); err == nil {
|
||||
return nil
|
||||
}
|
||||
// On many Linux systems, no explicit flush is needed as /etc/hosts is read directly
|
||||
// So we return nil here
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommand(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewDNSFlusher(t *testing.T) {
|
||||
tests := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
FlushMethodSystemd,
|
||||
FlushMethodNscd,
|
||||
}
|
||||
|
||||
for _, method := range tests {
|
||||
t.Run(string(method), func(t *testing.T) {
|
||||
flusher := NewDNSFlusher(method)
|
||||
assert.NotNil(t, flusher)
|
||||
assert.Equal(t, method, flusher.method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFlusher_DetectMethod(t *testing.T) {
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
|
||||
method := flusher.detectMethod()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
assert.Equal(t, FlushMethodBoth, method)
|
||||
case "linux":
|
||||
// Could be systemd, nscd, or auto depending on system
|
||||
assert.Contains(t, []FlushMethod{FlushMethodSystemd, FlushMethodNscd, FlushMethodAuto}, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushMethod_String(t *testing.T) {
|
||||
methods := map[FlushMethod]string{
|
||||
FlushMethodAuto: "auto",
|
||||
FlushMethodDscacheutil: "dscacheutil",
|
||||
FlushMethodKillall: "killall",
|
||||
FlushMethodBoth: "both",
|
||||
FlushMethodSystemd: "systemd",
|
||||
FlushMethodNscd: "nscd",
|
||||
}
|
||||
|
||||
for method, expected := range methods {
|
||||
t.Run(expected, func(t *testing.T) {
|
||||
assert.Equal(t, expected, string(method))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Actually testing DNS flush requires root and modifies system state,
|
||||
// so we skip those tests in unit tests. They would be integration tests.
|
||||
|
||||
func TestDNSFlusher_Flush_UnsupportedOS(t *testing.T) {
|
||||
// This test only makes sense if we're not on darwin or linux
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
|
||||
t.Skip("Test only applicable on unsupported OS")
|
||||
}
|
||||
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
err := flusher.Flush()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported operating system")
|
||||
}
|
||||
|
||||
// Matrix test for flush methods
|
||||
func TestFlushMethod_Matrix(t *testing.T) {
|
||||
methods := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
FlushMethodSystemd,
|
||||
FlushMethodNscd,
|
||||
}
|
||||
|
||||
platforms := []string{"darwin", "linux"}
|
||||
|
||||
for _, method := range methods {
|
||||
for _, platform := range platforms {
|
||||
t.Run(string(method)+"_"+platform, func(t *testing.T) {
|
||||
flusher := NewDNSFlusher(method)
|
||||
assert.NotNil(t, flusher)
|
||||
|
||||
// Just verify no panic when checking method
|
||||
_ = flusher.method
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDNSFlusher_DetectMethod(b *testing.B) {
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = flusher.detectMethod()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
// Package daemon implements the privileged daemon that manages /etc/hosts.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// HostsPath is the path to the system hosts file.
|
||||
HostsPath = "/etc/hosts"
|
||||
// BackupDir is the directory for hosts file backups.
|
||||
BackupDir = "/var/backups/lolcathost"
|
||||
// MaxBackups is the maximum number of backups to keep.
|
||||
MaxBackups = 10
|
||||
|
||||
// Markers for the managed section.
|
||||
markerStart = "# ========== LOLCATHOST MANAGED - DO NOT EDIT =========="
|
||||
markerEnd = "# ========== END LOLCATHOST =========="
|
||||
)
|
||||
|
||||
// HostEntry represents a single entry in the hosts file.
|
||||
type HostEntry struct {
|
||||
IP string
|
||||
Domain string
|
||||
Alias string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// HostsManager handles reading and writing the hosts file.
|
||||
type HostsManager struct {
|
||||
hostsPath string
|
||||
backupDir string
|
||||
}
|
||||
|
||||
// NewHostsManager creates a new hosts manager.
|
||||
func NewHostsManager() *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: HostsPath,
|
||||
backupDir: BackupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
|
||||
func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: hostsPath,
|
||||
backupDir: backupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadManagedEntries reads the lolcathost-managed entries from the hosts file.
|
||||
func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) {
|
||||
file, err := os.Open(m.hostsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open hosts file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var entries []HostEntry
|
||||
inManagedSection := false
|
||||
scanner := bufio.NewScanner(file)
|
||||
entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if line == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
|
||||
matches := entryRegex.FindStringSubmatch(line)
|
||||
if len(matches) == 4 {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: matches[1],
|
||||
Domain: matches[2],
|
||||
Alias: matches[3],
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// WriteManagedEntries writes the managed entries to the hosts file.
|
||||
func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error {
|
||||
// Create backup first
|
||||
if err := m.CreateBackup(); err != nil {
|
||||
return fmt.Errorf("failed to create backup: %w", err)
|
||||
}
|
||||
|
||||
// Read existing content
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
// Remove existing managed section
|
||||
newContent := m.removeManagedSection(string(content))
|
||||
|
||||
// Build new managed section
|
||||
managedSection := m.buildManagedSection(entries)
|
||||
|
||||
// Append managed section
|
||||
newContent = strings.TrimRight(newContent, "\n") + "\n\n" + managedSection
|
||||
|
||||
// Write atomically
|
||||
if err := m.writeAtomic(newContent); err != nil {
|
||||
return fmt.Errorf("failed to write hosts file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *HostsManager) removeManagedSection(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var result []string
|
||||
inManagedSection := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if trimmed == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
if !inManagedSection {
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing empty lines
|
||||
for len(result) > 0 && strings.TrimSpace(result[len(result)-1]) == "" {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
func (m *HostsManager) buildManagedSection(entries []HostEntry) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(markerStart)
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.Enabled {
|
||||
sb.WriteString(fmt.Sprintf("%s\t%s\t# lolcathost:%s\n", entry.IP, entry.Domain, entry.Alias))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(markerEnd)
|
||||
sb.WriteString("\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m *HostsManager) writeAtomic(content string) error {
|
||||
// Write to temp file first
|
||||
tmpFile := m.hostsPath + ".tmp"
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rename atomically
|
||||
if err := os.Rename(tmpFile, m.hostsPath); err != nil {
|
||||
os.Remove(tmpFile)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBackup creates a backup of the current hosts file.
|
||||
func (m *HostsManager) CreateBackup() error {
|
||||
if err := os.MkdirAll(m.backupDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp))
|
||||
|
||||
if err := os.WriteFile(backupPath, content, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write backup: %w", err)
|
||||
}
|
||||
|
||||
// Cleanup old backups
|
||||
if err := m.cleanupBackups(); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to cleanup backups: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *HostsManager) cleanupBackups() error {
|
||||
entries, err := os.ReadDir(m.backupDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var backups []os.DirEntry
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "hosts.") && strings.HasSuffix(entry.Name(), ".bak") {
|
||||
backups = append(backups, entry)
|
||||
}
|
||||
}
|
||||
|
||||
if len(backups) <= MaxBackups {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by name (timestamp) descending
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
return backups[i].Name() > backups[j].Name()
|
||||
})
|
||||
|
||||
// Remove oldest backups
|
||||
for i := MaxBackups; i < len(backups); i++ {
|
||||
path := filepath.Join(m.backupDir, backups[i].Name())
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBackups returns a list of available backups.
|
||||
func (m *HostsManager) ListBackups() ([]BackupInfo, error) {
|
||||
entries, err := os.ReadDir(m.backupDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backups []BackupInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasPrefix(entry.Name(), "hosts.") || !strings.HasSuffix(entry.Name(), ".bak") {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
backups = append(backups, BackupInfo{
|
||||
Name: entry.Name(),
|
||||
Timestamp: info.ModTime().Unix(),
|
||||
Size: info.Size(),
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by timestamp descending
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
return backups[i].Timestamp > backups[j].Timestamp
|
||||
})
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
// BackupInfo holds information about a backup file.
|
||||
type BackupInfo struct {
|
||||
Name string
|
||||
Timestamp int64
|
||||
Size int64
|
||||
}
|
||||
|
||||
// RestoreBackup restores a backup by name.
|
||||
func (m *HostsManager) RestoreBackup(name string) error {
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
// Validate backup name to prevent path traversal
|
||||
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
|
||||
return fmt.Errorf("invalid backup name")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read backup: %w", err)
|
||||
}
|
||||
|
||||
// Create a backup of current state before restoring
|
||||
if err := m.CreateBackup(); err != nil {
|
||||
return fmt.Errorf("failed to create backup before restore: %w", err)
|
||||
}
|
||||
|
||||
if err := m.writeAtomic(string(content)); err != nil {
|
||||
return fmt.Errorf("failed to restore backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostsManager_ReadManagedEntries(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
hostsContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
::1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 example.com # lolcathost:example-local
|
||||
192.168.1.1 api.example.com # lolcathost:api-local
|
||||
# ========== END LOLCATHOST ==========
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, entries, 2)
|
||||
assert.Equal(t, "127.0.0.1", entries[0].IP)
|
||||
assert.Equal(t, "example.com", entries[0].Domain)
|
||||
assert.Equal(t, "example-local", entries[0].Alias)
|
||||
assert.Equal(t, "192.168.1.1", entries[1].IP)
|
||||
assert.Equal(t, "api.example.com", entries[1].Domain)
|
||||
assert.Equal(t, "api-local", entries[1].Alias)
|
||||
}
|
||||
|
||||
func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
hostsContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, entries)
|
||||
}
|
||||
|
||||
func TestHostsManager_WriteManagedEntries(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create initial hosts file
|
||||
initialContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true},
|
||||
{IP: "127.0.0.1", Domain: "api.myapp.com", Alias: "api-local", Enabled: true},
|
||||
{IP: "192.168.1.1", Domain: "staging.myapp.com", Alias: "staging", Enabled: false},
|
||||
}
|
||||
|
||||
err = manager.WriteManagedEntries(entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
|
||||
assert.Contains(t, contentStr, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tmyapp.com\t# lolcathost:myapp-local")
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tapi.myapp.com\t# lolcathost:api-local")
|
||||
assert.NotContains(t, contentStr, "staging.myapp.com") // disabled
|
||||
assert.Contains(t, contentStr, "# ========== END LOLCATHOST ==========")
|
||||
}
|
||||
|
||||
func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create hosts file with existing managed section
|
||||
initialContent := `127.0.0.1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 old.com # lolcathost:old
|
||||
# ========== END LOLCATHOST ==========
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true},
|
||||
}
|
||||
|
||||
err = manager.WriteManagedEntries(entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
|
||||
assert.Contains(t, contentStr, "new.com")
|
||||
assert.NotContains(t, contentStr, "old.com")
|
||||
}
|
||||
|
||||
func TestHostsManager_CreateBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
hostsContent := "127.0.0.1\tlocalhost\n"
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify backup exists
|
||||
entries, err := os.ReadDir(backupDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, entries, 1)
|
||||
assert.True(t, strings.HasPrefix(entries[0].Name(), "hosts."))
|
||||
assert.True(t, strings.HasSuffix(entries[0].Name(), ".bak"))
|
||||
|
||||
// Verify backup content
|
||||
backupContent, err := os.ReadFile(filepath.Join(backupDir, entries[0].Name()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, hostsContent, string(backupContent))
|
||||
}
|
||||
|
||||
func TestHostsManager_ListBackups(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create hosts file
|
||||
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create backup files with different timestamps
|
||||
err = os.MkdirAll(backupDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
backupNames := []string{
|
||||
"hosts.20231201-120000.bak",
|
||||
"hosts.20231201-120001.bak",
|
||||
"hosts.20231201-120002.bak",
|
||||
}
|
||||
for _, name := range backupNames {
|
||||
err = os.WriteFile(filepath.Join(backupDir, name), []byte("backup"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, backups, 3)
|
||||
}
|
||||
|
||||
func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "nonexistent")
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, backups)
|
||||
}
|
||||
|
||||
func TestHostsManager_RestoreBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create initial hosts file
|
||||
initialContent := "initial content"
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create backup
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify hosts file
|
||||
err = os.WriteFile(hostsPath, []byte("modified content"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get backup name
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, backups, 1)
|
||||
|
||||
// Restore
|
||||
err = manager.RestoreBackup(backups[0].Name)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify content restored
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, initialContent, string(content))
|
||||
}
|
||||
|
||||
func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
manager := NewHostsManagerWithPaths(
|
||||
filepath.Join(tmpDir, "hosts"),
|
||||
filepath.Join(tmpDir, "backups"),
|
||||
)
|
||||
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
"hosts.bak", // Missing timestamp
|
||||
"notahosts.backup", // Wrong format
|
||||
"",
|
||||
}
|
||||
|
||||
for _, name := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := manager.RestoreBackup(name)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsManager_CleanupBackups(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create more than MaxBackups
|
||||
for i := 0; i < MaxBackups+5; i++ {
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify only MaxBackups remain
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(backups), MaxBackups)
|
||||
}
|
||||
|
||||
func TestHostsManager_RemoveManagedSection(t *testing.T) {
|
||||
manager := &HostsManager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with managed section",
|
||||
input: `127.0.0.1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 example.com # lolcathost:test
|
||||
# ========== END LOLCATHOST ==========
|
||||
`,
|
||||
expected: "127.0.0.1\tlocalhost",
|
||||
},
|
||||
{
|
||||
name: "without managed section",
|
||||
input: "127.0.0.1\tlocalhost\n",
|
||||
expected: "127.0.0.1\tlocalhost",
|
||||
},
|
||||
{
|
||||
name: "multiple managed sections",
|
||||
input: `127.0.0.1 localhost
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
entry1
|
||||
# ========== END LOLCATHOST ==========
|
||||
more content
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
entry2
|
||||
# ========== END LOLCATHOST ==========
|
||||
`,
|
||||
expected: "127.0.0.1\tlocalhost\nmore content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.removeManagedSection(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsManager_BuildManagedSection(t *testing.T) {
|
||||
manager := &HostsManager{}
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "a.com", Alias: "a", Enabled: true},
|
||||
{IP: "192.168.1.1", Domain: "b.com", Alias: "b", Enabled: true},
|
||||
{IP: "10.0.0.1", Domain: "c.com", Alias: "c", Enabled: false},
|
||||
}
|
||||
|
||||
result := manager.buildManagedSection(entries)
|
||||
|
||||
assert.Contains(t, result, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
|
||||
assert.Contains(t, result, "127.0.0.1\ta.com\t# lolcathost:a")
|
||||
assert.Contains(t, result, "192.168.1.1\tb.com\t# lolcathost:b")
|
||||
assert.NotContains(t, result, "c.com") // disabled
|
||||
assert.Contains(t, result, "# ========== END LOLCATHOST ==========")
|
||||
}
|
||||
|
||||
// Matrix tests for hosts file parsing
|
||||
func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
|
||||
ips := []string{"127.0.0.1", "192.168.1.1", "::1"}
|
||||
domains := []string{"example.com", "sub.example.com", "my-app.test"}
|
||||
aliases := []string{"test", "my-alias", "app-1"}
|
||||
|
||||
for _, ip := range ips {
|
||||
for _, domain := range domains {
|
||||
for _, alias := range aliases {
|
||||
t.Run(ip+"/"+domain+"/"+alias, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
content := "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n"
|
||||
content += ip + "\t" + domain + "\t# lolcathost:" + alias + "\n"
|
||||
content += "# ========== END LOLCATHOST ==========\n"
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
assert.Equal(t, ip, entries[0].IP)
|
||||
assert.Equal(t, domain, entries[0].Domain)
|
||||
assert.Equal(t, alias, entries[0].Alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
// Create a hosts file with many entries
|
||||
var content strings.Builder
|
||||
content.WriteString("127.0.0.1\tlocalhost\n")
|
||||
content.WriteString("# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n")
|
||||
for i := 0; i < 100; i++ {
|
||||
content.WriteString("127.0.0.1\texample" + string(rune('a'+i%26)) + ".com\t# lolcathost:alias" + string(rune('a'+i%26)) + "\n")
|
||||
}
|
||||
content.WriteString("# ========== END LOLCATHOST ==========\n")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte(content.String()), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = manager.ReadManagedEntries()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := make([]HostEntry, 50)
|
||||
for i := range entries {
|
||||
entries[i] = HostEntry{
|
||||
IP: "127.0.0.1",
|
||||
Domain: "example" + string(rune('a'+i%26)) + ".com",
|
||||
Alias: "alias" + string(rune('a'+i%26)),
|
||||
Enabled: i%2 == 0,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = manager.WriteManagedEntries(entries)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
//go:build darwin
|
||||
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getPeerCredentials extracts peer credentials from a Unix socket connection on macOS.
|
||||
// Note: macOS Xucred doesn't include PID, so we use LOCAL_PEERPID separately.
|
||||
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
|
||||
unixConn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawConn, err := unixConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var creds *PeerCredentials
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
xucred, err := unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get PID separately using LOCAL_PEERPID
|
||||
var pid int32
|
||||
pidLen := uint32(unsafe.Sizeof(pid))
|
||||
_, _, errno := syscall.Syscall6(
|
||||
syscall.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_LOCAL,
|
||||
0x002, // LOCAL_PEERPID
|
||||
uintptr(unsafe.Pointer(&pid)),
|
||||
uintptr(unsafe.Pointer(&pidLen)),
|
||||
0,
|
||||
)
|
||||
if errno != 0 {
|
||||
pid = 0
|
||||
}
|
||||
|
||||
creds = &PeerCredentials{
|
||||
UID: xucred.Uid,
|
||||
GID: xucred.Groups[0],
|
||||
PID: pid,
|
||||
}
|
||||
})
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
//go:build linux
|
||||
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getPeerCredentials extracts peer credentials from a Unix socket connection on Linux.
|
||||
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
|
||||
unixConn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawConn, err := unixConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var creds *PeerCredentials
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
creds = &PeerCredentials{
|
||||
UID: ucred.Uid,
|
||||
GID: ucred.Gid,
|
||||
PID: ucred.Pid,
|
||||
}
|
||||
})
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
// Package daemon provides security functions including rate limiting and audit logging.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuditLogPath is the path to the audit log file.
|
||||
AuditLogPath = "/var/log/lolcathost/audit.log"
|
||||
// RateLimit is the maximum requests per minute per PID.
|
||||
RateLimit = 100
|
||||
// RateLimitWindow is the time window for rate limiting.
|
||||
RateLimitWindow = time.Minute
|
||||
)
|
||||
|
||||
// RateLimiter implements per-PID rate limiting.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
requests map[int32][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter.
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
requests: make(map[int32][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given PID should be allowed.
|
||||
func (r *RateLimiter) Allow(pid int32) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
// Get existing requests for this PID
|
||||
reqs := r.requests[pid]
|
||||
|
||||
// Filter out old requests
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validReqs) >= r.limit {
|
||||
r.requests[pid] = validReqs
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new request
|
||||
validReqs = append(validReqs, now)
|
||||
r.requests[pid] = validReqs
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes old entries from the rate limiter.
|
||||
func (r *RateLimiter) Cleanup() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
for pid, reqs := range r.requests {
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
}
|
||||
}
|
||||
if len(validReqs) == 0 {
|
||||
delete(r.requests, pid)
|
||||
} else {
|
||||
r.requests[pid] = validReqs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AuditLogger handles audit logging.
|
||||
type AuditLogger struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
path string
|
||||
encoder *json.Encoder
|
||||
}
|
||||
|
||||
// AuditEntry represents a single audit log entry.
|
||||
type AuditEntry struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
UID uint32 `json:"uid"`
|
||||
PID int32 `json:"pid"`
|
||||
Action string `json:"action"`
|
||||
Details any `json:"details,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewAuditLogger creates a new audit logger.
|
||||
func NewAuditLogger(path string) (*AuditLogger, error) {
|
||||
// Ensure directory exists
|
||||
dir := path[:len(path)-len("/audit.log")]
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audit log: %w", err)
|
||||
}
|
||||
|
||||
return &AuditLogger{
|
||||
file: file,
|
||||
path: path,
|
||||
encoder: json.NewEncoder(file),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Log writes an audit entry.
|
||||
func (a *AuditLogger) Log(uid uint32, pid int32, action string, details any, success bool, errMsg string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
entry := AuditEntry{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
UID: uid,
|
||||
PID: pid,
|
||||
Action: action,
|
||||
Details: details,
|
||||
Success: success,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// Ignore encoding errors - audit logging should not fail the operation
|
||||
_ = a.encoder.Encode(entry)
|
||||
}
|
||||
|
||||
// Close closes the audit logger.
|
||||
func (a *AuditLogger) Close() error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.file != nil {
|
||||
err := a.file.Close()
|
||||
a.file = nil // Prevent double close
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerCredentials holds the credentials of a connected peer.
|
||||
type PeerCredentials struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
PID int32
|
||||
}
|
||||
|
||||
// isUserInGroup checks if a user (by UID) is a member of a group (by GID).
|
||||
// This checks supplementary groups, not just the primary GID.
|
||||
func isUserInGroup(uid uint32, targetGID uint32) bool {
|
||||
// Look up user by UID
|
||||
u, err := user.LookupId(fmt.Sprintf("%d", uid))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get user's group IDs
|
||||
groupIDs, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if target GID is in the list
|
||||
targetGIDStr := fmt.Sprintf("%d", targetGID)
|
||||
for _, gid := range groupIDs {
|
||||
if gid == targetGIDStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRateLimiter_Allow(t *testing.T) {
|
||||
t.Run("under limit", func(t *testing.T) {
|
||||
rl := NewRateLimiter(5, time.Minute)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow(123), "request %d should be allowed", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("over limit", func(t *testing.T) {
|
||||
rl := NewRateLimiter(3, time.Minute)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, rl.Allow(123))
|
||||
}
|
||||
|
||||
// 4th request should be blocked
|
||||
assert.False(t, rl.Allow(123))
|
||||
})
|
||||
|
||||
t.Run("different PIDs", func(t *testing.T) {
|
||||
rl := NewRateLimiter(2, time.Minute)
|
||||
|
||||
// PID 1
|
||||
assert.True(t, rl.Allow(1))
|
||||
assert.True(t, rl.Allow(1))
|
||||
assert.False(t, rl.Allow(1))
|
||||
|
||||
// PID 2 should have its own limit
|
||||
assert.True(t, rl.Allow(2))
|
||||
assert.True(t, rl.Allow(2))
|
||||
assert.False(t, rl.Allow(2))
|
||||
})
|
||||
|
||||
t.Run("window expiration", func(t *testing.T) {
|
||||
rl := NewRateLimiter(2, 10*time.Millisecond)
|
||||
|
||||
assert.True(t, rl.Allow(123))
|
||||
assert.True(t, rl.Allow(123))
|
||||
assert.False(t, rl.Allow(123))
|
||||
|
||||
// Wait for window to expire
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow(123))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiter_Cleanup(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10*time.Millisecond)
|
||||
|
||||
// Add requests from multiple PIDs
|
||||
for pid := int32(1); pid <= 5; pid++ {
|
||||
rl.Allow(pid)
|
||||
}
|
||||
|
||||
assert.Len(t, rl.requests, 5)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Cleanup
|
||||
rl.Cleanup()
|
||||
|
||||
assert.Empty(t, rl.requests)
|
||||
}
|
||||
|
||||
func TestAuditLogger_Log(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
|
||||
logger.Log(1000, 12345, "sync", nil, false, "sync failed")
|
||||
|
||||
// Read log file
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, `"action":"set"`)
|
||||
assert.Contains(t, contentStr, `"uid":1000`)
|
||||
assert.Contains(t, contentStr, `"pid":12345`)
|
||||
assert.Contains(t, contentStr, `"success":true`)
|
||||
assert.Contains(t, contentStr, `"action":"sync"`)
|
||||
assert.Contains(t, contentStr, `"success":false`)
|
||||
assert.Contains(t, contentStr, `"error":"sync failed"`)
|
||||
}
|
||||
|
||||
func TestAuditLogger_CreatesDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "subdir", "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
// Verify directory was created
|
||||
_, err = os.Stat(filepath.Dir(logPath))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuditLogger_Close(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = logger.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Closing again should not error
|
||||
err = logger.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPeerCredentials(t *testing.T) {
|
||||
creds := &PeerCredentials{
|
||||
UID: 501,
|
||||
GID: 20,
|
||||
PID: 12345,
|
||||
}
|
||||
|
||||
assert.Equal(t, uint32(501), creds.UID)
|
||||
assert.Equal(t, uint32(20), creds.GID)
|
||||
assert.Equal(t, int32(12345), creds.PID)
|
||||
}
|
||||
|
||||
// Matrix test for rate limiting
|
||||
func TestRateLimiter_Matrix(t *testing.T) {
|
||||
limits := []int{1, 5, 10, 100}
|
||||
windows := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, time.Second}
|
||||
|
||||
for _, limit := range limits {
|
||||
for _, window := range windows {
|
||||
t.Run(
|
||||
"limit="+string(rune('0'+limit))+"_window="+window.String(),
|
||||
func(t *testing.T) {
|
||||
rl := NewRateLimiter(limit, window)
|
||||
|
||||
// Should allow exactly 'limit' requests
|
||||
for i := 0; i < limit; i++ {
|
||||
assert.True(t, rl.Allow(1))
|
||||
}
|
||||
|
||||
// Next should be blocked
|
||||
assert.False(t, rl.Allow(1))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRateLimiter_Allow(b *testing.B) {
|
||||
rl := NewRateLimiter(RateLimit, RateLimitWindow)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Allow(int32(i % 100))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRateLimiter_Cleanup(b *testing.B) {
|
||||
rl := NewRateLimiter(RateLimit, RateLimitWindow)
|
||||
|
||||
// Pre-populate with requests
|
||||
for i := 0; i < 1000; i++ {
|
||||
rl.Allow(int32(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAuditLogger_Log(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(b, err)
|
||||
defer logger.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,803 @@
|
||||
// Package daemon provides the Unix socket server for the daemon.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Version is set by the main package at startup
|
||||
var Version = "dev"
|
||||
|
||||
// Server is the daemon's Unix socket server.
|
||||
type Server struct {
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
config *config.Manager
|
||||
hosts *HostsManager
|
||||
flusher *DNSFlusher
|
||||
rateLimiter *RateLimiter
|
||||
auditLogger *AuditLogger
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
requestCount int64
|
||||
startTime int64
|
||||
}
|
||||
|
||||
// NewServer creates a new daemon server.
|
||||
func NewServer(socketPath string, cfgManager *config.Manager) *Server {
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
config: cfgManager,
|
||||
hosts: NewHostsManager(),
|
||||
flusher: NewDNSFlusher(FlushMethodAuto),
|
||||
rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the server.
|
||||
func (s *Server) Start() error {
|
||||
// Remove existing socket
|
||||
os.Remove(s.socketPath)
|
||||
|
||||
listener, err := net.Listen("unix", s.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions: 0660 root:lolcathost
|
||||
if err := os.Chmod(s.socketPath, 0660); err != nil {
|
||||
listener.Close()
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
// Set socket group to lolcathost (GID 850)
|
||||
if err := os.Chown(s.socketPath, 0, 850); err != nil {
|
||||
listener.Close()
|
||||
return fmt.Errorf("failed to set socket ownership: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.running = true
|
||||
s.startTime = currentTimeUnix()
|
||||
|
||||
// Try to create audit logger, but don't fail if it doesn't work
|
||||
if logger, err := NewAuditLogger(AuditLogPath); err == nil {
|
||||
s.auditLogger = logger
|
||||
}
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func currentTimeUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// Stop stops the server.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
|
||||
close(s.stopCh)
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
os.Remove(s.socketPath)
|
||||
|
||||
if s.auditLogger != nil {
|
||||
s.auditLogger.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// LolcathostGID is the group ID for the lolcathost group.
|
||||
const LolcathostGID = 850
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Get peer credentials
|
||||
creds := s.getPeerCredentials(conn)
|
||||
|
||||
// Authorization check: verify peer is authorized
|
||||
if !s.isAuthorized(creds) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
if s.auditLogger != nil {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
if creds != nil {
|
||||
uid = creds.UID
|
||||
pid = creds.PID
|
||||
}
|
||||
s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
|
||||
continue
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.requestCount++
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handleRequest(&req, creds)
|
||||
s.writeResponse(conn, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// isAuthorized checks if the peer is authorized to access the daemon.
|
||||
// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850).
|
||||
func (s *Server) isAuthorized(creds *PeerCredentials) bool {
|
||||
if creds == nil {
|
||||
// Can't verify credentials - deny by default
|
||||
return false
|
||||
}
|
||||
|
||||
// Root is always authorized
|
||||
if creds.UID == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if user's primary GID is lolcathost
|
||||
if creds.GID == LolcathostGID {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check supplementary groups (user might be in lolcathost as secondary group)
|
||||
// This requires looking up the user's groups from the system
|
||||
return isUserInGroup(creds.UID, LolcathostGID)
|
||||
}
|
||||
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
|
||||
data, _ := json.Marshal(resp)
|
||||
data = append(data, '\n')
|
||||
conn.Write(data)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
if creds != nil {
|
||||
uid = creds.UID
|
||||
pid = creds.PID
|
||||
}
|
||||
|
||||
switch req.Type {
|
||||
case protocol.RequestPing:
|
||||
return s.handlePing()
|
||||
|
||||
case protocol.RequestStatus:
|
||||
return s.handleStatus()
|
||||
|
||||
case protocol.RequestList:
|
||||
return s.handleList()
|
||||
|
||||
case protocol.RequestSet:
|
||||
resp := s.handleSet(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.SetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestSync:
|
||||
resp := s.handleSync()
|
||||
if s.auditLogger != nil {
|
||||
s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestPreset:
|
||||
resp := s.handlePreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.PresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestRollback:
|
||||
resp := s.handleRollback(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.RollbackPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestBackups:
|
||||
return s.handleBackups()
|
||||
|
||||
case protocol.RequestAdd:
|
||||
resp := s.handleAdd(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.AddPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDelete:
|
||||
resp := s.handleDelete(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.DeletePayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestAddGroup:
|
||||
resp := s.handleAddGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.GroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDeleteGroup:
|
||||
resp := s.handleDeleteGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.GroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestListGroups:
|
||||
return s.handleListGroups()
|
||||
|
||||
case protocol.RequestRenameGroup:
|
||||
resp := s.handleRenameGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.RenameGroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestAddPreset:
|
||||
resp := s.handleAddPreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.AddPresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDeletePreset:
|
||||
resp := s.handleDeletePreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.PresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestListPresets:
|
||||
return s.handleListPresets()
|
||||
|
||||
default:
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePing() *protocol.Response {
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleStatus() *protocol.Response {
|
||||
s.mu.RLock()
|
||||
reqCount := s.requestCount
|
||||
startTime := s.startTime
|
||||
s.mu.RUnlock()
|
||||
|
||||
cfg := s.config.Get()
|
||||
var activeCount int
|
||||
if cfg != nil {
|
||||
for _, h := range cfg.GetAllHosts() {
|
||||
if h.Enabled {
|
||||
activeCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data := protocol.StatusData{
|
||||
Running: true,
|
||||
Version: Version,
|
||||
Uptime: nowUnix() - startTime,
|
||||
ActiveCount: activeCount,
|
||||
RequestCount: reqCount,
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(data)
|
||||
return resp
|
||||
}
|
||||
|
||||
func nowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
func (s *Server) handleList() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
var entries []protocol.HostEntry
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
entries = append(entries, protocol.HostEntry{
|
||||
Domain: h.Domain,
|
||||
IP: h.IP,
|
||||
Alias: h.Alias,
|
||||
Enabled: h.Enabled,
|
||||
Group: g.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.SetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
host, _ := cfg.FindHostByAlias(payload.Alias)
|
||||
if host == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Check for conflicts if enabling
|
||||
if payload.Enabled && !payload.Force {
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict,
|
||||
fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update config
|
||||
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: host.Domain,
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleSync() *protocol.Response {
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.PresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.ApplyPreset(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleRollback(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.RollbackPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if err := s.hosts.RestoreBackup(payload.BackupName); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err))
|
||||
}
|
||||
|
||||
// Flush DNS after restore
|
||||
s.flusher.Flush()
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleBackups() *protocol.Response {
|
||||
backups, err := s.hosts.ListBackups()
|
||||
if err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err))
|
||||
}
|
||||
|
||||
var infos []protocol.BackupInfo
|
||||
for _, b := range backups {
|
||||
infos = append(infos, protocol.BackupInfo{
|
||||
Name: b.Name,
|
||||
Timestamp: b.Timestamp,
|
||||
Size: b.Size,
|
||||
})
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.AddPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
// Validate domain
|
||||
if payload.Domain == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required")
|
||||
}
|
||||
|
||||
// Validate IP
|
||||
if payload.IP == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required")
|
||||
}
|
||||
|
||||
// Validate group
|
||||
if payload.Group == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required")
|
||||
}
|
||||
|
||||
// Check blocked domains
|
||||
if config.IsBlockedDomain(payload.Domain) {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain))
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
// Add to config (alias will be auto-generated if empty)
|
||||
if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: payload.Domain,
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.DeletePayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Alias == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
// Delete from config
|
||||
if !cfg.DeleteHost(payload.Alias) {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.GroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.AddGroup(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.GroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.DeleteGroup(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleListGroups() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.RenameGroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.OldName == "" || payload.NewName == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.AddPresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.PresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.DeletePreset(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleListPresets() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
presets := cfg.GetPresets()
|
||||
infos := make([]protocol.PresetInfo, len(presets))
|
||||
for i, p := range presets {
|
||||
infos[i] = protocol.PresetInfo{
|
||||
Name: p.Name,
|
||||
Enable: p.Enable,
|
||||
Disable: p.Disable,
|
||||
}
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) syncHostsFile() error {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("no configuration loaded")
|
||||
}
|
||||
|
||||
var entries []HostEntry
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: h.IP,
|
||||
Domain: h.Domain,
|
||||
Alias: h.Alias,
|
||||
Enabled: h.Enabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hosts.WriteManagedEntries(entries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
return s.flusher.Flush()
|
||||
}
|
||||
@@ -0,0 +1,474 @@
|
||||
// Package installer handles installation and uninstallation of the lolcathost daemon.
|
||||
package installer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
// GroupName is the name of the lolcathost group.
|
||||
GroupName = "lolcathost"
|
||||
// GroupGID is the GID for the lolcathost group (macOS).
|
||||
GroupGID = 850
|
||||
|
||||
// Paths
|
||||
LogDir = "/var/log/lolcathost"
|
||||
BackupDir = "/var/backups/lolcathost"
|
||||
SocketPath = "/var/run/lolcathost.sock"
|
||||
LaunchDaemonDir = "/Library/LaunchDaemons"
|
||||
SystemdDir = "/etc/systemd/system"
|
||||
)
|
||||
|
||||
// LaunchDaemonPlist is the macOS LaunchDaemon plist template.
|
||||
const LaunchDaemonPlist = `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.lolcathost.daemon</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>%s</string>
|
||||
<string>--daemon</string>
|
||||
<string>--config</string>
|
||||
<string>/etc/lolcathost/config.yaml</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/lolcathost/daemon.log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/lolcathost/daemon.err</string>
|
||||
</dict>
|
||||
</plist>
|
||||
`
|
||||
|
||||
// SystemdUnit is the Linux systemd unit template.
|
||||
const SystemdUnit = `[Unit]
|
||||
Description=lolcathost - Dynamic Host Management Daemon
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart=%s --daemon --config /etc/lolcathost/config.yaml
|
||||
Restart=always
|
||||
RestartSec=5
|
||||
User=root
|
||||
Group=root
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
|
||||
// Installer handles installation and uninstallation.
|
||||
type Installer struct {
|
||||
binaryPath string
|
||||
verbose bool
|
||||
}
|
||||
|
||||
// New creates a new installer.
|
||||
func New() (*Installer, error) {
|
||||
binaryPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get executable path: %w", err)
|
||||
}
|
||||
|
||||
// Resolve symlinks
|
||||
binaryPath, err = filepath.EvalSymlinks(binaryPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve executable path: %w", err)
|
||||
}
|
||||
|
||||
return &Installer{
|
||||
binaryPath: binaryPath,
|
||||
verbose: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Install performs the full installation.
|
||||
func (i *Installer) Install() error {
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("--install requires sudo")
|
||||
}
|
||||
|
||||
i.log("Installing lolcathost...")
|
||||
|
||||
// Create group
|
||||
if err := i.createGroup(); err != nil {
|
||||
return fmt.Errorf("failed to create group: %w", err)
|
||||
}
|
||||
|
||||
// Add current user to group
|
||||
if err := i.addCurrentUserToGroup(); err != nil {
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
|
||||
// Create directories
|
||||
if err := i.createDirectories(); err != nil {
|
||||
return fmt.Errorf("failed to create directories: %w", err)
|
||||
}
|
||||
|
||||
// Create system config for daemon
|
||||
if err := i.createSystemConfig(); err != nil {
|
||||
return fmt.Errorf("failed to create system config: %w", err)
|
||||
}
|
||||
|
||||
// Install service
|
||||
if runtime.GOOS == "darwin" {
|
||||
if err := i.installLaunchDaemon(); err != nil {
|
||||
return fmt.Errorf("failed to install LaunchDaemon: %w", err)
|
||||
}
|
||||
} else if runtime.GOOS == "linux" {
|
||||
if err := i.installSystemdService(); err != nil {
|
||||
return fmt.Errorf("failed to install systemd service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create default config for the invoking user
|
||||
if err := i.createDefaultConfig(); err != nil {
|
||||
i.log("Warning: failed to create default config: %v", err)
|
||||
}
|
||||
|
||||
i.log("")
|
||||
i.log("✓ Installed successfully!")
|
||||
i.log("")
|
||||
i.log("Next steps:")
|
||||
i.log(" 1. Open a NEW terminal (for group membership to take effect)")
|
||||
i.log(" 2. Run 'lolcathost' to start the TUI")
|
||||
i.log("")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Uninstall removes the installation.
|
||||
func (i *Installer) Uninstall() error {
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("--uninstall requires sudo")
|
||||
}
|
||||
|
||||
i.log("Uninstalling lolcathost...")
|
||||
|
||||
// Stop and remove service
|
||||
if runtime.GOOS == "darwin" {
|
||||
i.uninstallLaunchDaemon()
|
||||
} else if runtime.GOOS == "linux" {
|
||||
i.uninstallSystemdService()
|
||||
}
|
||||
|
||||
// Remove socket
|
||||
os.Remove(SocketPath)
|
||||
|
||||
// Note: We don't remove the group, logs, or backups
|
||||
// The user may want to keep these
|
||||
|
||||
i.log("")
|
||||
i.log("✓ Uninstalled successfully!")
|
||||
i.log("")
|
||||
i.log("Note: Log files, backups, and the group were preserved.")
|
||||
i.log("To fully remove, manually delete:")
|
||||
i.log(" - /var/log/lolcathost/")
|
||||
i.log(" - /var/backups/lolcathost/")
|
||||
i.log(" - ~/.config/lolcathost/")
|
||||
if runtime.GOOS == "darwin" {
|
||||
i.log(" - Remove group: sudo dscl . -delete /Groups/%s", GroupName)
|
||||
} else {
|
||||
i.log(" - Remove group: sudo groupdel %s", GroupName)
|
||||
}
|
||||
i.log("")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) log(format string, args ...any) {
|
||||
if i.verbose {
|
||||
fmt.Printf(format+"\n", args...)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Installer) createGroup() error {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return i.createGroupDarwin()
|
||||
case "linux":
|
||||
return i.createGroupLinux()
|
||||
default:
|
||||
return fmt.Errorf("unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Installer) createGroupDarwin() error {
|
||||
// Check if group exists
|
||||
if _, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName).Output(); err == nil {
|
||||
i.log(" Group '%s' already exists", GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating group '%s' (GID %d)...", GroupName, GroupGID)
|
||||
|
||||
// Create group
|
||||
cmds := [][]string{
|
||||
{"dscl", ".", "-create", "/Groups/" + GroupName},
|
||||
{"dscl", ".", "-create", "/Groups/" + GroupName, "PrimaryGroupID", strconv.Itoa(GroupGID)},
|
||||
{"dscl", ".", "-create", "/Groups/" + GroupName, "RealName", "lolcathost users"},
|
||||
}
|
||||
|
||||
for _, args := range cmds {
|
||||
if err := exec.Command(args[0], args[1:]...).Run(); err != nil {
|
||||
return fmt.Errorf("command %v failed: %w", args, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) createGroupLinux() error {
|
||||
// Check if group exists
|
||||
if _, err := exec.Command("getent", "group", GroupName).Output(); err == nil {
|
||||
i.log(" Group '%s' already exists", GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating group '%s'...", GroupName)
|
||||
|
||||
if err := exec.Command("groupadd", "-r", GroupName).Run(); err != nil {
|
||||
return fmt.Errorf("groupadd failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) addCurrentUserToGroup() error {
|
||||
// Get the real user (not root)
|
||||
username := os.Getenv("SUDO_USER")
|
||||
if username == "" {
|
||||
// Fall back to current user
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current user: %w", err)
|
||||
}
|
||||
username = u.Username
|
||||
}
|
||||
|
||||
if username == "root" {
|
||||
i.log(" Skipping adding root to group")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return i.addUserToGroupDarwin(username)
|
||||
case "linux":
|
||||
return i.addUserToGroupLinux(username)
|
||||
default:
|
||||
return fmt.Errorf("unsupported OS: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Installer) addUserToGroupDarwin(username string) error {
|
||||
// Check if user is already in group
|
||||
output, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName, "GroupMembership").Output()
|
||||
if err == nil && strings.Contains(string(output), username) {
|
||||
i.log(" User '%s' already in group '%s'", username, GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
|
||||
|
||||
if err := exec.Command("dscl", ".", "-append", "/Groups/"+GroupName, "GroupMembership", username).Run(); err != nil {
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) addUserToGroupLinux(username string) error {
|
||||
// Check if user is already in group
|
||||
output, err := exec.Command("id", "-nG", username).Output()
|
||||
if err == nil && strings.Contains(string(output), GroupName) {
|
||||
i.log(" User '%s' already in group '%s'", username, GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
|
||||
|
||||
if err := exec.Command("usermod", "-aG", GroupName, username).Run(); err != nil {
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) createDirectories() error {
|
||||
dirs := []string{LogDir, BackupDir, config.SystemConfigDir}
|
||||
|
||||
for _, dir := range dirs {
|
||||
i.log(" Creating directory '%s'...", dir)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create %s: %w", dir, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) createSystemConfig() error {
|
||||
// Check if system config already exists
|
||||
if _, err := os.Stat(config.SystemConfigPath); err == nil {
|
||||
i.log(" System config already exists at %s", config.SystemConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating system config at %s...", config.SystemConfigPath)
|
||||
return config.CreateDefault(config.SystemConfigPath)
|
||||
}
|
||||
|
||||
func (i *Installer) installLaunchDaemon() error {
|
||||
plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist")
|
||||
plistContent := fmt.Sprintf(LaunchDaemonPlist, i.binaryPath)
|
||||
|
||||
i.log(" Writing LaunchDaemon plist...")
|
||||
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write plist: %w", err)
|
||||
}
|
||||
|
||||
// Unload if already loaded
|
||||
exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
|
||||
|
||||
// Bootstrap the daemon
|
||||
i.log(" Starting daemon...")
|
||||
if err := exec.Command("launchctl", "bootstrap", "system", plistPath).Run(); err != nil {
|
||||
return fmt.Errorf("failed to bootstrap daemon: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) uninstallLaunchDaemon() {
|
||||
plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist")
|
||||
|
||||
i.log(" Stopping daemon...")
|
||||
exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
|
||||
|
||||
i.log(" Removing LaunchDaemon plist...")
|
||||
os.Remove(plistPath)
|
||||
}
|
||||
|
||||
func (i *Installer) installSystemdService() error {
|
||||
unitPath := filepath.Join(SystemdDir, "lolcathost.service")
|
||||
unitContent := fmt.Sprintf(SystemdUnit, i.binaryPath)
|
||||
|
||||
i.log(" Writing systemd unit...")
|
||||
if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write unit file: %w", err)
|
||||
}
|
||||
|
||||
// Reload systemd
|
||||
i.log(" Reloading systemd...")
|
||||
if err := exec.Command("systemctl", "daemon-reload").Run(); err != nil {
|
||||
return fmt.Errorf("failed to reload systemd: %w", err)
|
||||
}
|
||||
|
||||
// Enable and start the service
|
||||
i.log(" Enabling and starting service...")
|
||||
if err := exec.Command("systemctl", "enable", "--now", "lolcathost.service").Run(); err != nil {
|
||||
return fmt.Errorf("failed to enable service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *Installer) uninstallSystemdService() {
|
||||
i.log(" Stopping and disabling service...")
|
||||
exec.Command("systemctl", "disable", "--now", "lolcathost.service").Run()
|
||||
|
||||
i.log(" Removing systemd unit...")
|
||||
os.Remove(filepath.Join(SystemdDir, "lolcathost.service"))
|
||||
|
||||
exec.Command("systemctl", "daemon-reload").Run()
|
||||
}
|
||||
|
||||
func (i *Installer) createDefaultConfig() error {
|
||||
// Get the real user's home directory
|
||||
username := os.Getenv("SUDO_USER")
|
||||
if username == "" {
|
||||
return nil // Can't determine user
|
||||
}
|
||||
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to lookup user: %w", err)
|
||||
}
|
||||
|
||||
configPath := filepath.Join(u.HomeDir, ".config", "lolcathost", "config.yaml")
|
||||
|
||||
// Check if config already exists
|
||||
if _, err := os.Stat(configPath); err == nil {
|
||||
i.log(" Config already exists at %s", configPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating default config at %s...", configPath)
|
||||
|
||||
if err := config.CreateDefault(configPath); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Change ownership to the real user
|
||||
uid, _ := strconv.Atoi(u.Uid)
|
||||
gid, _ := strconv.Atoi(u.Gid)
|
||||
|
||||
configDir := filepath.Dir(configPath)
|
||||
os.Chown(configDir, uid, gid)
|
||||
os.Chown(filepath.Dir(configDir), uid, gid)
|
||||
os.Chown(configPath, uid, gid)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckInstallation checks if the daemon is properly installed.
|
||||
func CheckInstallation() error {
|
||||
// Check if socket exists
|
||||
if _, err := os.Stat(SocketPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("daemon not running (socket not found)")
|
||||
}
|
||||
|
||||
// Check if user is in group
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current user: %w", err)
|
||||
}
|
||||
|
||||
groups, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user groups: %w", err)
|
||||
}
|
||||
|
||||
inGroup := false
|
||||
for _, gid := range groups {
|
||||
g, err := user.LookupGroupId(gid)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if g.Name == GroupName {
|
||||
inGroup = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !inGroup {
|
||||
return fmt.Errorf("user '%s' is not in group '%s'. Run 'sudo lolcathost --install' and open a new terminal", u.Username, GroupName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
// Package protocol defines shared message types for client-daemon communication.
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SocketPath is the Unix socket path for daemon communication.
|
||||
const SocketPath = "/var/run/lolcathost.sock"
|
||||
|
||||
// RequestType defines the type of request.
|
||||
type RequestType string
|
||||
|
||||
const (
|
||||
RequestPing RequestType = "ping"
|
||||
RequestStatus RequestType = "status"
|
||||
RequestList RequestType = "list"
|
||||
RequestSet RequestType = "set"
|
||||
RequestAdd RequestType = "add"
|
||||
RequestDelete RequestType = "delete"
|
||||
RequestSync RequestType = "sync"
|
||||
RequestPreset RequestType = "preset"
|
||||
RequestRollback RequestType = "rollback"
|
||||
RequestBackups RequestType = "backups"
|
||||
RequestAddGroup RequestType = "add_group"
|
||||
RequestDeleteGroup RequestType = "delete_group"
|
||||
RequestRenameGroup RequestType = "rename_group"
|
||||
RequestListGroups RequestType = "list_groups"
|
||||
RequestAddPreset RequestType = "add_preset"
|
||||
RequestDeletePreset RequestType = "delete_preset"
|
||||
RequestListPresets RequestType = "list_presets"
|
||||
)
|
||||
|
||||
// ErrorCode defines standard error codes.
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrCodeInvalidRequest ErrorCode = "INVALID_REQUEST"
|
||||
ErrCodeInvalidDomain ErrorCode = "INVALID_DOMAIN"
|
||||
ErrCodeInvalidIP ErrorCode = "INVALID_IP"
|
||||
ErrCodeBlockedDomain ErrorCode = "BLOCKED_DOMAIN"
|
||||
ErrCodeRateLimited ErrorCode = "RATE_LIMITED"
|
||||
ErrCodeUnauthorized ErrorCode = "UNAUTHORIZED"
|
||||
ErrCodeNotFound ErrorCode = "NOT_FOUND"
|
||||
ErrCodeConflict ErrorCode = "CONFLICT"
|
||||
ErrCodeInternalError ErrorCode = "INTERNAL_ERROR"
|
||||
ErrCodePermissionError ErrorCode = "PERMISSION_ERROR"
|
||||
)
|
||||
|
||||
// Request represents a client request to the daemon.
|
||||
type Request struct {
|
||||
Type RequestType `json:"type"`
|
||||
Payload json.RawMessage `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
// SetPayload is the payload for set requests.
|
||||
type SetPayload struct {
|
||||
Alias string `json:"alias"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Force bool `json:"force,omitempty"`
|
||||
}
|
||||
|
||||
// PresetPayload is the payload for preset requests.
|
||||
type PresetPayload struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// RollbackPayload is the payload for rollback requests.
|
||||
type RollbackPayload struct {
|
||||
BackupName string `json:"backup_name"`
|
||||
}
|
||||
|
||||
// AddPayload is the payload for add requests.
|
||||
type AddPayload struct {
|
||||
Domain string `json:"domain"`
|
||||
IP string `json:"ip"`
|
||||
Alias string `json:"alias"`
|
||||
Group string `json:"group"`
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// DeletePayload is the payload for delete requests.
|
||||
type DeletePayload struct {
|
||||
Alias string `json:"alias"`
|
||||
}
|
||||
|
||||
// GroupPayload is the payload for group add/delete requests.
|
||||
type GroupPayload struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// RenameGroupPayload is the payload for rename_group requests.
|
||||
type RenameGroupPayload struct {
|
||||
OldName string `json:"old_name"`
|
||||
NewName string `json:"new_name"`
|
||||
}
|
||||
|
||||
// GroupsData is the data for list_groups responses.
|
||||
type GroupsData struct {
|
||||
Groups []string `json:"groups"`
|
||||
}
|
||||
|
||||
// AddPresetPayload is the payload for add_preset requests.
|
||||
type AddPresetPayload struct {
|
||||
Name string `json:"name"`
|
||||
Enable []string `json:"enable"`
|
||||
Disable []string `json:"disable"`
|
||||
}
|
||||
|
||||
// PresetInfo represents a preset with its configuration.
|
||||
type PresetInfo struct {
|
||||
Name string `json:"name"`
|
||||
Enable []string `json:"enable"`
|
||||
Disable []string `json:"disable"`
|
||||
}
|
||||
|
||||
// PresetsData is the data for list_presets responses.
|
||||
type PresetsData struct {
|
||||
Presets []PresetInfo `json:"presets"`
|
||||
}
|
||||
|
||||
// Response represents a daemon response.
|
||||
type Response struct {
|
||||
Status string `json:"status"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Code ErrorCode `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
// StatusData is the data for status responses.
|
||||
type StatusData struct {
|
||||
Running bool `json:"running"`
|
||||
Version string `json:"version"`
|
||||
Uptime int64 `json:"uptime_seconds"`
|
||||
ActiveCount int `json:"active_count"`
|
||||
RequestCount int64 `json:"request_count"`
|
||||
}
|
||||
|
||||
// HostEntry represents a single host entry.
|
||||
type HostEntry struct {
|
||||
Domain string `json:"domain"`
|
||||
IP string `json:"ip"`
|
||||
Alias string `json:"alias"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Group string `json:"group"`
|
||||
}
|
||||
|
||||
// ListData is the data for list responses.
|
||||
type ListData struct {
|
||||
Entries []HostEntry `json:"entries"`
|
||||
}
|
||||
|
||||
// SetData is the data for set responses.
|
||||
type SetData struct {
|
||||
Domain string `json:"domain"`
|
||||
Applied bool `json:"applied"`
|
||||
}
|
||||
|
||||
// BackupsData is the data for backups responses.
|
||||
type BackupsData struct {
|
||||
Backups []BackupInfo `json:"backups"`
|
||||
}
|
||||
|
||||
// BackupInfo represents a backup file.
|
||||
type BackupInfo struct {
|
||||
Name string `json:"name"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// NewRequest creates a new request with the given type and payload.
|
||||
func NewRequest(reqType RequestType, payload interface{}) (*Request, error) {
|
||||
req := &Request{Type: reqType}
|
||||
if payload != nil {
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal payload: %w", err)
|
||||
}
|
||||
req.Payload = data
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// NewOKResponse creates a success response with optional data.
|
||||
func NewOKResponse(data interface{}) (*Response, error) {
|
||||
resp := &Response{Status: "ok"}
|
||||
if data != nil {
|
||||
dataBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
resp.Data = dataBytes
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewErrorResponse creates an error response.
|
||||
func NewErrorResponse(code ErrorCode, message string) *Response {
|
||||
return &Response{
|
||||
Status: "error",
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// ParsePayload unmarshals the request payload into the given target.
|
||||
func (r *Request) ParsePayload(target interface{}) error {
|
||||
if r.Payload == nil {
|
||||
return fmt.Errorf("no payload in request")
|
||||
}
|
||||
return json.Unmarshal(r.Payload, target)
|
||||
}
|
||||
|
||||
// ParseData unmarshals the response data into the given target.
|
||||
func (r *Response) ParseData(target interface{}) error {
|
||||
if r.Data == nil {
|
||||
return fmt.Errorf("no data in response")
|
||||
}
|
||||
return json.Unmarshal(r.Data, target)
|
||||
}
|
||||
|
||||
// IsOK returns true if the response indicates success.
|
||||
func (r *Response) IsOK() bool {
|
||||
return r.Status == "ok"
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
reqType RequestType
|
||||
payload interface{}
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "ping request without payload",
|
||||
reqType: RequestPing,
|
||||
payload: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "set request with payload",
|
||||
reqType: RequestSet,
|
||||
payload: SetPayload{Alias: "test", Enabled: true},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "preset request with payload",
|
||||
reqType: RequestPreset,
|
||||
payload: PresetPayload{Name: "local"},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req, err := NewRequest(tt.reqType, tt.payload)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.reqType, req.Type)
|
||||
if tt.payload != nil {
|
||||
assert.NotNil(t, req.Payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_ParsePayload(t *testing.T) {
|
||||
t.Run("valid payload", func(t *testing.T) {
|
||||
payload := SetPayload{Alias: "test-alias", Enabled: true, Force: false}
|
||||
req, err := NewRequest(RequestSet, payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed SetPayload
|
||||
err = req.ParsePayload(&parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test-alias", parsed.Alias)
|
||||
assert.True(t, parsed.Enabled)
|
||||
assert.False(t, parsed.Force)
|
||||
})
|
||||
|
||||
t.Run("nil payload", func(t *testing.T) {
|
||||
req := &Request{Type: RequestPing}
|
||||
var parsed SetPayload
|
||||
err := req.ParsePayload(&parsed)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewOKResponse(t *testing.T) {
|
||||
t.Run("with data", func(t *testing.T) {
|
||||
data := StatusData{
|
||||
Running: true,
|
||||
Version: "1.0.0",
|
||||
Uptime: 3600,
|
||||
ActiveCount: 5,
|
||||
RequestCount: 100,
|
||||
}
|
||||
|
||||
resp, err := NewOKResponse(data)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.NotNil(t, resp.Data)
|
||||
assert.True(t, resp.IsOK())
|
||||
})
|
||||
|
||||
t.Run("without data", func(t *testing.T) {
|
||||
resp, err := NewOKResponse(nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp.Status)
|
||||
assert.Nil(t, resp.Data)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewErrorResponse(t *testing.T) {
|
||||
resp := NewErrorResponse(ErrCodeBlockedDomain, "domain is blocked")
|
||||
|
||||
assert.Equal(t, "error", resp.Status)
|
||||
assert.Equal(t, ErrCodeBlockedDomain, resp.Code)
|
||||
assert.Equal(t, "domain is blocked", resp.Message)
|
||||
assert.False(t, resp.IsOK())
|
||||
}
|
||||
|
||||
func TestResponse_ParseData(t *testing.T) {
|
||||
t.Run("valid data", func(t *testing.T) {
|
||||
data := ListData{
|
||||
Entries: []HostEntry{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"},
|
||||
},
|
||||
}
|
||||
resp, err := NewOKResponse(data)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed ListData
|
||||
err = resp.ParseData(&parsed)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, parsed.Entries, 1)
|
||||
assert.Equal(t, "example.com", parsed.Entries[0].Domain)
|
||||
})
|
||||
|
||||
t.Run("nil data", func(t *testing.T) {
|
||||
resp := &Response{Status: "ok"}
|
||||
var parsed ListData
|
||||
err := resp.ParseData(&parsed)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestTypes(t *testing.T) {
|
||||
types := []RequestType{
|
||||
RequestPing,
|
||||
RequestStatus,
|
||||
RequestList,
|
||||
RequestSet,
|
||||
RequestSync,
|
||||
RequestPreset,
|
||||
RequestRollback,
|
||||
RequestBackups,
|
||||
}
|
||||
|
||||
for _, rt := range types {
|
||||
t.Run(string(rt), func(t *testing.T) {
|
||||
req, err := NewRequest(rt, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, rt, req.Type)
|
||||
|
||||
// Verify JSON marshaling works
|
||||
data, err := json.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), string(rt))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorCodes(t *testing.T) {
|
||||
codes := []ErrorCode{
|
||||
ErrCodeInvalidRequest,
|
||||
ErrCodeInvalidDomain,
|
||||
ErrCodeInvalidIP,
|
||||
ErrCodeBlockedDomain,
|
||||
ErrCodeRateLimited,
|
||||
ErrCodeNotFound,
|
||||
ErrCodeConflict,
|
||||
ErrCodeInternalError,
|
||||
ErrCodePermissionError,
|
||||
}
|
||||
|
||||
for _, code := range codes {
|
||||
t.Run(string(code), func(t *testing.T) {
|
||||
resp := NewErrorResponse(code, "test error")
|
||||
assert.Equal(t, code, resp.Code)
|
||||
|
||||
// Verify JSON marshaling works
|
||||
data, err := json.Marshal(resp)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(data), string(code))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostEntry(t *testing.T) {
|
||||
entry := HostEntry{
|
||||
Domain: "example.com",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "example-local",
|
||||
Enabled: true,
|
||||
Group: "development",
|
||||
}
|
||||
|
||||
data, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed HostEntry
|
||||
err = json.Unmarshal(data, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, entry.Domain, parsed.Domain)
|
||||
assert.Equal(t, entry.IP, parsed.IP)
|
||||
assert.Equal(t, entry.Alias, parsed.Alias)
|
||||
assert.Equal(t, entry.Enabled, parsed.Enabled)
|
||||
assert.Equal(t, entry.Group, parsed.Group)
|
||||
}
|
||||
|
||||
func TestBackupInfo(t *testing.T) {
|
||||
info := BackupInfo{
|
||||
Name: "hosts.20231201-120000.bak",
|
||||
Timestamp: 1701432000,
|
||||
Size: 1024,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(info)
|
||||
require.NoError(t, err)
|
||||
|
||||
var parsed BackupInfo
|
||||
err = json.Unmarshal(data, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, info.Name, parsed.Name)
|
||||
assert.Equal(t, info.Timestamp, parsed.Timestamp)
|
||||
assert.Equal(t, info.Size, parsed.Size)
|
||||
}
|
||||
@@ -0,0 +1,904 @@
|
||||
// Package tui provides the main Bubble Tea application.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/client"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/version"
|
||||
)
|
||||
|
||||
// ViewMode represents the current view mode.
|
||||
type ViewMode int
|
||||
|
||||
const (
|
||||
ViewList ViewMode = iota
|
||||
ViewForm
|
||||
ViewPresets
|
||||
ViewGroups
|
||||
ViewHelp
|
||||
ViewSearch
|
||||
)
|
||||
|
||||
// Model is the main Bubble Tea model.
|
||||
type Model struct {
|
||||
// Client
|
||||
client *client.Client
|
||||
connected bool
|
||||
|
||||
// Config
|
||||
configPath string
|
||||
config *config.Manager
|
||||
|
||||
// Views
|
||||
mode ViewMode
|
||||
list *ListView
|
||||
form *Form
|
||||
presetPicker *PresetPicker
|
||||
groupPicker *GroupPicker
|
||||
searchInput textinput.Model
|
||||
|
||||
// State
|
||||
width int
|
||||
height int
|
||||
message string
|
||||
messageStyle string // "error" or "success"
|
||||
messageTime time.Time
|
||||
searchTerm string
|
||||
allGroups []string // All groups including empty ones
|
||||
|
||||
// Update notification
|
||||
updateAvailable bool
|
||||
updateVersion string
|
||||
updateURL string
|
||||
|
||||
// Version info for update checking
|
||||
version string
|
||||
githubOwner string
|
||||
githubRepo string
|
||||
}
|
||||
|
||||
// Message types
|
||||
type (
|
||||
connectMsg struct{ err error }
|
||||
refreshMsg struct {
|
||||
entries []protocol.HostEntry
|
||||
err error
|
||||
}
|
||||
toggleMsg struct {
|
||||
alias string
|
||||
err error
|
||||
}
|
||||
presetMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
addMsg struct {
|
||||
domain string
|
||||
err error
|
||||
}
|
||||
deleteMsg struct {
|
||||
alias string
|
||||
err error
|
||||
}
|
||||
addPresetMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
deletePresetMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
refreshPresetsMsg struct {
|
||||
presets []protocol.PresetInfo
|
||||
err error
|
||||
}
|
||||
addGroupMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
renameGroupMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
deleteGroupMsg struct {
|
||||
name string
|
||||
err error
|
||||
}
|
||||
refreshGroupsMsg struct {
|
||||
groups []string
|
||||
err error
|
||||
}
|
||||
clearMsgMsg struct{}
|
||||
tickMsg struct{}
|
||||
updateMsg struct {
|
||||
version string
|
||||
url string
|
||||
}
|
||||
)
|
||||
|
||||
// NewModel creates a new TUI model.
|
||||
func NewModel(socketPath, configPath string) *Model {
|
||||
searchInput := textinput.New()
|
||||
searchInput.Placeholder = "Search..."
|
||||
searchInput.CharLimit = 100
|
||||
searchInput.Width = 50
|
||||
|
||||
return &Model{
|
||||
client: client.New(socketPath),
|
||||
configPath: configPath,
|
||||
config: config.NewManager(configPath),
|
||||
list: NewListView(),
|
||||
form: NewForm(),
|
||||
presetPicker: NewPresetPicker(),
|
||||
groupPicker: NewGroupPicker(),
|
||||
searchInput: searchInput,
|
||||
mode: ViewList,
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the model.
|
||||
func (m *Model) Init() tea.Cmd {
|
||||
return tea.Batch(
|
||||
m.connect(),
|
||||
tea.SetWindowTitle("lolcathost"),
|
||||
m.tick(),
|
||||
m.checkForUpdate(),
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Model) connect() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
if err := m.client.Connect(); err != nil {
|
||||
return connectMsg{err: err}
|
||||
}
|
||||
return connectMsg{err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) refresh() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
entries, err := m.client.List()
|
||||
if err != nil {
|
||||
return refreshMsg{entries: nil, err: err}
|
||||
}
|
||||
return refreshMsg{entries: entries, err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) toggle(alias string, enabled bool) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
_, err := m.client.Set(alias, enabled, false)
|
||||
return toggleMsg{alias: alias, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) applyPreset(name string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.ApplyPreset(name)
|
||||
return presetMsg{name: name, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) addHost(domain, ip, alias, group string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
_, err := m.client.Add(domain, ip, alias, group, false)
|
||||
return addMsg{domain: domain, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) deleteHost(alias string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.Delete(alias)
|
||||
return deleteMsg{alias: alias, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) addPreset(name string, enable, disable []string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.AddPreset(name, enable, disable)
|
||||
return addPresetMsg{name: name, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) deletePreset(name string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.DeletePreset(name)
|
||||
return deletePresetMsg{name: name, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) refreshPresets() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
presets, err := m.client.ListPresets()
|
||||
return refreshPresetsMsg{presets: presets, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) addGroup(name string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.AddGroup(name)
|
||||
return addGroupMsg{name: name, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) renameGroup(oldName, newName string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.RenameGroup(oldName, newName)
|
||||
return renameGroupMsg{name: newName, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) deleteGroup(name string) tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
err := m.client.DeleteGroup(name)
|
||||
return deleteGroupMsg{name: name, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) refreshGroups() tea.Cmd {
|
||||
return func() tea.Msg {
|
||||
groups, err := m.client.ListGroups()
|
||||
return refreshGroupsMsg{groups: groups, err: err}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) tick() tea.Cmd {
|
||||
return tea.Tick(time.Second*3, func(t time.Time) tea.Msg {
|
||||
return tickMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Model) clearMsg() tea.Cmd {
|
||||
return tea.Tick(time.Second*3, func(t time.Time) tea.Msg {
|
||||
return clearMsgMsg{}
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Model) checkForUpdate() tea.Cmd {
|
||||
if m.githubOwner == "" || m.githubRepo == "" {
|
||||
return nil
|
||||
}
|
||||
return func() tea.Msg {
|
||||
checker := version.NewChecker(m.githubOwner, m.githubRepo, m.version)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if update := checker.CheckForUpdate(ctx); update != nil {
|
||||
return updateMsg{version: update.LatestVersion, url: update.ReleaseURL}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Update handles messages.
|
||||
func (m *Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
var cmds []tea.Cmd
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case tea.WindowSizeMsg:
|
||||
m.width = msg.Width
|
||||
m.height = msg.Height
|
||||
m.list.SetSize(msg.Width, msg.Height-10)
|
||||
m.form.SetSize(msg.Width, msg.Height)
|
||||
m.presetPicker.SetSize(msg.Width, msg.Height)
|
||||
m.groupPicker.SetSize(msg.Width, msg.Height)
|
||||
// Set search input width
|
||||
searchWidth := msg.Width - 20
|
||||
if searchWidth > 60 {
|
||||
searchWidth = 60
|
||||
}
|
||||
m.searchInput.Width = searchWidth
|
||||
|
||||
case tea.KeyMsg:
|
||||
cmd := m.handleKey(msg)
|
||||
if cmd != nil {
|
||||
cmds = append(cmds, cmd)
|
||||
}
|
||||
|
||||
case connectMsg:
|
||||
if msg.err != nil {
|
||||
m.connected = false
|
||||
m.setError(fmt.Sprintf("Failed to connect: %v", msg.err))
|
||||
} else {
|
||||
m.connected = true
|
||||
cmds = append(cmds, m.refresh())
|
||||
cmds = append(cmds, m.refreshPresets())
|
||||
cmds = append(cmds, m.refreshGroups())
|
||||
m.loadConfig()
|
||||
}
|
||||
|
||||
case refreshMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Refresh failed: %v", msg.err))
|
||||
// Mark as disconnected to trigger reconnect
|
||||
m.connected = false
|
||||
m.client.Close()
|
||||
} else if msg.entries != nil {
|
||||
m.list.SetItems(msg.entries)
|
||||
}
|
||||
|
||||
case toggleMsg:
|
||||
if msg.err != nil {
|
||||
m.list.SetError(msg.alias, true)
|
||||
m.setError(fmt.Sprintf("Toggle failed: %v", msg.err))
|
||||
} else {
|
||||
m.list.SetPending(msg.alias, false)
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess("Entry toggled")
|
||||
}
|
||||
|
||||
case presetMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Preset failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess(fmt.Sprintf("Applied preset: %s", msg.name))
|
||||
}
|
||||
m.mode = ViewList
|
||||
|
||||
case addMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Add failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess(fmt.Sprintf("Added host: %s", msg.domain))
|
||||
}
|
||||
m.mode = ViewList
|
||||
|
||||
case deleteMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Delete failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess(fmt.Sprintf("Deleted: %s", msg.alias))
|
||||
}
|
||||
|
||||
case addPresetMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Add preset failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refreshPresets())
|
||||
m.setSuccess(fmt.Sprintf("Added preset: %s", msg.name))
|
||||
}
|
||||
m.presetPicker.CancelForm()
|
||||
|
||||
case deletePresetMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Delete preset failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refreshPresets())
|
||||
m.setSuccess(fmt.Sprintf("Deleted preset: %s", msg.name))
|
||||
}
|
||||
m.presetPicker.CancelForm()
|
||||
|
||||
case refreshPresetsMsg:
|
||||
if msg.err == nil && msg.presets != nil {
|
||||
m.presetPicker.SetPresetsWithInfo(msg.presets)
|
||||
}
|
||||
|
||||
case addGroupMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Add group failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refreshGroups())
|
||||
cmds = append(cmds, m.refresh()) // Refresh list to show new group
|
||||
m.setSuccess(fmt.Sprintf("Added group: %s", msg.name))
|
||||
}
|
||||
m.groupPicker.CancelForm()
|
||||
|
||||
case renameGroupMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Rename group failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refreshGroups())
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess(fmt.Sprintf("Renamed group to: %s", msg.name))
|
||||
}
|
||||
m.groupPicker.CancelForm()
|
||||
|
||||
case deleteGroupMsg:
|
||||
if msg.err != nil {
|
||||
m.setError(fmt.Sprintf("Delete group failed: %v", msg.err))
|
||||
} else {
|
||||
cmds = append(cmds, m.refreshGroups())
|
||||
cmds = append(cmds, m.refresh())
|
||||
m.setSuccess(fmt.Sprintf("Deleted group: %s", msg.name))
|
||||
}
|
||||
m.groupPicker.CancelForm()
|
||||
|
||||
case refreshGroupsMsg:
|
||||
if msg.err == nil && msg.groups != nil {
|
||||
m.allGroups = msg.groups
|
||||
m.groupPicker.SetGroups(msg.groups)
|
||||
}
|
||||
|
||||
case clearMsgMsg:
|
||||
if time.Since(m.messageTime) >= time.Second*3 {
|
||||
m.message = ""
|
||||
}
|
||||
|
||||
case tickMsg:
|
||||
// Reconnect if disconnected
|
||||
if !m.connected {
|
||||
cmds = append(cmds, m.connect())
|
||||
}
|
||||
cmds = append(cmds, m.tick())
|
||||
|
||||
case updateMsg:
|
||||
if msg.version != "" {
|
||||
m.updateAvailable = true
|
||||
m.updateVersion = msg.version
|
||||
m.updateURL = msg.url
|
||||
}
|
||||
}
|
||||
|
||||
return m, tea.Batch(cmds...)
|
||||
}
|
||||
|
||||
func (m *Model) handleKey(msg tea.KeyMsg) tea.Cmd {
|
||||
// Global keys
|
||||
switch msg.String() {
|
||||
case "ctrl+c":
|
||||
return tea.Quit
|
||||
}
|
||||
|
||||
// Mode-specific keys
|
||||
switch m.mode {
|
||||
case ViewList:
|
||||
return m.handleListKey(msg)
|
||||
case ViewForm:
|
||||
return m.handleFormKey(msg)
|
||||
case ViewPresets:
|
||||
return m.handlePresetKey(msg)
|
||||
case ViewGroups:
|
||||
return m.handleGroupKey(msg)
|
||||
case ViewHelp:
|
||||
return m.handleHelpKey(msg)
|
||||
case ViewSearch:
|
||||
return m.handleSearchKey(msg)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleListKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "q":
|
||||
return tea.Quit
|
||||
case "esc":
|
||||
// Clear search if active
|
||||
if m.searchTerm != "" {
|
||||
m.searchTerm = ""
|
||||
m.searchInput.Reset()
|
||||
}
|
||||
case "up", "k":
|
||||
m.list.MoveUp()
|
||||
case "down", "j":
|
||||
m.list.MoveDown()
|
||||
case " ", "enter":
|
||||
return m.toggleSelected()
|
||||
case "n":
|
||||
m.mode = ViewForm
|
||||
m.form.SetGroups(m.allGroups)
|
||||
m.form.Init()
|
||||
case "e":
|
||||
if item := m.list.Selected(); item != nil {
|
||||
m.mode = ViewForm
|
||||
m.form.SetGroups(m.allGroups)
|
||||
m.form.InitEdit(item.Entry.Domain, item.Entry.IP, item.Entry.Alias, item.Entry.Group)
|
||||
}
|
||||
case "d":
|
||||
if item := m.list.Selected(); item != nil {
|
||||
return m.deleteHost(item.Entry.Alias)
|
||||
}
|
||||
case "p":
|
||||
m.mode = ViewPresets
|
||||
case "g":
|
||||
m.mode = ViewGroups
|
||||
return m.refreshGroups()
|
||||
case "/":
|
||||
m.mode = ViewSearch
|
||||
m.searchInput.Focus()
|
||||
case "?":
|
||||
m.mode = ViewHelp
|
||||
case "r":
|
||||
return m.refresh()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleFormKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.mode = ViewList
|
||||
return nil
|
||||
case "enter":
|
||||
if errMsg := m.form.Validate(); errMsg != "" {
|
||||
m.setError(errMsg)
|
||||
return m.clearMsg()
|
||||
}
|
||||
domain, ip, group := m.form.Values()
|
||||
if m.form.IsEdit() {
|
||||
// For edit, delete old and add new (simple approach)
|
||||
oldAlias := m.form.EditAlias()
|
||||
return tea.Sequence(
|
||||
func() tea.Msg {
|
||||
m.client.Delete(oldAlias)
|
||||
return nil
|
||||
},
|
||||
m.addHost(domain, ip, "", group), // Empty alias = auto-generate
|
||||
)
|
||||
}
|
||||
return m.addHost(domain, ip, "", group) // Empty alias = auto-generate
|
||||
}
|
||||
|
||||
return m.form.Update(msg)
|
||||
}
|
||||
|
||||
func (m *Model) handlePresetKey(msg tea.KeyMsg) tea.Cmd {
|
||||
// Handle based on preset picker mode
|
||||
switch m.presetPicker.Mode() {
|
||||
case PresetModeSelect:
|
||||
return m.handlePresetSelectKey(msg)
|
||||
case PresetModeAdd, PresetModeEdit:
|
||||
return m.handlePresetFormKey(msg)
|
||||
case PresetModeConfirmDelete:
|
||||
return m.handlePresetDeleteKey(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handlePresetSelectKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc", "q":
|
||||
m.mode = ViewList
|
||||
case "up", "k":
|
||||
m.presetPicker.MoveUp()
|
||||
case "down", "j":
|
||||
m.presetPicker.MoveDown()
|
||||
case "enter":
|
||||
if preset := m.presetPicker.Selected(); preset != "" {
|
||||
return m.applyPreset(preset)
|
||||
}
|
||||
case "n":
|
||||
m.presetPicker.InitAdd()
|
||||
case "e":
|
||||
m.presetPicker.InitEdit()
|
||||
case "d":
|
||||
m.presetPicker.InitDelete()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handlePresetFormKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.presetPicker.CancelForm()
|
||||
return nil
|
||||
case "enter":
|
||||
if errMsg := m.presetPicker.ValidateForm(); errMsg != "" {
|
||||
m.setError(errMsg)
|
||||
return m.clearMsg()
|
||||
}
|
||||
name, enable, disable := m.presetPicker.FormValues()
|
||||
if m.presetPicker.IsEdit() {
|
||||
// For edit, delete old and add new
|
||||
oldName := m.presetPicker.EditName()
|
||||
return tea.Sequence(
|
||||
func() tea.Msg {
|
||||
m.client.DeletePreset(oldName)
|
||||
return nil
|
||||
},
|
||||
m.addPreset(name, enable, disable),
|
||||
)
|
||||
}
|
||||
return m.addPreset(name, enable, disable)
|
||||
}
|
||||
return m.presetPicker.Update(msg)
|
||||
}
|
||||
|
||||
func (m *Model) handlePresetDeleteKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
if preset := m.presetPicker.Selected(); preset != "" {
|
||||
return m.deletePreset(preset)
|
||||
}
|
||||
m.presetPicker.CancelForm()
|
||||
case "n", "N", "esc":
|
||||
m.presetPicker.CancelForm()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleGroupKey(msg tea.KeyMsg) tea.Cmd {
|
||||
// Handle based on group picker mode
|
||||
switch m.groupPicker.Mode() {
|
||||
case GroupModeSelect:
|
||||
return m.handleGroupSelectKey(msg)
|
||||
case GroupModeAdd, GroupModeRename:
|
||||
return m.handleGroupFormKey(msg)
|
||||
case GroupModeConfirmDelete:
|
||||
return m.handleGroupDeleteKey(msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleGroupSelectKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc", "q":
|
||||
m.mode = ViewList
|
||||
case "up", "k":
|
||||
m.groupPicker.MoveUp()
|
||||
case "down", "j":
|
||||
m.groupPicker.MoveDown()
|
||||
case "n":
|
||||
m.groupPicker.InitAdd()
|
||||
case "r":
|
||||
m.groupPicker.InitRename()
|
||||
case "d":
|
||||
m.groupPicker.InitDelete()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleGroupFormKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.groupPicker.CancelForm()
|
||||
return nil
|
||||
case "enter":
|
||||
if errMsg := m.groupPicker.ValidateForm(); errMsg != "" {
|
||||
m.setError(errMsg)
|
||||
return m.clearMsg()
|
||||
}
|
||||
name := m.groupPicker.FormValue()
|
||||
if m.groupPicker.IsRename() {
|
||||
oldName := m.groupPicker.EditName()
|
||||
return m.renameGroup(oldName, name)
|
||||
}
|
||||
return m.addGroup(name)
|
||||
}
|
||||
return m.groupPicker.Update(msg)
|
||||
}
|
||||
|
||||
func (m *Model) handleGroupDeleteKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "y", "Y":
|
||||
if group := m.groupPicker.Selected(); group != "" {
|
||||
return m.deleteGroup(group)
|
||||
}
|
||||
m.groupPicker.CancelForm()
|
||||
case "n", "N", "esc":
|
||||
m.groupPicker.CancelForm()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleHelpKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc", "q", "?":
|
||||
m.mode = ViewList
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Model) handleSearchKey(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "esc":
|
||||
m.mode = ViewList
|
||||
m.searchTerm = ""
|
||||
m.searchInput.Reset()
|
||||
return nil
|
||||
case "enter":
|
||||
m.searchTerm = m.searchInput.Value()
|
||||
m.mode = ViewList
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd tea.Cmd
|
||||
m.searchInput, cmd = m.searchInput.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (m *Model) toggleSelected() tea.Cmd {
|
||||
item := m.list.Selected()
|
||||
if item == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.list.SetPending(item.Entry.Alias, true)
|
||||
return m.toggle(item.Entry.Alias, !item.Entry.Enabled)
|
||||
}
|
||||
|
||||
func (m *Model) loadConfig() {
|
||||
if err := m.config.Load(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cfg := m.config.Get()
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var presetNames []string
|
||||
for _, p := range cfg.Presets {
|
||||
presetNames = append(presetNames, p.Name)
|
||||
}
|
||||
m.presetPicker.SetPresets(presetNames)
|
||||
}
|
||||
|
||||
func (m *Model) setError(msg string) {
|
||||
m.message = msg
|
||||
m.messageStyle = "error"
|
||||
m.messageTime = time.Now()
|
||||
}
|
||||
|
||||
func (m *Model) setSuccess(msg string) {
|
||||
m.message = msg
|
||||
m.messageStyle = "success"
|
||||
m.messageTime = time.Now()
|
||||
}
|
||||
|
||||
// View renders the UI.
|
||||
func (m *Model) View() string {
|
||||
var sb strings.Builder
|
||||
|
||||
// Title with version
|
||||
title := titleStyle.Render("lolcathost - Host Management")
|
||||
sb.WriteString(title)
|
||||
|
||||
// Update notification
|
||||
if m.updateAvailable {
|
||||
sb.WriteString(" ")
|
||||
sb.WriteString(updateStyle.Render(fmt.Sprintf("Update available: v%s", m.updateVersion)))
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// Main content based on mode
|
||||
switch m.mode {
|
||||
case ViewList:
|
||||
sb.WriteString(m.list.ViewFiltered(m.searchTerm))
|
||||
case ViewForm:
|
||||
sb.WriteString(m.form.View())
|
||||
case ViewPresets:
|
||||
sb.WriteString(m.presetPicker.View())
|
||||
case ViewGroups:
|
||||
sb.WriteString(m.groupPicker.View())
|
||||
case ViewHelp:
|
||||
sb.WriteString(m.helpView())
|
||||
case ViewSearch:
|
||||
sb.WriteString(m.searchView())
|
||||
}
|
||||
|
||||
// Message
|
||||
if m.message != "" {
|
||||
sb.WriteString("\n")
|
||||
if m.messageStyle == "error" {
|
||||
sb.WriteString(errorMsgStyle.Render(m.message))
|
||||
} else {
|
||||
sb.WriteString(successMsgStyle.Render(m.message))
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate remaining space for footer positioning
|
||||
currentContent := sb.String()
|
||||
currentLines := strings.Count(currentContent, "\n") + 1
|
||||
|
||||
// Fill space to push footer to bottom (reserve 3 lines for footer)
|
||||
footerHeight := 3
|
||||
remainingLines := m.height - currentLines - footerHeight
|
||||
if remainingLines > 0 {
|
||||
sb.WriteString(strings.Repeat("\n", remainingLines))
|
||||
}
|
||||
|
||||
// Footer (help bar + status bar)
|
||||
if m.mode == ViewList {
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(m.helpBar())
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(m.statusBar())
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m *Model) helpBar() string {
|
||||
return helpBarStyle.Render(fmt.Sprintf("%s/%s: Navigate %s: Toggle %s: New %s: Edit %s: Delete %s: Presets %s: Groups %s: Search %s: Help %s: Quit",
|
||||
helpKeyStyle.Render("↑↓"),
|
||||
helpKeyStyle.Render("jk"),
|
||||
helpKeyStyle.Render("Space"),
|
||||
helpKeyStyle.Render("n"),
|
||||
helpKeyStyle.Render("e"),
|
||||
helpKeyStyle.Render("d"),
|
||||
helpKeyStyle.Render("p"),
|
||||
helpKeyStyle.Render("g"),
|
||||
helpKeyStyle.Render("/"),
|
||||
helpKeyStyle.Render("?"),
|
||||
helpKeyStyle.Render("q")))
|
||||
}
|
||||
|
||||
func (m *Model) statusBar() string {
|
||||
var status string
|
||||
if m.connected {
|
||||
status = connectedStyle.String()
|
||||
} else {
|
||||
status = disconnectedStyle.String()
|
||||
}
|
||||
|
||||
active := fmt.Sprintf("%d active", m.list.ActiveCount())
|
||||
total := fmt.Sprintf("%d total", m.list.Len())
|
||||
|
||||
return statusBarStyle.Render(fmt.Sprintf("%s | %s | %s", status, active, total))
|
||||
}
|
||||
|
||||
func (m *Model) helpView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render("Help"))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
help := []struct{ key, desc string }{
|
||||
{"↑/↓ or j/k", "Navigate up/down"},
|
||||
{"Space/Enter", "Toggle entry on/off"},
|
||||
{"n", "Add new entry"},
|
||||
{"e", "Edit selected entry"},
|
||||
{"d", "Delete selected entry"},
|
||||
{"p", "Open preset manager"},
|
||||
{"g", "Open group manager"},
|
||||
{"/", "Search"},
|
||||
{"r", "Refresh list"},
|
||||
{"?", "Toggle this help"},
|
||||
{"q", "Quit"},
|
||||
}
|
||||
|
||||
for _, h := range help {
|
||||
sb.WriteString(fmt.Sprintf(" %s %s\n",
|
||||
helpKeyStyle.Width(15).Render(h.key),
|
||||
helpDescStyle.Render(h.desc)))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpDescStyle.Render("Press ? or Esc to close"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (m *Model) searchView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render("Search"))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
sb.WriteString(inputFocusStyle.Render(m.searchInput.View()))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("Enter to search • Esc to cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
// Run starts the TUI application.
|
||||
func Run(socketPath, configPath string) error {
|
||||
return RunWithVersion(socketPath, configPath, "dev", "", "")
|
||||
}
|
||||
|
||||
// RunWithVersion starts the TUI application with version info for update checking.
|
||||
func RunWithVersion(socketPath, configPath, version, githubOwner, githubRepo string) error {
|
||||
m := NewModel(socketPath, configPath)
|
||||
m.version = version
|
||||
m.githubOwner = githubOwner
|
||||
m.githubRepo = githubRepo
|
||||
p := tea.NewProgram(m, tea.WithAltScreen())
|
||||
|
||||
_, err := p.Run()
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
// Package tui provides the form component for adding/editing entries.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// FormMode represents the form mode.
|
||||
type FormMode int
|
||||
|
||||
const (
|
||||
FormModeAdd FormMode = iota
|
||||
FormModeEdit
|
||||
)
|
||||
|
||||
// FormField represents a form field index.
|
||||
type FormField int
|
||||
|
||||
const (
|
||||
FieldDomain FormField = iota
|
||||
FieldIP
|
||||
FieldGroup
|
||||
FieldCount
|
||||
)
|
||||
|
||||
// Form handles the add/edit entry form.
|
||||
type Form struct {
|
||||
mode FormMode
|
||||
fields []textinput.Model
|
||||
focus FormField
|
||||
width int
|
||||
height int
|
||||
editAlias string // Original alias when editing
|
||||
|
||||
// Group dropdown
|
||||
groups []string
|
||||
groupCursor int
|
||||
groupFocused bool
|
||||
}
|
||||
|
||||
// NewForm creates a new form.
|
||||
func NewForm() *Form {
|
||||
fields := make([]textinput.Model, FieldCount)
|
||||
|
||||
// Domain field
|
||||
fields[FieldDomain] = textinput.New()
|
||||
fields[FieldDomain].Placeholder = "example.com"
|
||||
fields[FieldDomain].CharLimit = 253
|
||||
|
||||
// IP field
|
||||
fields[FieldIP] = textinput.New()
|
||||
fields[FieldIP].Placeholder = "127.0.0.1"
|
||||
fields[FieldIP].CharLimit = 45 // IPv6 max
|
||||
|
||||
// Group field (not used as text input, but kept for compatibility)
|
||||
fields[FieldGroup] = textinput.New()
|
||||
fields[FieldGroup].Placeholder = "development"
|
||||
fields[FieldGroup].CharLimit = 63
|
||||
|
||||
return &Form{
|
||||
fields: fields,
|
||||
focus: FieldDomain,
|
||||
groups: []string{"default"},
|
||||
}
|
||||
}
|
||||
|
||||
// SetGroups sets the available groups for the dropdown.
|
||||
func (f *Form) SetGroups(groups []string) {
|
||||
if len(groups) == 0 {
|
||||
f.groups = []string{"default"}
|
||||
} else {
|
||||
f.groups = groups
|
||||
}
|
||||
// Reset cursor if out of bounds
|
||||
if f.groupCursor >= len(f.groups) {
|
||||
f.groupCursor = 0
|
||||
}
|
||||
}
|
||||
|
||||
// Init initializes the form for adding a new entry.
|
||||
func (f *Form) Init() {
|
||||
f.mode = FormModeAdd
|
||||
f.editAlias = ""
|
||||
|
||||
for i := range f.fields {
|
||||
f.fields[i].Reset()
|
||||
}
|
||||
|
||||
f.fields[FieldIP].SetValue("127.0.0.1")
|
||||
f.groupCursor = 0
|
||||
f.groupFocused = false
|
||||
f.focus = FieldDomain
|
||||
f.fields[FieldDomain].Focus()
|
||||
}
|
||||
|
||||
// InitEdit initializes the form for editing an existing entry.
|
||||
func (f *Form) InitEdit(domain, ip, alias, group string) {
|
||||
f.mode = FormModeEdit
|
||||
f.editAlias = alias
|
||||
|
||||
f.fields[FieldDomain].SetValue(domain)
|
||||
f.fields[FieldIP].SetValue(ip)
|
||||
|
||||
// Find the group in the list
|
||||
f.groupCursor = 0
|
||||
for i, g := range f.groups {
|
||||
if g == group {
|
||||
f.groupCursor = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
f.groupFocused = false
|
||||
f.focus = FieldDomain
|
||||
f.fields[FieldDomain].Focus()
|
||||
}
|
||||
|
||||
// SetSize sets the form dimensions.
|
||||
func (f *Form) SetSize(width, height int) {
|
||||
f.width = width
|
||||
f.height = height
|
||||
|
||||
inputWidth := min(50, width-10)
|
||||
for i := range f.fields {
|
||||
f.fields[i].Width = inputWidth
|
||||
}
|
||||
}
|
||||
|
||||
// Update handles input events.
|
||||
func (f *Form) Update(msg tea.Msg) tea.Cmd {
|
||||
switch msg := msg.(type) {
|
||||
case tea.KeyMsg:
|
||||
// Handle group dropdown navigation
|
||||
if f.focus == FieldGroup {
|
||||
switch msg.String() {
|
||||
case "tab":
|
||||
f.nextField()
|
||||
return nil
|
||||
case "shift+tab":
|
||||
f.prevField()
|
||||
return nil
|
||||
case "up", "k":
|
||||
if f.groupCursor > 0 {
|
||||
f.groupCursor--
|
||||
}
|
||||
return nil
|
||||
case "down", "j":
|
||||
if f.groupCursor < len(f.groups)-1 {
|
||||
f.groupCursor++
|
||||
}
|
||||
return nil
|
||||
case "left":
|
||||
if f.groupCursor > 0 {
|
||||
f.groupCursor--
|
||||
}
|
||||
return nil
|
||||
case "right":
|
||||
if f.groupCursor < len(f.groups)-1 {
|
||||
f.groupCursor++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle text input fields
|
||||
switch msg.String() {
|
||||
case "tab", "down":
|
||||
f.nextField()
|
||||
return nil
|
||||
case "shift+tab", "up":
|
||||
f.prevField()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Update the focused text field (only for Domain and IP)
|
||||
if f.focus != FieldGroup {
|
||||
var cmd tea.Cmd
|
||||
f.fields[f.focus], cmd = f.fields[f.focus].Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Form) nextField() {
|
||||
if f.focus != FieldGroup {
|
||||
f.fields[f.focus].Blur()
|
||||
}
|
||||
f.focus = (f.focus + 1) % FieldCount
|
||||
if f.focus != FieldGroup {
|
||||
f.fields[f.focus].Focus()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Form) prevField() {
|
||||
if f.focus != FieldGroup {
|
||||
f.fields[f.focus].Blur()
|
||||
}
|
||||
f.focus = (f.focus - 1 + FieldCount) % FieldCount
|
||||
if f.focus != FieldGroup {
|
||||
f.fields[f.focus].Focus()
|
||||
}
|
||||
}
|
||||
|
||||
// Values returns the form values (domain, ip, group).
|
||||
func (f *Form) Values() (domain, ip, group string) {
|
||||
group = ""
|
||||
if f.groupCursor < len(f.groups) {
|
||||
group = f.groups[f.groupCursor]
|
||||
}
|
||||
return strings.TrimSpace(f.fields[FieldDomain].Value()),
|
||||
strings.TrimSpace(f.fields[FieldIP].Value()),
|
||||
group
|
||||
}
|
||||
|
||||
// EditAlias returns the original alias when editing.
|
||||
func (f *Form) EditAlias() string {
|
||||
return f.editAlias
|
||||
}
|
||||
|
||||
// IsEdit returns true if in edit mode.
|
||||
func (f *Form) IsEdit() bool {
|
||||
return f.mode == FormModeEdit
|
||||
}
|
||||
|
||||
// Validate validates the form values.
|
||||
func (f *Form) Validate() string {
|
||||
domain, ip, group := f.Values()
|
||||
|
||||
if domain == "" {
|
||||
return "Domain is required"
|
||||
}
|
||||
if ip == "" {
|
||||
return "IP address is required"
|
||||
}
|
||||
if group == "" {
|
||||
return "Group is required"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// View renders the form.
|
||||
func (f *Form) View() string {
|
||||
var sb strings.Builder
|
||||
|
||||
title := "Add New Entry"
|
||||
if f.mode == FormModeEdit {
|
||||
title = "Edit Entry"
|
||||
}
|
||||
|
||||
sb.WriteString(titleStyle.Render(title))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// Domain field
|
||||
sb.WriteString(inputLabelStyle.Render("Domain:"))
|
||||
sb.WriteString("\n")
|
||||
style := inputStyle
|
||||
if f.focus == FieldDomain {
|
||||
style = inputFocusStyle
|
||||
}
|
||||
sb.WriteString(style.Render(f.fields[FieldDomain].View()))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// IP field
|
||||
sb.WriteString(inputLabelStyle.Render("IP Address:"))
|
||||
sb.WriteString("\n")
|
||||
style = inputStyle
|
||||
if f.focus == FieldIP {
|
||||
style = inputFocusStyle
|
||||
}
|
||||
sb.WriteString(style.Render(f.fields[FieldIP].View()))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
// Group dropdown
|
||||
sb.WriteString(inputLabelStyle.Render("Group:"))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(f.renderGroupDropdown())
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • ←→ select group • Enter save • Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (f *Form) renderGroupDropdown() string {
|
||||
isFocused := f.focus == FieldGroup
|
||||
|
||||
// Get current group name
|
||||
currentGroup := "default"
|
||||
if f.groupCursor < len(f.groups) {
|
||||
currentGroup = f.groups[f.groupCursor]
|
||||
}
|
||||
|
||||
// Build the selector content: ◀ group_name ▶
|
||||
var content string
|
||||
if isFocused {
|
||||
// Show arrows when focused
|
||||
leftArrow := "◀"
|
||||
rightArrow := "▶"
|
||||
if f.groupCursor == 0 {
|
||||
leftArrow = " " // dim or hide left arrow at start
|
||||
}
|
||||
if f.groupCursor >= len(f.groups)-1 {
|
||||
rightArrow = " " // dim or hide right arrow at end
|
||||
}
|
||||
content = leftArrow + " " + currentGroup + " " + rightArrow
|
||||
} else {
|
||||
content = " " + currentGroup + " "
|
||||
}
|
||||
|
||||
// Show position indicator if multiple groups
|
||||
if len(f.groups) > 1 {
|
||||
content += fmt.Sprintf(" (%d/%d)", f.groupCursor+1, len(f.groups))
|
||||
}
|
||||
|
||||
// Apply border style
|
||||
if isFocused {
|
||||
return inputFocusStyle.Render(content)
|
||||
}
|
||||
return inputStyle.Render(content)
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
// Package tui provides the group management component.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
)
|
||||
|
||||
// GroupMode represents the group view mode.
|
||||
type GroupMode int
|
||||
|
||||
const (
|
||||
GroupModeSelect GroupMode = iota
|
||||
GroupModeAdd
|
||||
GroupModeRename
|
||||
GroupModeConfirmDelete
|
||||
)
|
||||
|
||||
// GroupPicker handles the group selection and management UI.
|
||||
type GroupPicker struct {
|
||||
groups []string
|
||||
cursor int
|
||||
width int
|
||||
height int
|
||||
mode GroupMode
|
||||
input textinput.Model
|
||||
editName string // Original name when renaming
|
||||
}
|
||||
|
||||
// NewGroupPicker creates a new group picker.
|
||||
func NewGroupPicker() *GroupPicker {
|
||||
input := textinput.New()
|
||||
input.Placeholder = "group-name"
|
||||
input.CharLimit = 63
|
||||
|
||||
return &GroupPicker{
|
||||
input: input,
|
||||
mode: GroupModeSelect,
|
||||
}
|
||||
}
|
||||
|
||||
// SetGroups updates the available groups.
|
||||
func (g *GroupPicker) SetGroups(groups []string) {
|
||||
g.groups = groups
|
||||
if g.cursor >= len(groups) {
|
||||
g.cursor = max(0, len(groups)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSize sets the picker dimensions.
|
||||
func (g *GroupPicker) SetSize(width, height int) {
|
||||
g.width = width
|
||||
g.height = height
|
||||
g.input.Width = min(50, width-10)
|
||||
}
|
||||
|
||||
// MoveUp moves the cursor up.
|
||||
func (g *GroupPicker) MoveUp() {
|
||||
if g.cursor > 0 {
|
||||
g.cursor--
|
||||
}
|
||||
}
|
||||
|
||||
// MoveDown moves the cursor down.
|
||||
func (g *GroupPicker) MoveDown() {
|
||||
if g.cursor < len(g.groups)-1 {
|
||||
g.cursor++
|
||||
}
|
||||
}
|
||||
|
||||
// Selected returns the currently selected group.
|
||||
func (g *GroupPicker) Selected() string {
|
||||
if g.cursor >= 0 && g.cursor < len(g.groups) {
|
||||
return g.groups[g.cursor]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Len returns the number of groups.
|
||||
func (g *GroupPicker) Len() int {
|
||||
return len(g.groups)
|
||||
}
|
||||
|
||||
// Mode returns the current mode.
|
||||
func (g *GroupPicker) Mode() GroupMode {
|
||||
return g.mode
|
||||
}
|
||||
|
||||
// InitAdd initializes the form for adding a new group.
|
||||
func (g *GroupPicker) InitAdd() {
|
||||
g.mode = GroupModeAdd
|
||||
g.editName = ""
|
||||
g.input.Reset()
|
||||
g.input.Focus()
|
||||
}
|
||||
|
||||
// InitRename initializes the form for renaming an existing group.
|
||||
func (g *GroupPicker) InitRename() {
|
||||
selected := g.Selected()
|
||||
if selected == "" {
|
||||
return
|
||||
}
|
||||
|
||||
g.mode = GroupModeRename
|
||||
g.editName = selected
|
||||
g.input.SetValue(selected)
|
||||
g.input.Focus()
|
||||
}
|
||||
|
||||
// InitDelete starts delete confirmation.
|
||||
func (g *GroupPicker) InitDelete() {
|
||||
if g.Selected() == "" {
|
||||
return
|
||||
}
|
||||
g.mode = GroupModeConfirmDelete
|
||||
}
|
||||
|
||||
// CancelForm cancels the current form operation.
|
||||
func (g *GroupPicker) CancelForm() {
|
||||
g.mode = GroupModeSelect
|
||||
g.editName = ""
|
||||
g.input.Reset()
|
||||
g.input.Blur()
|
||||
}
|
||||
|
||||
// Update handles input events for form mode.
|
||||
func (g *GroupPicker) Update(msg tea.KeyMsg) tea.Cmd {
|
||||
var cmd tea.Cmd
|
||||
g.input, cmd = g.input.Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// FormValue returns the form input value.
|
||||
func (g *GroupPicker) FormValue() string {
|
||||
return strings.TrimSpace(g.input.Value())
|
||||
}
|
||||
|
||||
// EditName returns the original name when renaming.
|
||||
func (g *GroupPicker) EditName() string {
|
||||
return g.editName
|
||||
}
|
||||
|
||||
// IsRename returns true if in rename mode.
|
||||
func (g *GroupPicker) IsRename() bool {
|
||||
return g.mode == GroupModeRename
|
||||
}
|
||||
|
||||
// ValidateForm validates the form value.
|
||||
func (g *GroupPicker) ValidateForm() string {
|
||||
value := g.FormValue()
|
||||
if value == "" {
|
||||
return "Group name is required"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// View renders the group picker.
|
||||
func (g *GroupPicker) View() string {
|
||||
switch g.mode {
|
||||
case GroupModeAdd, GroupModeRename:
|
||||
return g.formView()
|
||||
case GroupModeConfirmDelete:
|
||||
return g.deleteView()
|
||||
default:
|
||||
return g.selectView()
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GroupPicker) selectView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render("Groups"))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if len(g.groups) == 0 {
|
||||
sb.WriteString(helpDescStyle.Render("No groups configured."))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("Press 'n' to create one"))
|
||||
} else {
|
||||
for i, group := range g.groups {
|
||||
if i == g.cursor {
|
||||
sb.WriteString(presetSelectedStyle.Render("▸ " + group))
|
||||
} else {
|
||||
sb.WriteString(presetItemStyle.Render(" " + group))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("↑↓ navigate • n new • r rename • d delete • Esc back"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (g *GroupPicker) formView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
title := "Add New Group"
|
||||
if g.mode == GroupModeRename {
|
||||
title = "Rename Group"
|
||||
}
|
||||
|
||||
sb.WriteString(titleStyle.Render(title))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
sb.WriteString(inputLabelStyle.Render("Name:"))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(inputFocusStyle.Render(g.input.View()))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("Enter save • Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (g *GroupPicker) deleteView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
groupName := g.Selected()
|
||||
|
||||
sb.WriteString(titleStyle.Render("Delete Group"))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete group '" + groupName + "'?"))
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpDescStyle.Render("This will remove all hosts in this group!"))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
@@ -0,0 +1,429 @@
|
||||
// Package tui provides the list view component.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
"github.com/charmbracelet/lipgloss/table"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// EntryItem represents a displayable host entry.
|
||||
type EntryItem struct {
|
||||
Entry protocol.HostEntry
|
||||
Pending bool
|
||||
HasError bool
|
||||
}
|
||||
|
||||
// ListView handles the list of host entries.
|
||||
type ListView struct {
|
||||
items []EntryItem
|
||||
groups map[string][]int // group name -> indices in items
|
||||
groupOrder []string // ordered group names
|
||||
cursor int
|
||||
width int
|
||||
height int
|
||||
}
|
||||
|
||||
// NewListView creates a new list view.
|
||||
func NewListView() *ListView {
|
||||
return &ListView{
|
||||
groups: make(map[string][]int),
|
||||
}
|
||||
}
|
||||
|
||||
// SetItems updates the list items.
|
||||
func (l *ListView) SetItems(entries []protocol.HostEntry) {
|
||||
l.items = make([]EntryItem, len(entries))
|
||||
l.groups = make(map[string][]int)
|
||||
l.groupOrder = nil
|
||||
|
||||
groupSeen := make(map[string]bool)
|
||||
|
||||
for i, e := range entries {
|
||||
l.items[i] = EntryItem{Entry: e}
|
||||
|
||||
if !groupSeen[e.Group] {
|
||||
groupSeen[e.Group] = true
|
||||
l.groupOrder = append(l.groupOrder, e.Group)
|
||||
}
|
||||
|
||||
l.groups[e.Group] = append(l.groups[e.Group], i)
|
||||
}
|
||||
|
||||
// Reset cursor if out of bounds
|
||||
if l.cursor >= len(l.items) {
|
||||
l.cursor = max(0, len(l.items)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSize sets the view dimensions.
|
||||
func (l *ListView) SetSize(width, height int) {
|
||||
l.width = width
|
||||
l.height = height
|
||||
}
|
||||
|
||||
// MoveUp moves the cursor up.
|
||||
func (l *ListView) MoveUp() {
|
||||
if l.cursor > 0 {
|
||||
l.cursor--
|
||||
}
|
||||
}
|
||||
|
||||
// MoveDown moves the cursor down.
|
||||
func (l *ListView) MoveDown() {
|
||||
if l.cursor < len(l.items)-1 {
|
||||
l.cursor++
|
||||
}
|
||||
}
|
||||
|
||||
// Selected returns the currently selected item.
|
||||
func (l *ListView) Selected() *EntryItem {
|
||||
if l.cursor >= 0 && l.cursor < len(l.items) {
|
||||
return &l.items[l.cursor]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SelectedAlias returns the alias of the selected item.
|
||||
func (l *ListView) SelectedAlias() string {
|
||||
if item := l.Selected(); item != nil {
|
||||
return item.Entry.Alias
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetPending marks an item as pending.
|
||||
func (l *ListView) SetPending(alias string, pending bool) {
|
||||
for i := range l.items {
|
||||
if l.items[i].Entry.Alias == alias {
|
||||
l.items[i].Pending = pending
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetError marks an item as having an error.
|
||||
func (l *ListView) SetError(alias string, hasError bool) {
|
||||
for i := range l.items {
|
||||
if l.items[i].Entry.Alias == alias {
|
||||
l.items[i].HasError = hasError
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateEntry updates an entry's enabled state.
|
||||
func (l *ListView) UpdateEntry(alias string, enabled bool) {
|
||||
for i := range l.items {
|
||||
if l.items[i].Entry.Alias == alias {
|
||||
l.items[i].Entry.Enabled = enabled
|
||||
l.items[i].Pending = false
|
||||
l.items[i].HasError = false
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Len returns the number of items.
|
||||
func (l *ListView) Len() int {
|
||||
return len(l.items)
|
||||
}
|
||||
|
||||
// ActiveCount returns the number of enabled entries.
|
||||
func (l *ListView) ActiveCount() int {
|
||||
count := 0
|
||||
for _, item := range l.items {
|
||||
if item.Entry.Enabled {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// FindByAlias finds an item by alias.
|
||||
func (l *ListView) FindByAlias(alias string) *EntryItem {
|
||||
for i := range l.items {
|
||||
if l.items[i].Entry.Alias == alias {
|
||||
return &l.items[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Filter filters items by search term.
|
||||
func (l *ListView) Filter(term string) []EntryItem {
|
||||
if term == "" {
|
||||
return l.items
|
||||
}
|
||||
|
||||
term = strings.ToLower(term)
|
||||
var filtered []EntryItem
|
||||
for _, item := range l.items {
|
||||
if strings.Contains(strings.ToLower(item.Entry.Domain), term) ||
|
||||
strings.Contains(strings.ToLower(item.Entry.Alias), term) ||
|
||||
strings.Contains(strings.ToLower(item.Entry.IP), term) ||
|
||||
strings.Contains(strings.ToLower(item.Entry.Group), term) {
|
||||
filtered = append(filtered, item)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// ViewFiltered renders the list filtered by search term.
|
||||
func (l *ListView) ViewFiltered(searchTerm string) string {
|
||||
if searchTerm == "" {
|
||||
return l.View()
|
||||
}
|
||||
|
||||
filtered := l.Filter(searchTerm)
|
||||
if len(filtered) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(colorMuted)
|
||||
return "\n" + emptyStyle.Render(fmt.Sprintf(" No results for '%s'. Press Esc to clear search.", searchTerm)) + "\n"
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Show search indicator
|
||||
searchIndicator := lipgloss.NewStyle().
|
||||
Foreground(colorWarning).
|
||||
Bold(true).
|
||||
Render(fmt.Sprintf(" Search: %s (%d results)", searchTerm, len(filtered)))
|
||||
sb.WriteString(searchIndicator)
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Group header style - bright colors for dark terminals
|
||||
groupHeaderStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorGroupHeader).
|
||||
Background(lipgloss.Color("238")).
|
||||
Padding(0, 1).
|
||||
MarginTop(1)
|
||||
|
||||
// Organize filtered items by group
|
||||
groupItems := make(map[string][]EntryItem)
|
||||
var groupOrder []string
|
||||
groupSeen := make(map[string]bool)
|
||||
|
||||
for _, item := range filtered {
|
||||
group := item.Entry.Group
|
||||
if !groupSeen[group] {
|
||||
groupSeen[group] = true
|
||||
groupOrder = append(groupOrder, group)
|
||||
}
|
||||
groupItems[group] = append(groupItems[group], item)
|
||||
}
|
||||
|
||||
for _, groupName := range groupOrder {
|
||||
items := groupItems[groupName]
|
||||
if len(items) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Group header
|
||||
headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(items))
|
||||
sb.WriteString(groupHeaderStyle.Render(headerText))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Build rows for this group's table
|
||||
var rows [][]string
|
||||
for _, item := range items {
|
||||
status := l.getStatusString(item)
|
||||
rows = append(rows, []string{
|
||||
truncate(item.Entry.Domain, 30),
|
||||
truncate(item.Entry.IP, 15),
|
||||
status,
|
||||
})
|
||||
}
|
||||
|
||||
// Create table for this group
|
||||
t := table.New().
|
||||
Border(lipgloss.HiddenBorder()).
|
||||
Headers("DOMAIN", "IP ADDRESS", "STATUS").
|
||||
Rows(rows...).
|
||||
StyleFunc(func(row, col int) lipgloss.Style {
|
||||
// Header row
|
||||
if row == table.HeaderRow {
|
||||
return lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorHeader).
|
||||
Padding(0, 1)
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle().Padding(0, 1)
|
||||
|
||||
if row >= 0 && row < len(items) {
|
||||
item := items[row]
|
||||
|
||||
// Disabled rows are muted
|
||||
if !item.Entry.Enabled && !item.Pending && !item.HasError {
|
||||
return baseStyle.Foreground(colorMuted)
|
||||
}
|
||||
|
||||
// Status column gets colored based on status
|
||||
if col == 2 { // STATUS column
|
||||
if item.HasError {
|
||||
return baseStyle.Foreground(colorError)
|
||||
}
|
||||
if item.Pending {
|
||||
return baseStyle.Foreground(colorWarning)
|
||||
}
|
||||
if item.Entry.Enabled {
|
||||
return baseStyle.Foreground(colorSuccess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return baseStyle
|
||||
})
|
||||
|
||||
sb.WriteString(t.Render())
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// GroupCount returns the number of groups.
|
||||
func (l *ListView) GroupCount() int {
|
||||
return len(l.groupOrder)
|
||||
}
|
||||
|
||||
// GetGroups returns all group names.
|
||||
func (l *ListView) GetGroups() []string {
|
||||
return l.groupOrder
|
||||
}
|
||||
|
||||
// View renders the list with groups as headers.
|
||||
func (l *ListView) View() string {
|
||||
if len(l.items) == 0 {
|
||||
emptyStyle := lipgloss.NewStyle().Foreground(colorMuted)
|
||||
return "\n" + emptyStyle.Render(" No host entries configured. Press 'n' to add a new entry.") + "\n"
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
// Group header style - bright colors for dark terminals
|
||||
groupHeaderStyle := lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorGroupHeader).
|
||||
Background(lipgloss.Color("238")).
|
||||
Padding(0, 1).
|
||||
MarginTop(1)
|
||||
|
||||
for _, groupName := range l.groupOrder {
|
||||
indices := l.groups[groupName]
|
||||
if len(indices) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Group header
|
||||
headerText := fmt.Sprintf(" %s (%d)", strings.ToUpper(groupName), len(indices))
|
||||
sb.WriteString(groupHeaderStyle.Render(headerText))
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Build rows for this group's table
|
||||
var rows [][]string
|
||||
// Store actual item indices for cursor matching
|
||||
itemIndices := make([]int, len(indices))
|
||||
copy(itemIndices, indices)
|
||||
|
||||
for _, idx := range indices {
|
||||
item := l.items[idx]
|
||||
status := l.getStatusString(item)
|
||||
rows = append(rows, []string{
|
||||
truncate(item.Entry.Domain, 30),
|
||||
truncate(item.Entry.IP, 15),
|
||||
status,
|
||||
})
|
||||
}
|
||||
|
||||
// Create table for this group
|
||||
t := table.New().
|
||||
Border(lipgloss.HiddenBorder()).
|
||||
Headers("DOMAIN", "IP ADDRESS", "STATUS").
|
||||
Rows(rows...).
|
||||
StyleFunc(func(row, col int) lipgloss.Style {
|
||||
// Header row
|
||||
if row == table.HeaderRow {
|
||||
return lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorHeader).
|
||||
Padding(0, 1)
|
||||
}
|
||||
|
||||
baseStyle := lipgloss.NewStyle().Padding(0, 1)
|
||||
|
||||
// Check if this row is selected
|
||||
if row >= 0 && row < len(itemIndices) {
|
||||
actualItemIdx := itemIndices[row]
|
||||
isSelected := actualItemIdx == l.cursor
|
||||
item := l.items[actualItemIdx]
|
||||
|
||||
// Selected row gets background highlight
|
||||
if isSelected {
|
||||
return baseStyle.
|
||||
Background(colorSelectedBg).
|
||||
Foreground(colorSelectedFg)
|
||||
}
|
||||
|
||||
// Disabled rows are muted
|
||||
if !item.Entry.Enabled && !item.Pending && !item.HasError {
|
||||
return baseStyle.Foreground(colorMuted)
|
||||
}
|
||||
|
||||
// Status column gets colored based on status
|
||||
if col == 2 { // STATUS column
|
||||
if item.HasError {
|
||||
return baseStyle.Foreground(colorError)
|
||||
}
|
||||
if item.Pending {
|
||||
return baseStyle.Foreground(colorWarning)
|
||||
}
|
||||
if item.Entry.Enabled {
|
||||
return baseStyle.Foreground(colorSuccess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return baseStyle
|
||||
})
|
||||
|
||||
sb.WriteString(t.Render())
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (l *ListView) getStatusString(item EntryItem) string {
|
||||
if item.HasError {
|
||||
return "✗ Error"
|
||||
}
|
||||
if item.Pending {
|
||||
return "◐ Pending"
|
||||
}
|
||||
if item.Entry.Enabled {
|
||||
return "● Active"
|
||||
}
|
||||
return "○ Disabled"
|
||||
}
|
||||
|
||||
func truncate(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
if maxLen <= 3 {
|
||||
return s[:maxLen]
|
||||
}
|
||||
return s[:maxLen-3] + "..."
|
||||
}
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
@@ -0,0 +1,409 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
func TestListView_SetItems(t *testing.T) {
|
||||
lv := NewListView()
|
||||
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
|
||||
}
|
||||
|
||||
lv.SetItems(entries)
|
||||
|
||||
assert.Equal(t, 3, lv.Len())
|
||||
assert.Len(t, lv.groups, 2)
|
||||
assert.Contains(t, lv.groupOrder, "dev")
|
||||
assert.Contains(t, lv.groupOrder, "staging")
|
||||
}
|
||||
|
||||
func TestListView_Navigation(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
// Initial position
|
||||
assert.Equal(t, 0, lv.cursor)
|
||||
|
||||
// Move down
|
||||
lv.MoveDown()
|
||||
assert.Equal(t, 1, lv.cursor)
|
||||
|
||||
lv.MoveDown()
|
||||
assert.Equal(t, 2, lv.cursor)
|
||||
|
||||
// Can't move past end
|
||||
lv.MoveDown()
|
||||
assert.Equal(t, 2, lv.cursor)
|
||||
|
||||
// Move up
|
||||
lv.MoveUp()
|
||||
assert.Equal(t, 1, lv.cursor)
|
||||
|
||||
lv.MoveUp()
|
||||
assert.Equal(t, 0, lv.cursor)
|
||||
|
||||
// Can't move before start
|
||||
lv.MoveUp()
|
||||
assert.Equal(t, 0, lv.cursor)
|
||||
}
|
||||
|
||||
func TestListView_Selected(t *testing.T) {
|
||||
lv := NewListView()
|
||||
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
item := lv.Selected()
|
||||
assert.Nil(t, item)
|
||||
})
|
||||
|
||||
t.Run("with items", func(t *testing.T) {
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
item := lv.Selected()
|
||||
require.NotNil(t, item)
|
||||
assert.Equal(t, "a.com", item.Entry.Domain)
|
||||
|
||||
lv.MoveDown()
|
||||
item = lv.Selected()
|
||||
require.NotNil(t, item)
|
||||
assert.Equal(t, "b.com", item.Entry.Domain)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListView_SelectedAlias(t *testing.T) {
|
||||
lv := NewListView()
|
||||
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
alias := lv.SelectedAlias()
|
||||
assert.Empty(t, alias)
|
||||
})
|
||||
|
||||
t.Run("with items", func(t *testing.T) {
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "my-alias", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
alias := lv.SelectedAlias()
|
||||
assert.Equal(t, "my-alias", alias)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListView_SetPending(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
assert.False(t, lv.items[0].Pending)
|
||||
|
||||
lv.SetPending("a", true)
|
||||
assert.True(t, lv.items[0].Pending)
|
||||
|
||||
lv.SetPending("a", false)
|
||||
assert.False(t, lv.items[0].Pending)
|
||||
|
||||
// Non-existent alias should not panic
|
||||
lv.SetPending("nonexistent", true)
|
||||
}
|
||||
|
||||
func TestListView_SetError(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
assert.False(t, lv.items[0].HasError)
|
||||
|
||||
lv.SetError("a", true)
|
||||
assert.True(t, lv.items[0].HasError)
|
||||
|
||||
lv.SetError("a", false)
|
||||
assert.False(t, lv.items[0].HasError)
|
||||
}
|
||||
|
||||
func TestListView_UpdateEntry(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
lv.items[0].Pending = true
|
||||
lv.items[0].HasError = true
|
||||
|
||||
lv.UpdateEntry("a", true)
|
||||
|
||||
assert.True(t, lv.items[0].Entry.Enabled)
|
||||
assert.False(t, lv.items[0].Pending)
|
||||
assert.False(t, lv.items[0].HasError)
|
||||
}
|
||||
|
||||
func TestListView_ActiveCount(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true, Group: "staging"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
assert.Equal(t, 2, lv.ActiveCount())
|
||||
}
|
||||
|
||||
func TestListView_FindByAlias(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
item := lv.FindByAlias("b")
|
||||
require.NotNil(t, item)
|
||||
assert.Equal(t, "b.com", item.Entry.Domain)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
item := lv.FindByAlias("nonexistent")
|
||||
assert.Nil(t, item)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListView_Filter(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "myapp.com", IP: "127.0.0.1", Alias: "myapp", Enabled: true, Group: "dev"},
|
||||
{Domain: "api.myapp.com", IP: "127.0.0.1", Alias: "api", Enabled: false, Group: "dev"},
|
||||
{Domain: "other.com", IP: "192.168.1.1", Alias: "other", Enabled: true, Group: "staging"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
t.Run("empty term", func(t *testing.T) {
|
||||
filtered := lv.Filter("")
|
||||
assert.Len(t, filtered, 3)
|
||||
})
|
||||
|
||||
t.Run("by domain", func(t *testing.T) {
|
||||
filtered := lv.Filter("myapp")
|
||||
assert.Len(t, filtered, 2)
|
||||
})
|
||||
|
||||
t.Run("by alias", func(t *testing.T) {
|
||||
filtered := lv.Filter("api")
|
||||
assert.Len(t, filtered, 1)
|
||||
assert.Equal(t, "api.myapp.com", filtered[0].Entry.Domain)
|
||||
})
|
||||
|
||||
t.Run("by IP", func(t *testing.T) {
|
||||
filtered := lv.Filter("192.168")
|
||||
assert.Len(t, filtered, 1)
|
||||
assert.Equal(t, "other.com", filtered[0].Entry.Domain)
|
||||
})
|
||||
|
||||
t.Run("case insensitive", func(t *testing.T) {
|
||||
filtered := lv.Filter("MYAPP")
|
||||
assert.Len(t, filtered, 2)
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
filtered := lv.Filter("nonexistent")
|
||||
assert.Empty(t, filtered)
|
||||
})
|
||||
}
|
||||
|
||||
func TestListView_View(t *testing.T) {
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
lv := NewListView()
|
||||
view := lv.View()
|
||||
assert.Contains(t, view, "No host entries")
|
||||
})
|
||||
|
||||
t.Run("with items", func(t *testing.T) {
|
||||
lv := NewListView()
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
view := lv.View()
|
||||
// Group header is shown as section title (uppercase)
|
||||
assert.Contains(t, view, "DEV")
|
||||
// Table headers
|
||||
assert.Contains(t, view, "DOMAIN")
|
||||
assert.Contains(t, view, "IP ADDRESS")
|
||||
assert.Contains(t, view, "STATUS")
|
||||
// Data is in the view
|
||||
assert.Contains(t, view, "example.com")
|
||||
assert.Contains(t, view, "127.0.0.1")
|
||||
assert.Contains(t, view, "Active")
|
||||
})
|
||||
}
|
||||
|
||||
func TestListView_SetSize(t *testing.T) {
|
||||
lv := NewListView()
|
||||
lv.SetSize(80, 24)
|
||||
|
||||
assert.Equal(t, 80, lv.width)
|
||||
assert.Equal(t, 24, lv.height)
|
||||
}
|
||||
|
||||
func TestListView_CursorBounds(t *testing.T) {
|
||||
lv := NewListView()
|
||||
|
||||
// Set items
|
||||
entries := []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
lv.cursor = 1
|
||||
|
||||
// Set fewer items - cursor should be adjusted
|
||||
entries = []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
assert.Equal(t, 0, lv.cursor)
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
maxLen int
|
||||
expected string
|
||||
}{
|
||||
{"short", 10, "short"},
|
||||
{"exactly10!", 10, "exactly10!"},
|
||||
{"this is too long", 10, "this is..."},
|
||||
{"", 5, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := truncate(tt.input, tt.maxLen)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMax(t *testing.T) {
|
||||
assert.Equal(t, 5, max(3, 5))
|
||||
assert.Equal(t, 5, max(5, 3))
|
||||
assert.Equal(t, 5, max(5, 5))
|
||||
assert.Equal(t, 0, max(0, -1))
|
||||
}
|
||||
|
||||
// Matrix test for navigation
|
||||
func TestListView_Navigation_Matrix(t *testing.T) {
|
||||
sizes := []int{1, 5, 10, 100}
|
||||
|
||||
for _, size := range sizes {
|
||||
t.Run("size="+string(rune('0'+size)), func(t *testing.T) {
|
||||
lv := NewListView()
|
||||
|
||||
entries := make([]protocol.HostEntry, size)
|
||||
for i := range entries {
|
||||
entries[i] = protocol.HostEntry{
|
||||
Domain: "domain" + string(rune('a'+i%26)) + ".com",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "alias" + string(rune('a'+i%26)),
|
||||
Enabled: true,
|
||||
Group: "dev",
|
||||
}
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
// Move to end
|
||||
for i := 0; i < size*2; i++ {
|
||||
lv.MoveDown()
|
||||
}
|
||||
assert.Equal(t, size-1, lv.cursor)
|
||||
|
||||
// Move to start
|
||||
for i := 0; i < size*2; i++ {
|
||||
lv.MoveUp()
|
||||
}
|
||||
assert.Equal(t, 0, lv.cursor)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkListView_SetItems(b *testing.B) {
|
||||
entries := make([]protocol.HostEntry, 100)
|
||||
for i := range entries {
|
||||
entries[i] = protocol.HostEntry{
|
||||
Domain: "domain.com",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "alias",
|
||||
Enabled: true,
|
||||
Group: "dev",
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
lv := NewListView()
|
||||
lv.SetItems(entries)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkListView_Filter(b *testing.B) {
|
||||
lv := NewListView()
|
||||
entries := make([]protocol.HostEntry, 100)
|
||||
for i := range entries {
|
||||
entries[i] = protocol.HostEntry{
|
||||
Domain: "domain" + string(rune('a'+i%26)) + ".com",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "alias" + string(rune('a'+i%26)),
|
||||
Enabled: true,
|
||||
Group: "dev",
|
||||
}
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = lv.Filter("domain")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkListView_View(b *testing.B) {
|
||||
lv := NewListView()
|
||||
entries := make([]protocol.HostEntry, 50)
|
||||
for i := range entries {
|
||||
entries[i] = protocol.HostEntry{
|
||||
Domain: "domain.com",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "alias",
|
||||
Enabled: i%2 == 0,
|
||||
Group: "group" + string(rune('a'+i%5)),
|
||||
}
|
||||
}
|
||||
lv.SetItems(entries)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = lv.View()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,356 @@
|
||||
// Package tui provides the preset picker component.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/charmbracelet/bubbles/textinput"
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// PresetMode represents the preset view mode.
|
||||
type PresetMode int
|
||||
|
||||
const (
|
||||
PresetModeSelect PresetMode = iota
|
||||
PresetModeAdd
|
||||
PresetModeEdit
|
||||
PresetModeConfirmDelete
|
||||
)
|
||||
|
||||
// PresetFormField represents a form field index.
|
||||
type PresetFormField int
|
||||
|
||||
const (
|
||||
PresetFieldName PresetFormField = iota
|
||||
PresetFieldEnable
|
||||
PresetFieldDisable
|
||||
PresetFieldCount
|
||||
)
|
||||
|
||||
// PresetPicker handles the preset selection and management UI.
|
||||
type PresetPicker struct {
|
||||
presets []protocol.PresetInfo
|
||||
cursor int
|
||||
width int
|
||||
height int
|
||||
mode PresetMode
|
||||
fields []textinput.Model
|
||||
focus PresetFormField
|
||||
editName string // Original name when editing
|
||||
}
|
||||
|
||||
// NewPresetPicker creates a new preset picker.
|
||||
func NewPresetPicker() *PresetPicker {
|
||||
fields := make([]textinput.Model, PresetFieldCount)
|
||||
|
||||
// Name field
|
||||
fields[PresetFieldName] = textinput.New()
|
||||
fields[PresetFieldName].Placeholder = "preset-name"
|
||||
fields[PresetFieldName].CharLimit = 63
|
||||
|
||||
// Enable field
|
||||
fields[PresetFieldEnable] = textinput.New()
|
||||
fields[PresetFieldEnable].Placeholder = "alias1,alias2,alias3"
|
||||
fields[PresetFieldEnable].CharLimit = 500
|
||||
|
||||
// Disable field
|
||||
fields[PresetFieldDisable] = textinput.New()
|
||||
fields[PresetFieldDisable].Placeholder = "alias1,alias2,alias3"
|
||||
fields[PresetFieldDisable].CharLimit = 500
|
||||
|
||||
return &PresetPicker{
|
||||
fields: fields,
|
||||
mode: PresetModeSelect,
|
||||
}
|
||||
}
|
||||
|
||||
// SetPresets updates the available presets (legacy method for compatibility).
|
||||
func (p *PresetPicker) SetPresets(presets []string) {
|
||||
p.presets = make([]protocol.PresetInfo, len(presets))
|
||||
for i, name := range presets {
|
||||
p.presets[i] = protocol.PresetInfo{Name: name}
|
||||
}
|
||||
if p.cursor >= len(presets) {
|
||||
p.cursor = max(0, len(presets)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// SetPresetsWithInfo updates the available presets with full info.
|
||||
func (p *PresetPicker) SetPresetsWithInfo(presets []protocol.PresetInfo) {
|
||||
p.presets = presets
|
||||
if p.cursor >= len(presets) {
|
||||
p.cursor = max(0, len(presets)-1)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSize sets the picker dimensions.
|
||||
func (p *PresetPicker) SetSize(width, height int) {
|
||||
p.width = width
|
||||
p.height = height
|
||||
|
||||
inputWidth := min(60, width-10)
|
||||
for i := range p.fields {
|
||||
p.fields[i].Width = inputWidth
|
||||
}
|
||||
}
|
||||
|
||||
// MoveUp moves the cursor up.
|
||||
func (p *PresetPicker) MoveUp() {
|
||||
if p.cursor > 0 {
|
||||
p.cursor--
|
||||
}
|
||||
}
|
||||
|
||||
// MoveDown moves the cursor down.
|
||||
func (p *PresetPicker) MoveDown() {
|
||||
if p.cursor < len(p.presets)-1 {
|
||||
p.cursor++
|
||||
}
|
||||
}
|
||||
|
||||
// Selected returns the currently selected preset name.
|
||||
func (p *PresetPicker) Selected() string {
|
||||
if p.cursor >= 0 && p.cursor < len(p.presets) {
|
||||
return p.presets[p.cursor].Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SelectedInfo returns the currently selected preset info.
|
||||
func (p *PresetPicker) SelectedInfo() *protocol.PresetInfo {
|
||||
if p.cursor >= 0 && p.cursor < len(p.presets) {
|
||||
return &p.presets[p.cursor]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Len returns the number of presets.
|
||||
func (p *PresetPicker) Len() int {
|
||||
return len(p.presets)
|
||||
}
|
||||
|
||||
// Mode returns the current mode.
|
||||
func (p *PresetPicker) Mode() PresetMode {
|
||||
return p.mode
|
||||
}
|
||||
|
||||
// SetMode sets the mode.
|
||||
func (p *PresetPicker) SetMode(mode PresetMode) {
|
||||
p.mode = mode
|
||||
}
|
||||
|
||||
// InitAdd initializes the form for adding a new preset.
|
||||
func (p *PresetPicker) InitAdd() {
|
||||
p.mode = PresetModeAdd
|
||||
p.editName = ""
|
||||
for i := range p.fields {
|
||||
p.fields[i].Reset()
|
||||
}
|
||||
p.focus = PresetFieldName
|
||||
p.fields[PresetFieldName].Focus()
|
||||
}
|
||||
|
||||
// InitEdit initializes the form for editing an existing preset.
|
||||
func (p *PresetPicker) InitEdit() {
|
||||
preset := p.SelectedInfo()
|
||||
if preset == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.mode = PresetModeEdit
|
||||
p.editName = preset.Name
|
||||
|
||||
p.fields[PresetFieldName].SetValue(preset.Name)
|
||||
p.fields[PresetFieldEnable].SetValue(strings.Join(preset.Enable, ","))
|
||||
p.fields[PresetFieldDisable].SetValue(strings.Join(preset.Disable, ","))
|
||||
|
||||
p.focus = PresetFieldName
|
||||
p.fields[PresetFieldName].Focus()
|
||||
}
|
||||
|
||||
// InitDelete starts delete confirmation.
|
||||
func (p *PresetPicker) InitDelete() {
|
||||
if p.SelectedInfo() == nil {
|
||||
return
|
||||
}
|
||||
p.mode = PresetModeConfirmDelete
|
||||
}
|
||||
|
||||
// CancelForm cancels the current form operation.
|
||||
func (p *PresetPicker) CancelForm() {
|
||||
p.mode = PresetModeSelect
|
||||
p.editName = ""
|
||||
for i := range p.fields {
|
||||
p.fields[i].Reset()
|
||||
p.fields[i].Blur()
|
||||
}
|
||||
}
|
||||
|
||||
// Update handles input events for form mode.
|
||||
func (p *PresetPicker) Update(msg tea.KeyMsg) tea.Cmd {
|
||||
switch msg.String() {
|
||||
case "tab", "down":
|
||||
p.nextField()
|
||||
return nil
|
||||
case "shift+tab", "up":
|
||||
p.prevField()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update the focused field
|
||||
var cmd tea.Cmd
|
||||
p.fields[p.focus], cmd = p.fields[p.focus].Update(msg)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (p *PresetPicker) nextField() {
|
||||
p.fields[p.focus].Blur()
|
||||
p.focus = (p.focus + 1) % PresetFieldCount
|
||||
p.fields[p.focus].Focus()
|
||||
}
|
||||
|
||||
func (p *PresetPicker) prevField() {
|
||||
p.fields[p.focus].Blur()
|
||||
p.focus = (p.focus - 1 + PresetFieldCount) % PresetFieldCount
|
||||
p.fields[p.focus].Focus()
|
||||
}
|
||||
|
||||
// FormValues returns the form values (name, enable list, disable list).
|
||||
func (p *PresetPicker) FormValues() (name string, enable, disable []string) {
|
||||
name = strings.TrimSpace(p.fields[PresetFieldName].Value())
|
||||
|
||||
enableStr := strings.TrimSpace(p.fields[PresetFieldEnable].Value())
|
||||
if enableStr != "" {
|
||||
for _, s := range strings.Split(enableStr, ",") {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
enable = append(enable, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
disableStr := strings.TrimSpace(p.fields[PresetFieldDisable].Value())
|
||||
if disableStr != "" {
|
||||
for _, s := range strings.Split(disableStr, ",") {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
disable = append(disable, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return name, enable, disable
|
||||
}
|
||||
|
||||
// EditName returns the original name when editing.
|
||||
func (p *PresetPicker) EditName() string {
|
||||
return p.editName
|
||||
}
|
||||
|
||||
// IsEdit returns true if in edit mode.
|
||||
func (p *PresetPicker) IsEdit() bool {
|
||||
return p.mode == PresetModeEdit
|
||||
}
|
||||
|
||||
// ValidateForm validates the form values.
|
||||
func (p *PresetPicker) ValidateForm() string {
|
||||
name, enable, disable := p.FormValues()
|
||||
|
||||
if name == "" {
|
||||
return "Preset name is required"
|
||||
}
|
||||
if len(enable) == 0 && len(disable) == 0 {
|
||||
return "At least one alias to enable or disable is required"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// View renders the preset picker.
|
||||
func (p *PresetPicker) View() string {
|
||||
switch p.mode {
|
||||
case PresetModeAdd, PresetModeEdit:
|
||||
return p.formView()
|
||||
case PresetModeConfirmDelete:
|
||||
return p.deleteView()
|
||||
default:
|
||||
return p.selectView()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PresetPicker) selectView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString(titleStyle.Render("Presets"))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if len(p.presets) == 0 {
|
||||
sb.WriteString(helpDescStyle.Render("No presets configured."))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("Press 'n' to create one"))
|
||||
} else {
|
||||
for i, preset := range p.presets {
|
||||
if i == p.cursor {
|
||||
sb.WriteString(presetSelectedStyle.Render("▸ " + preset.Name))
|
||||
} else {
|
||||
sb.WriteString(presetItemStyle.Render(" " + preset.Name))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("↑↓ navigate • Enter apply • n new • e edit • d delete • Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (p *PresetPicker) formView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
title := "Add New Preset"
|
||||
if p.mode == PresetModeEdit {
|
||||
title = "Edit Preset"
|
||||
}
|
||||
|
||||
sb.WriteString(titleStyle.Render(title))
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
labels := []string{"Name:", "Enable aliases (comma-separated):", "Disable aliases (comma-separated):"}
|
||||
|
||||
for i, label := range labels {
|
||||
sb.WriteString(inputLabelStyle.Render(label))
|
||||
sb.WriteString("\n")
|
||||
|
||||
style := inputStyle
|
||||
if PresetFormField(i) == p.focus {
|
||||
style = inputFocusStyle
|
||||
}
|
||||
|
||||
sb.WriteString(style.Render(p.fields[i].View()))
|
||||
sb.WriteString("\n\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
sb.WriteString(helpDescStyle.Render("Tab/↓ next • Shift+Tab/↑ prev • Enter save • Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
|
||||
func (p *PresetPicker) deleteView() string {
|
||||
var sb strings.Builder
|
||||
|
||||
preset := p.SelectedInfo()
|
||||
presetName := ""
|
||||
if preset != nil {
|
||||
presetName = preset.Name
|
||||
}
|
||||
|
||||
sb.WriteString(titleStyle.Render("Delete Preset"))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(errorMsgStyle.Render("Are you sure you want to delete preset '" + presetName + "'?"))
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(helpDescStyle.Render("y confirm • n/Esc cancel"))
|
||||
|
||||
return dialogStyle.Render(sb.String())
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
// Package tui provides the terminal user interface.
|
||||
package tui
|
||||
|
||||
import (
|
||||
"github.com/charmbracelet/lipgloss"
|
||||
)
|
||||
|
||||
// Colors - matching kportal style, optimized for dark terminals
|
||||
var (
|
||||
colorPrimary = lipgloss.Color("205") // Pink/Magenta
|
||||
colorSuccess = lipgloss.Color("42") // Green
|
||||
colorWarning = lipgloss.Color("220") // Yellow
|
||||
colorError = lipgloss.Color("196") // Red
|
||||
colorMuted = lipgloss.Color("245") // Gray (brighter for dark terminals)
|
||||
colorAccent = lipgloss.Color("141") // Light purple (brighter for dark terminals)
|
||||
colorHeader = lipgloss.Color("220") // Yellow for headers
|
||||
colorSelectedBg = lipgloss.Color("236") // Gray background for selection
|
||||
colorSelectedFg = lipgloss.Color("255") // White foreground for selection
|
||||
colorGroupHeader = lipgloss.Color("213") // Light pink for group headers
|
||||
)
|
||||
|
||||
// Title and header styles
|
||||
var (
|
||||
titleStyle = lipgloss.NewStyle().
|
||||
Bold(true).
|
||||
Foreground(colorHeader).
|
||||
Padding(0, 1)
|
||||
)
|
||||
|
||||
// Status indicators
|
||||
var (
|
||||
enabledStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSuccess).
|
||||
Bold(true)
|
||||
|
||||
disabledStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
|
||||
pendingStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWarning)
|
||||
|
||||
errorIndicatorStyle = lipgloss.NewStyle().
|
||||
Foreground(colorError)
|
||||
)
|
||||
|
||||
// Status bar and help
|
||||
var (
|
||||
statusBarStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
|
||||
connectedStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSuccess).
|
||||
SetString("Connected")
|
||||
|
||||
disconnectedStyle = lipgloss.NewStyle().
|
||||
Foreground(colorError).
|
||||
SetString("Disconnected")
|
||||
|
||||
helpBarStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
|
||||
helpKeyStyle = lipgloss.NewStyle().
|
||||
Foreground(colorHeader).
|
||||
Bold(true)
|
||||
|
||||
helpDescStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
)
|
||||
|
||||
// Message styles
|
||||
var (
|
||||
errorMsgStyle = lipgloss.NewStyle().
|
||||
Foreground(colorError).
|
||||
Bold(true).
|
||||
MarginTop(1)
|
||||
|
||||
successMsgStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSuccess).
|
||||
MarginTop(1)
|
||||
|
||||
updateStyle = lipgloss.NewStyle().
|
||||
Foreground(colorSuccess).
|
||||
Bold(true)
|
||||
)
|
||||
|
||||
// Form styles
|
||||
var (
|
||||
inputLabelStyle = lipgloss.NewStyle().
|
||||
Foreground(colorPrimary).
|
||||
Bold(true)
|
||||
|
||||
inputStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(colorMuted).
|
||||
Padding(0, 1)
|
||||
|
||||
inputFocusStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(colorPrimary).
|
||||
Padding(0, 1)
|
||||
)
|
||||
|
||||
// Dialog/modal styles
|
||||
var (
|
||||
dialogStyle = lipgloss.NewStyle().
|
||||
Border(lipgloss.RoundedBorder()).
|
||||
BorderForeground(colorAccent).
|
||||
Padding(1, 2)
|
||||
|
||||
presetItemStyle = lipgloss.NewStyle().
|
||||
Padding(0, 1)
|
||||
|
||||
presetSelectedStyle = lipgloss.NewStyle().
|
||||
Background(colorSelectedBg).
|
||||
Foreground(colorSelectedFg).
|
||||
Padding(0, 1)
|
||||
)
|
||||
|
||||
// Indicator returns the appropriate status indicator string.
|
||||
func Indicator(enabled bool, pending bool, hasError bool) string {
|
||||
if hasError {
|
||||
return errorIndicatorStyle.Render("✗")
|
||||
}
|
||||
if pending {
|
||||
return pendingStyle.Render("◐")
|
||||
}
|
||||
if enabled {
|
||||
return enabledStyle.Render("●")
|
||||
}
|
||||
return disabledStyle.Render("○")
|
||||
}
|
||||
|
||||
// StatusText returns the status text with appropriate styling
|
||||
func StatusText(enabled bool, pending bool, hasError bool) string {
|
||||
if hasError {
|
||||
return errorIndicatorStyle.Render("✗ Error")
|
||||
}
|
||||
if pending {
|
||||
return pendingStyle.Render("◐ Pending")
|
||||
}
|
||||
if enabled {
|
||||
return enabledStyle.Render("● Active")
|
||||
}
|
||||
return disabledStyle.Render("○ Disabled")
|
||||
}
|
||||
|
||||
// HelpItem formats a help item.
|
||||
func HelpItem(key, desc string) string {
|
||||
return helpKeyStyle.Render(key) + " " + helpDescStyle.Render(desc)
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
// Package version provides version checking against GitHub releases.
|
||||
package version
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// githubReleasesURL is the GitHub API endpoint for latest release
|
||||
githubReleasesURL = "https://api.github.com/repos/%s/%s/releases/latest"
|
||||
// requestTimeout is the timeout for HTTP requests
|
||||
requestTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// ReleaseInfo contains information about a GitHub release
|
||||
type ReleaseInfo struct {
|
||||
TagName string `json:"tag_name"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// UpdateInfo contains information about an available update
|
||||
type UpdateInfo struct {
|
||||
CurrentVersion string
|
||||
LatestVersion string
|
||||
ReleaseURL string
|
||||
ReleaseName string
|
||||
}
|
||||
|
||||
// Checker checks for new versions on GitHub
|
||||
type Checker struct {
|
||||
owner string
|
||||
repo string
|
||||
current string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewChecker creates a new version checker
|
||||
func NewChecker(owner, repo, currentVersion string) *Checker {
|
||||
return &Checker{
|
||||
owner: owner,
|
||||
repo: repo,
|
||||
current: normalizeVersion(currentVersion),
|
||||
client: &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CheckForUpdate checks if a newer version is available.
|
||||
// Returns nil if current version is up to date or if check fails.
|
||||
// This is designed to fail silently - network errors should not impact the user.
|
||||
func (c *Checker) CheckForUpdate(ctx context.Context) *UpdateInfo {
|
||||
release, err := c.fetchLatestRelease(ctx)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
latestVersion := normalizeVersion(release.TagName)
|
||||
if isNewerVersion(latestVersion, c.current) {
|
||||
return &UpdateInfo{
|
||||
CurrentVersion: c.current,
|
||||
LatestVersion: latestVersion,
|
||||
ReleaseURL: release.HTMLURL,
|
||||
ReleaseName: release.Name,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fetchLatestRelease fetches the latest release info from GitHub API
|
||||
func (c *Checker) fetchLatestRelease(ctx context.Context) (*ReleaseInfo, error) {
|
||||
url := fmt.Sprintf(githubReleasesURL, c.owner, c.repo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "lolcathost-version-checker")
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release ReleaseInfo
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
// normalizeVersion removes 'v' or 'V' prefix and trims whitespace
|
||||
func normalizeVersion(v string) string {
|
||||
v = strings.TrimSpace(v)
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
v = strings.TrimPrefix(v, "V")
|
||||
return v
|
||||
}
|
||||
|
||||
// isNewerVersion compares two semver-like versions.
|
||||
// Returns true if latest is newer than current.
|
||||
func isNewerVersion(latest, current string) bool {
|
||||
latestParts := parseVersion(latest)
|
||||
currentParts := parseVersion(current)
|
||||
|
||||
// Compare each part
|
||||
for i := 0; i < len(latestParts) && i < len(currentParts); i++ {
|
||||
if latestParts[i] > currentParts[i] {
|
||||
return true
|
||||
}
|
||||
if latestParts[i] < currentParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// If all compared parts are equal, longer version is newer
|
||||
// e.g., 1.0.1 > 1.0
|
||||
return len(latestParts) > len(currentParts)
|
||||
}
|
||||
|
||||
// parseVersion splits a version string into numeric parts
|
||||
func parseVersion(v string) []int {
|
||||
// Remove any suffix like -beta, -rc1, etc.
|
||||
if idx := strings.IndexAny(v, "-+"); idx != -1 {
|
||||
v = v[:idx]
|
||||
}
|
||||
|
||||
parts := strings.Split(v, ".")
|
||||
result := make([]int, 0, len(parts))
|
||||
|
||||
for _, p := range parts {
|
||||
var num int
|
||||
fmt.Sscanf(p, "%d", &num)
|
||||
result = append(result, num)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// FormatUpdateMessage formats a user-friendly update notification
|
||||
func (u *UpdateInfo) FormatUpdateMessage() string {
|
||||
return fmt.Sprintf("New version available: %s (current: %s) - %s",
|
||||
u.LatestVersion, u.CurrentVersion, u.ReleaseURL)
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNormalizeVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"v1.0.0", "1.0.0"},
|
||||
{"1.0.0", "1.0.0"},
|
||||
{" v2.1.3 ", "2.1.3"},
|
||||
{"V1.0.0", "1.0.0"},
|
||||
{"v0.1.0", "0.1.0"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := normalizeVersion(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []int
|
||||
}{
|
||||
{"1.0.0", []int{1, 0, 0}},
|
||||
{"2.1.3", []int{2, 1, 3}},
|
||||
{"1.0", []int{1, 0}},
|
||||
{"10.20.30", []int{10, 20, 30}},
|
||||
{"1.0.0-beta", []int{1, 0, 0}},
|
||||
{"1.0.0-rc1", []int{1, 0, 0}},
|
||||
{"1.0.0+build123", []int{1, 0, 0}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := parseVersion(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNewerVersion(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
latest string
|
||||
current string
|
||||
expected bool
|
||||
}{
|
||||
{"major version bump", "2.0.0", "1.0.0", true},
|
||||
{"minor version bump", "1.1.0", "1.0.0", true},
|
||||
{"patch version bump", "1.0.1", "1.0.0", true},
|
||||
{"same version", "1.0.0", "1.0.0", false},
|
||||
{"current is newer major", "1.0.0", "2.0.0", false},
|
||||
{"current is newer minor", "1.0.0", "1.1.0", false},
|
||||
{"current is newer patch", "1.0.0", "1.0.1", false},
|
||||
{"longer version is newer", "1.0.1", "1.0", true},
|
||||
{"shorter version is older", "1.0", "1.0.1", false},
|
||||
{"double digit versions", "10.0.0", "9.0.0", true},
|
||||
{"with prerelease suffix", "1.1.0", "1.0.0-beta", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isNewerVersion(tt.latest, tt.current)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInfo_FormatUpdateMessage(t *testing.T) {
|
||||
info := &UpdateInfo{
|
||||
CurrentVersion: "1.0.0",
|
||||
LatestVersion: "1.1.0",
|
||||
ReleaseURL: "https://github.com/lukaszraczylo/lolcathost/releases/tag/v1.1.0",
|
||||
}
|
||||
|
||||
msg := info.FormatUpdateMessage()
|
||||
assert.Contains(t, msg, "1.0.0")
|
||||
assert.Contains(t, msg, "1.1.0")
|
||||
assert.Contains(t, msg, "https://github.com")
|
||||
}
|
||||
|
||||
func TestNewChecker(t *testing.T) {
|
||||
checker := NewChecker("lukaszraczylo", "lolcathost", "v1.0.0")
|
||||
|
||||
assert.Equal(t, "lukaszraczylo", checker.owner)
|
||||
assert.Equal(t, "lolcathost", checker.repo)
|
||||
assert.Equal(t, "1.0.0", checker.current) // Should be normalized
|
||||
assert.NotNil(t, checker.client)
|
||||
}
|
||||
Reference in New Issue
Block a user