mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-05 23:29:18 +00:00
Cleanup, signing and update of internals.
This commit is contained in:
@@ -29,14 +29,6 @@ func New(socketPath string) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithTimeout creates a new client with a custom timeout.
|
||||
func NewWithTimeout(socketPath string, timeout time.Duration) *Client {
|
||||
return &Client{
|
||||
socketPath: socketPath,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the daemon.
|
||||
func (c *Client) Connect() error {
|
||||
c.mu.Lock()
|
||||
@@ -435,14 +427,3 @@ func (c *Client) ListPresets() ([]protocol.PresetInfo, error) {
|
||||
}
|
||||
return data.Presets, nil
|
||||
}
|
||||
|
||||
// IsConnected checks if the daemon is reachable.
|
||||
func IsConnected(socketPath string) bool {
|
||||
client := New(socketPath)
|
||||
if err := client.Connect(); err != nil {
|
||||
return false
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
return client.Ping() == nil
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -398,31 +397,6 @@ func TestClient_NotConnected(t *testing.T) {
|
||||
assert.Contains(t, err.Error(), "not connected")
|
||||
}
|
||||
|
||||
func TestClient_Timeout(t *testing.T) {
|
||||
client := NewWithTimeout("/nonexistent.sock", 100*time.Millisecond)
|
||||
assert.Equal(t, 100*time.Millisecond, client.timeout)
|
||||
}
|
||||
|
||||
func TestIsConnected(t *testing.T) {
|
||||
t.Run("connected", func(t *testing.T) {
|
||||
server := newMockServer(t)
|
||||
defer server.close()
|
||||
|
||||
server.handler = func(req *protocol.Request) *protocol.Response {
|
||||
resp, _ := protocol.NewOKResponse(nil)
|
||||
return resp
|
||||
}
|
||||
|
||||
connected := IsConnected(server.path)
|
||||
assert.True(t, connected)
|
||||
})
|
||||
|
||||
t.Run("not connected", func(t *testing.T) {
|
||||
connected := IsConnected("/nonexistent/socket.sock")
|
||||
assert.False(t, connected)
|
||||
})
|
||||
}
|
||||
|
||||
// Matrix test for request types
|
||||
func TestClient_RequestTypes_Matrix(t *testing.T) {
|
||||
types := []struct {
|
||||
|
||||
+147
-24
@@ -124,6 +124,12 @@ func (m *Manager) Get() *Config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// Reload reloads the configuration from disk.
|
||||
// This is useful for rolling back after a failed operation.
|
||||
func (m *Manager) Reload() error {
|
||||
return m.Load()
|
||||
}
|
||||
|
||||
// Watch starts watching the config file for changes.
|
||||
func (m *Manager) Watch(onChange func(*Config)) error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
@@ -151,8 +157,15 @@ func (m *Manager) watchLoop() {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
if err := m.Load(); err == nil && m.onChange != nil {
|
||||
m.onChange(m.Get())
|
||||
// Load and notify under lock to prevent race conditions
|
||||
m.mu.Lock()
|
||||
err := m.loadLocked()
|
||||
cfg := m.config
|
||||
onChange := m.onChange
|
||||
m.mu.Unlock()
|
||||
|
||||
if err == nil && onChange != nil {
|
||||
onChange(cfg)
|
||||
}
|
||||
}
|
||||
case <-m.watcher.Errors:
|
||||
@@ -163,6 +176,27 @@ func (m *Manager) watchLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// loadLocked reads and parses the configuration file.
|
||||
// Caller must hold m.mu lock.
|
||||
func (m *Manager) loadLocked() error {
|
||||
data, err := os.ReadFile(m.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if err := ValidateConfig(&cfg); err != nil {
|
||||
return fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
m.config = &cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops watching the config file.
|
||||
func (m *Manager) Stop() {
|
||||
close(m.stopCh)
|
||||
@@ -180,16 +214,26 @@ func (c *Config) GetAllHosts() []Host {
|
||||
return hosts
|
||||
}
|
||||
|
||||
// FindHostByAlias finds a host by its alias.
|
||||
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
|
||||
// findHostIndices finds the group and host indices for a given alias.
|
||||
// Returns -1, -1 if not found.
|
||||
func (c *Config) findHostIndices(alias string) (groupIdx, hostIdx int) {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
return &c.Groups[i].Hosts[j], &c.Groups[i]
|
||||
return i, j
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
return -1, -1
|
||||
}
|
||||
|
||||
// FindHostByAlias finds a host by its alias.
|
||||
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
|
||||
groupIdx, hostIdx := c.findHostIndices(alias)
|
||||
if groupIdx < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return &c.Groups[groupIdx].Hosts[hostIdx], &c.Groups[groupIdx]
|
||||
}
|
||||
|
||||
// FindPreset finds a preset by name.
|
||||
@@ -204,15 +248,12 @@ func (c *Config) FindPreset(name string) *Preset {
|
||||
|
||||
// SetHostEnabled sets the enabled state of a host by alias.
|
||||
func (c *Config) SetHostEnabled(alias string, enabled bool) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts[j].Enabled = enabled
|
||||
return true
|
||||
}
|
||||
}
|
||||
groupIdx, hostIdx := c.findHostIndices(alias)
|
||||
if groupIdx < 0 {
|
||||
return false
|
||||
}
|
||||
return false
|
||||
c.Groups[groupIdx].Hosts[hostIdx].Enabled = enabled
|
||||
return true
|
||||
}
|
||||
|
||||
// GenerateAlias creates a unique alias from a domain name.
|
||||
@@ -327,15 +368,68 @@ func (c *Config) GetGroups() []string {
|
||||
|
||||
// DeleteHost removes a host by alias.
|
||||
func (c *Config) DeleteHost(alias string) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...)
|
||||
return true
|
||||
}
|
||||
groupIdx, hostIdx := c.findHostIndices(alias)
|
||||
if groupIdx < 0 {
|
||||
return false
|
||||
}
|
||||
c.Groups[groupIdx].Hosts = append(c.Groups[groupIdx].Hosts[:hostIdx], c.Groups[groupIdx].Hosts[hostIdx+1:]...)
|
||||
return true
|
||||
}
|
||||
|
||||
// UpdateHost updates an existing host by alias.
|
||||
func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error {
|
||||
// Find the host
|
||||
foundGroup, foundHost := c.findHostIndices(oldAlias)
|
||||
if foundGroup < 0 {
|
||||
return fmt.Errorf("alias not found: %s", oldAlias)
|
||||
}
|
||||
|
||||
// Check for duplicate alias if alias is changing
|
||||
if oldAlias != newAlias {
|
||||
if existing, _ := c.FindHostByAlias(newAlias); existing != nil {
|
||||
return fmt.Errorf("alias already exists: %s", newAlias)
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
// Get current enabled state
|
||||
enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled
|
||||
|
||||
// If group is changing, move to new group
|
||||
if c.Groups[foundGroup].Name != groupName {
|
||||
// Remove from old group
|
||||
c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...)
|
||||
|
||||
// Add to new group
|
||||
host := Host{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: newAlias,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
// Find or create target group
|
||||
found := false
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == groupName {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: groupName,
|
||||
Hosts: []Host{host},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Update in place
|
||||
c.Groups[foundGroup].Hosts[foundHost].Domain = domain
|
||||
c.Groups[foundGroup].Hosts[foundHost].IP = ip
|
||||
c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPreset applies a preset to the configuration.
|
||||
@@ -387,6 +481,35 @@ func (c *Config) GetPresets() []Preset {
|
||||
return c.Presets
|
||||
}
|
||||
|
||||
// Clone creates a deep copy of the configuration.
|
||||
func (c *Config) Clone() *Config {
|
||||
clone := &Config{
|
||||
Settings: c.Settings,
|
||||
Groups: make([]Group, len(c.Groups)),
|
||||
Presets: make([]Preset, len(c.Presets)),
|
||||
}
|
||||
|
||||
for i, g := range c.Groups {
|
||||
clone.Groups[i] = Group{
|
||||
Name: g.Name,
|
||||
Hosts: make([]Host, len(g.Hosts)),
|
||||
}
|
||||
copy(clone.Groups[i].Hosts, g.Hosts)
|
||||
}
|
||||
|
||||
for i, p := range c.Presets {
|
||||
clone.Presets[i] = Preset{
|
||||
Name: p.Name,
|
||||
Enable: make([]string, len(p.Enable)),
|
||||
Disable: make([]string, len(p.Disable)),
|
||||
}
|
||||
copy(clone.Presets[i].Enable, p.Enable)
|
||||
copy(clone.Presets[i].Disable, p.Disable)
|
||||
}
|
||||
|
||||
return clone
|
||||
}
|
||||
|
||||
// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed.
|
||||
func (c *Config) EnsureDefaultGroup() {
|
||||
if len(c.Groups) == 0 {
|
||||
@@ -412,7 +535,7 @@ func (m *Manager) Save() error {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
// #nosec G306 -- config file should be world-readable
|
||||
// #nosec G306 - Config file permissions are intentionally 0644
|
||||
if err := os.WriteFile(m.path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config: %w", err)
|
||||
}
|
||||
@@ -423,7 +546,7 @@ func (m *Manager) Save() error {
|
||||
// CreateDefault creates a default configuration file.
|
||||
func CreateDefault(path string) error {
|
||||
dir := filepath.Dir(path)
|
||||
// #nosec G301 -- config directory should be world-readable
|
||||
// #nosec G301 - Config directory permissions are intentionally 0755
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
@@ -465,7 +588,7 @@ func CreateDefault(path string) error {
|
||||
return fmt.Errorf("failed to marshal default config: %w", err)
|
||||
}
|
||||
|
||||
// #nosec G306 -- config file should be world-readable
|
||||
// #nosec G306 - Config file permissions are intentionally 0644
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write default config: %w", err)
|
||||
}
|
||||
|
||||
@@ -92,11 +92,6 @@ func (d *Daemon) Run() error {
|
||||
return d.shutdown()
|
||||
}
|
||||
|
||||
// Stop signals the daemon to stop.
|
||||
func (d *Daemon) Stop() {
|
||||
close(d.stopCh)
|
||||
}
|
||||
|
||||
func (d *Daemon) shutdown() error {
|
||||
close(d.cleanupCh)
|
||||
d.config.Stop()
|
||||
|
||||
@@ -137,6 +137,6 @@ func (f *DNSFlusher) flushLinux(method FlushMethod) error {
|
||||
}
|
||||
|
||||
func runCommand(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
cmd := exec.Command(name, args...) // #nosec G204 - Commands are hardcoded DNS flush utilities, not user input
|
||||
return cmd.Run()
|
||||
}
|
||||
|
||||
+14
-66
@@ -2,7 +2,6 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -25,6 +24,10 @@ const (
|
||||
markerEnd = "# ========== END LOLCATHOST =========="
|
||||
)
|
||||
|
||||
// entryRegex matches host entries in the managed section.
|
||||
// Compiled once at package init for efficiency.
|
||||
var entryRegex = regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
|
||||
|
||||
// HostEntry represents a single entry in the hosts file.
|
||||
type HostEntry struct {
|
||||
IP string
|
||||
@@ -47,59 +50,6 @@ func NewHostsManager() *HostsManager {
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
|
||||
func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: hostsPath,
|
||||
backupDir: backupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadManagedEntries reads the lolcathost-managed entries from the hosts file.
|
||||
func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) {
|
||||
file, err := os.Open(m.hostsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open hosts file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var entries []HostEntry
|
||||
inManagedSection := false
|
||||
scanner := bufio.NewScanner(file)
|
||||
entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if line == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
|
||||
matches := entryRegex.FindStringSubmatch(line)
|
||||
if len(matches) == 4 {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: matches[1],
|
||||
Domain: matches[2],
|
||||
Alias: matches[3],
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// WriteManagedEntries writes the managed entries to the hosts file.
|
||||
func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error {
|
||||
// Create backup first
|
||||
@@ -178,7 +128,7 @@ func (m *HostsManager) buildManagedSection(entries []HostEntry) string {
|
||||
func (m *HostsManager) writeAtomic(content string) error {
|
||||
// Write to temp file first
|
||||
tmpFile := m.hostsPath + ".tmp"
|
||||
// #nosec G306 -- hosts file must be world-readable
|
||||
// #nosec G306 - Hosts file permissions are intentionally 0644
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -194,12 +144,12 @@ func (m *HostsManager) writeAtomic(content string) error {
|
||||
|
||||
// CreateBackup creates a backup of the current hosts file.
|
||||
func (m *HostsManager) CreateBackup() error {
|
||||
// #nosec G301 -- backup directory should be world-readable for recovery
|
||||
// #nosec G301 - Backup directory permissions are intentionally 0755
|
||||
if err := os.MkdirAll(m.backupDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
content, err := os.ReadFile(m.hostsPath) // #nosec G304 - Path is controlled by daemon, not user input
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
@@ -207,7 +157,7 @@ func (m *HostsManager) CreateBackup() error {
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp))
|
||||
|
||||
// #nosec G306 -- backup files should be world-readable for recovery
|
||||
// #nosec G306 - Backup file permissions are intentionally 0644
|
||||
if err := os.WriteFile(backupPath, content, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write backup: %w", err)
|
||||
}
|
||||
@@ -297,15 +247,14 @@ type BackupInfo struct {
|
||||
|
||||
// GetBackupContent returns the content of a backup file.
|
||||
func (m *HostsManager) GetBackupContent(name string) (string, error) {
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
// Validate backup name to prevent path traversal
|
||||
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
|
||||
return "", fmt.Errorf("invalid backup name")
|
||||
}
|
||||
|
||||
// #nosec G304 -- backupPath is validated above: filepath.Base(name) == name and prefix/suffix checks
|
||||
content, err := os.ReadFile(backupPath)
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
content, err := os.ReadFile(backupPath) // #nosec G304 - Path is validated above to prevent traversal
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read backup: %w", err)
|
||||
}
|
||||
@@ -315,15 +264,14 @@ func (m *HostsManager) GetBackupContent(name string) (string, error) {
|
||||
|
||||
// RestoreBackup restores a backup by name.
|
||||
func (m *HostsManager) RestoreBackup(name string) error {
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
// Validate backup name to prevent path traversal
|
||||
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
|
||||
return fmt.Errorf("invalid backup name")
|
||||
}
|
||||
|
||||
// #nosec G304 -- backupPath is validated above: filepath.Base(name) == name and prefix/suffix checks
|
||||
content, err := os.ReadFile(backupPath)
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
content, err := os.ReadFile(backupPath) // #nosec G304 - Path is validated above to prevent traversal
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read backup: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -10,7 +11,53 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostsManager_ReadManagedEntries(t *testing.T) {
|
||||
// newHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
|
||||
func newHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: hostsPath,
|
||||
backupDir: backupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// readManagedEntries reads the lolcathost-managed entries from the hosts file (for testing).
|
||||
func (m *HostsManager) readManagedEntries() ([]HostEntry, error) {
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
var entries []HostEntry
|
||||
inManagedSection := false
|
||||
|
||||
for _, line := range strings.Split(string(content), "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
if line == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if line == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
|
||||
matches := entryRegex.FindStringSubmatch(line)
|
||||
if len(matches) == 4 {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: matches[1],
|
||||
Domain: matches[2],
|
||||
Alias: matches[3],
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
func TestHostsManager_readManagedEntries(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
@@ -26,8 +73,8 @@ func TestHostsManager_ReadManagedEntries(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
manager := newHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.readManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, entries, 2)
|
||||
@@ -39,7 +86,7 @@ func TestHostsManager_ReadManagedEntries(t *testing.T) {
|
||||
assert.Equal(t, "api-local", entries[1].Alias)
|
||||
}
|
||||
|
||||
func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
|
||||
func TestHostsManager_readManagedEntries_NoSection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
@@ -49,8 +96,8 @@ func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
manager := newHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.readManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, entries)
|
||||
@@ -68,7 +115,7 @@ func TestHostsManager_WriteManagedEntries(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true},
|
||||
@@ -107,7 +154,7 @@ func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true},
|
||||
@@ -134,7 +181,7 @@ func TestHostsManager_CreateBackup(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
@@ -175,7 +222,7 @@ func TestHostsManager_ListBackups(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
@@ -187,7 +234,7 @@ func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) {
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "nonexistent")
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
@@ -204,7 +251,7 @@ func TestHostsManager_RestoreBackup(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create backup
|
||||
err = manager.CreateBackup()
|
||||
@@ -231,7 +278,7 @@ func TestHostsManager_RestoreBackup(t *testing.T) {
|
||||
|
||||
func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
manager := NewHostsManagerWithPaths(
|
||||
manager := newHostsManagerWithPaths(
|
||||
filepath.Join(tmpDir, "hosts"),
|
||||
filepath.Join(tmpDir, "backups"),
|
||||
)
|
||||
@@ -259,7 +306,7 @@ func TestHostsManager_CleanupBackups(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create more than MaxBackups
|
||||
for i := 0; i < MaxBackups+5; i++ {
|
||||
@@ -338,7 +385,7 @@ func TestHostsManager_BuildManagedSection(t *testing.T) {
|
||||
}
|
||||
|
||||
// Matrix tests for hosts file parsing
|
||||
func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
|
||||
func TestHostsManager_readManagedEntries_Matrix(t *testing.T) {
|
||||
ips := []string{"127.0.0.1", "192.168.1.1", "::1"}
|
||||
domains := []string{"example.com", "sub.example.com", "my-app.test"}
|
||||
aliases := []string{"test", "my-alias", "app-1"}
|
||||
@@ -357,8 +404,8 @@ func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
|
||||
err := os.WriteFile(hostsPath, []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
manager := newHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.readManagedEntries()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
@@ -371,7 +418,7 @@ func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
|
||||
func BenchmarkHostsManager_readManagedEntries(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
@@ -387,11 +434,11 @@ func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
|
||||
err := os.WriteFile(hostsPath, []byte(content.String()), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
manager := newHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = manager.ReadManagedEntries()
|
||||
_, _ = manager.readManagedEntries()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -403,7 +450,7 @@ func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) {
|
||||
err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
manager := newHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := make([]HostEntry, 50)
|
||||
for i := range entries {
|
||||
|
||||
@@ -30,10 +30,15 @@ func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Groups array is not empty before accessing
|
||||
if xucred.Ngroups == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Get PID separately using LOCAL_PEERPID
|
||||
var pid int32
|
||||
pidLen := uint32(unsafe.Sizeof(pid))
|
||||
// #nosec G103 -- unsafe required for low-level syscall to get peer PID
|
||||
// #nosec G103 - unsafe.Pointer required for syscall to get peer PID
|
||||
_, _, errno := syscall.Syscall6(
|
||||
syscall.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
|
||||
+71
-35
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -19,20 +20,27 @@ const (
|
||||
RateLimitWindow = time.Minute
|
||||
)
|
||||
|
||||
// RateLimiter implements per-PID rate limiting.
|
||||
// pidRateBucket holds rate limiting data for a single PID using a ring buffer.
|
||||
type pidRateBucket struct {
|
||||
timestamps []time.Time // Ring buffer of request timestamps
|
||||
head int // Next write position
|
||||
count int // Number of valid entries
|
||||
}
|
||||
|
||||
// RateLimiter implements per-PID rate limiting with efficient memory usage.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
requests map[int32][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
mu sync.Mutex
|
||||
buckets map[int32]*pidRateBucket
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter.
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
requests: make(map[int32][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
buckets: make(map[int32]*pidRateBucket),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,31 +52,42 @@ func (r *RateLimiter) Allow(pid int32) bool {
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
// Get existing requests for this PID
|
||||
reqs := r.requests[pid]
|
||||
bucket, exists := r.buckets[pid]
|
||||
if !exists {
|
||||
// Create new bucket with fixed capacity
|
||||
bucket = &pidRateBucket{
|
||||
timestamps: make([]time.Time, r.limit),
|
||||
head: 0,
|
||||
count: 0,
|
||||
}
|
||||
r.buckets[pid] = bucket
|
||||
}
|
||||
|
||||
// Filter out old requests
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
// Count valid (non-expired) requests in the ring buffer
|
||||
validCount := 0
|
||||
for i := 0; i < bucket.count; i++ {
|
||||
idx := (bucket.head - bucket.count + i + r.limit) % r.limit
|
||||
if bucket.timestamps[idx].After(cutoff) {
|
||||
validCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validReqs) >= r.limit {
|
||||
r.requests[pid] = validReqs
|
||||
if validCount >= r.limit {
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new request
|
||||
validReqs = append(validReqs, now)
|
||||
r.requests[pid] = validReqs
|
||||
// Add new request to ring buffer (overwrites oldest if full)
|
||||
bucket.timestamps[bucket.head] = now
|
||||
bucket.head = (bucket.head + 1) % r.limit
|
||||
if bucket.count < r.limit {
|
||||
bucket.count++
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes old entries from the rate limiter.
|
||||
// Cleanup removes stale PID entries from the rate limiter.
|
||||
func (r *RateLimiter) Cleanup() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
@@ -76,17 +95,18 @@ func (r *RateLimiter) Cleanup() {
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
for pid, reqs := range r.requests {
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
for pid, bucket := range r.buckets {
|
||||
// Check if all timestamps are expired
|
||||
hasValid := false
|
||||
for i := 0; i < bucket.count; i++ {
|
||||
idx := (bucket.head - bucket.count + i + r.limit) % r.limit
|
||||
if bucket.timestamps[idx].After(cutoff) {
|
||||
hasValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(validReqs) == 0 {
|
||||
delete(r.requests, pid)
|
||||
} else {
|
||||
r.requests[pid] = validReqs
|
||||
if !hasValid {
|
||||
delete(r.buckets, pid)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,12 +134,12 @@ type AuditEntry struct {
|
||||
func NewAuditLogger(path string) (*AuditLogger, error) {
|
||||
// Ensure directory exists
|
||||
dir := path[:len(path)-len("/audit.log")]
|
||||
// #nosec G301 -- log directory should be world-readable
|
||||
// #nosec G301 - Log directory permissions are intentionally 0755
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
// #nosec G304,G302 -- path is from constant AuditLogPath; audit log should be world-readable
|
||||
// #nosec G302,G304,G306 - Path is constant, permissions are intentional for audit log
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audit log: %w", err)
|
||||
@@ -187,12 +207,28 @@ func isUserInGroup(uid uint32, targetGID uint32) bool {
|
||||
}
|
||||
|
||||
// Check if target GID is in the list
|
||||
targetGIDStr := fmt.Sprintf("%d", targetGID)
|
||||
for _, gid := range groupIDs {
|
||||
if gid == targetGIDStr {
|
||||
for _, gidStr := range groupIDs {
|
||||
gid, err := strconv.ParseUint(gidStr, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if uint32(gid) == targetGID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// lookupGroupGID looks up a group by name and returns its GID.
|
||||
func lookupGroupGID(name string) (int, error) {
|
||||
group, err := user.LookupGroup(name)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("group not found: %s", name)
|
||||
}
|
||||
gid, err := strconv.Atoi(group.Gid)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid GID for group %s: %s", name, group.Gid)
|
||||
}
|
||||
return gid, nil
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func TestRateLimiter_Cleanup(t *testing.T) {
|
||||
rl.Allow(pid)
|
||||
}
|
||||
|
||||
assert.Len(t, rl.requests, 5)
|
||||
assert.Len(t, rl.buckets, 5)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
@@ -76,7 +76,7 @@ func TestRateLimiter_Cleanup(t *testing.T) {
|
||||
// Cleanup
|
||||
rl.Cleanup()
|
||||
|
||||
assert.Empty(t, rl.requests)
|
||||
assert.Empty(t, rl.buckets)
|
||||
}
|
||||
|
||||
func TestAuditLogger_Log(t *testing.T) {
|
||||
|
||||
+79
-54
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
@@ -50,20 +51,28 @@ func (s *Server) Start() error {
|
||||
// Remove existing socket
|
||||
_ = os.Remove(s.socketPath)
|
||||
|
||||
// Set umask to create socket with restricted permissions (0660)
|
||||
// This prevents TOCTOU vulnerability between socket creation and chmod
|
||||
oldUmask := syscall.Umask(0117) // 0777 & ~0117 = 0660
|
||||
|
||||
listener, err := net.Listen("unix", s.socketPath)
|
||||
|
||||
// Restore original umask immediately after socket creation
|
||||
syscall.Umask(oldUmask)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions: 0660 root:lolcathost
|
||||
// #nosec G302 -- socket must be group-accessible for lolcathost group members
|
||||
if err := os.Chmod(s.socketPath, 0660); err != nil {
|
||||
_ = listener.Close()
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
// Look up the lolcathost group GID dynamically
|
||||
gid, err := lookupGroupGID("lolcathost")
|
||||
if err != nil {
|
||||
// Fall back to default GID if group lookup fails
|
||||
gid = LolcathostGID
|
||||
}
|
||||
|
||||
// Set socket group to lolcathost (GID 850)
|
||||
if err := os.Chown(s.socketPath, 0, 850); err != nil {
|
||||
// Set socket group to lolcathost
|
||||
if err := os.Chown(s.socketPath, 0, gid); err != nil {
|
||||
_ = listener.Close()
|
||||
return fmt.Errorf("failed to set socket ownership: %w", err)
|
||||
}
|
||||
@@ -126,6 +135,9 @@ func (s *Server) acceptLoop() {
|
||||
// LolcathostGID is the group ID for the lolcathost group.
|
||||
const LolcathostGID = 850
|
||||
|
||||
// connectionReadTimeout is the maximum time to wait for a client to send data.
|
||||
const connectionReadTimeout = 30 * time.Second
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
@@ -134,7 +146,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
// Authorization check: verify peer is authorized
|
||||
if !s.isAuthorized(creds) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
_ = s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
if s.auditLogger != nil {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
@@ -149,6 +161,11 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// Set read deadline to prevent clients from hanging indefinitely
|
||||
if err := conn.SetReadDeadline(time.Now().Add(connectionReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
@@ -156,13 +173,17 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
|
||||
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON")); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
|
||||
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded")); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -171,7 +192,9 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handleRequest(&req, creds)
|
||||
s.writeResponse(conn, resp)
|
||||
if err := s.writeResponse(conn, resp); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,10 +221,16 @@ func (s *Server) isAuthorized(creds *PeerCredentials) bool {
|
||||
return isUserInGroup(creds.UID, LolcathostGID)
|
||||
}
|
||||
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
|
||||
data, _ := json.Marshal(resp)
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) error {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, _ = conn.Write(data)
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
|
||||
@@ -427,14 +456,9 @@ func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
|
||||
// Update config
|
||||
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
@@ -468,14 +492,9 @@ func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
|
||||
@@ -573,14 +592,9 @@ func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
@@ -610,14 +624,9 @@ func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
|
||||
@@ -671,14 +680,9 @@ func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
@@ -824,3 +828,24 @@ func (s *Server) syncHostsFile() error {
|
||||
// Flush DNS cache
|
||||
return s.flusher.Flush()
|
||||
}
|
||||
|
||||
// saveAndSync saves the configuration and syncs to /etc/hosts atomically.
|
||||
// If sync fails, it attempts to reload the previous config from disk.
|
||||
func (s *Server) saveAndSync() error {
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
// Attempt to reload previous config on sync failure
|
||||
if reloadErr := s.config.Reload(); reloadErr != nil {
|
||||
// Log reload failure but return original sync error
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to reload config after sync failure: %v\n", reloadErr)
|
||||
}
|
||||
return fmt.Errorf("failed to sync hosts (config rolled back): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ func setupTestServer(t *testing.T) (*Server, string, func()) {
|
||||
server := &Server{
|
||||
socketPath: socketPath,
|
||||
config: cfgManager,
|
||||
hosts: NewHostsManagerWithPaths(hostsPath, backupDir),
|
||||
hosts: newHostsManagerWithPaths(hostsPath, backupDir),
|
||||
flusher: NewDNSFlusher(FlushMethodAuto),
|
||||
rateLimiter: NewRateLimiter(100, time.Minute),
|
||||
stopCh: make(chan struct{}),
|
||||
@@ -754,7 +754,7 @@ func BenchmarkServer_HandleSet(b *testing.B) {
|
||||
|
||||
server := &Server{
|
||||
config: cfgManager,
|
||||
hosts: NewHostsManagerWithPaths(hostsPath, backupDir),
|
||||
hosts: newHostsManagerWithPaths(hostsPath, backupDir),
|
||||
flusher: NewDNSFlusher(FlushMethodAuto),
|
||||
rateLimiter: NewRateLimiter(100000, time.Minute),
|
||||
}
|
||||
|
||||
@@ -210,14 +210,14 @@ func (i *Installer) createGroup() error {
|
||||
|
||||
func (i *Installer) createGroupDarwin() error {
|
||||
// Check if group exists
|
||||
if _, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName).Output(); err == nil {
|
||||
if _, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName).Output(); err == nil { // #nosec G204 - Commands use hardcoded safe values
|
||||
i.log(" Group '%s' already exists", GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating group '%s' (GID %d)...", GroupName, GroupGID)
|
||||
|
||||
// Create group
|
||||
// Create group - commands are hardcoded with constant GroupName
|
||||
cmds := [][]string{
|
||||
{"dscl", ".", "-create", "/Groups/" + GroupName},
|
||||
{"dscl", ".", "-create", "/Groups/" + GroupName, "PrimaryGroupID", strconv.Itoa(GroupGID)},
|
||||
@@ -225,7 +225,7 @@ func (i *Installer) createGroupDarwin() error {
|
||||
}
|
||||
|
||||
for _, args := range cmds {
|
||||
// #nosec G204 -- args are hardcoded dscl commands with the constant GroupName
|
||||
// #nosec G204 - Commands use hardcoded safe values
|
||||
if err := exec.Command(args[0], args[1:]...).Run(); err != nil {
|
||||
return fmt.Errorf("command %v failed: %w", args, err)
|
||||
}
|
||||
@@ -236,14 +236,14 @@ func (i *Installer) createGroupDarwin() error {
|
||||
|
||||
func (i *Installer) createGroupLinux() error {
|
||||
// Check if group exists
|
||||
if _, err := exec.Command("getent", "group", GroupName).Output(); err == nil {
|
||||
if _, err := exec.Command("getent", "group", GroupName).Output(); err == nil { // #nosec G204 - GroupName is a constant
|
||||
i.log(" Group '%s' already exists", GroupName)
|
||||
return nil
|
||||
}
|
||||
|
||||
i.log(" Creating group '%s'...", GroupName)
|
||||
|
||||
if err := exec.Command("groupadd", "-r", GroupName).Run(); err != nil {
|
||||
if err := exec.Command("groupadd", "-r", GroupName).Run(); err != nil { // #nosec G204 - GroupName is a constant
|
||||
return fmt.Errorf("groupadd failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -279,7 +279,7 @@ func (i *Installer) addCurrentUserToGroup() error {
|
||||
|
||||
func (i *Installer) addUserToGroupDarwin(username string) error {
|
||||
// Check if user is already in group
|
||||
output, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName, "GroupMembership").Output()
|
||||
output, err := exec.Command("dscl", ".", "-read", "/Groups/"+GroupName, "GroupMembership").Output() // #nosec G204 - GroupName is a constant
|
||||
if err == nil && strings.Contains(string(output), username) {
|
||||
i.log(" User '%s' already in group '%s'", username, GroupName)
|
||||
return nil
|
||||
@@ -287,7 +287,7 @@ func (i *Installer) addUserToGroupDarwin(username string) error {
|
||||
|
||||
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
|
||||
|
||||
if err := exec.Command("dscl", ".", "-append", "/Groups/"+GroupName, "GroupMembership", username).Run(); err != nil {
|
||||
if err := exec.Command("dscl", ".", "-append", "/Groups/"+GroupName, "GroupMembership", username).Run(); err != nil { // #nosec G204 - GroupName is a constant, username from SUDO_USER env
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
|
||||
@@ -296,7 +296,7 @@ func (i *Installer) addUserToGroupDarwin(username string) error {
|
||||
|
||||
func (i *Installer) addUserToGroupLinux(username string) error {
|
||||
// Check if user is already in group
|
||||
output, err := exec.Command("id", "-nG", username).Output()
|
||||
output, err := exec.Command("id", "-nG", username).Output() // #nosec G204 - username from SUDO_USER env or current user
|
||||
if err == nil && strings.Contains(string(output), GroupName) {
|
||||
i.log(" User '%s' already in group '%s'", username, GroupName)
|
||||
return nil
|
||||
@@ -304,7 +304,7 @@ func (i *Installer) addUserToGroupLinux(username string) error {
|
||||
|
||||
i.log(" Adding user '%s' to group '%s'...", username, GroupName)
|
||||
|
||||
if err := exec.Command("usermod", "-aG", GroupName, username).Run(); err != nil {
|
||||
if err := exec.Command("usermod", "-aG", GroupName, username).Run(); err != nil { // #nosec G204 - GroupName is a constant, username from SUDO_USER env
|
||||
return fmt.Errorf("failed to add user to group: %w", err)
|
||||
}
|
||||
|
||||
@@ -316,7 +316,7 @@ func (i *Installer) createDirectories() error {
|
||||
|
||||
for _, dir := range dirs {
|
||||
i.log(" Creating directory '%s'...", dir)
|
||||
// #nosec G301 -- system directories should be world-readable
|
||||
// #nosec G301 - System directories are intentionally 0755
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create %s: %w", dir, err)
|
||||
}
|
||||
@@ -342,7 +342,7 @@ func (i *Installer) installLaunchDaemon() error {
|
||||
|
||||
// Unload if already loaded (do this before writing plist)
|
||||
i.log(" Stopping existing daemon if running...")
|
||||
_ = exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
|
||||
_ = exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run() // #nosec G204 - Hardcoded service name
|
||||
|
||||
// Give launchd time to fully unload the service
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
@@ -351,21 +351,20 @@ func (i *Installer) installLaunchDaemon() error {
|
||||
_ = os.Remove(plistPath)
|
||||
|
||||
i.log(" Writing LaunchDaemon plist...")
|
||||
// #nosec G306 -- plist files are world-readable by convention
|
||||
// #nosec G306 - Plist file permissions are intentionally 0644
|
||||
if err := os.WriteFile(plistPath, []byte(plistContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write plist: %w", err)
|
||||
}
|
||||
|
||||
// Bootstrap the daemon
|
||||
i.log(" Starting daemon...")
|
||||
// #nosec G204 -- plistPath is constructed from constant LaunchDaemonDir
|
||||
cmd := exec.Command("launchctl", "bootstrap", "system", plistPath)
|
||||
cmd := exec.Command("launchctl", "bootstrap", "system", plistPath) // #nosec G204 - plistPath is constructed from constants
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// Exit code 5 means "service already loaded" - try kickstart instead
|
||||
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 5 {
|
||||
i.log(" Service already registered, restarting...")
|
||||
if err := exec.Command("launchctl", "kickstart", "-k", "system/com.lolcathost.daemon").Run(); err != nil {
|
||||
if err := exec.Command("launchctl", "kickstart", "-k", "system/com.lolcathost.daemon").Run(); err != nil { // #nosec G204 - Hardcoded service name
|
||||
return fmt.Errorf("failed to restart daemon: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -380,7 +379,7 @@ func (i *Installer) uninstallLaunchDaemon() {
|
||||
plistPath := filepath.Join(LaunchDaemonDir, "com.lolcathost.daemon.plist")
|
||||
|
||||
i.log(" Stopping daemon...")
|
||||
_ = exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run()
|
||||
_ = exec.Command("launchctl", "bootout", "system/com.lolcathost.daemon").Run() // #nosec G204 - Hardcoded service name
|
||||
|
||||
i.log(" Removing LaunchDaemon plist...")
|
||||
_ = os.Remove(plistPath)
|
||||
@@ -391,20 +390,20 @@ func (i *Installer) installSystemdService() error {
|
||||
unitContent := fmt.Sprintf(SystemdUnit, i.binaryPath)
|
||||
|
||||
i.log(" Writing systemd unit...")
|
||||
// #nosec G306 -- systemd unit files are world-readable by convention
|
||||
// #nosec G306 - Unit file permissions are intentionally 0644
|
||||
if err := os.WriteFile(unitPath, []byte(unitContent), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write unit file: %w", err)
|
||||
}
|
||||
|
||||
// Reload systemd
|
||||
i.log(" Reloading systemd...")
|
||||
if err := exec.Command("systemctl", "daemon-reload").Run(); err != nil {
|
||||
if err := exec.Command("systemctl", "daemon-reload").Run(); err != nil { // #nosec G204 - Hardcoded systemctl command
|
||||
return fmt.Errorf("failed to reload systemd: %w", err)
|
||||
}
|
||||
|
||||
// Enable and start the service
|
||||
i.log(" Enabling and starting service...")
|
||||
if err := exec.Command("systemctl", "enable", "--now", "lolcathost.service").Run(); err != nil {
|
||||
if err := exec.Command("systemctl", "enable", "--now", "lolcathost.service").Run(); err != nil { // #nosec G204 - Hardcoded service name
|
||||
return fmt.Errorf("failed to enable service: %w", err)
|
||||
}
|
||||
|
||||
@@ -413,12 +412,12 @@ func (i *Installer) installSystemdService() error {
|
||||
|
||||
func (i *Installer) uninstallSystemdService() {
|
||||
i.log(" Stopping and disabling service...")
|
||||
_ = exec.Command("systemctl", "disable", "--now", "lolcathost.service").Run()
|
||||
_ = exec.Command("systemctl", "disable", "--now", "lolcathost.service").Run() // #nosec G204 - Hardcoded service name
|
||||
|
||||
i.log(" Removing systemd unit...")
|
||||
_ = os.Remove(filepath.Join(SystemdDir, "lolcathost.service"))
|
||||
|
||||
_ = exec.Command("systemctl", "daemon-reload").Run()
|
||||
_ = exec.Command("systemctl", "daemon-reload").Run() // #nosec G204 - Hardcoded systemctl command
|
||||
}
|
||||
|
||||
func (i *Installer) createDefaultConfig() error {
|
||||
@@ -435,7 +434,7 @@ func (i *Installer) createDefaultConfig() error {
|
||||
|
||||
// Create config directory
|
||||
configDir := filepath.Dir(configPath)
|
||||
// #nosec G301 -- config directory should be world-readable
|
||||
// #nosec G301 - Config directory permissions are intentionally 0755
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
@@ -37,12 +37,6 @@ var (
|
||||
|
||||
disabledStyle = lipgloss.NewStyle().
|
||||
Foreground(colorMuted)
|
||||
|
||||
pendingStyle = lipgloss.NewStyle().
|
||||
Foreground(colorWarning)
|
||||
|
||||
errorIndicatorStyle = lipgloss.NewStyle().
|
||||
Foreground(colorError)
|
||||
)
|
||||
|
||||
// Status bar and help
|
||||
@@ -118,34 +112,6 @@ var (
|
||||
Padding(0, 1)
|
||||
)
|
||||
|
||||
// Indicator returns the appropriate status indicator string.
|
||||
func Indicator(enabled bool, pending bool, hasError bool) string {
|
||||
if hasError {
|
||||
return errorIndicatorStyle.Render("✗")
|
||||
}
|
||||
if pending {
|
||||
return pendingStyle.Render("◐")
|
||||
}
|
||||
if enabled {
|
||||
return enabledStyle.Render("●")
|
||||
}
|
||||
return disabledStyle.Render("○")
|
||||
}
|
||||
|
||||
// StatusText returns the status text with appropriate styling
|
||||
func StatusText(enabled bool, pending bool, hasError bool) string {
|
||||
if hasError {
|
||||
return errorIndicatorStyle.Render("✗ Error")
|
||||
}
|
||||
if pending {
|
||||
return pendingStyle.Render("◐ Pending")
|
||||
}
|
||||
if enabled {
|
||||
return enabledStyle.Render("● Active")
|
||||
}
|
||||
return disabledStyle.Render("○ Disabled")
|
||||
}
|
||||
|
||||
// WrapHelpText wraps help text to fit within maxWidth, splitting on bullet separators.
|
||||
// If maxWidth is 0 or negative, returns the original text.
|
||||
func WrapHelpText(text string, maxWidth int) string {
|
||||
|
||||
@@ -146,7 +146,7 @@ func parseVersion(v string) []int {
|
||||
|
||||
for _, p := range parts {
|
||||
var num int
|
||||
_, _ = fmt.Sscanf(p, "%d", &num)
|
||||
_, _ = fmt.Sscanf(p, "%d", &num) // Error intentionally ignored - non-numeric parts become 0
|
||||
result = append(result, num)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user