Initial commit.

This commit is contained in:
2025-11-28 02:50:25 +00:00
commit 22552aec99
41 changed files with 10626 additions and 0 deletions
+541
View File
@@ -0,0 +1,541 @@
// Package config handles YAML configuration parsing and hot-reload.
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/fsnotify/fsnotify"
"gopkg.in/yaml.v3"
)
// SystemConfigDir is the system-wide config directory for the daemon.
const SystemConfigDir = "/etc/lolcathost"
// SystemConfigPath is the system-wide config file path for the daemon.
const SystemConfigPath = "/etc/lolcathost/config.yaml"
// DefaultConfigDir returns the default config directory path for users.
func DefaultConfigDir() string {
home, err := os.UserHomeDir()
if err != nil {
return ""
}
return filepath.Join(home, ".config", "lolcathost")
}
// DefaultConfigPath returns the default config file path for users.
func DefaultConfigPath() string {
return filepath.Join(DefaultConfigDir(), "config.yaml")
}
// FlushMethod defines DNS cache flush methods.
type FlushMethod string
const (
FlushMethodAuto FlushMethod = "auto"
FlushMethodDscacheutil FlushMethod = "dscacheutil"
FlushMethodKillall FlushMethod = "killall"
FlushMethodBoth FlushMethod = "both"
)
// Settings holds global configuration settings.
type Settings struct {
AutoApply bool `yaml:"autoApply"`
FlushMethod FlushMethod `yaml:"flushMethod"`
}
// Host represents a single host entry in configuration.
type Host struct {
Domain string `yaml:"domain"`
IP string `yaml:"ip"`
Alias string `yaml:"alias"`
Enabled bool `yaml:"enabled"`
}
// Group represents a group of host entries.
type Group struct {
Name string `yaml:"name"`
Hosts []Host `yaml:"hosts"`
}
// Preset defines a named preset that enables/disables specific aliases.
type Preset struct {
Name string `yaml:"name"`
Enable []string `yaml:"enable,omitempty"`
Disable []string `yaml:"disable,omitempty"`
}
// Config represents the complete configuration.
type Config struct {
Settings Settings `yaml:"settings"`
Groups []Group `yaml:"groups"`
Presets []Preset `yaml:"presets"`
}
// Manager handles configuration loading and watching.
type Manager struct {
path string
config *Config
mu sync.RWMutex
watcher *fsnotify.Watcher
onChange func(*Config)
stopCh chan struct{}
}
// NewManager creates a new config manager.
func NewManager(path string) *Manager {
return &Manager{
path: path,
stopCh: make(chan struct{}),
}
}
// Load reads and parses the configuration file.
func (m *Manager) Load() error {
data, err := os.ReadFile(m.path)
if err != nil {
return fmt.Errorf("failed to read config file: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return fmt.Errorf("failed to parse config file: %w", err)
}
if err := ValidateConfig(&cfg); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
m.mu.Lock()
m.config = &cfg
m.mu.Unlock()
return nil
}
// Get returns the current configuration.
func (m *Manager) Get() *Config {
m.mu.RLock()
defer m.mu.RUnlock()
return m.config
}
// Watch starts watching the config file for changes.
func (m *Manager) Watch(onChange func(*Config)) error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return fmt.Errorf("failed to create watcher: %w", err)
}
m.watcher = watcher
m.onChange = onChange
go m.watchLoop()
if err := watcher.Add(m.path); err != nil {
return fmt.Errorf("failed to watch config file: %w", err)
}
return nil
}
func (m *Manager) watchLoop() {
for {
select {
case event, ok := <-m.watcher.Events:
if !ok {
return
}
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
if err := m.Load(); err == nil && m.onChange != nil {
m.onChange(m.Get())
}
}
case <-m.watcher.Errors:
// Ignore watcher errors
case <-m.stopCh:
return
}
}
}
// Stop stops watching the config file.
func (m *Manager) Stop() {
close(m.stopCh)
if m.watcher != nil {
m.watcher.Close()
}
}
// GetAllHosts returns all hosts from all groups.
func (c *Config) GetAllHosts() []Host {
var hosts []Host
for _, g := range c.Groups {
hosts = append(hosts, g.Hosts...)
}
return hosts
}
// FindHostByAlias finds a host by its alias.
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
return &c.Groups[i].Hosts[j], &c.Groups[i]
}
}
}
return nil, nil
}
// FindPreset finds a preset by name.
func (c *Config) FindPreset(name string) *Preset {
for i := range c.Presets {
if c.Presets[i].Name == name {
return &c.Presets[i]
}
}
return nil
}
// SetHostEnabled sets the enabled state of a host by alias.
func (c *Config) SetHostEnabled(alias string, enabled bool) bool {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
c.Groups[i].Hosts[j].Enabled = enabled
return true
}
}
}
return false
}
// GenerateAlias creates a unique alias from a domain name.
func (c *Config) GenerateAlias(domain string) string {
// Convert domain to alias format: example.com -> example-com
alias := strings.ReplaceAll(domain, ".", "-")
alias = strings.ReplaceAll(alias, "_", "-")
alias = strings.ToLower(alias)
// Check if alias exists, if so append a number
baseAlias := alias
counter := 1
for {
if existing, _ := c.FindHostByAlias(alias); existing == nil {
break
}
counter++
alias = fmt.Sprintf("%s-%d", baseAlias, counter)
}
return alias
}
// AddHost adds a new host to the configuration.
func (c *Config) AddHost(domain, ip, alias, groupName string, enabled bool) error {
// Auto-generate alias if empty
if alias == "" {
alias = c.GenerateAlias(domain)
} else {
// Check for duplicate alias
if existing, _ := c.FindHostByAlias(alias); existing != nil {
return fmt.Errorf("alias already exists: %s", alias)
}
}
host := Host{
Domain: domain,
IP: ip,
Alias: alias,
Enabled: enabled,
}
// Find or create group
for i := range c.Groups {
if c.Groups[i].Name == groupName {
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
return nil
}
}
// Create new group
c.Groups = append(c.Groups, Group{
Name: groupName,
Hosts: []Host{host},
})
return nil
}
// AddGroup adds a new empty group.
func (c *Config) AddGroup(name string) error {
// Check if group already exists
for _, g := range c.Groups {
if g.Name == name {
return fmt.Errorf("group already exists: %s", name)
}
}
c.Groups = append(c.Groups, Group{
Name: name,
Hosts: []Host{},
})
return nil
}
// DeleteGroup removes a group and all its hosts.
func (c *Config) DeleteGroup(name string) error {
for i, g := range c.Groups {
if g.Name == name {
c.Groups = append(c.Groups[:i], c.Groups[i+1:]...)
return nil
}
}
return fmt.Errorf("group not found: %s", name)
}
// RenameGroup renames an existing group.
func (c *Config) RenameGroup(oldName, newName string) error {
// Check if new name already exists
for _, g := range c.Groups {
if g.Name == newName {
return fmt.Errorf("group already exists: %s", newName)
}
}
for i := range c.Groups {
if c.Groups[i].Name == oldName {
c.Groups[i].Name = newName
return nil
}
}
return fmt.Errorf("group not found: %s", oldName)
}
// GetGroups returns all group names.
func (c *Config) GetGroups() []string {
names := make([]string, len(c.Groups))
for i, g := range c.Groups {
names[i] = g.Name
}
return names
}
// DeleteHost removes a host by alias.
func (c *Config) DeleteHost(alias string) bool {
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == alias {
c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...)
return true
}
}
}
return false
}
// UpdateHost updates an existing host by alias.
func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error {
// Find the host
var foundGroup int = -1
var foundHost int = -1
for i := range c.Groups {
for j := range c.Groups[i].Hosts {
if c.Groups[i].Hosts[j].Alias == oldAlias {
foundGroup = i
foundHost = j
break
}
}
if foundHost >= 0 {
break
}
}
if foundHost < 0 {
return fmt.Errorf("alias not found: %s", oldAlias)
}
// Check for duplicate alias if alias is changing
if oldAlias != newAlias {
if existing, _ := c.FindHostByAlias(newAlias); existing != nil {
return fmt.Errorf("alias already exists: %s", newAlias)
}
}
// Get current enabled state
enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled
// If group is changing, move to new group
if c.Groups[foundGroup].Name != groupName {
// Remove from old group
c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...)
// Add to new group
host := Host{
Domain: domain,
IP: ip,
Alias: newAlias,
Enabled: enabled,
}
// Find or create target group
found := false
for i := range c.Groups {
if c.Groups[i].Name == groupName {
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
found = true
break
}
}
if !found {
c.Groups = append(c.Groups, Group{
Name: groupName,
Hosts: []Host{host},
})
}
} else {
// Update in place
c.Groups[foundGroup].Hosts[foundHost].Domain = domain
c.Groups[foundGroup].Hosts[foundHost].IP = ip
c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias
}
return nil
}
// ApplyPreset applies a preset to the configuration.
func (c *Config) ApplyPreset(name string) error {
preset := c.FindPreset(name)
if preset == nil {
return fmt.Errorf("preset not found: %s", name)
}
for _, alias := range preset.Enable {
c.SetHostEnabled(alias, true)
}
for _, alias := range preset.Disable {
c.SetHostEnabled(alias, false)
}
return nil
}
// AddPreset adds a new preset.
func (c *Config) AddPreset(name string, enable, disable []string) error {
// Check if preset already exists
for _, p := range c.Presets {
if p.Name == name {
return fmt.Errorf("preset already exists: %s", name)
}
}
c.Presets = append(c.Presets, Preset{
Name: name,
Enable: enable,
Disable: disable,
})
return nil
}
// DeletePreset removes a preset by name.
func (c *Config) DeletePreset(name string) error {
for i, p := range c.Presets {
if p.Name == name {
c.Presets = append(c.Presets[:i], c.Presets[i+1:]...)
return nil
}
}
return fmt.Errorf("preset not found: %s", name)
}
// GetPresets returns all presets.
func (c *Config) GetPresets() []Preset {
return c.Presets
}
// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed.
func (c *Config) EnsureDefaultGroup() {
if len(c.Groups) == 0 {
c.Groups = append(c.Groups, Group{
Name: "default",
Hosts: []Host{},
})
}
}
// Save writes the configuration to the file.
func (m *Manager) Save() error {
m.mu.RLock()
cfg := m.config
m.mu.RUnlock()
if cfg == nil {
return fmt.Errorf("no config loaded")
}
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal config: %w", err)
}
if err := os.WriteFile(m.path, data, 0644); err != nil {
return fmt.Errorf("failed to write config: %w", err)
}
return nil
}
// CreateDefault creates a default configuration file.
func CreateDefault(path string) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create config directory: %w", err)
}
cfg := &Config{
Settings: Settings{
AutoApply: true,
FlushMethod: FlushMethodAuto,
},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{
Domain: "example.local",
IP: "127.0.0.1",
Alias: "example-local",
Enabled: false,
},
},
},
},
Presets: []Preset{
{
Name: "local",
Enable: []string{"example-local"},
Disable: []string{},
},
{
Name: "clear",
Enable: []string{},
Disable: []string{"example-local"},
},
},
}
data, err := yaml.Marshal(cfg)
if err != nil {
return fmt.Errorf("failed to marshal default config: %w", err)
}
if err := os.WriteFile(path, data, 0644); err != nil {
return fmt.Errorf("failed to write default config: %w", err)
}
return nil
}
+267
View File
@@ -0,0 +1,267 @@
package config
import (
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestConfig_GetAllHosts(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false},
},
},
{
Name: "staging",
Hosts: []Host{
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true},
},
},
},
}
hosts := cfg.GetAllHosts()
assert.Len(t, hosts, 3)
assert.Equal(t, "a.com", hosts[0].Domain)
assert.Equal(t, "b.com", hosts[1].Domain)
assert.Equal(t, "c.com", hosts[2].Domain)
}
func TestConfig_FindHostByAlias(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
},
},
},
}
t.Run("found", func(t *testing.T) {
host, group := cfg.FindHostByAlias("example")
require.NotNil(t, host)
require.NotNil(t, group)
assert.Equal(t, "example.com", host.Domain)
assert.Equal(t, "dev", group.Name)
})
t.Run("not found", func(t *testing.T) {
host, group := cfg.FindHostByAlias("nonexistent")
assert.Nil(t, host)
assert.Nil(t, group)
})
}
func TestConfig_FindPreset(t *testing.T) {
cfg := &Config{
Presets: []Preset{
{Name: "local", Enable: []string{"a"}, Disable: []string{"b"}},
{Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}},
},
}
t.Run("found", func(t *testing.T) {
preset := cfg.FindPreset("local")
require.NotNil(t, preset)
assert.Equal(t, "local", preset.Name)
assert.Equal(t, []string{"a"}, preset.Enable)
})
t.Run("not found", func(t *testing.T) {
preset := cfg.FindPreset("nonexistent")
assert.Nil(t, preset)
})
}
func TestConfig_SetHostEnabled(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: false},
},
},
},
}
t.Run("enable existing", func(t *testing.T) {
result := cfg.SetHostEnabled("example", true)
assert.True(t, result)
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
})
t.Run("disable existing", func(t *testing.T) {
result := cfg.SetHostEnabled("example", false)
assert.True(t, result)
assert.False(t, cfg.Groups[0].Hosts[0].Enabled)
})
t.Run("nonexistent alias", func(t *testing.T) {
result := cfg.SetHostEnabled("nonexistent", true)
assert.False(t, result)
})
}
func TestConfig_ApplyPreset(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false},
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "swap", Enable: []string{"a"}, Disable: []string{"b"}},
},
}
t.Run("valid preset", func(t *testing.T) {
err := cfg.ApplyPreset("swap")
require.NoError(t, err)
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
assert.False(t, cfg.Groups[0].Hosts[1].Enabled)
})
t.Run("nonexistent preset", func(t *testing.T) {
err := cfg.ApplyPreset("nonexistent")
assert.Error(t, err)
})
}
func TestManager_LoadAndGet(t *testing.T) {
// Create temp config file
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
configContent := `
settings:
autoApply: true
flushMethod: auto
groups:
- name: development
hosts:
- domain: example.com
ip: 127.0.0.1
alias: example-local
enabled: true
presets:
- name: local
enable: [example-local]
disable: []
`
err := os.WriteFile(configPath, []byte(configContent), 0644)
require.NoError(t, err)
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
require.NotNil(t, cfg)
assert.True(t, cfg.Settings.AutoApply)
assert.Equal(t, FlushMethodAuto, cfg.Settings.FlushMethod)
assert.Len(t, cfg.Groups, 1)
assert.Equal(t, "development", cfg.Groups[0].Name)
assert.Len(t, cfg.Groups[0].Hosts, 1)
assert.Equal(t, "example.com", cfg.Groups[0].Hosts[0].Domain)
}
func TestManager_Save(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
// Create initial config
err := CreateDefault(configPath)
require.NoError(t, err)
// Load and modify
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
cfg.Groups[0].Hosts[0].Enabled = true
// Save
err = manager.Save()
require.NoError(t, err)
// Reload and verify
manager2 := NewManager(configPath)
err = manager2.Load()
require.NoError(t, err)
cfg2 := manager2.Get()
assert.True(t, cfg2.Groups[0].Hosts[0].Enabled)
}
func TestCreateDefault(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "subdir", "config.yaml")
err := CreateDefault(configPath)
require.NoError(t, err)
// Verify file exists
_, err = os.Stat(configPath)
require.NoError(t, err)
// Verify content is valid
manager := NewManager(configPath)
err = manager.Load()
require.NoError(t, err)
cfg := manager.Get()
require.NotNil(t, cfg)
assert.True(t, cfg.Settings.AutoApply)
assert.Len(t, cfg.Groups, 1)
assert.Len(t, cfg.Presets, 2)
}
func TestManager_Load_InvalidYAML(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0644)
require.NoError(t, err)
manager := NewManager(configPath)
err = manager.Load()
assert.Error(t, err)
}
func TestManager_Load_FileNotFound(t *testing.T) {
manager := NewManager("/nonexistent/path/config.yaml")
err := manager.Load()
assert.Error(t, err)
}
func TestFlushMethod(t *testing.T) {
methods := []FlushMethod{
FlushMethodAuto,
FlushMethodDscacheutil,
FlushMethodKillall,
FlushMethodBoth,
}
for _, m := range methods {
t.Run(string(m), func(t *testing.T) {
assert.NotEmpty(t, string(m))
})
}
}
+211
View File
@@ -0,0 +1,211 @@
// Package config provides validation functions for configuration.
package config
import (
"fmt"
"net"
"regexp"
"strings"
)
// domainRegex validates domain names.
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$|^localhost$`)
// aliasRegex validates alias names.
var aliasRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`)
// blockedDomains contains domains that cannot be modified.
var blockedDomains = map[string]bool{
"apple.com": true,
"icloud.com": true,
"icloud-content.com": true,
"apple-dns.cn": true,
"apple-dns.net": true,
"mzstatic.com": true,
"itunes.apple.com": true,
"updates.apple.com": true,
}
// ValidationError represents a configuration validation error.
type ValidationError struct {
Field string
Message string
}
func (e *ValidationError) Error() string {
return fmt.Sprintf("%s: %s", e.Field, e.Message)
}
// ValidateConfig validates the entire configuration.
func ValidateConfig(cfg *Config) error {
if cfg == nil {
return &ValidationError{Field: "config", Message: "config is nil"}
}
if err := validateSettings(&cfg.Settings); err != nil {
return err
}
// Track aliases for uniqueness
aliases := make(map[string]bool)
for i, g := range cfg.Groups {
if err := validateGroup(&g, i, aliases); err != nil {
return err
}
}
for i, p := range cfg.Presets {
if err := validatePreset(&p, i, aliases); err != nil {
return err
}
}
return nil
}
func validateSettings(s *Settings) error {
switch s.FlushMethod {
case FlushMethodAuto, FlushMethodDscacheutil, FlushMethodKillall, FlushMethodBoth, "":
// Valid
default:
return &ValidationError{
Field: "settings.flushMethod",
Message: fmt.Sprintf("invalid flush method: %s", s.FlushMethod),
}
}
return nil
}
func validateGroup(g *Group, index int, aliases map[string]bool) error {
if strings.TrimSpace(g.Name) == "" {
return &ValidationError{
Field: fmt.Sprintf("groups[%d].name", index),
Message: "group name is required",
}
}
for i, h := range g.Hosts {
if err := validateHost(&h, index, i, aliases); err != nil {
return err
}
}
return nil
}
func validateHost(h *Host, groupIndex, hostIndex int, aliases map[string]bool) error {
fieldPrefix := fmt.Sprintf("groups[%d].hosts[%d]", groupIndex, hostIndex)
// Validate domain
if !ValidateDomain(h.Domain) {
return &ValidationError{
Field: fieldPrefix + ".domain",
Message: fmt.Sprintf("invalid domain: %s", h.Domain),
}
}
// Check blocked domains
if IsBlockedDomain(h.Domain) {
return &ValidationError{
Field: fieldPrefix + ".domain",
Message: fmt.Sprintf("domain is blocked: %s", h.Domain),
}
}
// Validate IP
if !ValidateIP(h.IP) {
return &ValidationError{
Field: fieldPrefix + ".ip",
Message: fmt.Sprintf("invalid IP address: %s", h.IP),
}
}
// Validate alias
if !ValidateAlias(h.Alias) {
return &ValidationError{
Field: fieldPrefix + ".alias",
Message: fmt.Sprintf("invalid alias: %s", h.Alias),
}
}
// Check alias uniqueness
if aliases[h.Alias] {
return &ValidationError{
Field: fieldPrefix + ".alias",
Message: fmt.Sprintf("duplicate alias: %s", h.Alias),
}
}
aliases[h.Alias] = true
return nil
}
func validatePreset(p *Preset, index int, aliases map[string]bool) error {
fieldPrefix := fmt.Sprintf("presets[%d]", index)
if strings.TrimSpace(p.Name) == "" {
return &ValidationError{
Field: fieldPrefix + ".name",
Message: "preset name is required",
}
}
// Note: We don't validate preset aliases strictly anymore.
// Unknown aliases in presets will simply be skipped when applying the preset.
// This allows presets to survive when hosts are removed from the config.
return nil
}
// ValidateDomain checks if a domain name is valid.
func ValidateDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}
// ValidateIP checks if an IP address is valid (IPv4 or IPv6).
func ValidateIP(ip string) bool {
if ip == "" {
return false
}
return net.ParseIP(ip) != nil
}
// ValidateAlias checks if an alias is valid.
func ValidateAlias(alias string) bool {
if alias == "" {
return false
}
return aliasRegex.MatchString(alias)
}
// IsBlockedDomain checks if a domain is in the blocklist.
func IsBlockedDomain(domain string) bool {
domain = strings.ToLower(domain)
// Check exact match
if blockedDomains[domain] {
return true
}
// Check if it's a subdomain of a blocked domain
for blocked := range blockedDomains {
if strings.HasSuffix(domain, "."+blocked) {
return true
}
}
return false
}
// GetBlockedDomains returns a copy of the blocked domains list.
func GetBlockedDomains() []string {
domains := make([]string, 0, len(blockedDomains))
for d := range blockedDomains {
domains = append(domains, d)
}
return domains
}
+436
View File
@@ -0,0 +1,436 @@
package config
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateDomain(t *testing.T) {
tests := []struct {
domain string
valid bool
}{
{"example.com", true},
{"sub.example.com", true},
{"my-app.example.com", true},
{"localhost", true},
{"a.b.c.d.example.com", true},
{"example123.com", true},
{"", false},
{"-example.com", false},
{"example-.com", false},
{"example.c", false}, // TLD too short
{"example", false}, // No TLD
{".example.com", false},
{"example..com", false},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
result := ValidateDomain(tt.domain)
assert.Equal(t, tt.valid, result, "domain: %s", tt.domain)
})
}
}
func TestValidateIP(t *testing.T) {
tests := []struct {
ip string
valid bool
}{
// Valid IPv4
{"127.0.0.1", true},
{"192.168.1.1", true},
{"0.0.0.0", true},
{"255.255.255.255", true},
// Valid IPv6
{"::1", true},
{"2001:db8::1", true},
{"fe80::1", true},
{"::ffff:192.168.1.1", true},
// Invalid
{"", false},
{"256.0.0.1", false},
{"192.168.1", false},
{"not-an-ip", false},
{"192.168.1.1.1", false},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := ValidateIP(tt.ip)
assert.Equal(t, tt.valid, result, "ip: %s", tt.ip)
})
}
}
func TestValidateAlias(t *testing.T) {
tests := []struct {
alias string
valid bool
}{
{"my-alias", true},
{"myalias", true},
{"my_alias", true},
{"alias123", true},
{"a", true},
{"a-b_c-d", true},
{"", false},
{"-startswithdash", false},
{"_startswithunderscore", false},
{"has spaces", false},
{"has.dot", false},
}
for _, tt := range tests {
t.Run(tt.alias, func(t *testing.T) {
result := ValidateAlias(tt.alias)
assert.Equal(t, tt.valid, result, "alias: %s", tt.alias)
})
}
}
func TestIsBlockedDomain(t *testing.T) {
tests := []struct {
domain string
blocked bool
}{
// Blocked domains
{"apple.com", true},
{"icloud.com", true},
{"sub.apple.com", true},
{"deep.sub.icloud.com", true},
{"APPLE.COM", true}, // Case insensitive
// Allowed domains
{"example.com", false},
{"myapp.com", false},
{"applestore.com", false}, // Not a subdomain
{"notapple.com", false},
}
for _, tt := range tests {
t.Run(tt.domain, func(t *testing.T) {
result := IsBlockedDomain(tt.domain)
assert.Equal(t, tt.blocked, result, "domain: %s", tt.domain)
})
}
}
func TestGetBlockedDomains(t *testing.T) {
domains := GetBlockedDomains()
assert.NotEmpty(t, domains)
assert.Contains(t, domains, "apple.com")
assert.Contains(t, domains, "icloud.com")
}
func TestValidateConfig(t *testing.T) {
t.Run("valid config", func(t *testing.T) {
cfg := &Config{
Settings: Settings{
AutoApply: true,
FlushMethod: FlushMethodAuto,
},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"example"}, Disable: []string{}},
},
}
err := ValidateConfig(cfg)
assert.NoError(t, err)
})
t.Run("nil config", func(t *testing.T) {
err := ValidateConfig(nil)
assert.Error(t, err)
})
t.Run("invalid flush method", func(t *testing.T) {
cfg := &Config{
Settings: Settings{FlushMethod: "invalid"},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("empty group name", func(t *testing.T) {
cfg := &Config{
Groups: []Group{{Name: "", Hosts: []Host{}}},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid domain", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "invalid", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("blocked domain", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "apple.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid IP", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "invalid", Alias: "test", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("invalid alias", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "-invalid", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("duplicate alias", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "a.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
{Domain: "b.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
},
},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("empty preset name", func(t *testing.T) {
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "", Enable: []string{}},
},
}
err := ValidateConfig(cfg)
assert.Error(t, err)
})
t.Run("preset with unknown alias is allowed", func(t *testing.T) {
// Unknown aliases in presets are now allowed (they're simply skipped when applied)
// This allows presets to survive when hosts are removed from the config
cfg := &Config{
Groups: []Group{
{
Name: "dev",
Hosts: []Host{
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"unknown"}},
},
}
err := ValidateConfig(cfg)
assert.NoError(t, err)
})
}
func TestValidationError(t *testing.T) {
err := &ValidationError{Field: "test.field", Message: "test message"}
assert.Equal(t, "test.field: test message", err.Error())
}
func TestValidateSettings(t *testing.T) {
tests := []struct {
name string
method FlushMethod
wantErr bool
}{
{"auto", FlushMethodAuto, false},
{"dscacheutil", FlushMethodDscacheutil, false},
{"killall", FlushMethodKillall, false},
{"both", FlushMethodBoth, false},
{"empty", "", false},
{"invalid", "invalid", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
settings := &Settings{FlushMethod: tt.method}
err := validateSettings(settings)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// Matrix testing for domain validation
func TestValidateDomain_Matrix(t *testing.T) {
prefixes := []string{"", "sub.", "a.b."}
domains := []string{"example", "my-app", "test123"}
tlds := []string{".com", ".io", ".co.uk", ".dev"}
for _, prefix := range prefixes {
for _, domain := range domains {
for _, tld := range tlds {
fullDomain := prefix + domain + tld
t.Run(fullDomain, func(t *testing.T) {
result := ValidateDomain(fullDomain)
assert.True(t, result, "expected %s to be valid", fullDomain)
})
}
}
}
}
// Matrix testing for IP validation
func TestValidateIP_Matrix(t *testing.T) {
octets := []string{"0", "127", "192", "255"}
for _, o1 := range octets {
for _, o2 := range octets {
for _, o3 := range octets {
for _, o4 := range octets {
ip := o1 + "." + o2 + "." + o3 + "." + o4
t.Run(ip, func(t *testing.T) {
result := ValidateIP(ip)
assert.True(t, result, "expected %s to be valid", ip)
})
}
}
}
}
}
// Benchmark tests
func BenchmarkValidateDomain(b *testing.B) {
domains := []string{
"example.com",
"sub.example.com",
"very.long.subdomain.chain.example.com",
}
for _, domain := range domains {
b.Run(domain, func(b *testing.B) {
for i := 0; i < b.N; i++ {
ValidateDomain(domain)
}
})
}
}
func BenchmarkValidateIP(b *testing.B) {
ips := []string{
"127.0.0.1",
"192.168.1.1",
"::1",
"2001:db8::1",
}
for _, ip := range ips {
b.Run(ip, func(b *testing.B) {
for i := 0; i < b.N; i++ {
ValidateIP(ip)
}
})
}
}
func BenchmarkIsBlockedDomain(b *testing.B) {
domains := []string{
"example.com", // not blocked
"apple.com", // blocked
"sub.icloud.com", // blocked subdomain
}
for _, domain := range domains {
b.Run(domain, func(b *testing.B) {
for i := 0; i < b.N; i++ {
IsBlockedDomain(domain)
}
})
}
}
func BenchmarkValidateConfig(b *testing.B) {
cfg := &Config{
Settings: Settings{AutoApply: true, FlushMethod: FlushMethodAuto},
Groups: []Group{
{
Name: "development",
Hosts: []Host{
{Domain: "a.example.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
{Domain: "b.example.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
{Domain: "c.example.com", IP: "127.0.0.1", Alias: "c", Enabled: false},
},
},
},
Presets: []Preset{
{Name: "local", Enable: []string{"a", "b"}, Disable: []string{"c"}},
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
err := ValidateConfig(cfg)
require.NoError(b, err)
}
}