mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-07-02 03:36:20 +00:00
Cleanup, signing and update of internals.
This commit is contained in:
+79
-54
@@ -8,6 +8,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lukaszraczylo/lolcathost/internal/config"
|
||||
@@ -50,20 +51,28 @@ func (s *Server) Start() error {
|
||||
// Remove existing socket
|
||||
_ = os.Remove(s.socketPath)
|
||||
|
||||
// Set umask to create socket with restricted permissions (0660)
|
||||
// This prevents TOCTOU vulnerability between socket creation and chmod
|
||||
oldUmask := syscall.Umask(0117) // 0777 & ~0117 = 0660
|
||||
|
||||
listener, err := net.Listen("unix", s.socketPath)
|
||||
|
||||
// Restore original umask immediately after socket creation
|
||||
syscall.Umask(oldUmask)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on socket: %w", err)
|
||||
}
|
||||
|
||||
// Set socket permissions: 0660 root:lolcathost
|
||||
// #nosec G302 -- socket must be group-accessible for lolcathost group members
|
||||
if err := os.Chmod(s.socketPath, 0660); err != nil {
|
||||
_ = listener.Close()
|
||||
return fmt.Errorf("failed to set socket permissions: %w", err)
|
||||
// Look up the lolcathost group GID dynamically
|
||||
gid, err := lookupGroupGID("lolcathost")
|
||||
if err != nil {
|
||||
// Fall back to default GID if group lookup fails
|
||||
gid = LolcathostGID
|
||||
}
|
||||
|
||||
// Set socket group to lolcathost (GID 850)
|
||||
if err := os.Chown(s.socketPath, 0, 850); err != nil {
|
||||
// Set socket group to lolcathost
|
||||
if err := os.Chown(s.socketPath, 0, gid); err != nil {
|
||||
_ = listener.Close()
|
||||
return fmt.Errorf("failed to set socket ownership: %w", err)
|
||||
}
|
||||
@@ -126,6 +135,9 @@ func (s *Server) acceptLoop() {
|
||||
// LolcathostGID is the group ID for the lolcathost group.
|
||||
const LolcathostGID = 850
|
||||
|
||||
// connectionReadTimeout is the maximum time to wait for a client to send data.
|
||||
const connectionReadTimeout = 30 * time.Second
|
||||
|
||||
func (s *Server) handleConnection(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
@@ -134,7 +146,7 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
// Authorization check: verify peer is authorized
|
||||
if !s.isAuthorized(creds) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
_ = s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeUnauthorized, "unauthorized: user not in lolcathost group"))
|
||||
if s.auditLogger != nil {
|
||||
var uid uint32
|
||||
var pid int32
|
||||
@@ -149,6 +161,11 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// Set read deadline to prevent clients from hanging indefinitely
|
||||
if err := conn.SetReadDeadline(time.Now().Add(connectionReadTimeout)); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
line, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return
|
||||
@@ -156,13 +173,17 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
|
||||
var req protocol.Request
|
||||
if err := json.Unmarshal(line, &req); err != nil {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON"))
|
||||
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeInvalidRequest, "invalid JSON")); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Rate limiting
|
||||
if creds != nil && !s.rateLimiter.Allow(creds.PID) {
|
||||
s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded"))
|
||||
if err := s.writeResponse(conn, protocol.NewErrorResponse(protocol.ErrCodeRateLimited, "rate limit exceeded")); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -171,7 +192,9 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
s.mu.Unlock()
|
||||
|
||||
resp := s.handleRequest(&req, creds)
|
||||
s.writeResponse(conn, resp)
|
||||
if err := s.writeResponse(conn, resp); err != nil {
|
||||
return // Connection error, stop handling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -198,10 +221,16 @@ func (s *Server) isAuthorized(creds *PeerCredentials) bool {
|
||||
return isUserInGroup(creds.UID, LolcathostGID)
|
||||
}
|
||||
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) {
|
||||
data, _ := json.Marshal(resp)
|
||||
func (s *Server) writeResponse(conn net.Conn, resp *protocol.Response) error {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response: %w", err)
|
||||
}
|
||||
data = append(data, '\n')
|
||||
_, _ = conn.Write(data)
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return fmt.Errorf("failed to write response: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(req *protocol.Request, creds *PeerCredentials) *protocol.Response {
|
||||
@@ -427,14 +456,9 @@ func (s *Server) handleSet(req *protocol.Request) *protocol.Response {
|
||||
// Update config
|
||||
cfg.SetHostEnabled(payload.Alias, payload.Enabled)
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
@@ -468,14 +492,9 @@ func (s *Server) handlePreset(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"preset": payload.Name, "applied": "true"})
|
||||
@@ -573,14 +592,9 @@ func (s *Server) handleAdd(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeConflict, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(protocol.SetData{
|
||||
@@ -610,14 +624,9 @@ func (s *Server) handleDelete(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, fmt.Sprintf("alias not found: %s", payload.Alias))
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Alias})
|
||||
@@ -671,14 +680,9 @@ func (s *Server) handleDeleteGroup(req *protocol.Request) *protocol.Response {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeNotFound, err.Error())
|
||||
}
|
||||
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to save config: %v", err))
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, fmt.Sprintf("failed to sync hosts: %v", err))
|
||||
// Save and sync with rollback on failure
|
||||
if err := s.saveAndSync(); err != nil {
|
||||
return protocol.NewErrorResponse(protocol.ErrCodeInternalError, err.Error())
|
||||
}
|
||||
|
||||
resp, _ := protocol.NewOKResponse(map[string]string{"deleted": payload.Name})
|
||||
@@ -824,3 +828,24 @@ func (s *Server) syncHostsFile() error {
|
||||
// Flush DNS cache
|
||||
return s.flusher.Flush()
|
||||
}
|
||||
|
||||
// saveAndSync saves the configuration and syncs to /etc/hosts atomically.
|
||||
// If sync fails, it attempts to reload the previous config from disk.
|
||||
func (s *Server) saveAndSync() error {
|
||||
// Save config
|
||||
if err := s.config.Save(); err != nil {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
// Sync to hosts file
|
||||
if err := s.syncHostsFile(); err != nil {
|
||||
// Attempt to reload previous config on sync failure
|
||||
if reloadErr := s.config.Reload(); reloadErr != nil {
|
||||
// Log reload failure but return original sync error
|
||||
fmt.Fprintf(os.Stderr, "warning: failed to reload config after sync failure: %v\n", reloadErr)
|
||||
}
|
||||
return fmt.Errorf("failed to sync hosts (config rolled back): %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user