Initial commit.

This commit is contained in:
2025-11-28 02:50:25 +00:00
commit 22552aec99
41 changed files with 10626 additions and 0 deletions
+133
View File
@@ -0,0 +1,133 @@
// Package daemon provides the main daemon loop and lifecycle management.
package daemon
import (
"fmt"
"os"
"os/signal"
"syscall"
"time"
"github.com/lukaszraczylo/lolcathost/internal/config"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// Daemon represents the lolcathost daemon.
type Daemon struct {
server *Server
config *config.Manager
stopCh chan struct{}
cleanupCh chan struct{}
}
// New creates a new daemon instance.
func New(configPath string) (*Daemon, error) {
cfgManager := config.NewManager(configPath)
// Try to load config, create default if it doesn't exist
if err := cfgManager.Load(); err != nil {
if os.IsNotExist(err) {
if err := config.CreateDefault(configPath); err != nil {
return nil, fmt.Errorf("failed to create default config: %w", err)
}
if err := cfgManager.Load(); err != nil {
return nil, fmt.Errorf("failed to load default config: %w", err)
}
} else {
return nil, fmt.Errorf("failed to load config: %w", err)
}
}
// Ensure at least one group exists
cfg := cfgManager.Get()
if cfg != nil {
cfg.EnsureDefaultGroup()
// Save if we added a default group
if len(cfg.Groups) == 1 && cfg.Groups[0].Name == "default" && len(cfg.Groups[0].Hosts) == 0 {
cfgManager.Save()
}
}
server := NewServer(protocol.SocketPath, cfgManager)
return &Daemon{
server: server,
config: cfgManager,
stopCh: make(chan struct{}),
cleanupCh: make(chan struct{}),
}, nil
}
// Run starts the daemon and blocks until stopped.
func (d *Daemon) Run() error {
// Verify we're running as root
if os.Geteuid() != 0 {
return fmt.Errorf("daemon must run as root")
}
// Start the server
if err := d.server.Start(); err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
// Watch config for changes
if err := d.config.Watch(d.onConfigChange); err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to watch config: %v\n", err)
}
// Start cleanup goroutine
go d.cleanupLoop()
// Wait for shutdown signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
select {
case <-sigCh:
fmt.Println("Received shutdown signal")
case <-d.stopCh:
fmt.Println("Shutdown requested")
}
return d.shutdown()
}
// Stop signals the daemon to stop.
func (d *Daemon) Stop() {
close(d.stopCh)
}
func (d *Daemon) shutdown() error {
close(d.cleanupCh)
d.config.Stop()
if err := d.server.Stop(); err != nil {
return fmt.Errorf("failed to stop server: %w", err)
}
return nil
}
func (d *Daemon) onConfigChange(cfg *config.Config) {
fmt.Println("Config changed, syncing hosts file...")
// The server will use the updated config on next request
// We could trigger a sync here if autoApply is enabled
if cfg != nil && cfg.Settings.AutoApply {
// Sync hosts file with new config
// This is handled by the server internally
}
}
func (d *Daemon) cleanupLoop() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
d.server.rateLimiter.Cleanup()
case <-d.cleanupCh:
return
}
}
}
+142
View File
@@ -0,0 +1,142 @@
// Package daemon provides DNS cache flushing functionality.
package daemon
import (
"fmt"
"os/exec"
"runtime"
)
// DNSFlusher handles DNS cache flushing.
type DNSFlusher struct {
method FlushMethod
}
// FlushMethod defines the DNS flush method to use.
type FlushMethod string
const (
FlushMethodAuto FlushMethod = "auto"
FlushMethodDscacheutil FlushMethod = "dscacheutil"
FlushMethodKillall FlushMethod = "killall"
FlushMethodBoth FlushMethod = "both"
FlushMethodSystemd FlushMethod = "systemd"
FlushMethodNscd FlushMethod = "nscd"
)
// NewDNSFlusher creates a new DNS flusher.
func NewDNSFlusher(method FlushMethod) *DNSFlusher {
return &DNSFlusher{method: method}
}
// Flush flushes the DNS cache using the configured method.
func (f *DNSFlusher) Flush() error {
method := f.method
if method == FlushMethodAuto || method == "" {
method = f.detectMethod()
}
switch runtime.GOOS {
case "darwin":
return f.flushDarwin(method)
case "linux":
return f.flushLinux(method)
default:
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
}
}
func (f *DNSFlusher) detectMethod() FlushMethod {
switch runtime.GOOS {
case "darwin":
return FlushMethodBoth
case "linux":
// Check for systemd-resolve first
if _, err := exec.LookPath("systemd-resolve"); err == nil {
return FlushMethodSystemd
}
if _, err := exec.LookPath("resolvectl"); err == nil {
return FlushMethodSystemd
}
// Fall back to nscd
if _, err := exec.LookPath("nscd"); err == nil {
return FlushMethodNscd
}
return FlushMethodAuto
default:
return FlushMethodAuto
}
}
func (f *DNSFlusher) flushDarwin(method FlushMethod) error {
var errs []error
switch method {
case FlushMethodDscacheutil:
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
return fmt.Errorf("dscacheutil failed: %w", err)
}
case FlushMethodKillall:
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
return fmt.Errorf("killall mDNSResponder failed: %w", err)
}
case FlushMethodBoth:
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
errs = append(errs, fmt.Errorf("dscacheutil failed: %w", err))
}
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
errs = append(errs, fmt.Errorf("killall mDNSResponder failed: %w", err))
}
if len(errs) == 2 {
return fmt.Errorf("all DNS flush methods failed: %v, %v", errs[0], errs[1])
}
default:
// Auto - try both
_ = runCommand("dscacheutil", "-flushcache")
_ = runCommand("killall", "-HUP", "mDNSResponder")
}
return nil
}
func (f *DNSFlusher) flushLinux(method FlushMethod) error {
switch method {
case FlushMethodSystemd:
// Try resolvectl first (newer), then systemd-resolve (older)
if err := runCommand("resolvectl", "flush-caches"); err != nil {
if err := runCommand("systemd-resolve", "--flush-caches"); err != nil {
return fmt.Errorf("systemd DNS flush failed: %w", err)
}
}
case FlushMethodNscd:
// Try to restart nscd
if err := runCommand("nscd", "-i", "hosts"); err != nil {
// Try service restart as fallback
if err := runCommand("service", "nscd", "restart"); err != nil {
return fmt.Errorf("nscd flush failed: %w", err)
}
}
default:
// Auto - try all methods
// Try systemd first
if err := runCommand("resolvectl", "flush-caches"); err == nil {
return nil
}
if err := runCommand("systemd-resolve", "--flush-caches"); err == nil {
return nil
}
// Try nscd
if err := runCommand("nscd", "-i", "hosts"); err == nil {
return nil
}
// On many Linux systems, no explicit flush is needed as /etc/hosts is read directly
// So we return nil here
}
return nil
}
func runCommand(name string, args ...string) error {
cmd := exec.Command(name, args...)
return cmd.Run()
}
+108
View File
@@ -0,0 +1,108 @@
package daemon
import (
"runtime"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewDNSFlusher(t *testing.T) {
tests := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
FlushMethodSystemd,
FlushMethodNscd,
}
for _, method := range tests {
t.Run(string(method), func(t *testing.T) {
flusher := NewDNSFlusher(method)
assert.NotNil(t, flusher)
assert.Equal(t, method, flusher.method)
})
}
}
func TestDNSFlusher_DetectMethod(t *testing.T) {
flusher := NewDNSFlusher(FlushMethodAuto)
method := flusher.detectMethod()
switch runtime.GOOS {
case "darwin":
assert.Equal(t, FlushMethodBoth, method)
case "linux":
// Could be systemd, nscd, or auto depending on system
assert.Contains(t, []FlushMethod{FlushMethodSystemd, FlushMethodNscd, FlushMethodAuto}, method)
}
}
func TestFlushMethod_String(t *testing.T) {
methods := map[FlushMethod]string{
FlushMethodAuto: "auto",
FlushMethodDscacheutil: "dscacheutil",
FlushMethodKillall: "killall",
FlushMethodBoth: "both",
FlushMethodSystemd: "systemd",
FlushMethodNscd: "nscd",
}
for method, expected := range methods {
t.Run(expected, func(t *testing.T) {
assert.Equal(t, expected, string(method))
})
}
}
// Note: Actually testing DNS flush requires root and modifies system state,
// so we skip those tests in unit tests. They would be integration tests.
func TestDNSFlusher_Flush_UnsupportedOS(t *testing.T) {
// This test only makes sense if we're not on darwin or linux
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
t.Skip("Test only applicable on unsupported OS")
}
flusher := NewDNSFlusher(FlushMethodAuto)
err := flusher.Flush()
assert.Error(t, err)
assert.Contains(t, err.Error(), "unsupported operating system")
}
// Matrix test for flush methods
func TestFlushMethod_Matrix(t *testing.T) {
methods := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
FlushMethodSystemd,
FlushMethodNscd,
}
platforms := []string{"darwin", "linux"}
for _, method := range methods {
for _, platform := range platforms {
t.Run(string(method)+"_"+platform, func(t *testing.T) {
flusher := NewDNSFlusher(method)
assert.NotNil(t, flusher)
// Just verify no panic when checking method
_ = flusher.method
})
}
}
}
func BenchmarkDNSFlusher_DetectMethod(b *testing.B) {
flusher := NewDNSFlusher(FlushMethodAuto)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = flusher.detectMethod()
}
}
+319
View File
@@ -0,0 +1,319 @@
// Package daemon implements the privileged daemon that manages /etc/hosts.
package daemon
import (
"bufio"
"fmt"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"time"
)
const (
// HostsPath is the path to the system hosts file.
HostsPath = "/etc/hosts"
// BackupDir is the directory for hosts file backups.
BackupDir = "/var/backups/lolcathost"
// MaxBackups is the maximum number of backups to keep.
MaxBackups = 10
// Markers for the managed section.
markerStart = "# ========== LOLCATHOST MANAGED - DO NOT EDIT =========="
markerEnd = "# ========== END LOLCATHOST =========="
)
// HostEntry represents a single entry in the hosts file.
type HostEntry struct {
IP string
Domain string
Alias string
Enabled bool
}
// HostsManager handles reading and writing the hosts file.
type HostsManager struct {
hostsPath string
backupDir string
}
// NewHostsManager creates a new hosts manager.
func NewHostsManager() *HostsManager {
return &HostsManager{
hostsPath: HostsPath,
backupDir: BackupDir,
}
}
// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
return &HostsManager{
hostsPath: hostsPath,
backupDir: backupDir,
}
}
// ReadManagedEntries reads the lolcathost-managed entries from the hosts file.
func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) {
file, err := os.Open(m.hostsPath)
if err != nil {
return nil, fmt.Errorf("failed to open hosts file: %w", err)
}
defer file.Close()
var entries []HostEntry
inManagedSection := false
scanner := bufio.NewScanner(file)
entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == markerStart {
inManagedSection = true
continue
}
if line == markerEnd {
inManagedSection = false
continue
}
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
matches := entryRegex.FindStringSubmatch(line)
if len(matches) == 4 {
entries = append(entries, HostEntry{
IP: matches[1],
Domain: matches[2],
Alias: matches[3],
Enabled: true,
})
}
}
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("failed to read hosts file: %w", err)
}
return entries, nil
}
// WriteManagedEntries writes the managed entries to the hosts file.
func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error {
// Create backup first
if err := m.CreateBackup(); err != nil {
return fmt.Errorf("failed to create backup: %w", err)
}
// Read existing content
content, err := os.ReadFile(m.hostsPath)
if err != nil {
return fmt.Errorf("failed to read hosts file: %w", err)
}
// Remove existing managed section
newContent := m.removeManagedSection(string(content))
// Build new managed section
managedSection := m.buildManagedSection(entries)
// Append managed section
newContent = strings.TrimRight(newContent, "\n") + "\n\n" + managedSection
// Write atomically
if err := m.writeAtomic(newContent); err != nil {
return fmt.Errorf("failed to write hosts file: %w", err)
}
return nil
}
func (m *HostsManager) removeManagedSection(content string) string {
lines := strings.Split(content, "\n")
var result []string
inManagedSection := false
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == markerStart {
inManagedSection = true
continue
}
if trimmed == markerEnd {
inManagedSection = false
continue
}
if !inManagedSection {
result = append(result, line)
}
}
// Remove trailing empty lines
for len(result) > 0 && strings.TrimSpace(result[len(result)-1]) == "" {
result = result[:len(result)-1]
}
return strings.Join(result, "\n")
}
func (m *HostsManager) buildManagedSection(entries []HostEntry) string {
var sb strings.Builder
sb.WriteString(markerStart)
sb.WriteString("\n")
for _, entry := range entries {
if entry.Enabled {
sb.WriteString(fmt.Sprintf("%s\t%s\t# lolcathost:%s\n", entry.IP, entry.Domain, entry.Alias))
}
}
sb.WriteString(markerEnd)
sb.WriteString("\n")
return sb.String()
}
func (m *HostsManager) writeAtomic(content string) error {
// Write to temp file first
tmpFile := m.hostsPath + ".tmp"
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
return err
}
// Rename atomically
if err := os.Rename(tmpFile, m.hostsPath); err != nil {
os.Remove(tmpFile)
return err
}
return nil
}
// CreateBackup creates a backup of the current hosts file.
func (m *HostsManager) CreateBackup() error {
if err := os.MkdirAll(m.backupDir, 0755); err != nil {
return fmt.Errorf("failed to create backup directory: %w", err)
}
content, err := os.ReadFile(m.hostsPath)
if err != nil {
return fmt.Errorf("failed to read hosts file: %w", err)
}
timestamp := time.Now().Format("20060102-150405")
backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp))
if err := os.WriteFile(backupPath, content, 0644); err != nil {
return fmt.Errorf("failed to write backup: %w", err)
}
// Cleanup old backups
if err := m.cleanupBackups(); err != nil {
// Log but don't fail
fmt.Fprintf(os.Stderr, "warning: failed to cleanup backups: %v\n", err)
}
return nil
}
func (m *HostsManager) cleanupBackups() error {
entries, err := os.ReadDir(m.backupDir)
if err != nil {
return err
}
var backups []os.DirEntry
for _, entry := range entries {
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "hosts.") && strings.HasSuffix(entry.Name(), ".bak") {
backups = append(backups, entry)
}
}
if len(backups) <= MaxBackups {
return nil
}
// Sort by name (timestamp) descending
sort.Slice(backups, func(i, j int) bool {
return backups[i].Name() > backups[j].Name()
})
// Remove oldest backups
for i := MaxBackups; i < len(backups); i++ {
path := filepath.Join(m.backupDir, backups[i].Name())
os.Remove(path)
}
return nil
}
// ListBackups returns a list of available backups.
func (m *HostsManager) ListBackups() ([]BackupInfo, error) {
entries, err := os.ReadDir(m.backupDir)
if err != nil {
if os.IsNotExist(err) {
return nil, nil
}
return nil, err
}
var backups []BackupInfo
for _, entry := range entries {
if entry.IsDir() || !strings.HasPrefix(entry.Name(), "hosts.") || !strings.HasSuffix(entry.Name(), ".bak") {
continue
}
info, err := entry.Info()
if err != nil {
continue
}
backups = append(backups, BackupInfo{
Name: entry.Name(),
Timestamp: info.ModTime().Unix(),
Size: info.Size(),
})
}
// Sort by timestamp descending
sort.Slice(backups, func(i, j int) bool {
return backups[i].Timestamp > backups[j].Timestamp
})
return backups, nil
}
// BackupInfo holds information about a backup file.
type BackupInfo struct {
Name string
Timestamp int64
Size int64
}
// RestoreBackup restores a backup by name.
func (m *HostsManager) RestoreBackup(name string) error {
backupPath := filepath.Join(m.backupDir, name)
// Validate backup name to prevent path traversal
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
return fmt.Errorf("invalid backup name")
}
content, err := os.ReadFile(backupPath)
if err != nil {
return fmt.Errorf("failed to read backup: %w", err)
}
// Create a backup of current state before restoring
if err := m.CreateBackup(); err != nil {
return fmt.Errorf("failed to create backup before restore: %w", err)
}
if err := m.writeAtomic(string(content)); err != nil {
return fmt.Errorf("failed to restore backup: %w", err)
}
return nil
}
+422
View File
@@ -0,0 +1,422 @@
package daemon
import (
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHostsManager_ReadManagedEntries(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
hostsContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
::1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 example.com # lolcathost:example-local
192.168.1.1 api.example.com # lolcathost:api-local
# ========== END LOLCATHOST ==========
`
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
assert.Len(t, entries, 2)
assert.Equal(t, "127.0.0.1", entries[0].IP)
assert.Equal(t, "example.com", entries[0].Domain)
assert.Equal(t, "example-local", entries[0].Alias)
assert.Equal(t, "192.168.1.1", entries[1].IP)
assert.Equal(t, "api.example.com", entries[1].Domain)
assert.Equal(t, "api-local", entries[1].Alias)
}
func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
hostsContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
`
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
assert.Empty(t, entries)
}
func TestHostsManager_WriteManagedEntries(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create initial hosts file
initialContent := `127.0.0.1 localhost
255.255.255.255 broadcasthost
`
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true},
{IP: "127.0.0.1", Domain: "api.myapp.com", Alias: "api-local", Enabled: true},
{IP: "192.168.1.1", Domain: "staging.myapp.com", Alias: "staging", Enabled: false},
}
err = manager.WriteManagedEntries(entries)
require.NoError(t, err)
// Read back
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
assert.Contains(t, contentStr, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
assert.Contains(t, contentStr, "127.0.0.1\tmyapp.com\t# lolcathost:myapp-local")
assert.Contains(t, contentStr, "127.0.0.1\tapi.myapp.com\t# lolcathost:api-local")
assert.NotContains(t, contentStr, "staging.myapp.com") // disabled
assert.Contains(t, contentStr, "# ========== END LOLCATHOST ==========")
}
func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create hosts file with existing managed section
initialContent := `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 old.com # lolcathost:old
# ========== END LOLCATHOST ==========
`
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true},
}
err = manager.WriteManagedEntries(entries)
require.NoError(t, err)
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
assert.Contains(t, contentStr, "new.com")
assert.NotContains(t, contentStr, "old.com")
}
func TestHostsManager_CreateBackup(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
hostsContent := "127.0.0.1\tlocalhost\n"
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
err = manager.CreateBackup()
require.NoError(t, err)
// Verify backup exists
entries, err := os.ReadDir(backupDir)
require.NoError(t, err)
assert.Len(t, entries, 1)
assert.True(t, strings.HasPrefix(entries[0].Name(), "hosts."))
assert.True(t, strings.HasSuffix(entries[0].Name(), ".bak"))
// Verify backup content
backupContent, err := os.ReadFile(filepath.Join(backupDir, entries[0].Name()))
require.NoError(t, err)
assert.Equal(t, hostsContent, string(backupContent))
}
func TestHostsManager_ListBackups(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create hosts file
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
require.NoError(t, err)
// Manually create backup files with different timestamps
err = os.MkdirAll(backupDir, 0755)
require.NoError(t, err)
backupNames := []string{
"hosts.20231201-120000.bak",
"hosts.20231201-120001.bak",
"hosts.20231201-120002.bak",
}
for _, name := range backupNames {
err = os.WriteFile(filepath.Join(backupDir, name), []byte("backup"), 0644)
require.NoError(t, err)
}
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.Len(t, backups, 3)
}
func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "nonexistent")
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.Empty(t, backups)
}
func TestHostsManager_RestoreBackup(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
// Create initial hosts file
initialContent := "initial content"
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
// Create backup
err = manager.CreateBackup()
require.NoError(t, err)
// Modify hosts file
err = os.WriteFile(hostsPath, []byte("modified content"), 0644)
require.NoError(t, err)
// Get backup name
backups, err := manager.ListBackups()
require.NoError(t, err)
require.Len(t, backups, 1)
// Restore
err = manager.RestoreBackup(backups[0].Name)
require.NoError(t, err)
// Verify content restored
content, err := os.ReadFile(hostsPath)
require.NoError(t, err)
assert.Equal(t, initialContent, string(content))
}
func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) {
tmpDir := t.TempDir()
manager := NewHostsManagerWithPaths(
filepath.Join(tmpDir, "hosts"),
filepath.Join(tmpDir, "backups"),
)
tests := []string{
"../../../etc/passwd",
"hosts.bak", // Missing timestamp
"notahosts.backup", // Wrong format
"",
}
for _, name := range tests {
t.Run(name, func(t *testing.T) {
err := manager.RestoreBackup(name)
assert.Error(t, err)
})
}
}
func TestHostsManager_CleanupBackups(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
// Create more than MaxBackups
for i := 0; i < MaxBackups+5; i++ {
err = manager.CreateBackup()
require.NoError(t, err)
}
// Verify only MaxBackups remain
backups, err := manager.ListBackups()
require.NoError(t, err)
assert.LessOrEqual(t, len(backups), MaxBackups)
}
func TestHostsManager_RemoveManagedSection(t *testing.T) {
manager := &HostsManager{}
tests := []struct {
name string
input string
expected string
}{
{
name: "with managed section",
input: `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
127.0.0.1 example.com # lolcathost:test
# ========== END LOLCATHOST ==========
`,
expected: "127.0.0.1\tlocalhost",
},
{
name: "without managed section",
input: "127.0.0.1\tlocalhost\n",
expected: "127.0.0.1\tlocalhost",
},
{
name: "multiple managed sections",
input: `127.0.0.1 localhost
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
entry1
# ========== END LOLCATHOST ==========
more content
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
entry2
# ========== END LOLCATHOST ==========
`,
expected: "127.0.0.1\tlocalhost\nmore content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.removeManagedSection(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestHostsManager_BuildManagedSection(t *testing.T) {
manager := &HostsManager{}
entries := []HostEntry{
{IP: "127.0.0.1", Domain: "a.com", Alias: "a", Enabled: true},
{IP: "192.168.1.1", Domain: "b.com", Alias: "b", Enabled: true},
{IP: "10.0.0.1", Domain: "c.com", Alias: "c", Enabled: false},
}
result := manager.buildManagedSection(entries)
assert.Contains(t, result, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
assert.Contains(t, result, "127.0.0.1\ta.com\t# lolcathost:a")
assert.Contains(t, result, "192.168.1.1\tb.com\t# lolcathost:b")
assert.NotContains(t, result, "c.com") // disabled
assert.Contains(t, result, "# ========== END LOLCATHOST ==========")
}
// Matrix tests for hosts file parsing
func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
ips := []string{"127.0.0.1", "192.168.1.1", "::1"}
domains := []string{"example.com", "sub.example.com", "my-app.test"}
aliases := []string{"test", "my-alias", "app-1"}
for _, ip := range ips {
for _, domain := range domains {
for _, alias := range aliases {
t.Run(ip+"/"+domain+"/"+alias, func(t *testing.T) {
tmpDir := t.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
content := "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n"
content += ip + "\t" + domain + "\t# lolcathost:" + alias + "\n"
content += "# ========== END LOLCATHOST ==========\n"
err := os.WriteFile(hostsPath, []byte(content), 0644)
require.NoError(t, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
entries, err := manager.ReadManagedEntries()
require.NoError(t, err)
require.Len(t, entries, 1)
assert.Equal(t, ip, entries[0].IP)
assert.Equal(t, domain, entries[0].Domain)
assert.Equal(t, alias, entries[0].Alias)
})
}
}
}
}
func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
tmpDir := b.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
// Create a hosts file with many entries
var content strings.Builder
content.WriteString("127.0.0.1\tlocalhost\n")
content.WriteString("# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n")
for i := 0; i < 100; i++ {
content.WriteString("127.0.0.1\texample" + string(rune('a'+i%26)) + ".com\t# lolcathost:alias" + string(rune('a'+i%26)) + "\n")
}
content.WriteString("# ========== END LOLCATHOST ==========\n")
err := os.WriteFile(hostsPath, []byte(content.String()), 0644)
require.NoError(b, err)
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = manager.ReadManagedEntries()
}
}
func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) {
tmpDir := b.TempDir()
hostsPath := filepath.Join(tmpDir, "hosts")
backupDir := filepath.Join(tmpDir, "backups")
err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644)
require.NoError(b, err)
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
entries := make([]HostEntry, 50)
for i := range entries {
entries[i] = HostEntry{
IP: "127.0.0.1",
Domain: "example" + string(rune('a'+i%26)) + ".com",
Alias: "alias" + string(rune('a'+i%26)),
Enabled: i%2 == 0,
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = manager.WriteManagedEntries(entries)
}
}
+57
View File
@@ -0,0 +1,57 @@
//go:build darwin
package daemon
import (
"net"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// getPeerCredentials extracts peer credentials from a Unix socket connection on macOS.
// Note: macOS Xucred doesn't include PID, so we use LOCAL_PEERPID separately.
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return nil
}
rawConn, err := unixConn.SyscallConn()
if err != nil {
return nil
}
var creds *PeerCredentials
rawConn.Control(func(fd uintptr) {
xucred, err := unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
if err != nil {
return
}
// Get PID separately using LOCAL_PEERPID
var pid int32
pidLen := uint32(unsafe.Sizeof(pid))
_, _, errno := syscall.Syscall6(
syscall.SYS_GETSOCKOPT,
fd,
unix.SOL_LOCAL,
0x002, // LOCAL_PEERPID
uintptr(unsafe.Pointer(&pid)),
uintptr(unsafe.Pointer(&pidLen)),
0,
)
if errno != 0 {
pid = 0
}
creds = &PeerCredentials{
UID: xucred.Uid,
GID: xucred.Groups[0],
PID: pid,
}
})
return creds
}
+37
View File
@@ -0,0 +1,37 @@
//go:build linux
package daemon
import (
"net"
"golang.org/x/sys/unix"
)
// getPeerCredentials extracts peer credentials from a Unix socket connection on Linux.
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
unixConn, ok := conn.(*net.UnixConn)
if !ok {
return nil
}
rawConn, err := unixConn.SyscallConn()
if err != nil {
return nil
}
var creds *PeerCredentials
rawConn.Control(func(fd uintptr) {
ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
if err != nil {
return
}
creds = &PeerCredentials{
UID: ucred.Uid,
GID: ucred.Gid,
PID: ucred.Pid,
}
})
return creds
}
+196
View File
@@ -0,0 +1,196 @@
// Package daemon provides security functions including rate limiting and audit logging.
package daemon
import (
"encoding/json"
"fmt"
"os"
"os/user"
"sync"
"time"
)
const (
// AuditLogPath is the path to the audit log file.
AuditLogPath = "/var/log/lolcathost/audit.log"
// RateLimit is the maximum requests per minute per PID.
RateLimit = 100
// RateLimitWindow is the time window for rate limiting.
RateLimitWindow = time.Minute
)
// RateLimiter implements per-PID rate limiting.
type RateLimiter struct {
mu sync.Mutex
requests map[int32][]time.Time
limit int
window time.Duration
}
// NewRateLimiter creates a new rate limiter.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
return &RateLimiter{
requests: make(map[int32][]time.Time),
limit: limit,
window: window,
}
}
// Allow checks if a request from the given PID should be allowed.
func (r *RateLimiter) Allow(pid int32) bool {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
cutoff := now.Add(-r.window)
// Get existing requests for this PID
reqs := r.requests[pid]
// Filter out old requests
var validReqs []time.Time
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
// Check if under limit
if len(validReqs) >= r.limit {
r.requests[pid] = validReqs
return false
}
// Add new request
validReqs = append(validReqs, now)
r.requests[pid] = validReqs
return true
}
// Cleanup removes old entries from the rate limiter.
func (r *RateLimiter) Cleanup() {
r.mu.Lock()
defer r.mu.Unlock()
now := time.Now()
cutoff := now.Add(-r.window)
for pid, reqs := range r.requests {
var validReqs []time.Time
for _, t := range reqs {
if t.After(cutoff) {
validReqs = append(validReqs, t)
}
}
if len(validReqs) == 0 {
delete(r.requests, pid)
} else {
r.requests[pid] = validReqs
}
}
}
// AuditLogger handles audit logging.
type AuditLogger struct {
mu sync.Mutex
file *os.File
path string
encoder *json.Encoder
}
// AuditEntry represents a single audit log entry.
type AuditEntry struct {
Timestamp string `json:"timestamp"`
UID uint32 `json:"uid"`
PID int32 `json:"pid"`
Action string `json:"action"`
Details any `json:"details,omitempty"`
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// NewAuditLogger creates a new audit logger.
func NewAuditLogger(path string) (*AuditLogger, error) {
// Ensure directory exists
dir := path[:len(path)-len("/audit.log")]
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log directory: %w", err)
}
file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
if err != nil {
return nil, fmt.Errorf("failed to open audit log: %w", err)
}
return &AuditLogger{
file: file,
path: path,
encoder: json.NewEncoder(file),
}, nil
}
// Log writes an audit entry.
func (a *AuditLogger) Log(uid uint32, pid int32, action string, details any, success bool, errMsg string) {
a.mu.Lock()
defer a.mu.Unlock()
entry := AuditEntry{
Timestamp: time.Now().UTC().Format(time.RFC3339),
UID: uid,
PID: pid,
Action: action,
Details: details,
Success: success,
Error: errMsg,
}
// Ignore encoding errors - audit logging should not fail the operation
_ = a.encoder.Encode(entry)
}
// Close closes the audit logger.
func (a *AuditLogger) Close() error {
a.mu.Lock()
defer a.mu.Unlock()
if a.file != nil {
err := a.file.Close()
a.file = nil // Prevent double close
return err
}
return nil
}
// PeerCredentials holds the credentials of a connected peer.
type PeerCredentials struct {
UID uint32
GID uint32
PID int32
}
// isUserInGroup checks if a user (by UID) is a member of a group (by GID).
// This checks supplementary groups, not just the primary GID.
func isUserInGroup(uid uint32, targetGID uint32) bool {
// Look up user by UID
u, err := user.LookupId(fmt.Sprintf("%d", uid))
if err != nil {
return false
}
// Get user's group IDs
groupIDs, err := u.GroupIds()
if err != nil {
return false
}
// Check if target GID is in the list
targetGIDStr := fmt.Sprintf("%d", targetGID)
for _, gid := range groupIDs {
if gid == targetGIDStr {
return true
}
}
return false
}
+206
View File
@@ -0,0 +1,206 @@
package daemon
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimiter_Allow(t *testing.T) {
t.Run("under limit", func(t *testing.T) {
rl := NewRateLimiter(5, time.Minute)
for i := 0; i < 5; i++ {
assert.True(t, rl.Allow(123), "request %d should be allowed", i)
}
})
t.Run("over limit", func(t *testing.T) {
rl := NewRateLimiter(3, time.Minute)
for i := 0; i < 3; i++ {
assert.True(t, rl.Allow(123))
}
// 4th request should be blocked
assert.False(t, rl.Allow(123))
})
t.Run("different PIDs", func(t *testing.T) {
rl := NewRateLimiter(2, time.Minute)
// PID 1
assert.True(t, rl.Allow(1))
assert.True(t, rl.Allow(1))
assert.False(t, rl.Allow(1))
// PID 2 should have its own limit
assert.True(t, rl.Allow(2))
assert.True(t, rl.Allow(2))
assert.False(t, rl.Allow(2))
})
t.Run("window expiration", func(t *testing.T) {
rl := NewRateLimiter(2, 10*time.Millisecond)
assert.True(t, rl.Allow(123))
assert.True(t, rl.Allow(123))
assert.False(t, rl.Allow(123))
// Wait for window to expire
time.Sleep(15 * time.Millisecond)
// Should be allowed again
assert.True(t, rl.Allow(123))
})
}
func TestRateLimiter_Cleanup(t *testing.T) {
rl := NewRateLimiter(10, 10*time.Millisecond)
// Add requests from multiple PIDs
for pid := int32(1); pid <= 5; pid++ {
rl.Allow(pid)
}
assert.Len(t, rl.requests, 5)
// Wait for expiration
time.Sleep(15 * time.Millisecond)
// Cleanup
rl.Cleanup()
assert.Empty(t, rl.requests)
}
func TestAuditLogger_Log(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
defer logger.Close()
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
logger.Log(1000, 12345, "sync", nil, false, "sync failed")
// Read log file
content, err := os.ReadFile(logPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, `"action":"set"`)
assert.Contains(t, contentStr, `"uid":1000`)
assert.Contains(t, contentStr, `"pid":12345`)
assert.Contains(t, contentStr, `"success":true`)
assert.Contains(t, contentStr, `"action":"sync"`)
assert.Contains(t, contentStr, `"success":false`)
assert.Contains(t, contentStr, `"error":"sync failed"`)
}
func TestAuditLogger_CreatesDirectory(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "subdir", "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
defer logger.Close()
// Verify directory was created
_, err = os.Stat(filepath.Dir(logPath))
assert.NoError(t, err)
}
func TestAuditLogger_Close(t *testing.T) {
tmpDir := t.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(t, err)
err = logger.Close()
assert.NoError(t, err)
// Closing again should not error
err = logger.Close()
assert.NoError(t, err)
}
func TestPeerCredentials(t *testing.T) {
creds := &PeerCredentials{
UID: 501,
GID: 20,
PID: 12345,
}
assert.Equal(t, uint32(501), creds.UID)
assert.Equal(t, uint32(20), creds.GID)
assert.Equal(t, int32(12345), creds.PID)
}
// Matrix test for rate limiting
func TestRateLimiter_Matrix(t *testing.T) {
limits := []int{1, 5, 10, 100}
windows := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, time.Second}
for _, limit := range limits {
for _, window := range windows {
t.Run(
"limit="+string(rune('0'+limit))+"_window="+window.String(),
func(t *testing.T) {
rl := NewRateLimiter(limit, window)
// Should allow exactly 'limit' requests
for i := 0; i < limit; i++ {
assert.True(t, rl.Allow(1))
}
// Next should be blocked
assert.False(t, rl.Allow(1))
},
)
}
}
}
func BenchmarkRateLimiter_Allow(b *testing.B) {
rl := NewRateLimiter(RateLimit, RateLimitWindow)
b.ResetTimer()
for i := 0; i < b.N; i++ {
rl.Allow(int32(i % 100))
}
}
func BenchmarkRateLimiter_Cleanup(b *testing.B) {
rl := NewRateLimiter(RateLimit, RateLimitWindow)
// Pre-populate with requests
for i := 0; i < 1000; i++ {
rl.Allow(int32(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
rl.Cleanup()
}
}
func BenchmarkAuditLogger_Log(b *testing.B) {
tmpDir := b.TempDir()
logPath := filepath.Join(tmpDir, "audit.log")
logger, err := NewAuditLogger(logPath)
require.NoError(b, err)
defer logger.Close()
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
}
}
+803
View File
@@ -0,0 +1,803 @@
// Package daemon provides the Unix socket server for the daemon.
package daemon
import (
"bufio"
"encoding/json"
"fmt"
"net"
"os"
"sync"
"time"
"github.com/lukaszraczylo/lolcathost/internal/config"
"github.com/lukaszraczylo/lolcathost/internal/protocol"
)
// Version is set by the main package at startup
var Version = "dev"
// Server is the daemon's Unix socket server.
type Server struct {
socketPath string
listener net.Listener
config *config.Manager
hosts *HostsManager
flusher *DNSFlusher
rateLimiter *RateLimiter
auditLogger *AuditLogger
mu sync.RWMutex
running bool
stopCh chan struct{}
requestCount int64
startTime int64
}
// NewServer creates a new daemon server.
func NewServer(socketPath string, cfgManager *config.Manager) *Server {
return &Server{
socketPath: socketPath,
config: cfgManager,
hosts: NewHostsManager(),
flusher: NewDNSFlusher(FlushMethodAuto),
rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow),
stopCh: make(chan struct{}),
}
}
// Start starts the server.
func (s *Server) Start() error {
// Remove existing socket
os.Remove(s.socketPath)
listener, err := net.Listen("unix", s.socketPath)
if err != nil {
return fmt.Errorf("failed to listen on socket: %w", err)
}
// Set socket permissions: 0660 root:lolcathost
if err := os.Chmod(s.socketPath, 0660); err != nil {
listener.Close()
return fmt.Errorf("failed to set socket permissions: %w", err)
}
// Set socket group to lolcathost (GID 850)
if err := os.Chown(s.socketPath, 0, 850); err != nil {
listener.Close()
return fmt.Errorf("failed to set socket ownership: %w", err)
}
s.listener = listener
s.running = true
s.startTime = currentTimeUnix()
// Try to create audit logger, but don't fail if it doesn't work
if logger, err := NewAuditLogger(AuditLogPath); err == nil {
s.auditLogger = logger
}
go s.acceptLoop()
return nil
}
func currentTimeUnix() int64 {
return time.Now().Unix()
}
// Stop stops the server.
func (s *Server) Stop() error {
s.mu.Lock()
s.running = false
s.mu.Unlock()
close(s.stopCh)
if s.listener != nil {
s.listener.Close()
}
os.Remove(s.socketPath)
if s.auditLogger != nil {
s.auditLogger.Close()
}
return nil
}
func (s *Server) acceptLoop() {
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.stopCh:
return
default:
continue
}
}
go s.handleConnection(conn)
}
}
// LolcathostGID is the group ID for the lolcathost group.
const LolcathostGID = 850
func (s *Server) handleConnection(conn net.Conn) {
defer conn.Close()
// Get peer credentials
creds := s.getPeerCredentials(conn)
// Authorization check: verify peer is authorized
if !s.isAuthorized(creds) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
if s.auditLogger != nil {
var uid uint32
var pid int32
if creds != nil {
uid = creds.UID
pid = creds.PID
}
s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt")
}
return
}
reader := bufio.NewReader(conn)
for {
line, err := reader.ReadBytes('\n')
if err != nil {
return
}
var req protocol.Request
if err := json.Unmarshal(line, &req); err != nil {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
continue
}
// Rate limiting
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
continue
}
s.mu.Lock()
s.requestCount++
s.mu.Unlock()
resp := s.handleRequest(&req, creds)
s.writeResponse(conn, resp)
}
}
// isAuthorized checks if the peer is authorized to access the daemon.
// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850).
func (s *Server) isAuthorized(creds *PeerCredentials) bool {
if creds == nil {
// Can't verify credentials - deny by default
return false
}
// Root is always authorized
if creds.UID == 0 {
return true
}
// Check if user's primary GID is lolcathost
if creds.GID == LolcathostGID {
return true
}
// Check supplementary groups (user might be in lolcathost as secondary group)
// This requires looking up the user's groups from the system
return isUserInGroup(creds.UID, LolcathostGID)
}
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
data, _ := json.Marshal(resp)
data = append(data, '\n')
conn.Write(data)
}
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
var uid uint32
var pid int32
if creds != nil {
uid = creds.UID
pid = creds.PID
}
switch req.Type {
case protocol.RequestPing:
return s.handlePing()
case protocol.RequestStatus:
return s.handleStatus()
case protocol.RequestList:
return s.handleList()
case protocol.RequestSet:
resp := s.handleSet(req)
if s.auditLogger != nil {
var payload protocol.SetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestSync:
resp := s.handleSync()
if s.auditLogger != nil {
s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestPreset:
resp := s.handlePreset(req)
if s.auditLogger != nil {
var payload protocol.PresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestRollback:
resp := s.handleRollback(req)
if s.auditLogger != nil {
var payload protocol.RollbackPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestBackups:
return s.handleBackups()
case protocol.RequestAdd:
resp := s.handleAdd(req)
if s.auditLogger != nil {
var payload protocol.AddPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDelete:
resp := s.handleDelete(req)
if s.auditLogger != nil {
var payload protocol.DeletePayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestAddGroup:
resp := s.handleAddGroup(req)
if s.auditLogger != nil {
var payload protocol.GroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDeleteGroup:
resp := s.handleDeleteGroup(req)
if s.auditLogger != nil {
var payload protocol.GroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestListGroups:
return s.handleListGroups()
case protocol.RequestRenameGroup:
resp := s.handleRenameGroup(req)
if s.auditLogger != nil {
var payload protocol.RenameGroupPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestAddPreset:
resp := s.handleAddPreset(req)
if s.auditLogger != nil {
var payload protocol.AddPresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestDeletePreset:
resp := s.handleDeletePreset(req)
if s.auditLogger != nil {
var payload protocol.PresetPayload
_ = req.ParsePayload(&payload)
s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message)
}
return resp
case protocol.RequestListPresets:
return s.handleListPresets()
default:
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type))
}
}
func (s *Server) handlePing() *protocol.Response {
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
return resp
}
func (s *Server) handleStatus() *protocol.Response {
s.mu.RLock()
reqCount := s.requestCount
startTime := s.startTime
s.mu.RUnlock()
cfg := s.config.Get()
var activeCount int
if cfg != nil {
for _, h := range cfg.GetAllHosts() {
if h.Enabled {
activeCount++
}
}
}
data := protocol.StatusData{
Running: true,
Version: Version,
Uptime: nowUnix() - startTime,
ActiveCount: activeCount,
RequestCount: reqCount,
}
resp, _ := protocol.NewOKResponse(data)
return resp
}
func nowUnix() int64 {
return time.Now().Unix()
}
func (s *Server) handleList() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
var entries []protocol.HostEntry
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
entries = append(entries, protocol.HostEntry{
Domain: h.Domain,
IP: h.IP,
Alias: h.Alias,
Enabled: h.Enabled,
Group: g.Name,
})
}
}
resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries})
return resp
}
func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
var payload protocol.SetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
host, _ := cfg.FindHostByAlias(payload.Alias)
if host == nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
}
// Check for conflicts if enabling
if payload.Enabled && !payload.Force {
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled {
return protocol.NewErrorResponse(protocol.ErrCodeConflict,
fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias))
}
}
}
}
// Update config
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
Domain: host.Domain,
Applied: true,
})
return resp
}
func (s *Server) handleSync() *protocol.Response {
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
return resp
}
func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
var payload protocol.PresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.ApplyPreset(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
return resp
}
func (s *Server) handleRollback(req *protocol.Request) *protocol.Response {
var payload protocol.RollbackPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if err := s.hosts.RestoreBackup(payload.BackupName); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err))
}
// Flush DNS after restore
s.flusher.Flush()
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
return resp
}
func (s *Server) handleBackups() *protocol.Response {
backups, err := s.hosts.ListBackups()
if err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err))
}
var infos []protocol.BackupInfo
for _, b := range backups {
infos = append(infos, protocol.BackupInfo{
Name: b.Name,
Timestamp: b.Timestamp,
Size: b.Size,
})
}
resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos})
return resp
}
func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
var payload protocol.AddPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
// Validate domain
if payload.Domain == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required")
}
// Validate IP
if payload.IP == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required")
}
// Validate group
if payload.Group == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required")
}
// Check blocked domains
if config.IsBlockedDomain(payload.Domain) {
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain))
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
// Add to config (alias will be auto-generated if empty)
if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
Domain: payload.Domain,
Applied: true,
})
return resp
}
func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
var payload protocol.DeletePayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Alias == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
// Delete from config
if !cfg.DeleteHost(payload.Alias) {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
return resp
}
func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response {
var payload protocol.GroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.AddGroup(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
return resp
}
func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
var payload protocol.GroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.DeleteGroup(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
return resp
}
func (s *Server) handleListGroups() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()})
return resp
}
func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response {
var payload protocol.RenameGroupPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.OldName == "" || payload.NewName == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName})
return resp
}
func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response {
var payload protocol.AddPresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
return resp
}
func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response {
var payload protocol.PresetPayload
if err := req.ParsePayload(&payload); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
}
if payload.Name == "" {
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
}
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
if err := cfg.DeletePreset(payload.Name); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
return resp
}
func (s *Server) handleListPresets() *protocol.Response {
cfg := s.config.Get()
if cfg == nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
}
presets := cfg.GetPresets()
infos := make([]protocol.PresetInfo, len(presets))
for i, p := range presets {
infos[i] = protocol.PresetInfo{
Name: p.Name,
Enable: p.Enable,
Disable: p.Disable,
}
}
resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos})
return resp
}
func (s *Server) syncHostsFile() error {
cfg := s.config.Get()
if cfg == nil {
return fmt.Errorf("no configuration loaded")
}
var entries []HostEntry
for _, g := range cfg.Groups {
for _, h := range g.Hosts {
entries = append(entries, HostEntry{
IP: h.IP,
Domain: h.Domain,
Alias: h.Alias,
Enabled: h.Enabled,
})
}
}
if err := s.hosts.WriteManagedEntries(entries); err != nil {
return err
}
// Flush DNS cache
return s.flusher.Flush()
}