mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-05 23:29:18 +00:00
Initial commit.
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
// Package daemon provides the main daemon loop and lifecycle management.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Daemon represents the lolcathost daemon.
|
||||
type Daemon struct {
|
||||
server *Server
|
||||
config *config.Manager
|
||||
stopCh chan struct{}
|
||||
cleanupCh chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new daemon instance.
|
||||
func New(configPath string) (*Daemon, error) {
|
||||
cfgManager := config.NewManager(configPath)
|
||||
|
||||
// Try to load config, create default if it doesn't exist
|
||||
if err := cfgManager.Load(); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err := config.CreateDefault(configPath); err != nil {
|
||||
return nil, fmt.Errorf("failed to create default config: %w", err)
|
||||
}
|
||||
if err := cfgManager.Load(); err != nil {
|
||||
return nil, fmt.Errorf("failed to load default config: %w", err)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure at least one group exists
|
||||
cfg := cfgManager.Get()
|
||||
if cfg != nil {
|
||||
cfg.EnsureDefaultGroup()
|
||||
// Save if we added a default group
|
||||
if len(cfg.Groups) == 1 && cfg.Groups[0].Name == "default" && len(cfg.Groups[0].Hosts) == 0 {
|
||||
cfgManager.Save()
|
||||
}
|
||||
}
|
||||
|
||||
server := NewServer(protocol.SocketPath, cfgManager)
|
||||
|
||||
return &Daemon{
|
||||
server: server,
|
||||
config: cfgManager,
|
||||
stopCh: make(chan struct{}),
|
||||
cleanupCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run starts the daemon and blocks until stopped.
|
||||
func (d *Daemon) Run() error {
|
||||
// Verify we're running as root
|
||||
if os.Geteuid() != 0 {
|
||||
return fmt.Errorf("daemon must run as root")
|
||||
}
|
||||
|
||||
// Start the server
|
||||
if err := d.server.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
// Watch config for changes
|
||||
if err := d.config.Watch(d.onConfigChange); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to watch config: %v\n", err)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go d.cleanupLoop()
|
||||
|
||||
// Wait for shutdown signal
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case <-sigCh:
|
||||
fmt.Println("Received shutdown signal")
|
||||
case <-d.stopCh:
|
||||
fmt.Println("Shutdown requested")
|
||||
}
|
||||
|
||||
return d.shutdown()
|
||||
}
|
||||
|
||||
// Stop signals the daemon to stop.
|
||||
func (d *Daemon) Stop() {
|
||||
close(d.stopCh)
|
||||
}
|
||||
|
||||
func (d *Daemon) shutdown() error {
|
||||
close(d.cleanupCh)
|
||||
d.config.Stop()
|
||||
|
||||
if err := d.server.Stop(); err != nil {
|
||||
return fmt.Errorf("failed to stop server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Daemon) onConfigChange(cfg *config.Config) {
|
||||
fmt.Println("Config changed, syncing hosts file...")
|
||||
// The server will use the updated config on next request
|
||||
// We could trigger a sync here if autoApply is enabled
|
||||
if cfg != nil && cfg.Settings.AutoApply {
|
||||
// Sync hosts file with new config
|
||||
// This is handled by the server internally
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Daemon) cleanupLoop() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
d.server.rateLimiter.Cleanup()
|
||||
case <-d.cleanupCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
// Package daemon provides DNS cache flushing functionality.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
// DNSFlusher handles DNS cache flushing.
|
||||
type DNSFlusher struct {
|
||||
method FlushMethod
|
||||
}
|
||||
|
||||
// FlushMethod defines the DNS flush method to use.
|
||||
type FlushMethod string
|
||||
|
||||
const (
|
||||
FlushMethodAuto FlushMethod = "auto"
|
||||
FlushMethodDscacheutil FlushMethod = "dscacheutil"
|
||||
FlushMethodKillall FlushMethod = "killall"
|
||||
FlushMethodBoth FlushMethod = "both"
|
||||
FlushMethodSystemd FlushMethod = "systemd"
|
||||
FlushMethodNscd FlushMethod = "nscd"
|
||||
)
|
||||
|
||||
// NewDNSFlusher creates a new DNS flusher.
|
||||
func NewDNSFlusher(method FlushMethod) *DNSFlusher {
|
||||
return &DNSFlusher{method: method}
|
||||
}
|
||||
|
||||
// Flush flushes the DNS cache using the configured method.
|
||||
func (f *DNSFlusher) Flush() error {
|
||||
method := f.method
|
||||
if method == FlushMethodAuto || method == "" {
|
||||
method = f.detectMethod()
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return f.flushDarwin(method)
|
||||
case "linux":
|
||||
return f.flushLinux(method)
|
||||
default:
|
||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) detectMethod() FlushMethod {
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
return FlushMethodBoth
|
||||
case "linux":
|
||||
// Check for systemd-resolve first
|
||||
if _, err := exec.LookPath("systemd-resolve"); err == nil {
|
||||
return FlushMethodSystemd
|
||||
}
|
||||
if _, err := exec.LookPath("resolvectl"); err == nil {
|
||||
return FlushMethodSystemd
|
||||
}
|
||||
// Fall back to nscd
|
||||
if _, err := exec.LookPath("nscd"); err == nil {
|
||||
return FlushMethodNscd
|
||||
}
|
||||
return FlushMethodAuto
|
||||
default:
|
||||
return FlushMethodAuto
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) flushDarwin(method FlushMethod) error {
|
||||
var errs []error
|
||||
|
||||
switch method {
|
||||
case FlushMethodDscacheutil:
|
||||
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
|
||||
return fmt.Errorf("dscacheutil failed: %w", err)
|
||||
}
|
||||
case FlushMethodKillall:
|
||||
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
|
||||
return fmt.Errorf("killall mDNSResponder failed: %w", err)
|
||||
}
|
||||
case FlushMethodBoth:
|
||||
if err := runCommand("dscacheutil", "-flushcache"); err != nil {
|
||||
errs = append(errs, fmt.Errorf("dscacheutil failed: %w", err))
|
||||
}
|
||||
if err := runCommand("killall", "-HUP", "mDNSResponder"); err != nil {
|
||||
errs = append(errs, fmt.Errorf("killall mDNSResponder failed: %w", err))
|
||||
}
|
||||
if len(errs) == 2 {
|
||||
return fmt.Errorf("all DNS flush methods failed: %v, %v", errs[0], errs[1])
|
||||
}
|
||||
default:
|
||||
// Auto - try both
|
||||
_ = runCommand("dscacheutil", "-flushcache")
|
||||
_ = runCommand("killall", "-HUP", "mDNSResponder")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *DNSFlusher) flushLinux(method FlushMethod) error {
|
||||
switch method {
|
||||
case FlushMethodSystemd:
|
||||
// Try resolvectl first (newer), then systemd-resolve (older)
|
||||
if err := runCommand("resolvectl", "flush-caches"); err != nil {
|
||||
if err := runCommand("systemd-resolve", "--flush-caches"); err != nil {
|
||||
return fmt.Errorf("systemd DNS flush failed: %w", err)
|
||||
}
|
||||
}
|
||||
case FlushMethodNscd:
|
||||
// Try to restart nscd
|
||||
if err := runCommand("nscd", "-i", "hosts"); err != nil {
|
||||
// Try service restart as fallback
|
||||
if err := runCommand("service", "nscd", "restart"); err != nil {
|
||||
return fmt.Errorf("nscd flush failed: %w", err)
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Auto - try all methods
|
||||
// Try systemd first
|
||||
if err := runCommand("resolvectl", "flush-caches"); err == nil {
|
||||
return nil
|
||||
}
|
||||
if err := runCommand("systemd-resolve", "--flush-caches"); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Try nscd
|
||||
if err := runCommand("nscd", "-i", "hosts"); err == nil {
|
||||
return nil
|
||||
}
|
||||
// On many Linux systems, no explicit flush is needed as /etc/hosts is read directly
|
||||
// So we return nil here
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runCommand(name string, args ...string) error {
|
||||
cmd := exec.Command(name, args...)
|
||||
return cmd.Run()
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewDNSFlusher(t *testing.T) {
|
||||
tests := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
FlushMethodSystemd,
|
||||
FlushMethodNscd,
|
||||
}
|
||||
|
||||
for _, method := range tests {
|
||||
t.Run(string(method), func(t *testing.T) {
|
||||
flusher := NewDNSFlusher(method)
|
||||
assert.NotNil(t, flusher)
|
||||
assert.Equal(t, method, flusher.method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSFlusher_DetectMethod(t *testing.T) {
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
|
||||
method := flusher.detectMethod()
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
assert.Equal(t, FlushMethodBoth, method)
|
||||
case "linux":
|
||||
// Could be systemd, nscd, or auto depending on system
|
||||
assert.Contains(t, []FlushMethod{FlushMethodSystemd, FlushMethodNscd, FlushMethodAuto}, method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlushMethod_String(t *testing.T) {
|
||||
methods := map[FlushMethod]string{
|
||||
FlushMethodAuto: "auto",
|
||||
FlushMethodDscacheutil: "dscacheutil",
|
||||
FlushMethodKillall: "killall",
|
||||
FlushMethodBoth: "both",
|
||||
FlushMethodSystemd: "systemd",
|
||||
FlushMethodNscd: "nscd",
|
||||
}
|
||||
|
||||
for method, expected := range methods {
|
||||
t.Run(expected, func(t *testing.T) {
|
||||
assert.Equal(t, expected, string(method))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Actually testing DNS flush requires root and modifies system state,
|
||||
// so we skip those tests in unit tests. They would be integration tests.
|
||||
|
||||
func TestDNSFlusher_Flush_UnsupportedOS(t *testing.T) {
|
||||
// This test only makes sense if we're not on darwin or linux
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" {
|
||||
t.Skip("Test only applicable on unsupported OS")
|
||||
}
|
||||
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
err := flusher.Flush()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported operating system")
|
||||
}
|
||||
|
||||
// Matrix test for flush methods
|
||||
func TestFlushMethod_Matrix(t *testing.T) {
|
||||
methods := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
FlushMethodSystemd,
|
||||
FlushMethodNscd,
|
||||
}
|
||||
|
||||
platforms := []string{"darwin", "linux"}
|
||||
|
||||
for _, method := range methods {
|
||||
for _, platform := range platforms {
|
||||
t.Run(string(method)+"_"+platform, func(t *testing.T) {
|
||||
flusher := NewDNSFlusher(method)
|
||||
assert.NotNil(t, flusher)
|
||||
|
||||
// Just verify no panic when checking method
|
||||
_ = flusher.method
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDNSFlusher_DetectMethod(b *testing.B) {
|
||||
flusher := NewDNSFlusher(FlushMethodAuto)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = flusher.detectMethod()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
// Package daemon implements the privileged daemon that manages /etc/hosts.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// HostsPath is the path to the system hosts file.
|
||||
HostsPath = "/etc/hosts"
|
||||
// BackupDir is the directory for hosts file backups.
|
||||
BackupDir = "/var/backups/lolcathost"
|
||||
// MaxBackups is the maximum number of backups to keep.
|
||||
MaxBackups = 10
|
||||
|
||||
// Markers for the managed section.
|
||||
markerStart = "# ========== LOLCATHOST MANAGED - DO NOT EDIT =========="
|
||||
markerEnd = "# ========== END LOLCATHOST =========="
|
||||
)
|
||||
|
||||
// HostEntry represents a single entry in the hosts file.
|
||||
type HostEntry struct {
|
||||
IP string
|
||||
Domain string
|
||||
Alias string
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// HostsManager handles reading and writing the hosts file.
|
||||
type HostsManager struct {
|
||||
hostsPath string
|
||||
backupDir string
|
||||
}
|
||||
|
||||
// NewHostsManager creates a new hosts manager.
|
||||
func NewHostsManager() *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: HostsPath,
|
||||
backupDir: BackupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// NewHostsManagerWithPaths creates a hosts manager with custom paths (for testing).
|
||||
func NewHostsManagerWithPaths(hostsPath, backupDir string) *HostsManager {
|
||||
return &HostsManager{
|
||||
hostsPath: hostsPath,
|
||||
backupDir: backupDir,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadManagedEntries reads the lolcathost-managed entries from the hosts file.
|
||||
func (m *HostsManager) ReadManagedEntries() ([]HostEntry, error) {
|
||||
file, err := os.Open(m.hostsPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open hosts file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
var entries []HostEntry
|
||||
inManagedSection := false
|
||||
scanner := bufio.NewScanner(file)
|
||||
entryRegex := regexp.MustCompile(`^(\S+)\s+(\S+)\s+#\s*lolcathost:(\S+)$`)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if line == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
|
||||
if inManagedSection && !strings.HasPrefix(line, "#") && line != "" {
|
||||
matches := entryRegex.FindStringSubmatch(line)
|
||||
if len(matches) == 4 {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: matches[1],
|
||||
Domain: matches[2],
|
||||
Alias: matches[3],
|
||||
Enabled: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// WriteManagedEntries writes the managed entries to the hosts file.
|
||||
func (m *HostsManager) WriteManagedEntries(entries []HostEntry) error {
|
||||
// Create backup first
|
||||
if err := m.CreateBackup(); err != nil {
|
||||
return fmt.Errorf("failed to create backup: %w", err)
|
||||
}
|
||||
|
||||
// Read existing content
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
// Remove existing managed section
|
||||
newContent := m.removeManagedSection(string(content))
|
||||
|
||||
// Build new managed section
|
||||
managedSection := m.buildManagedSection(entries)
|
||||
|
||||
// Append managed section
|
||||
newContent = strings.TrimRight(newContent, "\n") + "\n\n" + managedSection
|
||||
|
||||
// Write atomically
|
||||
if err := m.writeAtomic(newContent); err != nil {
|
||||
return fmt.Errorf("failed to write hosts file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *HostsManager) removeManagedSection(content string) string {
|
||||
lines := strings.Split(content, "\n")
|
||||
var result []string
|
||||
inManagedSection := false
|
||||
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if trimmed == markerStart {
|
||||
inManagedSection = true
|
||||
continue
|
||||
}
|
||||
if trimmed == markerEnd {
|
||||
inManagedSection = false
|
||||
continue
|
||||
}
|
||||
if !inManagedSection {
|
||||
result = append(result, line)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing empty lines
|
||||
for len(result) > 0 && strings.TrimSpace(result[len(result)-1]) == "" {
|
||||
result = result[:len(result)-1]
|
||||
}
|
||||
|
||||
return strings.Join(result, "\n")
|
||||
}
|
||||
|
||||
func (m *HostsManager) buildManagedSection(entries []HostEntry) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(markerStart)
|
||||
sb.WriteString("\n")
|
||||
|
||||
for _, entry := range entries {
|
||||
if entry.Enabled {
|
||||
sb.WriteString(fmt.Sprintf("%s\t%s\t# lolcathost:%s\n", entry.IP, entry.Domain, entry.Alias))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString(markerEnd)
|
||||
sb.WriteString("\n")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func (m *HostsManager) writeAtomic(content string) error {
|
||||
// Write to temp file first
|
||||
tmpFile := m.hostsPath + ".tmp"
|
||||
if err := os.WriteFile(tmpFile, []byte(content), 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Rename atomically
|
||||
if err := os.Rename(tmpFile, m.hostsPath); err != nil {
|
||||
os.Remove(tmpFile)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBackup creates a backup of the current hosts file.
|
||||
func (m *HostsManager) CreateBackup() error {
|
||||
if err := os.MkdirAll(m.backupDir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create backup directory: %w", err)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(m.hostsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read hosts file: %w", err)
|
||||
}
|
||||
|
||||
timestamp := time.Now().Format("20060102-150405")
|
||||
backupPath := filepath.Join(m.backupDir, fmt.Sprintf("hosts.%s.bak", timestamp))
|
||||
|
||||
if err := os.WriteFile(backupPath, content, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write backup: %w", err)
|
||||
}
|
||||
|
||||
// Cleanup old backups
|
||||
if err := m.cleanupBackups(); err != nil {
|
||||
// Log but don't fail
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to cleanup backups: %v\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *HostsManager) cleanupBackups() error {
|
||||
entries, err := os.ReadDir(m.backupDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var backups []os.DirEntry
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasPrefix(entry.Name(), "hosts.") && strings.HasSuffix(entry.Name(), ".bak") {
|
||||
backups = append(backups, entry)
|
||||
}
|
||||
}
|
||||
|
||||
if len(backups) <= MaxBackups {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sort by name (timestamp) descending
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
return backups[i].Name() > backups[j].Name()
|
||||
})
|
||||
|
||||
// Remove oldest backups
|
||||
for i := MaxBackups; i < len(backups); i++ {
|
||||
path := filepath.Join(m.backupDir, backups[i].Name())
|
||||
os.Remove(path)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListBackups returns a list of available backups.
|
||||
func (m *HostsManager) ListBackups() ([]BackupInfo, error) {
|
||||
entries, err := os.ReadDir(m.backupDir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var backups []BackupInfo
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() || !strings.HasPrefix(entry.Name(), "hosts.") || !strings.HasSuffix(entry.Name(), ".bak") {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
backups = append(backups, BackupInfo{
|
||||
Name: entry.Name(),
|
||||
Timestamp: info.ModTime().Unix(),
|
||||
Size: info.Size(),
|
||||
})
|
||||
}
|
||||
|
||||
// Sort by timestamp descending
|
||||
sort.Slice(backups, func(i, j int) bool {
|
||||
return backups[i].Timestamp > backups[j].Timestamp
|
||||
})
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
// BackupInfo holds information about a backup file.
|
||||
type BackupInfo struct {
|
||||
Name string
|
||||
Timestamp int64
|
||||
Size int64
|
||||
}
|
||||
|
||||
// RestoreBackup restores a backup by name.
|
||||
func (m *HostsManager) RestoreBackup(name string) error {
|
||||
backupPath := filepath.Join(m.backupDir, name)
|
||||
|
||||
// Validate backup name to prevent path traversal
|
||||
if filepath.Base(name) != name || !strings.HasPrefix(name, "hosts.") || !strings.HasSuffix(name, ".bak") {
|
||||
return fmt.Errorf("invalid backup name")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(backupPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read backup: %w", err)
|
||||
}
|
||||
|
||||
// Create a backup of current state before restoring
|
||||
if err := m.CreateBackup(); err != nil {
|
||||
return fmt.Errorf("failed to create backup before restore: %w", err)
|
||||
}
|
||||
|
||||
if err := m.writeAtomic(string(content)); err != nil {
|
||||
return fmt.Errorf("failed to restore backup: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHostsManager_ReadManagedEntries(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
hostsContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
::1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 example.com # lolcathost:example-local
|
||||
192.168.1.1 api.example.com # lolcathost:api-local
|
||||
# ========== END LOLCATHOST ==========
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, entries, 2)
|
||||
assert.Equal(t, "127.0.0.1", entries[0].IP)
|
||||
assert.Equal(t, "example.com", entries[0].Domain)
|
||||
assert.Equal(t, "example-local", entries[0].Alias)
|
||||
assert.Equal(t, "192.168.1.1", entries[1].IP)
|
||||
assert.Equal(t, "api.example.com", entries[1].Domain)
|
||||
assert.Equal(t, "api-local", entries[1].Alias)
|
||||
}
|
||||
|
||||
func TestHostsManager_ReadManagedEntries_NoSection(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
hostsContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, entries)
|
||||
}
|
||||
|
||||
func TestHostsManager_WriteManagedEntries(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create initial hosts file
|
||||
initialContent := `127.0.0.1 localhost
|
||||
255.255.255.255 broadcasthost
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "myapp.com", Alias: "myapp-local", Enabled: true},
|
||||
{IP: "127.0.0.1", Domain: "api.myapp.com", Alias: "api-local", Enabled: true},
|
||||
{IP: "192.168.1.1", Domain: "staging.myapp.com", Alias: "staging", Enabled: false},
|
||||
}
|
||||
|
||||
err = manager.WriteManagedEntries(entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read back
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
|
||||
assert.Contains(t, contentStr, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tmyapp.com\t# lolcathost:myapp-local")
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tapi.myapp.com\t# lolcathost:api-local")
|
||||
assert.NotContains(t, contentStr, "staging.myapp.com") // disabled
|
||||
assert.Contains(t, contentStr, "# ========== END LOLCATHOST ==========")
|
||||
}
|
||||
|
||||
func TestHostsManager_WriteManagedEntries_UpdatesExisting(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create hosts file with existing managed section
|
||||
initialContent := `127.0.0.1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 old.com # lolcathost:old
|
||||
# ========== END LOLCATHOST ==========
|
||||
`
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "new.com", Alias: "new", Enabled: true},
|
||||
}
|
||||
|
||||
err = manager.WriteManagedEntries(entries)
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "127.0.0.1\tlocalhost")
|
||||
assert.Contains(t, contentStr, "new.com")
|
||||
assert.NotContains(t, contentStr, "old.com")
|
||||
}
|
||||
|
||||
func TestHostsManager_CreateBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
hostsContent := "127.0.0.1\tlocalhost\n"
|
||||
err := os.WriteFile(hostsPath, []byte(hostsContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify backup exists
|
||||
entries, err := os.ReadDir(backupDir)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, entries, 1)
|
||||
assert.True(t, strings.HasPrefix(entries[0].Name(), "hosts."))
|
||||
assert.True(t, strings.HasSuffix(entries[0].Name(), ".bak"))
|
||||
|
||||
// Verify backup content
|
||||
backupContent, err := os.ReadFile(filepath.Join(backupDir, entries[0].Name()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, hostsContent, string(backupContent))
|
||||
}
|
||||
|
||||
func TestHostsManager_ListBackups(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create hosts file
|
||||
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually create backup files with different timestamps
|
||||
err = os.MkdirAll(backupDir, 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
backupNames := []string{
|
||||
"hosts.20231201-120000.bak",
|
||||
"hosts.20231201-120001.bak",
|
||||
"hosts.20231201-120002.bak",
|
||||
}
|
||||
for _, name := range backupNames {
|
||||
err = os.WriteFile(filepath.Join(backupDir, name), []byte("backup"), 0644)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, backups, 3)
|
||||
}
|
||||
|
||||
func TestHostsManager_ListBackups_NoBackupDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "nonexistent")
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, backups)
|
||||
}
|
||||
|
||||
func TestHostsManager_RestoreBackup(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
// Create initial hosts file
|
||||
initialContent := "initial content"
|
||||
err := os.WriteFile(hostsPath, []byte(initialContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create backup
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Modify hosts file
|
||||
err = os.WriteFile(hostsPath, []byte("modified content"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get backup name
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, backups, 1)
|
||||
|
||||
// Restore
|
||||
err = manager.RestoreBackup(backups[0].Name)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify content restored
|
||||
content, err := os.ReadFile(hostsPath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, initialContent, string(content))
|
||||
}
|
||||
|
||||
func TestHostsManager_RestoreBackup_InvalidName(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
manager := NewHostsManagerWithPaths(
|
||||
filepath.Join(tmpDir, "hosts"),
|
||||
filepath.Join(tmpDir, "backups"),
|
||||
)
|
||||
|
||||
tests := []string{
|
||||
"../../../etc/passwd",
|
||||
"hosts.bak", // Missing timestamp
|
||||
"notahosts.backup", // Wrong format
|
||||
"",
|
||||
}
|
||||
|
||||
for _, name := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
err := manager.RestoreBackup(name)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsManager_CleanupBackups(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte("localhost"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
// Create more than MaxBackups
|
||||
for i := 0; i < MaxBackups+5; i++ {
|
||||
err = manager.CreateBackup()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Verify only MaxBackups remain
|
||||
backups, err := manager.ListBackups()
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(backups), MaxBackups)
|
||||
}
|
||||
|
||||
func TestHostsManager_RemoveManagedSection(t *testing.T) {
|
||||
manager := &HostsManager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "with managed section",
|
||||
input: `127.0.0.1 localhost
|
||||
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
127.0.0.1 example.com # lolcathost:test
|
||||
# ========== END LOLCATHOST ==========
|
||||
`,
|
||||
expected: "127.0.0.1\tlocalhost",
|
||||
},
|
||||
{
|
||||
name: "without managed section",
|
||||
input: "127.0.0.1\tlocalhost\n",
|
||||
expected: "127.0.0.1\tlocalhost",
|
||||
},
|
||||
{
|
||||
name: "multiple managed sections",
|
||||
input: `127.0.0.1 localhost
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
entry1
|
||||
# ========== END LOLCATHOST ==========
|
||||
more content
|
||||
# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========
|
||||
entry2
|
||||
# ========== END LOLCATHOST ==========
|
||||
`,
|
||||
expected: "127.0.0.1\tlocalhost\nmore content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := manager.removeManagedSection(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostsManager_BuildManagedSection(t *testing.T) {
|
||||
manager := &HostsManager{}
|
||||
|
||||
entries := []HostEntry{
|
||||
{IP: "127.0.0.1", Domain: "a.com", Alias: "a", Enabled: true},
|
||||
{IP: "192.168.1.1", Domain: "b.com", Alias: "b", Enabled: true},
|
||||
{IP: "10.0.0.1", Domain: "c.com", Alias: "c", Enabled: false},
|
||||
}
|
||||
|
||||
result := manager.buildManagedSection(entries)
|
||||
|
||||
assert.Contains(t, result, "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========")
|
||||
assert.Contains(t, result, "127.0.0.1\ta.com\t# lolcathost:a")
|
||||
assert.Contains(t, result, "192.168.1.1\tb.com\t# lolcathost:b")
|
||||
assert.NotContains(t, result, "c.com") // disabled
|
||||
assert.Contains(t, result, "# ========== END LOLCATHOST ==========")
|
||||
}
|
||||
|
||||
// Matrix tests for hosts file parsing
|
||||
func TestHostsManager_ReadManagedEntries_Matrix(t *testing.T) {
|
||||
ips := []string{"127.0.0.1", "192.168.1.1", "::1"}
|
||||
domains := []string{"example.com", "sub.example.com", "my-app.test"}
|
||||
aliases := []string{"test", "my-alias", "app-1"}
|
||||
|
||||
for _, ip := range ips {
|
||||
for _, domain := range domains {
|
||||
for _, alias := range aliases {
|
||||
t.Run(ip+"/"+domain+"/"+alias, func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
content := "# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n"
|
||||
content += ip + "\t" + domain + "\t# lolcathost:" + alias + "\n"
|
||||
content += "# ========== END LOLCATHOST ==========\n"
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte(content), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
entries, err := manager.ReadManagedEntries()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 1)
|
||||
|
||||
assert.Equal(t, ip, entries[0].IP)
|
||||
assert.Equal(t, domain, entries[0].Domain)
|
||||
assert.Equal(t, alias, entries[0].Alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHostsManager_ReadManagedEntries(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
|
||||
// Create a hosts file with many entries
|
||||
var content strings.Builder
|
||||
content.WriteString("127.0.0.1\tlocalhost\n")
|
||||
content.WriteString("# ========== LOLCATHOST MANAGED - DO NOT EDIT ==========\n")
|
||||
for i := 0; i < 100; i++ {
|
||||
content.WriteString("127.0.0.1\texample" + string(rune('a'+i%26)) + ".com\t# lolcathost:alias" + string(rune('a'+i%26)) + "\n")
|
||||
}
|
||||
content.WriteString("# ========== END LOLCATHOST ==========\n")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte(content.String()), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, filepath.Join(tmpDir, "backups"))
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = manager.ReadManagedEntries()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkHostsManager_WriteManagedEntries(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
hostsPath := filepath.Join(tmpDir, "hosts")
|
||||
backupDir := filepath.Join(tmpDir, "backups")
|
||||
|
||||
err := os.WriteFile(hostsPath, []byte("127.0.0.1\tlocalhost\n"), 0644)
|
||||
require.NoError(b, err)
|
||||
|
||||
manager := NewHostsManagerWithPaths(hostsPath, backupDir)
|
||||
|
||||
entries := make([]HostEntry, 50)
|
||||
for i := range entries {
|
||||
entries[i] = HostEntry{
|
||||
IP: "127.0.0.1",
|
||||
Domain: "example" + string(rune('a'+i%26)) + ".com",
|
||||
Alias: "alias" + string(rune('a'+i%26)),
|
||||
Enabled: i%2 == 0,
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = manager.WriteManagedEntries(entries)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
//go:build darwin
|
||||
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"net"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getPeerCredentials extracts peer credentials from a Unix socket connection on macOS.
|
||||
// Note: macOS Xucred doesn't include PID, so we use LOCAL_PEERPID separately.
|
||||
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
|
||||
unixConn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawConn, err := unixConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var creds *PeerCredentials
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
xucred, err := unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get PID separately using LOCAL_PEERPID
|
||||
var pid int32
|
||||
pidLen := uint32(unsafe.Sizeof(pid))
|
||||
_, _, errno := syscall.Syscall6(
|
||||
syscall.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
unix.SOL_LOCAL,
|
||||
0x002, // LOCAL_PEERPID
|
||||
uintptr(unsafe.Pointer(&pid)),
|
||||
uintptr(unsafe.Pointer(&pidLen)),
|
||||
0,
|
||||
)
|
||||
if errno != 0 {
|
||||
pid = 0
|
||||
}
|
||||
|
||||
creds = &PeerCredentials{
|
||||
UID: xucred.Uid,
|
||||
GID: xucred.Groups[0],
|
||||
PID: pid,
|
||||
}
|
||||
})
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
//go:build linux
|
||||
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// getPeerCredentials extracts peer credentials from a Unix socket connection on Linux.
|
||||
func (s *Server) getPeerCredentials(conn net.Conn) *PeerCredentials {
|
||||
unixConn, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawConn, err := unixConn.SyscallConn()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var creds *PeerCredentials
|
||||
rawConn.Control(func(fd uintptr) {
|
||||
ucred, err := unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
creds = &PeerCredentials{
|
||||
UID: ucred.Uid,
|
||||
GID: ucred.Gid,
|
||||
PID: ucred.Pid,
|
||||
}
|
||||
})
|
||||
|
||||
return creds
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
// Package daemon provides security functions including rate limiting and audit logging.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// AuditLogPath is the path to the audit log file.
|
||||
AuditLogPath = "/var/log/lolcathost/audit.log"
|
||||
// RateLimit is the maximum requests per minute per PID.
|
||||
RateLimit = 100
|
||||
// RateLimitWindow is the time window for rate limiting.
|
||||
RateLimitWindow = time.Minute
|
||||
)
|
||||
|
||||
// RateLimiter implements per-PID rate limiting.
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
requests map[int32][]time.Time
|
||||
limit int
|
||||
window time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter.
|
||||
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
requests: make(map[int32][]time.Time),
|
||||
limit: limit,
|
||||
window: window,
|
||||
}
|
||||
}
|
||||
|
||||
// Allow checks if a request from the given PID should be allowed.
|
||||
func (r *RateLimiter) Allow(pid int32) bool {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
// Get existing requests for this PID
|
||||
reqs := r.requests[pid]
|
||||
|
||||
// Filter out old requests
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if under limit
|
||||
if len(validReqs) >= r.limit {
|
||||
r.requests[pid] = validReqs
|
||||
return false
|
||||
}
|
||||
|
||||
// Add new request
|
||||
validReqs = append(validReqs, now)
|
||||
r.requests[pid] = validReqs
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Cleanup removes old entries from the rate limiter.
|
||||
func (r *RateLimiter) Cleanup() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-r.window)
|
||||
|
||||
for pid, reqs := range r.requests {
|
||||
var validReqs []time.Time
|
||||
for _, t := range reqs {
|
||||
if t.After(cutoff) {
|
||||
validReqs = append(validReqs, t)
|
||||
}
|
||||
}
|
||||
if len(validReqs) == 0 {
|
||||
delete(r.requests, pid)
|
||||
} else {
|
||||
r.requests[pid] = validReqs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AuditLogger handles audit logging.
|
||||
type AuditLogger struct {
|
||||
mu sync.Mutex
|
||||
file *os.File
|
||||
path string
|
||||
encoder *json.Encoder
|
||||
}
|
||||
|
||||
// AuditEntry represents a single audit log entry.
|
||||
type AuditEntry struct {
|
||||
Timestamp string `json:"timestamp"`
|
||||
UID uint32 `json:"uid"`
|
||||
PID int32 `json:"pid"`
|
||||
Action string `json:"action"`
|
||||
Details any `json:"details,omitempty"`
|
||||
Success bool `json:"success"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// NewAuditLogger creates a new audit logger.
|
||||
func NewAuditLogger(path string) (*AuditLogger, error) {
|
||||
// Ensure directory exists
|
||||
dir := path[:len(path)-len("/audit.log")]
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log directory: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open audit log: %w", err)
|
||||
}
|
||||
|
||||
return &AuditLogger{
|
||||
file: file,
|
||||
path: path,
|
||||
encoder: json.NewEncoder(file),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Log writes an audit entry.
|
||||
func (a *AuditLogger) Log(uid uint32, pid int32, action string, details any, success bool, errMsg string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
entry := AuditEntry{
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
UID: uid,
|
||||
PID: pid,
|
||||
Action: action,
|
||||
Details: details,
|
||||
Success: success,
|
||||
Error: errMsg,
|
||||
}
|
||||
|
||||
// Ignore encoding errors - audit logging should not fail the operation
|
||||
_ = a.encoder.Encode(entry)
|
||||
}
|
||||
|
||||
// Close closes the audit logger.
|
||||
func (a *AuditLogger) Close() error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if a.file != nil {
|
||||
err := a.file.Close()
|
||||
a.file = nil // Prevent double close
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PeerCredentials holds the credentials of a connected peer.
|
||||
type PeerCredentials struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
PID int32
|
||||
}
|
||||
|
||||
// isUserInGroup checks if a user (by UID) is a member of a group (by GID).
|
||||
// This checks supplementary groups, not just the primary GID.
|
||||
func isUserInGroup(uid uint32, targetGID uint32) bool {
|
||||
// Look up user by UID
|
||||
u, err := user.LookupId(fmt.Sprintf("%d", uid))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get user's group IDs
|
||||
groupIDs, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if target GID is in the list
|
||||
targetGIDStr := fmt.Sprintf("%d", targetGID)
|
||||
for _, gid := range groupIDs {
|
||||
if gid == targetGIDStr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRateLimiter_Allow(t *testing.T) {
|
||||
t.Run("under limit", func(t *testing.T) {
|
||||
rl := NewRateLimiter(5, time.Minute)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, rl.Allow(123), "request %d should be allowed", i)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("over limit", func(t *testing.T) {
|
||||
rl := NewRateLimiter(3, time.Minute)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
assert.True(t, rl.Allow(123))
|
||||
}
|
||||
|
||||
// 4th request should be blocked
|
||||
assert.False(t, rl.Allow(123))
|
||||
})
|
||||
|
||||
t.Run("different PIDs", func(t *testing.T) {
|
||||
rl := NewRateLimiter(2, time.Minute)
|
||||
|
||||
// PID 1
|
||||
assert.True(t, rl.Allow(1))
|
||||
assert.True(t, rl.Allow(1))
|
||||
assert.False(t, rl.Allow(1))
|
||||
|
||||
// PID 2 should have its own limit
|
||||
assert.True(t, rl.Allow(2))
|
||||
assert.True(t, rl.Allow(2))
|
||||
assert.False(t, rl.Allow(2))
|
||||
})
|
||||
|
||||
t.Run("window expiration", func(t *testing.T) {
|
||||
rl := NewRateLimiter(2, 10*time.Millisecond)
|
||||
|
||||
assert.True(t, rl.Allow(123))
|
||||
assert.True(t, rl.Allow(123))
|
||||
assert.False(t, rl.Allow(123))
|
||||
|
||||
// Wait for window to expire
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Should be allowed again
|
||||
assert.True(t, rl.Allow(123))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRateLimiter_Cleanup(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10*time.Millisecond)
|
||||
|
||||
// Add requests from multiple PIDs
|
||||
for pid := int32(1); pid <= 5; pid++ {
|
||||
rl.Allow(pid)
|
||||
}
|
||||
|
||||
assert.Len(t, rl.requests, 5)
|
||||
|
||||
// Wait for expiration
|
||||
time.Sleep(15 * time.Millisecond)
|
||||
|
||||
// Cleanup
|
||||
rl.Cleanup()
|
||||
|
||||
assert.Empty(t, rl.requests)
|
||||
}
|
||||
|
||||
func TestAuditLogger_Log(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
|
||||
logger.Log(1000, 12345, "sync", nil, false, "sync failed")
|
||||
|
||||
// Read log file
|
||||
content, err := os.ReadFile(logPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, `"action":"set"`)
|
||||
assert.Contains(t, contentStr, `"uid":1000`)
|
||||
assert.Contains(t, contentStr, `"pid":12345`)
|
||||
assert.Contains(t, contentStr, `"success":true`)
|
||||
assert.Contains(t, contentStr, `"action":"sync"`)
|
||||
assert.Contains(t, contentStr, `"success":false`)
|
||||
assert.Contains(t, contentStr, `"error":"sync failed"`)
|
||||
}
|
||||
|
||||
func TestAuditLogger_CreatesDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "subdir", "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
// Verify directory was created
|
||||
_, err = os.Stat(filepath.Dir(logPath))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAuditLogger_Close(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = logger.Close()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Closing again should not error
|
||||
err = logger.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPeerCredentials(t *testing.T) {
|
||||
creds := &PeerCredentials{
|
||||
UID: 501,
|
||||
GID: 20,
|
||||
PID: 12345,
|
||||
}
|
||||
|
||||
assert.Equal(t, uint32(501), creds.UID)
|
||||
assert.Equal(t, uint32(20), creds.GID)
|
||||
assert.Equal(t, int32(12345), creds.PID)
|
||||
}
|
||||
|
||||
// Matrix test for rate limiting
|
||||
func TestRateLimiter_Matrix(t *testing.T) {
|
||||
limits := []int{1, 5, 10, 100}
|
||||
windows := []time.Duration{10 * time.Millisecond, 100 * time.Millisecond, time.Second}
|
||||
|
||||
for _, limit := range limits {
|
||||
for _, window := range windows {
|
||||
t.Run(
|
||||
"limit="+string(rune('0'+limit))+"_window="+window.String(),
|
||||
func(t *testing.T) {
|
||||
rl := NewRateLimiter(limit, window)
|
||||
|
||||
// Should allow exactly 'limit' requests
|
||||
for i := 0; i < limit; i++ {
|
||||
assert.True(t, rl.Allow(1))
|
||||
}
|
||||
|
||||
// Next should be blocked
|
||||
assert.False(t, rl.Allow(1))
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRateLimiter_Allow(b *testing.B) {
|
||||
rl := NewRateLimiter(RateLimit, RateLimitWindow)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Allow(int32(i % 100))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRateLimiter_Cleanup(b *testing.B) {
|
||||
rl := NewRateLimiter(RateLimit, RateLimitWindow)
|
||||
|
||||
// Pre-populate with requests
|
||||
for i := 0; i < 1000; i++ {
|
||||
rl.Allow(int32(i))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
rl.Cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAuditLogger_Log(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
logPath := filepath.Join(tmpDir, "audit.log")
|
||||
|
||||
logger, err := NewAuditLogger(logPath)
|
||||
require.NoError(b, err)
|
||||
defer logger.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
logger.Log(1000, 12345, "set", map[string]string{"alias": "test"}, true, "")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,803 @@
|
||||
// Package daemon provides the Unix socket server for the daemon.
|
||||
package daemon
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
||||
)
|
||||
|
||||
// Version is set by the main package at startup
|
||||
var Version = "dev"
|
||||
|
||||
// Server is the daemon's Unix socket server.
|
||||
type Server struct {
|
||||
socketPath string
|
||||
listener net.Listener
|
||||
config *config.Manager
|
||||
hosts *HostsManager
|
||||
flusher *DNSFlusher
|
||||
rateLimiter *RateLimiter
|
||||
auditLogger *AuditLogger
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
requestCount int64
|
||||
startTime int64
|
||||
}
|
||||
|
||||
// NewServer creates a new daemon server.
|
||||
func NewServer(socketPath string, cfgManager *config.Manager) *Server {
|
||||
return &Server{
|
||||
socketPath: socketPath,
|
||||
config: cfgManager,
|
||||
hosts: NewHostsManager(),
|
||||
flusher: NewDNSFlusher(FlushMethodAuto),
|
||||
rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the server.
|
||||
func (s *Server) Start() error {
|
||||
// Remove existing socket
|
||||
os.Remove(s.socketPath)
|
||||
|
||||
listener, err := net.Listen("unix", s.socketPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions: 0660 root:lolcathost
|
||||
if err := os.Chmod(s.socketPath, 0660); err != nil {
|
||||
listener.Close()
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
}
|
||||
|
||||
// Set socket group to lolcathost (GID 850)
|
||||
if err := os.Chown(s.socketPath, 0, 850); err != nil {
|
||||
listener.Close()
|
||||
return fmt.Errorf("failed to set socket ownership: %w", err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.running = true
|
||||
s.startTime = currentTimeUnix()
|
||||
|
||||
// Try to create audit logger, but don't fail if it doesn't work
|
||||
if logger, err := NewAuditLogger(AuditLogPath); err == nil {
|
||||
s.auditLogger = logger
|
||||
}
|
||||
|
||||
go s.acceptLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func currentTimeUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
// Stop stops the server.
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
s.running = false
|
||||
s.mu.Unlock()
|
||||
|
||||
close(s.stopCh)
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
os.Remove(s.socketPath)
|
||||
|
||||
if s.auditLogger != nil {
|
||||
s.auditLogger.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) acceptLoop() {
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
go s.handleConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// LolcathostGID is the group ID for the lolcathost group.
|
||||
const LolcathostGID = 850
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
// Get peer credentials
|
||||
creds := s.getPeerCredentials(conn)
|
||||
|
||||
// Authorization check: verify peer is authorized
|
||||
if !s.isAuthorized(creds) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
if s.auditLogger != nil {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
if creds != nil {
|
||||
uid = creds.UID
|
||||
pid = creds.PID
|
||||
}
|
||||
s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
|
||||
continue
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.requestCount++
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handleRequest(&req, creds)
|
||||
s.writeResponse(conn, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// isAuthorized checks if the peer is authorized to access the daemon.
|
||||
// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850).
|
||||
func (s *Server) isAuthorized(creds *PeerCredentials) bool {
|
||||
if creds == nil {
|
||||
// Can't verify credentials - deny by default
|
||||
return false
|
||||
}
|
||||
|
||||
// Root is always authorized
|
||||
if creds.UID == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if user's primary GID is lolcathost
|
||||
if creds.GID == LolcathostGID {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check supplementary groups (user might be in lolcathost as secondary group)
|
||||
// This requires looking up the user's groups from the system
|
||||
return isUserInGroup(creds.UID, LolcathostGID)
|
||||
}
|
||||
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
|
||||
data, _ := json.Marshal(resp)
|
||||
data = append(data, '\n')
|
||||
conn.Write(data)
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
if creds != nil {
|
||||
uid = creds.UID
|
||||
pid = creds.PID
|
||||
}
|
||||
|
||||
switch req.Type {
|
||||
case protocol.RequestPing:
|
||||
return s.handlePing()
|
||||
|
||||
case protocol.RequestStatus:
|
||||
return s.handleStatus()
|
||||
|
||||
case protocol.RequestList:
|
||||
return s.handleList()
|
||||
|
||||
case protocol.RequestSet:
|
||||
resp := s.handleSet(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.SetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestSync:
|
||||
resp := s.handleSync()
|
||||
if s.auditLogger != nil {
|
||||
s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestPreset:
|
||||
resp := s.handlePreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.PresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestRollback:
|
||||
resp := s.handleRollback(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.RollbackPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestBackups:
|
||||
return s.handleBackups()
|
||||
|
||||
case protocol.RequestAdd:
|
||||
resp := s.handleAdd(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.AddPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDelete:
|
||||
resp := s.handleDelete(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.DeletePayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestAddGroup:
|
||||
resp := s.handleAddGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.GroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDeleteGroup:
|
||||
resp := s.handleDeleteGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.GroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestListGroups:
|
||||
return s.handleListGroups()
|
||||
|
||||
case protocol.RequestRenameGroup:
|
||||
resp := s.handleRenameGroup(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.RenameGroupPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestAddPreset:
|
||||
resp := s.handleAddPreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.AddPresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestDeletePreset:
|
||||
resp := s.handleDeletePreset(req)
|
||||
if s.auditLogger != nil {
|
||||
var payload protocol.PresetPayload
|
||||
_ = req.ParsePayload(&payload)
|
||||
s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message)
|
||||
}
|
||||
return resp
|
||||
|
||||
case protocol.RequestListPresets:
|
||||
return s.handleListPresets()
|
||||
|
||||
default:
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePing() *protocol.Response {
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleStatus() *protocol.Response {
|
||||
s.mu.RLock()
|
||||
reqCount := s.requestCount
|
||||
startTime := s.startTime
|
||||
s.mu.RUnlock()
|
||||
|
||||
cfg := s.config.Get()
|
||||
var activeCount int
|
||||
if cfg != nil {
|
||||
for _, h := range cfg.GetAllHosts() {
|
||||
if h.Enabled {
|
||||
activeCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data := protocol.StatusData{
|
||||
Running: true,
|
||||
Version: Version,
|
||||
Uptime: nowUnix() - startTime,
|
||||
ActiveCount: activeCount,
|
||||
RequestCount: reqCount,
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(data)
|
||||
return resp
|
||||
}
|
||||
|
||||
func nowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
func (s *Server) handleList() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
var entries []protocol.HostEntry
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
entries = append(entries, protocol.HostEntry{
|
||||
Domain: h.Domain,
|
||||
IP: h.IP,
|
||||
Alias: h.Alias,
|
||||
Enabled: h.Enabled,
|
||||
Group: g.Name,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.SetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
host, _ := cfg.FindHostByAlias(payload.Alias)
|
||||
if host == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Check for conflicts if enabling
|
||||
if payload.Enabled && !payload.Force {
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict,
|
||||
fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update config
|
||||
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: host.Domain,
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleSync() *protocol.Response {
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.PresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.ApplyPreset(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleRollback(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.RollbackPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if err := s.hosts.RestoreBackup(payload.BackupName); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err))
|
||||
}
|
||||
|
||||
// Flush DNS after restore
|
||||
s.flusher.Flush()
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleBackups() *protocol.Response {
|
||||
backups, err := s.hosts.ListBackups()
|
||||
if err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err))
|
||||
}
|
||||
|
||||
var infos []protocol.BackupInfo
|
||||
for _, b := range backups {
|
||||
infos = append(infos, protocol.BackupInfo{
|
||||
Name: b.Name,
|
||||
Timestamp: b.Timestamp,
|
||||
Size: b.Size,
|
||||
})
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.AddPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
// Validate domain
|
||||
if payload.Domain == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required")
|
||||
}
|
||||
|
||||
// Validate IP
|
||||
if payload.IP == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required")
|
||||
}
|
||||
|
||||
// Validate group
|
||||
if payload.Group == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required")
|
||||
}
|
||||
|
||||
// Check blocked domains
|
||||
if config.IsBlockedDomain(payload.Domain) {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain))
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
// Add to config (alias will be auto-generated if empty)
|
||||
if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
Domain: payload.Domain,
|
||||
Applied: true,
|
||||
})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.DeletePayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Alias == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
// Delete from config
|
||||
if !cfg.DeleteHost(payload.Alias) {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.GroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.AddGroup(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.GroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.DeleteGroup(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleListGroups() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.RenameGroupPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.OldName == "" || payload.NewName == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.AddPresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response {
|
||||
var payload protocol.PresetPayload
|
||||
if err := req.ParsePayload(&payload); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
||||
}
|
||||
|
||||
if payload.Name == "" {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
||||
}
|
||||
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
if err := cfg.DeletePreset(payload.Name); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) handleListPresets() *protocol.Response {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
||||
}
|
||||
|
||||
presets := cfg.GetPresets()
|
||||
infos := make([]protocol.PresetInfo, len(presets))
|
||||
for i, p := range presets {
|
||||
infos[i] = protocol.PresetInfo{
|
||||
Name: p.Name,
|
||||
Enable: p.Enable,
|
||||
Disable: p.Disable,
|
||||
}
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos})
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) syncHostsFile() error {
|
||||
cfg := s.config.Get()
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("no configuration loaded")
|
||||
}
|
||||
|
||||
var entries []HostEntry
|
||||
for _, g := range cfg.Groups {
|
||||
for _, h := range g.Hosts {
|
||||
entries = append(entries, HostEntry{
|
||||
IP: h.IP,
|
||||
Domain: h.Domain,
|
||||
Alias: h.Alias,
|
||||
Enabled: h.Enabled,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.hosts.WriteManagedEntries(entries); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Flush DNS cache
|
||||
return s.flusher.Flush()
|
||||
}
|
||||
Reference in New Issue
Block a user