diff --git a/internal/client/client_test.go b/internal/client/client_test.go index ac9fd15..59e0ff3 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -474,6 +474,276 @@ func TestClient_RequestTypes_Matrix(t *testing.T) { } } +func TestClient_Add(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestAdd { + var payload protocol.AddPayload + req.ParsePayload(&payload) + assert.Equal(t, "test.local", payload.Domain) + assert.Equal(t, "127.0.0.1", payload.IP) + assert.Equal(t, "test-local", payload.Alias) + assert.Equal(t, "dev", payload.Group) + assert.True(t, payload.Enabled) + + resp, _ := protocol.NewOKResponse(protocol.SetData{Domain: payload.Domain, Applied: true}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + data, err := client.Add("test.local", "127.0.0.1", "test-local", "dev", true) + assert.NoError(t, err) + assert.Equal(t, "test.local", data.Domain) + assert.True(t, data.Applied) +} + +func TestClient_Delete(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestDelete { + var payload protocol.DeletePayload + req.ParsePayload(&payload) + assert.Equal(t, "test-alias", payload.Alias) + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.Delete("test-alias") + assert.NoError(t, err) +} + +func TestClient_AddGroup(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestAddGroup { + var payload protocol.GroupPayload + req.ParsePayload(&payload) + assert.Equal(t, "newgroup", payload.Name) + + resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.AddGroup("newgroup") + assert.NoError(t, err) +} + +func TestClient_DeleteGroup(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestDeleteGroup { + var payload protocol.GroupPayload + req.ParsePayload(&payload) + assert.Equal(t, "todelete", payload.Name) + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.DeleteGroup("todelete") + assert.NoError(t, err) +} + +func TestClient_RenameGroup(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestRenameGroup { + var payload protocol.RenameGroupPayload + req.ParsePayload(&payload) + assert.Equal(t, "oldname", payload.OldName) + assert.Equal(t, "newname", payload.NewName) + + resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.RenameGroup("oldname", "newname") + assert.NoError(t, err) +} + +func TestClient_ListGroups(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestListGroups { + resp, _ := protocol.NewOKResponse(protocol.GroupsData{ + Groups: []string{"dev", "staging", "prod"}, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + groups, err := client.ListGroups() + require.NoError(t, err) + + assert.Equal(t, []string{"dev", "staging", "prod"}, groups) +} + +func TestClient_GetBackupContent(t *testing.T) { + server := newMockServer(t) + defer server.close() + + expectedContent := "127.0.0.1\tlocalhost\n" + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestBackupContent { + var payload protocol.BackupContentPayload + req.ParsePayload(&payload) + assert.Equal(t, "hosts.backup.bak", payload.BackupName) + + resp, _ := protocol.NewOKResponse(protocol.BackupContentData{ + Content: expectedContent, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + content, err := client.GetBackupContent("hosts.backup.bak") + require.NoError(t, err) + + assert.Equal(t, expectedContent, content) +} + +func TestClient_AddPreset(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestAddPreset { + var payload protocol.AddPresetPayload + req.ParsePayload(&payload) + assert.Equal(t, "newpreset", payload.Name) + assert.Equal(t, []string{"a", "b"}, payload.Enable) + assert.Equal(t, []string{"c"}, payload.Disable) + + resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.AddPreset("newpreset", []string{"a", "b"}, []string{"c"}) + assert.NoError(t, err) +} + +func TestClient_DeletePreset(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestDeletePreset { + var payload protocol.PresetPayload + req.ParsePayload(&payload) + assert.Equal(t, "todelete", payload.Name) + + resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name}) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + err = client.DeletePreset("todelete") + assert.NoError(t, err) +} + +func TestClient_ListPresets(t *testing.T) { + server := newMockServer(t) + defer server.close() + + server.handler = func(req *protocol.Request) *protocol.Response { + if req.Type == protocol.RequestListPresets { + resp, _ := protocol.NewOKResponse(protocol.PresetsData{ + Presets: []protocol.PresetInfo{ + {Name: "local", Enable: []string{"a"}, Disable: []string{"b"}}, + {Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}}, + }, + }) + return resp + } + return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "unexpected") + } + + client := New(server.path) + err := client.Connect() + require.NoError(t, err) + defer client.Close() + + presets, err := client.ListPresets() + require.NoError(t, err) + + assert.Len(t, presets, 2) + assert.Equal(t, "local", presets[0].Name) + assert.Equal(t, "staging", presets[1].Name) +} + func BenchmarkClient_Ping(b *testing.B) { tmpDir := b.TempDir() socketPath := filepath.Join(tmpDir, "bench.sock") diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ab58dfe..552ca40 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -265,3 +265,328 @@ func TestFlushMethod(t *testing.T) { }) } } + +func TestDefaultConfigDir(t *testing.T) { + dir := DefaultConfigDir() + assert.NotEmpty(t, dir) + assert.Contains(t, dir, ".config/lolcathost") +} + +func TestDefaultConfigPath(t *testing.T) { + path := DefaultConfigPath() + assert.NotEmpty(t, path) + assert.Contains(t, path, "config.yaml") +} + +func TestConfig_GenerateAlias(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + { + Name: "dev", + Hosts: []Host{ + {Domain: "existing.com", IP: "127.0.0.1", Alias: "existing-com", Enabled: true}, + }, + }, + }, + } + + t.Run("simple domain", func(t *testing.T) { + alias := cfg.GenerateAlias("newdomain.com") + assert.Equal(t, "newdomain-com", alias) + }) + + t.Run("domain with underscore", func(t *testing.T) { + alias := cfg.GenerateAlias("my_app.test") + assert.Equal(t, "my-app-test", alias) + }) + + t.Run("duplicate generates numbered alias", func(t *testing.T) { + alias := cfg.GenerateAlias("existing.com") + assert.Equal(t, "existing-com-2", alias) + }) +} + +func TestConfig_AddHost(t *testing.T) { + t.Run("add to existing group", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{}}, + }, + } + err := cfg.AddHost("test.local", "127.0.0.1", "test-local", "dev", true) + require.NoError(t, err) + assert.Len(t, cfg.Groups[0].Hosts, 1) + assert.Equal(t, "test.local", cfg.Groups[0].Hosts[0].Domain) + }) + + t.Run("add to new group", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.AddHost("test.local", "127.0.0.1", "test-local", "newgroup", true) + require.NoError(t, err) + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "newgroup", cfg.Groups[0].Name) + }) + + t.Run("auto-generate alias", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.AddHost("auto.test", "127.0.0.1", "", "dev", true) + require.NoError(t, err) + assert.Equal(t, "auto-test", cfg.Groups[0].Hosts[0].Alias) + }) + + t.Run("duplicate alias error", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{{Domain: "a.com", IP: "127.0.0.1", Alias: "existing"}}}, + }, + } + err := cfg.AddHost("b.com", "127.0.0.1", "existing", "dev", true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "alias already exists") + }) +} + +func TestConfig_AddGroup(t *testing.T) { + t.Run("add new group", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.AddGroup("newgroup") + require.NoError(t, err) + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "newgroup", cfg.Groups[0].Name) + }) + + t.Run("duplicate group error", func(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "existing"}}} + err := cfg.AddGroup("existing") + assert.Error(t, err) + assert.Contains(t, err.Error(), "group already exists") + }) +} + +func TestConfig_DeleteGroup(t *testing.T) { + t.Run("delete existing group", func(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "todelete"}, {Name: "keep"}}} + err := cfg.DeleteGroup("todelete") + require.NoError(t, err) + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "keep", cfg.Groups[0].Name) + }) + + t.Run("delete nonexistent group", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.DeleteGroup("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "group not found") + }) +} + +func TestConfig_RenameGroup(t *testing.T) { + t.Run("rename existing group", func(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "oldname"}}} + err := cfg.RenameGroup("oldname", "newname") + require.NoError(t, err) + assert.Equal(t, "newname", cfg.Groups[0].Name) + }) + + t.Run("rename to existing name error", func(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "a"}, {Name: "b"}}} + err := cfg.RenameGroup("a", "b") + assert.Error(t, err) + assert.Contains(t, err.Error(), "group already exists") + }) + + t.Run("rename nonexistent group", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.RenameGroup("nonexistent", "newname") + assert.Error(t, err) + assert.Contains(t, err.Error(), "group not found") + }) +} + +func TestConfig_GetGroups(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "a"}, {Name: "b"}, {Name: "c"}}} + groups := cfg.GetGroups() + assert.Equal(t, []string{"a", "b", "c"}, groups) +} + +func TestConfig_DeleteHost(t *testing.T) { + t.Run("delete existing host", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{ + {Domain: "a.com", Alias: "a"}, + {Domain: "b.com", Alias: "b"}, + }}, + }, + } + result := cfg.DeleteHost("a") + assert.True(t, result) + assert.Len(t, cfg.Groups[0].Hosts, 1) + assert.Equal(t, "b", cfg.Groups[0].Hosts[0].Alias) + }) + + t.Run("delete nonexistent host", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + result := cfg.DeleteHost("nonexistent") + assert.False(t, result) + }) +} + +func TestConfig_UpdateHost(t *testing.T) { + t.Run("update in same group", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{{Domain: "old.com", IP: "127.0.0.1", Alias: "test"}}}, + }, + } + err := cfg.UpdateHost("test", "new.com", "192.168.1.1", "test", "dev") + require.NoError(t, err) + assert.Equal(t, "new.com", cfg.Groups[0].Hosts[0].Domain) + assert.Equal(t, "192.168.1.1", cfg.Groups[0].Hosts[0].IP) + }) + + t.Run("move to different group", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "source", Hosts: []Host{{Domain: "a.com", IP: "127.0.0.1", Alias: "test"}}}, + {Name: "target", Hosts: []Host{}}, + }, + } + err := cfg.UpdateHost("test", "a.com", "127.0.0.1", "test", "target") + require.NoError(t, err) + assert.Len(t, cfg.Groups[0].Hosts, 0) + assert.Len(t, cfg.Groups[1].Hosts, 1) + }) + + t.Run("move to new group", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "source", Hosts: []Host{{Domain: "a.com", IP: "127.0.0.1", Alias: "test"}}}, + }, + } + err := cfg.UpdateHost("test", "a.com", "127.0.0.1", "test", "newgroup") + require.NoError(t, err) + assert.Len(t, cfg.Groups, 2) + assert.Equal(t, "newgroup", cfg.Groups[1].Name) + }) + + t.Run("change alias", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{{Domain: "a.com", IP: "127.0.0.1", Alias: "old"}}}, + }, + } + err := cfg.UpdateHost("old", "a.com", "127.0.0.1", "new", "dev") + require.NoError(t, err) + assert.Equal(t, "new", cfg.Groups[0].Hosts[0].Alias) + }) + + t.Run("alias conflict error", func(t *testing.T) { + cfg := &Config{ + Groups: []Group{ + {Name: "dev", Hosts: []Host{ + {Domain: "a.com", Alias: "a"}, + {Domain: "b.com", Alias: "b"}, + }}, + }, + } + err := cfg.UpdateHost("a", "a.com", "127.0.0.1", "b", "dev") + assert.Error(t, err) + assert.Contains(t, err.Error(), "alias already exists") + }) + + t.Run("host not found error", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + err := cfg.UpdateHost("nonexistent", "a.com", "127.0.0.1", "new", "dev") + assert.Error(t, err) + assert.Contains(t, err.Error(), "alias not found") + }) +} + +func TestConfig_AddPreset(t *testing.T) { + t.Run("add new preset", func(t *testing.T) { + cfg := &Config{Presets: []Preset{}} + err := cfg.AddPreset("newpreset", []string{"a"}, []string{"b"}) + require.NoError(t, err) + assert.Len(t, cfg.Presets, 1) + assert.Equal(t, "newpreset", cfg.Presets[0].Name) + }) + + t.Run("duplicate preset error", func(t *testing.T) { + cfg := &Config{Presets: []Preset{{Name: "existing"}}} + err := cfg.AddPreset("existing", nil, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "preset already exists") + }) +} + +func TestConfig_DeletePreset(t *testing.T) { + t.Run("delete existing preset", func(t *testing.T) { + cfg := &Config{Presets: []Preset{{Name: "todelete"}, {Name: "keep"}}} + err := cfg.DeletePreset("todelete") + require.NoError(t, err) + assert.Len(t, cfg.Presets, 1) + assert.Equal(t, "keep", cfg.Presets[0].Name) + }) + + t.Run("delete nonexistent preset", func(t *testing.T) { + cfg := &Config{Presets: []Preset{}} + err := cfg.DeletePreset("nonexistent") + assert.Error(t, err) + assert.Contains(t, err.Error(), "preset not found") + }) +} + +func TestConfig_GetPresets(t *testing.T) { + cfg := &Config{Presets: []Preset{{Name: "a"}, {Name: "b"}}} + presets := cfg.GetPresets() + assert.Len(t, presets, 2) +} + +func TestConfig_EnsureDefaultGroup(t *testing.T) { + t.Run("creates default when empty", func(t *testing.T) { + cfg := &Config{Groups: []Group{}} + cfg.EnsureDefaultGroup() + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "default", cfg.Groups[0].Name) + }) + + t.Run("does nothing when groups exist", func(t *testing.T) { + cfg := &Config{Groups: []Group{{Name: "existing"}}} + cfg.EnsureDefaultGroup() + assert.Len(t, cfg.Groups, 1) + assert.Equal(t, "existing", cfg.Groups[0].Name) + }) +} + +func TestManager_Watch(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + err := CreateDefault(configPath) + require.NoError(t, err) + + manager := NewManager(configPath) + err = manager.Load() + require.NoError(t, err) + + changeCh := make(chan *Config, 1) + err = manager.Watch(func(cfg *Config) { + changeCh <- cfg + }) + require.NoError(t, err) + + // Stop the watcher + manager.Stop() +} + +func TestManager_Save_NoConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + + manager := NewManager(configPath) + // Don't load, so config is nil + err := manager.Save() + assert.Error(t, err) + assert.Contains(t, err.Error(), "no config loaded") +} diff --git a/internal/daemon/server_test.go b/internal/daemon/server_test.go new file mode 100644 index 0000000..e223404 --- /dev/null +++ b/internal/daemon/server_test.go @@ -0,0 +1,756 @@ +package daemon + +import ( + "encoding/json" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/lukaszraczylo/lolcathost/internal/config" + "github.com/lukaszraczylo/lolcathost/internal/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTestServer(t *testing.T) (*Server, string, func()) { + tmpDir := t.TempDir() + socketPath := filepath.Join(tmpDir, "test.sock") + configPath := filepath.Join(tmpDir, "config.yaml") + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + // Create initial hosts file + err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644) + require.NoError(t, err) + + // Create config + err = config.CreateDefault(configPath) + require.NoError(t, err) + + cfgManager := config.NewManager(configPath) + err = cfgManager.Load() + require.NoError(t, err) + + server := &Server{ + socketPath: socketPath, + config: cfgManager, + hosts: NewHostsManagerWithPaths(hostsPath, backupDir), + flusher: NewDNSFlusher(FlushMethodAuto), + rateLimiter: NewRateLimiter(100, time.Minute), + stopCh: make(chan struct{}), + } + + cleanup := func() { + server.Stop() + } + + return server, tmpDir, cleanup +} + +func TestServer_HandlePing(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handlePing() + assert.Equal(t, "ok", resp.Status) +} + +func TestServer_HandleStatus(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handleStatus() + assert.Equal(t, "ok", resp.Status) + + var data protocol.StatusData + err := resp.ParseData(&data) + require.NoError(t, err) + + assert.True(t, data.Running) +} + +func TestServer_HandleList(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handleList() + assert.Equal(t, "ok", resp.Status) + + var data protocol.ListData + err := resp.ParseData(&data) + require.NoError(t, err) + + assert.NotNil(t, data.Entries) +} + +func TestServer_HandleSet(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // First add a host to set + cfg := server.config.Get() + cfg.AddHost("test.local", "127.0.0.1", "test-local", "default", false) + server.config.Save() + + t.Run("enable host", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{ + Alias: "test-local", + Enabled: true, + }) + resp := server.handleSet(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("disable host", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{ + Alias: "test-local", + Enabled: false, + }) + resp := server.handleSet(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("nonexistent host", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{ + Alias: "nonexistent", + Enabled: true, + }) + resp := server.handleSet(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeNotFound, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestSet, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleSet(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleAdd(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + t.Run("valid host", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: "newhost.local", + IP: "127.0.0.1", + Group: "default", + }) + resp := server.handleAdd(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("duplicate alias", func(t *testing.T) { + // When alias is explicitly provided, duplicates are rejected + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: "another.local", + IP: "127.0.0.1", + Alias: "newhost-local", // Same alias as auto-generated for newhost.local + Group: "default", + }) + resp := server.handleAdd(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeConflict, resp.Code) + }) + + t.Run("blocked domain", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: "apple.com", + IP: "127.0.0.1", + Group: "default", + }) + resp := server.handleAdd(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeBlockedDomain, resp.Code) + }) + + t.Run("invalid domain", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: "", + IP: "127.0.0.1", + Group: "default", + }) + resp := server.handleAdd(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("empty IP", func(t *testing.T) { + // Only empty IP is rejected, format is not validated + req, _ := protocol.NewRequest(protocol.RequestAdd, protocol.AddPayload{ + Domain: "valid.local", + IP: "", + Group: "default", + }) + resp := server.handleAdd(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeInvalidIP, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestAdd, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleAdd(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleDelete(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Add a host first + cfg := server.config.Get() + cfg.AddHost("todelete.local", "127.0.0.1", "todelete", "default", false) + server.config.Save() + + t.Run("delete existing", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{ + Alias: "todelete", + }) + resp := server.handleDelete(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("delete nonexistent", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDelete, protocol.DeletePayload{ + Alias: "nonexistent", + }) + resp := server.handleDelete(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeNotFound, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestDelete, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleDelete(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleSync(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handleSync() + assert.Equal(t, "ok", resp.Status) +} + +func TestServer_HandleBackups(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Create a backup first + server.hosts.CreateBackup() + + resp := server.handleBackups() + assert.Equal(t, "ok", resp.Status) + + var data protocol.BackupsData + err := resp.ParseData(&data) + require.NoError(t, err) + assert.NotNil(t, data.Backups) +} + +func TestServer_HandleAddGroup(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + t.Run("add new group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{ + Name: "newgroup", + }) + resp := server.handleAddGroup(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("add duplicate group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{ + Name: "newgroup", + }) + resp := server.handleAddGroup(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeConflict, resp.Code) + }) + + t.Run("empty name", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddGroup, protocol.GroupPayload{ + Name: "", + }) + resp := server.handleAddGroup(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestAddGroup, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleAddGroup(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleDeleteGroup(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Add a group first + cfg := server.config.Get() + cfg.AddGroup("todeletegroup") + server.config.Save() + + t.Run("delete existing group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{ + Name: "todeletegroup", + }) + resp := server.handleDeleteGroup(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("delete nonexistent group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDeleteGroup, protocol.GroupPayload{ + Name: "nonexistent", + }) + resp := server.handleDeleteGroup(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeNotFound, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestDeleteGroup, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleDeleteGroup(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleListGroups(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handleListGroups() + assert.Equal(t, "ok", resp.Status) + + var data protocol.GroupsData + err := resp.ParseData(&data) + require.NoError(t, err) + assert.NotNil(t, data.Groups) +} + +func TestServer_HandleRenameGroup(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Add a group to rename + cfg := server.config.Get() + cfg.AddGroup("oldname") + server.config.Save() + + t.Run("rename existing group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{ + OldName: "oldname", + NewName: "newname", + }) + resp := server.handleRenameGroup(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("rename nonexistent group", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestRenameGroup, protocol.RenameGroupPayload{ + OldName: "nonexistent", + NewName: "newname2", + }) + resp := server.handleRenameGroup(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestRenameGroup, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleRenameGroup(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleAddPreset(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + t.Run("add new preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{ + Name: "newpreset", + Enable: []string{"alias1"}, + Disable: []string{"alias2"}, + }) + resp := server.handleAddPreset(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("add duplicate preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{ + Name: "newpreset", + Enable: []string{"alias1"}, + Disable: []string{"alias2"}, + }) + resp := server.handleAddPreset(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeConflict, resp.Code) + }) + + t.Run("empty name", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestAddPreset, protocol.AddPresetPayload{ + Name: "", + }) + resp := server.handleAddPreset(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestAddPreset, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleAddPreset(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleDeletePreset(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Add a preset first + cfg := server.config.Get() + cfg.AddPreset("todeletepreset", []string{"a"}, []string{"b"}) + server.config.Save() + + t.Run("delete existing preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{ + Name: "todeletepreset", + }) + resp := server.handleDeletePreset(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("delete nonexistent preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestDeletePreset, protocol.PresetPayload{ + Name: "nonexistent", + }) + resp := server.handleDeletePreset(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeNotFound, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestDeletePreset, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleDeletePreset(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleListPresets(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + resp := server.handleListPresets() + assert.Equal(t, "ok", resp.Status) + + var data protocol.PresetsData + err := resp.ParseData(&data) + require.NoError(t, err) + assert.NotNil(t, data.Presets) +} + +func TestServer_HandlePreset(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Add hosts and preset + cfg := server.config.Get() + cfg.AddHost("host1.local", "127.0.0.1", "host1", "default", false) + cfg.AddHost("host2.local", "127.0.0.1", "host2", "default", false) + cfg.AddPreset("testpreset", []string{"host1"}, []string{"host2"}) + server.config.Save() + + t.Run("apply existing preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{ + Name: "testpreset", + }) + resp := server.handlePreset(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("apply nonexistent preset", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestPreset, protocol.PresetPayload{ + Name: "nonexistent", + }) + resp := server.handlePreset(req) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeNotFound, resp.Code) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestPreset, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handlePreset(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleRollback(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Create a backup first + server.hosts.CreateBackup() + backups, _ := server.hosts.ListBackups() + require.NotEmpty(t, backups) + + t.Run("rollback to existing backup", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{ + BackupName: backups[0].Name, + }) + resp := server.handleRollback(req) + assert.Equal(t, "ok", resp.Status) + }) + + t.Run("rollback to nonexistent backup", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestRollback, protocol.RollbackPayload{ + BackupName: "nonexistent.bak", + }) + resp := server.handleRollback(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestRollback, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleRollback(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleBackupContent(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + // Create a backup first + server.hosts.CreateBackup() + backups, _ := server.hosts.ListBackups() + require.NotEmpty(t, backups) + + t.Run("get existing backup content", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestBackupContent, protocol.BackupContentPayload{ + BackupName: backups[0].Name, + }) + resp := server.handleBackupContent(req) + assert.Equal(t, "ok", resp.Status) + + var data protocol.BackupContentData + err := resp.ParseData(&data) + require.NoError(t, err) + assert.NotEmpty(t, data.Content) + }) + + t.Run("get nonexistent backup content", func(t *testing.T) { + req, _ := protocol.NewRequest(protocol.RequestBackupContent, protocol.BackupContentPayload{ + BackupName: "nonexistent.bak", + }) + resp := server.handleBackupContent(req) + assert.Equal(t, "error", resp.Status) + }) + + t.Run("invalid payload", func(t *testing.T) { + req := &protocol.Request{ + Type: protocol.RequestBackupContent, + Payload: json.RawMessage(`{invalid`), + } + resp := server.handleBackupContent(req) + assert.Equal(t, "error", resp.Status) + }) +} + +func TestServer_HandleRequest_UnknownType(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + req := &protocol.Request{ + Type: "unknown_type", + } + creds := &PeerCredentials{UID: 0, GID: 0, PID: 1} + resp := server.handleRequest(req, creds) + assert.Equal(t, "error", resp.Status) + assert.Equal(t, protocol.ErrCodeInvalidRequest, resp.Code) +} + +func TestServer_IsAuthorized(t *testing.T) { + server, _, cleanup := setupTestServer(t) + defer cleanup() + + t.Run("root user", func(t *testing.T) { + creds := &PeerCredentials{UID: 0, GID: 0, PID: 1} + assert.True(t, server.isAuthorized(creds)) + }) + + t.Run("nil credentials", func(t *testing.T) { + assert.False(t, server.isAuthorized(nil)) + }) +} + +func TestServer_StartStop(t *testing.T) { + // Skip test if not running as root (server.Start requires root to chown socket) + if os.Getuid() != 0 { + t.Skip("Test requires root privileges to create socket with proper ownership") + } + + server, _, _ := setupTestServer(t) + // Don't use cleanup since we manually call Stop + + // Start server in goroutine + errCh := make(chan error, 1) + go func() { + errCh <- server.Start() + }() + + // Give it time to start + time.Sleep(100 * time.Millisecond) + + // Verify socket exists + _, err := os.Stat(server.socketPath) + assert.NoError(t, err) + + // Stop server + err = server.Stop() + assert.NoError(t, err) + + // Verify socket is removed + _, err = os.Stat(server.socketPath) + assert.True(t, os.IsNotExist(err)) +} + +func TestServer_AcceptConnection(t *testing.T) { + // Skip test if not running as root (server.Start requires root to chown socket) + if os.Getuid() != 0 { + t.Skip("Test requires root privileges to create socket with proper ownership") + } + + server, _, _ := setupTestServer(t) + // Don't use cleanup - manually stop + + // Start server + go server.Start() + time.Sleep(100 * time.Millisecond) + defer server.Stop() + + // Connect to server + conn, err := net.Dial("unix", server.socketPath) + require.NoError(t, err) + defer conn.Close() + + // Send ping request + req, _ := protocol.NewRequest(protocol.RequestPing, nil) + encoder := json.NewEncoder(conn) + err = encoder.Encode(req) + require.NoError(t, err) + + // Read response + decoder := json.NewDecoder(conn) + var resp protocol.Response + err = decoder.Decode(&resp) + require.NoError(t, err) + + assert.Equal(t, "ok", resp.Status) +} + +// Benchmarks + +func BenchmarkServer_HandlePing(b *testing.B) { + tmpDir := b.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + config.CreateDefault(configPath) + cfgManager := config.NewManager(configPath) + cfgManager.Load() + + server := &Server{ + config: cfgManager, + rateLimiter: NewRateLimiter(100000, time.Minute), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + server.handlePing() + } +} + +func BenchmarkServer_HandleList(b *testing.B) { + tmpDir := b.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + config.CreateDefault(configPath) + cfgManager := config.NewManager(configPath) + cfgManager.Load() + + server := &Server{ + config: cfgManager, + rateLimiter: NewRateLimiter(100000, time.Minute), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + server.handleList() + } +} + +func BenchmarkServer_HandleSet(b *testing.B) { + tmpDir := b.TempDir() + configPath := filepath.Join(tmpDir, "config.yaml") + hostsPath := filepath.Join(tmpDir, "hosts") + backupDir := filepath.Join(tmpDir, "backups") + + os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644) + config.CreateDefault(configPath) + cfgManager := config.NewManager(configPath) + cfgManager.Load() + + // Add a test host + cfg := cfgManager.Get() + cfg.AddHost("bench.local", "127.0.0.1", "bench-local", "default", false) + + server := &Server{ + config: cfgManager, + hosts: NewHostsManagerWithPaths(hostsPath, backupDir), + flusher: NewDNSFlusher(FlushMethodAuto), + rateLimiter: NewRateLimiter(100000, time.Minute), + } + + req, _ := protocol.NewRequest(protocol.RequestSet, protocol.SetPayload{ + Alias: "bench-local", + Enabled: true, + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + server.handleSet(req) + } +}