mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-05 23:29:18 +00:00
29263dc8a2
* gosec govulncheck runs
* Fix flaky TestRateLimiter_Matrix test
The test was failing due to two issues:
1. Test name generation used invalid character conversion (string(rune('0'+limit)))
which produced non-printable characters for limits >= 10
2. Using 10ms windows with 100 requests caused race conditions - early requests
would expire before all 100 were made, allowing the 101st request
Changed to use struct-based test cases with proper fmt.Sprintf naming and
a consistent 1-second window that won't expire during rapid test execution.
827 lines
22 KiB
Go
827 lines
22 KiB
Go
// Package daemon provides the Unix socket server for the daemon.
|
|
package daemon
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lukaszraczylo/lolcathost/internal/config"
|
|
"github.com/lukaszraczylo/lolcathost/internal/protocol"
|
|
)
|
|
|
|
// Version is the daemon version string reported by the status request.
// It defaults to "dev" and is set by the main package at startup.
var Version = "dev"
|
|
|
|
// Server is the daemon's Unix socket server.
//
// It listens on socketPath, authorizes each connecting peer, and answers
// newline-delimited JSON requests by dispatching them to the handle* methods.
type Server struct {
	// socketPath is the filesystem path of the Unix socket.
	socketPath string
	// listener is assigned by Start and closed by Stop.
	listener net.Listener
	// config supplies the configuration that handlers read, mutate and save.
	config *config.Manager
	// hosts writes the managed hosts entries and serves backup operations.
	hosts *HostsManager
	// flusher flushes the DNS cache after the hosts file changes.
	flusher *DNSFlusher
	// rateLimiter is consulted once per request, keyed by the peer's PID.
	rateLimiter *RateLimiter
	// auditLogger is optional; it stays nil when Start cannot create it.
	auditLogger *AuditLogger
	// mu is held when requestCount is updated, when Stop clears running,
	// and when handleStatus snapshots requestCount and startTime.
	mu sync.RWMutex
	// running is set by Start and cleared by Stop.
	running bool
	// stopCh is closed by Stop so the accept loop can exit.
	stopCh chan struct{}
	// requestCount is the number of requests that passed rate limiting.
	requestCount int64
	// startTime is the Unix time (seconds) recorded by Start.
	startTime int64
}
|
|
|
|
// NewServer creates a new daemon server.
|
|
func NewServer(socketPath string, cfgManager *config.Manager) *Server {
|
|
return &Server{
|
|
socketPath: socketPath,
|
|
config: cfgManager,
|
|
hosts: NewHostsManager(),
|
|
flusher: NewDNSFlusher(FlushMethodAuto),
|
|
rateLimiter: NewRateLimiter(RateLimit, RateLimitWindow),
|
|
stopCh: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Start starts the server.
|
|
func (s *Server) Start() error {
|
|
// Remove existing socket
|
|
_ = os.Remove(s.socketPath)
|
|
|
|
listener, err := net.Listen("unix", s.socketPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on socket: %w", err)
|
|
}
|
|
|
|
// Set socket permissions: 0660 root:lolcathost
|
|
// #nosec G302 -- socket must be group-accessible for lolcathost group members
|
|
if err := os.Chmod(s.socketPath, 0660); err != nil {
|
|
_ = listener.Close()
|
|
return fmt.Errorf("failed to set socket permissions: %w", err)
|
|
}
|
|
|
|
// Set socket group to lolcathost (GID 850)
|
|
if err := os.Chown(s.socketPath, 0, 850); err != nil {
|
|
_ = listener.Close()
|
|
return fmt.Errorf("failed to set socket ownership: %w", err)
|
|
}
|
|
|
|
s.listener = listener
|
|
s.running = true
|
|
s.startTime = currentTimeUnix()
|
|
|
|
// Try to create audit logger, but don't fail if it doesn't work
|
|
if logger, err := NewAuditLogger(AuditLogPath); err == nil {
|
|
s.auditLogger = logger
|
|
}
|
|
|
|
go s.acceptLoop()
|
|
|
|
return nil
|
|
}
|
|
|
|
// currentTimeUnix returns the current wall-clock time as Unix seconds.
func currentTimeUnix() int64 {
	now := time.Now()
	return now.Unix()
}
|
|
|
|
// Stop stops the server.
|
|
func (s *Server) Stop() error {
|
|
s.mu.Lock()
|
|
s.running = false
|
|
s.mu.Unlock()
|
|
|
|
close(s.stopCh)
|
|
|
|
if s.listener != nil {
|
|
_ = s.listener.Close()
|
|
}
|
|
|
|
_ = os.Remove(s.socketPath)
|
|
|
|
if s.auditLogger != nil {
|
|
_ = s.auditLogger.Close()
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) acceptLoop() {
|
|
for {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-s.stopCh:
|
|
return
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
|
|
go s.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
// LolcathostGID is the group ID for the lolcathost group.
// isAuthorized compares it against a peer's primary GID and passes it to the
// supplementary-group lookup.
const LolcathostGID = 850
|
|
|
|
func (s *Server) handleConnection(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
// Get peer credentials
|
|
creds := s.getPeerCredentials(conn)
|
|
|
|
// Authorization check: verify peer is authorized
|
|
if !s.isAuthorized(creds) {
|
|
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
|
if s.auditLogger != nil {
|
|
var uid uint32
|
|
var pid int32
|
|
if creds != nil {
|
|
uid = creds.UID
|
|
pid = creds.PID
|
|
}
|
|
s.auditLogger.Log(uid, pid, "connect", nil, false, "unauthorized access attempt")
|
|
}
|
|
return
|
|
}
|
|
|
|
reader := bufio.NewReader(conn)
|
|
for {
|
|
line, err := reader.ReadBytes('\n')
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
var req protocol.Request
|
|
if err := json.Unmarshal(line, &req); err != nil {
|
|
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
|
|
continue
|
|
}
|
|
|
|
// Rate limiting
|
|
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
|
|
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
|
|
continue
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.requestCount++
|
|
s.mu.Unlock()
|
|
|
|
resp := s.handleRequest(&req, creds)
|
|
s.writeResponse(conn, resp)
|
|
}
|
|
}
|
|
|
|
// isAuthorized checks if the peer is authorized to access the daemon.
|
|
// Authorized users are: root (UID 0) or members of the lolcathost group (GID 850).
|
|
func (s *Server) isAuthorized(creds *PeerCredentials) bool {
|
|
if creds == nil {
|
|
// Can't verify credentials - deny by default
|
|
return false
|
|
}
|
|
|
|
// Root is always authorized
|
|
if creds.UID == 0 {
|
|
return true
|
|
}
|
|
|
|
// Check if user's primary GID is lolcathost
|
|
if creds.GID == LolcathostGID {
|
|
return true
|
|
}
|
|
|
|
// Check supplementary groups (user might be in lolcathost as secondary group)
|
|
// This requires looking up the user's groups from the system
|
|
return isUserInGroup(creds.UID, LolcathostGID)
|
|
}
|
|
|
|
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
|
|
data, _ := json.Marshal(resp)
|
|
data = append(data, '\n')
|
|
_, _ = conn.Write(data)
|
|
}
|
|
|
|
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
|
|
var uid uint32
|
|
var pid int32
|
|
if creds != nil {
|
|
uid = creds.UID
|
|
pid = creds.PID
|
|
}
|
|
|
|
switch req.Type {
|
|
case protocol.RequestPing:
|
|
return s.handlePing()
|
|
|
|
case protocol.RequestStatus:
|
|
return s.handleStatus()
|
|
|
|
case protocol.RequestList:
|
|
return s.handleList()
|
|
|
|
case protocol.RequestSet:
|
|
resp := s.handleSet(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.SetPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "set", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestSync:
|
|
resp := s.handleSync()
|
|
if s.auditLogger != nil {
|
|
s.auditLogger.Log(uid, pid, "sync", nil, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestPreset:
|
|
resp := s.handlePreset(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.PresetPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "preset", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestRollback:
|
|
resp := s.handleRollback(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.RollbackPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "rollback", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestBackups:
|
|
return s.handleBackups()
|
|
|
|
case protocol.RequestBackupContent:
|
|
return s.handleBackupContent(req)
|
|
|
|
case protocol.RequestAdd:
|
|
resp := s.handleAdd(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.AddPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "add", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestDelete:
|
|
resp := s.handleDelete(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.DeletePayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "delete", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestAddGroup:
|
|
resp := s.handleAddGroup(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.GroupPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "add_group", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestDeleteGroup:
|
|
resp := s.handleDeleteGroup(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.GroupPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "delete_group", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestListGroups:
|
|
return s.handleListGroups()
|
|
|
|
case protocol.RequestRenameGroup:
|
|
resp := s.handleRenameGroup(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.RenameGroupPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "rename_group", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestAddPreset:
|
|
resp := s.handleAddPreset(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.AddPresetPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "add_preset", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestDeletePreset:
|
|
resp := s.handleDeletePreset(req)
|
|
if s.auditLogger != nil {
|
|
var payload protocol.PresetPayload
|
|
_ = req.ParsePayload(&payload)
|
|
s.auditLogger.Log(uid, pid, "delete_preset", payload, resp.IsOK(), resp.Message)
|
|
}
|
|
return resp
|
|
|
|
case protocol.RequestListPresets:
|
|
return s.handleListPresets()
|
|
|
|
default:
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, fmt.Sprintf("unknown request type: %s", req.Type))
|
|
}
|
|
}
|
|
|
|
func (s *Server) handlePing() *protocol.Response {
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"pong": "ok"})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleStatus() *protocol.Response {
|
|
s.mu.RLock()
|
|
reqCount := s.requestCount
|
|
startTime := s.startTime
|
|
s.mu.RUnlock()
|
|
|
|
cfg := s.config.Get()
|
|
var activeCount int
|
|
if cfg != nil {
|
|
for _, h := range cfg.GetAllHosts() {
|
|
if h.Enabled {
|
|
activeCount++
|
|
}
|
|
}
|
|
}
|
|
|
|
data := protocol.StatusData{
|
|
Running: true,
|
|
Version: Version,
|
|
Uptime: nowUnix() - startTime,
|
|
ActiveCount: activeCount,
|
|
RequestCount: reqCount,
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(data)
|
|
return resp
|
|
}
|
|
|
|
// nowUnix returns the current wall-clock time as Unix seconds.
func nowUnix() int64 {
	current := time.Now()
	return current.Unix()
}
|
|
|
|
func (s *Server) handleList() *protocol.Response {
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
var entries []protocol.HostEntry
|
|
for _, g := range cfg.Groups {
|
|
for _, h := range g.Hosts {
|
|
entries = append(entries, protocol.HostEntry{
|
|
Domain: h.Domain,
|
|
IP: h.IP,
|
|
Alias: h.Alias,
|
|
Enabled: h.Enabled,
|
|
Group: g.Name,
|
|
})
|
|
}
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.ListData{Entries: entries})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.SetPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
host, _ := cfg.FindHostByAlias(payload.Alias)
|
|
if host == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
|
}
|
|
|
|
// Check for conflicts if enabling
|
|
if payload.Enabled && !payload.Force {
|
|
for _, g := range cfg.Groups {
|
|
for _, h := range g.Hosts {
|
|
if h.Alias != payload.Alias && h.Domain == host.Domain && h.Enabled {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeConflict,
|
|
fmt.Sprintf("domain %s already mapped by alias %s (use force to override)", host.Domain, h.Alias))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Update config
|
|
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
// Sync to hosts file
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
|
Domain: host.Domain,
|
|
Applied: true,
|
|
})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleSync() *protocol.Response {
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]bool{"synced": true})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.PresetPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.ApplyPreset(payload.Name); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
// Sync to hosts file
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleRollback(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.RollbackPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if err := s.hosts.RestoreBackup(payload.BackupName); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to restore backup: %v", err))
|
|
}
|
|
|
|
// Flush DNS after restore
|
|
_ = s.flusher.Flush()
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"restored": payload.BackupName})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleBackups() *protocol.Response {
|
|
backups, err := s.hosts.ListBackups()
|
|
if err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to list backups: %v", err))
|
|
}
|
|
|
|
var infos []protocol.BackupInfo
|
|
for _, b := range backups {
|
|
infos = append(infos, protocol.BackupInfo{
|
|
Name: b.Name,
|
|
Timestamp: b.Timestamp,
|
|
Size: b.Size,
|
|
})
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.BackupsData{Backups: infos})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleBackupContent(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.BackupContentPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.BackupName == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "backup name is required")
|
|
}
|
|
|
|
content, err := s.hosts.GetBackupContent(payload.BackupName)
|
|
if err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("failed to get backup content: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.BackupContentData{Content: content})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.AddPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
// Validate domain
|
|
if payload.Domain == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidDomain, "domain is required")
|
|
}
|
|
|
|
// Validate IP
|
|
if payload.IP == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidIP, "IP address is required")
|
|
}
|
|
|
|
// Validate group
|
|
if payload.Group == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group is required")
|
|
}
|
|
|
|
// Check blocked domains
|
|
if config.IsBlockedDomain(payload.Domain) {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeBlockedDomain, fmt.Sprintf("domain %s is blocked", payload.Domain))
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
// Add to config (alias will be auto-generated if empty)
|
|
if err := cfg.AddHost(payload.Domain, payload.IP, payload.Alias, payload.Group, payload.Enabled); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
// Sync to hosts file
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
|
Domain: payload.Domain,
|
|
Applied: true,
|
|
})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.DeletePayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.Alias == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "alias is required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
// Delete from config
|
|
if !cfg.DeleteHost(payload.Alias) {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
// Sync to hosts file
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleAddGroup(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.GroupPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.Name == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.AddGroup(payload.Name); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.GroupPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.Name == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "group name is required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.DeleteGroup(payload.Name); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
// Sync to hosts file
|
|
if err := s.syncHostsFile(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleListGroups() *protocol.Response {
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.GroupsData{Groups: cfg.GetGroups()})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleRenameGroup(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.RenameGroupPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.OldName == "" || payload.NewName == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "old_name and new_name are required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.RenameGroup(payload.OldName, payload.NewName); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"renamed": payload.NewName})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleAddPreset(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.AddPresetPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.Name == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.AddPreset(payload.Name, payload.Enable, payload.Disable); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"added": payload.Name})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleDeletePreset(req *protocol.Request) *protocol.Response {
|
|
var payload protocol.PresetPayload
|
|
if err := req.ParsePayload(&payload); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid payload")
|
|
}
|
|
|
|
if payload.Name == "" {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "preset name is required")
|
|
}
|
|
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
if err := cfg.DeletePreset(payload.Name); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
|
}
|
|
|
|
// Save config
|
|
if err := s.config.Save(); err != nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) handleListPresets() *protocol.Response {
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, "no configuration loaded")
|
|
}
|
|
|
|
presets := cfg.GetPresets()
|
|
infos := make([]protocol.PresetInfo, len(presets))
|
|
for i, p := range presets {
|
|
infos[i] = protocol.PresetInfo{
|
|
Name: p.Name,
|
|
Enable: p.Enable,
|
|
Disable: p.Disable,
|
|
}
|
|
}
|
|
|
|
resp, _ := protocol.NewOKResponse(protocol.PresetsData{Presets: infos})
|
|
return resp
|
|
}
|
|
|
|
func (s *Server) syncHostsFile() error {
|
|
cfg := s.config.Get()
|
|
if cfg == nil {
|
|
return fmt.Errorf("no configuration loaded")
|
|
}
|
|
|
|
var entries []HostEntry
|
|
for _, g := range cfg.Groups {
|
|
for _, h := range g.Hosts {
|
|
entries = append(entries, HostEntry{
|
|
IP: h.IP,
|
|
Domain: h.Domain,
|
|
Alias: h.Alias,
|
|
Enabled: h.Enabled,
|
|
})
|
|
}
|
|
}
|
|
|
|
if err := s.hosts.WriteManagedEntries(entries); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Flush DNS cache
|
|
return s.flusher.Flush()
|
|
}
|