Cleanup, signing and update of internals.

This commit is contained in:
2025-12-15 00:32:53 +00:00
parent 2d9c28657b
commit 100251b896
19 changed files with 439 additions and 313 deletions
+79 -54
View File
@@ -8,6 +8,7 @@ import (
"net"
"os"
"sync"
"syscall"
"time"
"github.com/lukaszraczylo/lolcathost/internal/config"
@@ -50,20 +51,28 @@ func (s *Server) Start() error {
// Remove existing socket
_ = os.Remove(s.socketPath)
// Set umask to create socket with restricted permissions (0660)
// This prevents TOCTOU vulnerability between socket creation and chmod
oldUmask := syscall.Umask(0117) // 0777 & ~0117 = 0660
listener, err := net.Listen("unix", s.socketPath)
// Restore original umask immediately after socket creation
syscall.Umask(oldUmask)
if err != nil {
return fmt.Errorf("failed to listen on socket: %w", err)
}
// Set socket permissions: 0660 root:lolcathost
// #nosec G302 -- socket must be group-accessible for lolcathost group members
if err := os.Chmod(s.socketPath, 0660); err != nil {
_ = listener.Close()
return fmt.Errorf("failed to set socket permissions: %w", err)
// Look up the lolcathost group GID dynamically
gid, err := lookupGroupGID("lolcathost")
if err != nil {
// Fall back to default GID if group lookup fails
gid = LolcathostGID
}
// Set socket group to lolcathost (GID 850)
if err := os.Chown(s.socketPath, 0, 850); err != nil {
// Set socket group to lolcathost
if err := os.Chown(s.socketPath, 0, gid); err != nil {
_ = listener.Close()
return fmt.Errorf("failed to set socket ownership: %w", err)
}
@@ -126,6 +135,9 @@ func (s *Server) acceptLoop() {
// LolcathostGID is the group ID for the lolcathost group.
const LolcathostGID = 850
// connectionReadTimeout is the maximum time to wait for a client to send data.
const connectionReadTimeout = 30 * time.Second
func (s *Server) handleConnection(conn net.Conn) {
defer conn.Close()
@@ -134,7 +146,7 @@ func (s *Server) handleConnection(conn net.Conn) {
// Authorization check: verify peer is authorized
if !s.isAuthorized(creds) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
_ = s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
if s.auditLogger != nil {
var uid uint32
var pid int32
@@ -149,6 +161,11 @@ func (s *Server) handleConnection(conn net.Conn) {
reader := bufio.NewReader(conn)
for {
// Set read deadline to prevent clients from hanging indefinitely
if err := conn.SetReadDeadline(time.Now().Add(connectionReadTimeout)); err != nil {
return
}
line, err := reader.ReadBytes('\n')
if err != nil {
return
@@ -156,13 +173,17 @@ func (s *Server) handleConnection(conn net.Conn) {
var req protocol.Request
if err := json.Unmarshal(line, &req); err != nil {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON")); err != nil {
return // Connection error, stop handling
}
continue
}
// Rate limiting
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded")); err != nil {
return // Connection error, stop handling
}
continue
}
@@ -171,7 +192,9 @@ func (s *Server) handleConnection(conn net.Conn) {
s.mu.Unlock()
resp := s.handleRequest(&req, creds)
s.writeResponse(conn, resp)
if err := s.writeResponse(conn, resp); err != nil {
return // Connection error, stop handling
}
}
}
@@ -198,10 +221,16 @@ func (s *Server) isAuthorized(creds *PeerCredentials) bool {
return isUserInGroup(creds.UID, LolcathostGID)
}
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
data, _ := json.Marshal(resp)
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) error {
data, err := json.Marshal(resp)
if err != nil {
return fmt.Errorf("failed to marshal response: %w", err)
}
data = append(data, '\n')
_, _ = conn.Write(data)
if _, err := conn.Write(data); err != nil {
return fmt.Errorf("failed to write response: %w", err)
}
return nil
}
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
@@ -427,14 +456,9 @@ func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
// Update config
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
// Save and sync with rollback on failure
if err := s.saveAndSync(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
@@ -468,14 +492,9 @@ func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
// Save and sync with rollback on failure
if err := s.saveAndSync(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
}
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
@@ -573,14 +592,9 @@ func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
// Save and sync with rollback on failure
if err := s.saveAndSync(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
}
resp, _ := protocol.NewOKResponse(protocol.SetData{
@@ -610,14 +624,9 @@ func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
// Save and sync with rollback on failure
if err := s.saveAndSync(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
@@ -671,14 +680,9 @@ func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
}
// Save config
if err := s.config.Save(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
// Save and sync with rollback on failure
if err := s.saveAndSync(); err != nil {
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
}
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
@@ -824,3 +828,24 @@ func (s *Server) syncHostsFile() error {
// Flush DNS cache
return s.flusher.Flush()
}
// saveAndSync saves the configuration and syncs to /etc/hosts atomically.
// If sync fails, it attempts to reload the previous config from disk.
func (s *Server) saveAndSync() error {
// Save config
if err := s.config.Save(); err != nil {
return fmt.Errorf("failed to save config: %w", err)
}
// Sync to hosts file
if err := s.syncHostsFile(); err != nil {
// Attempt to reload previous config on sync failure
if reloadErr := s.config.Reload(); reloadErr != nil {
// Log reload failure but return original sync error
fmt.Fprintf(os.Stderr, "warning: failed to reload config after sync failure: %v\n", reloadErr)
}
return fmt.Errorf("failed to sync hosts (config rolled back): %w", err)
}
return nil
}