mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-05 23:29:18 +00:00
Initial commit.
This commit is contained in:
@@ -0,0 +1,427 @@
|
||||
// Package client provides a client library for communicating with the lolcathost daemon.
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Client is a client for the lolcathost daemon.
|
||||
type Client struct {
|
||||
socketPath string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
timeout time.Duration
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new client.
|
||||
func New(socketPath string) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithTimeout creates a new client with a custom timeout.
|
||||
func NewWithTimeout(socketPath string, timeout time.Duration) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the daemon.
|
||||
func (c *Client) Connect() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
// Close existing connection if any
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
c.reader = nil
|
||||
}
|
||||
|
||||
conn, err := net.DialTimeout("unix", c.socketPath, c.timeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.reader = bufio.NewReader(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn != nil {
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
c.reader = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// send sends a request and receives a response.
|
||||
func (c *Client) send(req *protocol.Request) (*protocol.Response, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.conn == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
// Set deadline
|
||||
c.conn.SetDeadline(time.Now().Add(c.timeout))
|
||||
|
||||
// Send request
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
|
||||
if _, err := c.conn.Write(data); err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
|
||||
// Read response
|
||||
line, err := c.reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
var resp protocol.Response
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Ping checks if the daemon is responsive.
|
||||
func (c *Client) Ping() error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestPing, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("ping failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Status returns the daemon's status.
|
||||
func (c *Client) Status() (*protocol.StatusData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestStatus, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("status failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.StatusData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// List returns all host entries.
|
||||
func (c *Client) List() ([]protocol.HostEntry, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestList, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("list failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.ListData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Entries, nil
|
||||
}
|
||||
|
||||
// Set enables or disables a host entry by alias.
|
||||
func (c *Client) Set(alias string, enabled bool, force bool) (*protocol.SetData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{
|
||||
Alias: alias,
|
||||
Enabled: enabled,
|
||||
Force: force,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.SetData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// Enable enables a host entry by alias.
|
||||
func (c *Client) Enable(alias string) (*protocol.SetData, error) {
|
||||
return c.Set(alias, true, false)
|
||||
}
|
||||
|
||||
// Disable disables a host entry by alias.
|
||||
func (c *Client) Disable(alias string) (*protocol.SetData, error) {
|
||||
return c.Set(alias, false, false)
|
||||
}
|
||||
|
||||
// Add adds a new host entry.
|
||||
func (c *Client) Add(domain, ip, alias, group string, enabled bool) (*protocol.SetData, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: alias,
|
||||
Group: group,
|
||||
Enabled: enabled,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.SetData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
// Delete removes a host entry by alias.
|
||||
func (c *Client) Delete(alias string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{
|
||||
Alias: alias,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddGroup adds a new group.
|
||||
func (c *Client) AddGroup(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup removes a group and all its hosts.
|
||||
func (c *Client) DeleteGroup(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListGroups returns all group names.
|
||||
func (c *Client) ListGroups() ([]string, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestListGroups, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.GroupsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Groups, nil
|
||||
}
|
||||
|
||||
// Sync synchronizes the config to the hosts file.
|
||||
func (c *Client) Sync() error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestSync, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("sync failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPreset applies a named preset.
|
||||
func (c *Client) ApplyPreset(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("preset failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Rollback restores a backup by name.
|
||||
func (c *Client) Rollback(backupName string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{
|
||||
BackupName: backupName,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("rollback failed: %s", resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBackups returns available backups.
|
||||
func (c *Client) ListBackups() ([]protocol.BackupInfo, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestBackups, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("backups failed: %s", resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.BackupsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Backups, nil
|
||||
}
|
||||
|
||||
// RenameGroup renames a group.
|
||||
func (c *Client) RenameGroup(oldName, newName string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{
|
||||
OldName: oldName,
|
||||
NewName: newName,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPreset adds a new preset.
|
||||
func (c *Client) AddPreset(name string, enable, disable []string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{
|
||||
Name: name,
|
||||
Enable: enable,
|
||||
Disable: disable,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePreset removes a preset by name.
|
||||
func (c *Client) DeletePreset(name string) error {
|
||||
req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{
|
||||
Name: name,
|
||||
})
|
||||
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPresets returns all presets.
|
||||
func (c *Client) ListPresets() ([]protocol.PresetInfo, error) {
|
||||
req, _ := protocol.NewRequest(protocol.RequestListPresets, nil)
|
||||
resp, err := c.send(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !resp.IsOK() {
|
||||
return nil, fmt.Errorf("%s: %s", resp.Code, resp.Message)
|
||||
}
|
||||
|
||||
var data protocol.PresetsData
|
||||
if err := resp.ParseData(&data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data.Presets, nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the daemon is reachable.
|
||||
func IsConnected(socketPath string) bool {
|
||||
client := New(socketPath)
|
||||
if err := client.Connect(); err != nil {
|
||||
return false
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
return client.Ping() == nil
|
||||
}
|
||||
@@ -0,0 +1,516 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// mockServer creates a mock Unix socket server for testing
|
||||
type mockServer struct {
|
||||
listener net.Listener
|
||||
path string
|
||||
handler func(req *protocol.Request) *protocol.Response
|
||||
}
|
||||
|
||||
func newMockServer(t *testing.T) *mockServer {
|
||||
// Use /tmp directly to avoid long paths (Unix socket paths have ~104 char limit on macOS)
|
||||
tmpDir, err := os.MkdirTemp("/tmp", "lolcat")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { os.RemoveAll(tmpDir) })
|
||||
|
||||
socketPath := filepath.Join(tmpDir, "s.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
ms := &mockServer{
|
||||
listener: listener,
|
||||
path: socketPath,
|
||||
}
|
||||
|
||||
go ms.serve()
|
||||
|
||||
return ms
|
||||
}
|
||||
|
||||
func (ms *mockServer) serve() {
|
||||
for {
|
||||
conn, err := ms.listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go ms.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mockServer) handleConn(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var resp *protocol.Response
|
||||
if ms.handler != nil {
|
||||
resp = ms.handler(&req)
|
||||
} else {
|
||||
resp, _ = protocol.NewOKResponse(nil)
|
||||
}
|
||||
|
||||
data, _ := json.Marshal(resp)
|
||||
conn.Write(append(data, '\n'))
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mockServer) close() {
|
||||
ms.listener.Close()
|
||||
os.Remove(ms.path)
|
||||
}
|
||||
|
||||
func TestClient_Connect(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
assert.NotNil(t, client.conn)
|
||||
assert.NotNil(t, client.reader)
|
||||
})
|
||||
|
||||
t.Run("failure - socket not found", func(t *testing.T) {
|
||||
client := New("/nonexistent/socket.sock")
|
||||
err := client.Connect()
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClient_Ping(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestPing {
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected request")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Ping()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Status(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestStatus {
|
||||
resp, _ := protocol.NewOKResponse(protocol.StatusData{
|
||||
Running: true,
|
||||
Version: "1.0.0",
|
||||
Uptime: 3600,
|
||||
ActiveCount: 5,
|
||||
RequestCount: 100,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
status, err := client.Status()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, status.Running)
|
||||
assert.Equal(t, "1.0.0", status.Version)
|
||||
assert.Equal(t, int64(3600), status.Uptime)
|
||||
assert.Equal(t, 5, status.ActiveCount)
|
||||
assert.Equal(t, int64(100), status.RequestCount)
|
||||
}
|
||||
|
||||
func TestClient_List(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestList {
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{
|
||||
Entries: []protocol.HostEntry{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true, Group: "dev"},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false, Group: "dev"},
|
||||
},
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
entries, err := client.List()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, entries, 2)
|
||||
assert.Equal(t, "a.com", entries[0].Domain)
|
||||
assert.True(t, entries[0].Enabled)
|
||||
assert.Equal(t, "b.com", entries[1].Domain)
|
||||
assert.False(t, entries[1].Enabled)
|
||||
}
|
||||
|
||||
func TestClient_Set(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: "example.com",
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
data, err := client.Set("test", true, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "example.com", data.Domain)
|
||||
assert.True(t, data.Applied)
|
||||
}
|
||||
|
||||
func TestClient_Enable(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.True(t, payload.Enabled)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Enable("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Disable(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSet {
|
||||
var payload protocol.SetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.False(t, payload.Enabled)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: "test.com", Applied: true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Disable("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Sync(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestSync {
|
||||
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Sync()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_ApplyPreset(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestPreset {
|
||||
var payload protocol.PresetPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.Equal(t, "local", payload.Name)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": "local"})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.ApplyPreset("local")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_Rollback(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestRollback {
|
||||
var payload protocol.RollbackPayload
|
||||
req.ParsePayload(&payload)
|
||||
assert.Equal(t, "hosts.backup.bak", payload.BackupName)
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
err = client.Rollback("hosts.backup.bak")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestClient_ListBackups(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
if req.Type == protocol.RequestBackups {
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{
|
||||
Backups: []protocol.BackupInfo{
|
||||
{Name: "hosts.20231201.bak", Timestamp: 1701432000, Size: 1024},
|
||||
{Name: "hosts.20231130.bak", Timestamp: 1701345600, Size: 1000},
|
||||
},
|
||||
})
|
||||
return resp
|
||||
}
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
backups, err := client.ListBackups()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, backups, 2)
|
||||
assert.Equal(t, "hosts.20231201.bak", backups[0].Name)
|
||||
}
|
||||
|
||||
func TestClient_ErrorResponse(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, "domain is blocked")
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_, err = client.Set("test", true, false)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "domain is blocked")
|
||||
}
|
||||
|
||||
func TestClient_NotConnected(t *testing.T) {
|
||||
client := New("/nonexistent/socket.sock")
|
||||
|
||||
_, err := client.Status()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "not connected")
|
||||
}
|
||||
|
||||
func TestClient_Timeout(t *testing.T) {
|
||||
client := NewWithTimeout("/nonexistent.sock", 100*time.Millisecond)
|
||||
assert.Equal(t, 100*time.Millisecond, client.timeout)
|
||||
}
|
||||
|
||||
func TestIsConnected(t *testing.T) {
|
||||
t.Run("connected", func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
return resp
|
||||
}
|
||||
|
||||
connected := IsConnected(server.path)
|
||||
assert.True(t, connected)
|
||||
})
|
||||
|
||||
t.Run("not connected", func(t *testing.T) {
|
||||
connected := IsConnected("/nonexistent/socket.sock")
|
||||
assert.False(t, connected)
|
||||
})
|
||||
}
|
||||
|
||||
// Matrix test for request types
|
||||
func TestClient_RequestTypes_Matrix(t *testing.T) {
|
||||
types := []struct {
|
||||
name string
|
||||
reqType protocol.RequestType
|
||||
call func(*Client) error
|
||||
}{
|
||||
{"ping", protocol.RequestPing, func(c *Client) error { return c.Ping() }},
|
||||
{"status", protocol.RequestStatus, func(c *Client) error { _, err := c.Status(); return err }},
|
||||
{"list", protocol.RequestList, func(c *Client) error { _, err := c.List(); return err }},
|
||||
{"sync", protocol.RequestSync, func(c *Client) error { return c.Sync() }},
|
||||
{"preset", protocol.RequestPreset, func(c *Client) error { return c.ApplyPreset("test") }},
|
||||
{"backups", protocol.RequestBackups, func(c *Client) error { _, err := c.ListBackups(); return err }},
|
||||
}
|
||||
|
||||
for _, tt := range types {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
receivedType := protocol.RequestType("")
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
receivedType = req.Type
|
||||
|
||||
switch req.Type {
|
||||
case protocol.RequestStatus:
|
||||
resp, _ := protocol.NewOKResponse(protocol.StatusData{})
|
||||
return resp
|
||||
case protocol.RequestList:
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{})
|
||||
return resp
|
||||
case protocol.RequestBackups:
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{})
|
||||
return resp
|
||||
default:
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
client := New(server.path)
|
||||
err := client.Connect()
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
_ = tt.call(client)
|
||||
assert.Equal(t, tt.reqType, receivedType)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkClient_Ping(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
socketPath := filepath.Join(tmpDir, "bench.sock")
|
||||
|
||||
listener, err := net.Listen("unix", socketPath)
|
||||
require.NoError(b, err)
|
||||
defer listener.Close()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
reader := bufio.NewReader(c)
|
||||
for {
|
||||
_, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
data, _ := json.Marshal(resp)
|
||||
c.Write(append(data, '\n'))
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
client := New(socketPath)
|
||||
err = client.Connect()
|
||||
require.NoError(b, err)
|
||||
defer client.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = client.Ping()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user