mirror of
https://github.com/lukaszraczylo/lolcathost.git
synced 2026-06-11 00:08:57 +00:00
Initial commit.
This commit is contained in:
@@ -0,0 +1,541 @@
|
||||
// Package config handles YAML configuration parsing and hot-reload.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// SystemConfigDir is the system-wide config directory for the daemon.
|
||||
const SystemConfigDir = "/etc/lolcathost"
|
||||
|
||||
// SystemConfigPath is the system-wide config file path for the daemon.
|
||||
const SystemConfigPath = "/etc/lolcathost/config.yaml"
|
||||
|
||||
// DefaultConfigDir returns the default config directory path for users.
|
||||
func DefaultConfigDir() string {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(home, ".config", "lolcathost")
|
||||
}
|
||||
|
||||
// DefaultConfigPath returns the default config file path for users.
|
||||
func DefaultConfigPath() string {
|
||||
return filepath.Join(DefaultConfigDir(), "config.yaml")
|
||||
}
|
||||
|
||||
// FlushMethod defines DNS cache flush methods.
|
||||
type FlushMethod string
|
||||
|
||||
const (
|
||||
FlushMethodAuto FlushMethod = "auto"
|
||||
FlushMethodDscacheutil FlushMethod = "dscacheutil"
|
||||
FlushMethodKillall FlushMethod = "killall"
|
||||
FlushMethodBoth FlushMethod = "both"
|
||||
)
|
||||
|
||||
// Settings holds global configuration settings.
|
||||
type Settings struct {
|
||||
AutoApply bool `yaml:"autoApply"`
|
||||
FlushMethod FlushMethod `yaml:"flushMethod"`
|
||||
}
|
||||
|
||||
// Host represents a single host entry in configuration.
|
||||
type Host struct {
|
||||
Domain string `yaml:"domain"`
|
||||
IP string `yaml:"ip"`
|
||||
Alias string `yaml:"alias"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// Group represents a group of host entries.
|
||||
type Group struct {
|
||||
Name string `yaml:"name"`
|
||||
Hosts []Host `yaml:"hosts"`
|
||||
}
|
||||
|
||||
// Preset defines a named preset that enables/disables specific aliases.
|
||||
type Preset struct {
|
||||
Name string `yaml:"name"`
|
||||
Enable []string `yaml:"enable,omitempty"`
|
||||
Disable []string `yaml:"disable,omitempty"`
|
||||
}
|
||||
|
||||
// Config represents the complete configuration.
|
||||
type Config struct {
|
||||
Settings Settings `yaml:"settings"`
|
||||
Groups []Group `yaml:"groups"`
|
||||
Presets []Preset `yaml:"presets"`
|
||||
}
|
||||
|
||||
// Manager handles configuration loading and watching.
|
||||
type Manager struct {
|
||||
path string
|
||||
config *Config
|
||||
mu sync.RWMutex
|
||||
watcher *fsnotify.Watcher
|
||||
onChange func(*Config)
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewManager creates a new config manager.
|
||||
func NewManager(path string) *Manager {
|
||||
return &Manager{
|
||||
path: path,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads and parses the configuration file.
|
||||
func (m *Manager) Load() error {
|
||||
data, err := os.ReadFile(m.path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
if err := ValidateConfig(&cfg); err != nil {
|
||||
return fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
m.config = &cfg
|
||||
m.mu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get returns the current configuration.
|
||||
func (m *Manager) Get() *Config {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.config
|
||||
}
|
||||
|
||||
// Watch starts watching the config file for changes.
|
||||
func (m *Manager) Watch(onChange func(*Config)) error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create watcher: %w", err)
|
||||
}
|
||||
|
||||
m.watcher = watcher
|
||||
m.onChange = onChange
|
||||
|
||||
go m.watchLoop()
|
||||
|
||||
if err := watcher.Add(m.path); err != nil {
|
||||
return fmt.Errorf("failed to watch config file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) watchLoop() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-m.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Has(fsnotify.Write) || event.Has(fsnotify.Create) {
|
||||
if err := m.Load(); err == nil && m.onChange != nil {
|
||||
m.onChange(m.Get())
|
||||
}
|
||||
}
|
||||
case <-m.watcher.Errors:
|
||||
// Ignore watcher errors
|
||||
case <-m.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops watching the config file.
|
||||
func (m *Manager) Stop() {
|
||||
close(m.stopCh)
|
||||
if m.watcher != nil {
|
||||
m.watcher.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllHosts returns all hosts from all groups.
|
||||
func (c *Config) GetAllHosts() []Host {
|
||||
var hosts []Host
|
||||
for _, g := range c.Groups {
|
||||
hosts = append(hosts, g.Hosts...)
|
||||
}
|
||||
return hosts
|
||||
}
|
||||
|
||||
// FindHostByAlias finds a host by its alias.
|
||||
func (c *Config) FindHostByAlias(alias string) (*Host, *Group) {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
return &c.Groups[i].Hosts[j], &c.Groups[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// FindPreset finds a preset by name.
|
||||
func (c *Config) FindPreset(name string) *Preset {
|
||||
for i := range c.Presets {
|
||||
if c.Presets[i].Name == name {
|
||||
return &c.Presets[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHostEnabled sets the enabled state of a host by alias.
|
||||
func (c *Config) SetHostEnabled(alias string, enabled bool) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts[j].Enabled = enabled
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GenerateAlias creates a unique alias from a domain name.
|
||||
func (c *Config) GenerateAlias(domain string) string {
|
||||
// Convert domain to alias format: example.com -> example-com
|
||||
alias := strings.ReplaceAll(domain, ".", "-")
|
||||
alias = strings.ReplaceAll(alias, "_", "-")
|
||||
alias = strings.ToLower(alias)
|
||||
|
||||
// Check if alias exists, if so append a number
|
||||
baseAlias := alias
|
||||
counter := 1
|
||||
for {
|
||||
if existing, _ := c.FindHostByAlias(alias); existing == nil {
|
||||
break
|
||||
}
|
||||
counter++
|
||||
alias = fmt.Sprintf("%s-%d", baseAlias, counter)
|
||||
}
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
// AddHost adds a new host to the configuration.
|
||||
func (c *Config) AddHost(domain, ip, alias, groupName string, enabled bool) error {
|
||||
// Auto-generate alias if empty
|
||||
if alias == "" {
|
||||
alias = c.GenerateAlias(domain)
|
||||
} else {
|
||||
// Check for duplicate alias
|
||||
if existing, _ := c.FindHostByAlias(alias); existing != nil {
|
||||
return fmt.Errorf("alias already exists: %s", alias)
|
||||
}
|
||||
}
|
||||
|
||||
host := Host{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: alias,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
// Find or create group
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == groupName {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create new group
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: groupName,
|
||||
Hosts: []Host{host},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddGroup adds a new empty group.
|
||||
func (c *Config) AddGroup(name string) error {
|
||||
// Check if group already exists
|
||||
for _, g := range c.Groups {
|
||||
if g.Name == name {
|
||||
return fmt.Errorf("group already exists: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: name,
|
||||
Hosts: []Host{},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup removes a group and all its hosts.
|
||||
func (c *Config) DeleteGroup(name string) error {
|
||||
for i, g := range c.Groups {
|
||||
if g.Name == name {
|
||||
c.Groups = append(c.Groups[:i], c.Groups[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("group not found: %s", name)
|
||||
}
|
||||
|
||||
// RenameGroup renames an existing group.
|
||||
func (c *Config) RenameGroup(oldName, newName string) error {
|
||||
// Check if new name already exists
|
||||
for _, g := range c.Groups {
|
||||
if g.Name == newName {
|
||||
return fmt.Errorf("group already exists: %s", newName)
|
||||
}
|
||||
}
|
||||
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == oldName {
|
||||
c.Groups[i].Name = newName
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("group not found: %s", oldName)
|
||||
}
|
||||
|
||||
// GetGroups returns all group names.
|
||||
func (c *Config) GetGroups() []string {
|
||||
names := make([]string, len(c.Groups))
|
||||
for i, g := range c.Groups {
|
||||
names[i] = g.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// DeleteHost removes a host by alias.
|
||||
func (c *Config) DeleteHost(alias string) bool {
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == alias {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts[:j], c.Groups[i].Hosts[j+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// UpdateHost updates an existing host by alias.
|
||||
func (c *Config) UpdateHost(oldAlias, domain, ip, newAlias, groupName string) error {
|
||||
// Find the host
|
||||
var foundGroup int = -1
|
||||
var foundHost int = -1
|
||||
for i := range c.Groups {
|
||||
for j := range c.Groups[i].Hosts {
|
||||
if c.Groups[i].Hosts[j].Alias == oldAlias {
|
||||
foundGroup = i
|
||||
foundHost = j
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundHost >= 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if foundHost < 0 {
|
||||
return fmt.Errorf("alias not found: %s", oldAlias)
|
||||
}
|
||||
|
||||
// Check for duplicate alias if alias is changing
|
||||
if oldAlias != newAlias {
|
||||
if existing, _ := c.FindHostByAlias(newAlias); existing != nil {
|
||||
return fmt.Errorf("alias already exists: %s", newAlias)
|
||||
}
|
||||
}
|
||||
|
||||
// Get current enabled state
|
||||
enabled := c.Groups[foundGroup].Hosts[foundHost].Enabled
|
||||
|
||||
// If group is changing, move to new group
|
||||
if c.Groups[foundGroup].Name != groupName {
|
||||
// Remove from old group
|
||||
c.Groups[foundGroup].Hosts = append(c.Groups[foundGroup].Hosts[:foundHost], c.Groups[foundGroup].Hosts[foundHost+1:]...)
|
||||
|
||||
// Add to new group
|
||||
host := Host{
|
||||
Domain: domain,
|
||||
IP: ip,
|
||||
Alias: newAlias,
|
||||
Enabled: enabled,
|
||||
}
|
||||
|
||||
// Find or create target group
|
||||
found := false
|
||||
for i := range c.Groups {
|
||||
if c.Groups[i].Name == groupName {
|
||||
c.Groups[i].Hosts = append(c.Groups[i].Hosts, host)
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: groupName,
|
||||
Hosts: []Host{host},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// Update in place
|
||||
c.Groups[foundGroup].Hosts[foundHost].Domain = domain
|
||||
c.Groups[foundGroup].Hosts[foundHost].IP = ip
|
||||
c.Groups[foundGroup].Hosts[foundHost].Alias = newAlias
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPreset applies a preset to the configuration.
|
||||
func (c *Config) ApplyPreset(name string) error {
|
||||
preset := c.FindPreset(name)
|
||||
if preset == nil {
|
||||
return fmt.Errorf("preset not found: %s", name)
|
||||
}
|
||||
|
||||
for _, alias := range preset.Enable {
|
||||
c.SetHostEnabled(alias, true)
|
||||
}
|
||||
for _, alias := range preset.Disable {
|
||||
c.SetHostEnabled(alias, false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPreset adds a new preset.
|
||||
func (c *Config) AddPreset(name string, enable, disable []string) error {
|
||||
// Check if preset already exists
|
||||
for _, p := range c.Presets {
|
||||
if p.Name == name {
|
||||
return fmt.Errorf("preset already exists: %s", name)
|
||||
}
|
||||
}
|
||||
|
||||
c.Presets = append(c.Presets, Preset{
|
||||
Name: name,
|
||||
Enable: enable,
|
||||
Disable: disable,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePreset removes a preset by name.
|
||||
func (c *Config) DeletePreset(name string) error {
|
||||
for i, p := range c.Presets {
|
||||
if p.Name == name {
|
||||
c.Presets = append(c.Presets[:i], c.Presets[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("preset not found: %s", name)
|
||||
}
|
||||
|
||||
// GetPresets returns all presets.
|
||||
func (c *Config) GetPresets() []Preset {
|
||||
return c.Presets
|
||||
}
|
||||
|
||||
// EnsureDefaultGroup ensures at least one group exists, creating "default" if needed.
|
||||
func (c *Config) EnsureDefaultGroup() {
|
||||
if len(c.Groups) == 0 {
|
||||
c.Groups = append(c.Groups, Group{
|
||||
Name: "default",
|
||||
Hosts: []Host{},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Save writes the configuration to the file.
|
||||
func (m *Manager) Save() error {
|
||||
m.mu.RLock()
|
||||
cfg := m.config
|
||||
m.mu.RUnlock()
|
||||
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("no config loaded")
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(m.path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDefault creates a default configuration file.
|
||||
func CreateDefault(path string) error {
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create config directory: %w", err)
|
||||
}
|
||||
|
||||
cfg := &Config{
|
||||
Settings: Settings{
|
||||
AutoApply: true,
|
||||
FlushMethod: FlushMethodAuto,
|
||||
},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{
|
||||
Domain: "example.local",
|
||||
IP: "127.0.0.1",
|
||||
Alias: "example-local",
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{
|
||||
Name: "local",
|
||||
Enable: []string{"example-local"},
|
||||
Disable: []string{},
|
||||
},
|
||||
{
|
||||
Name: "clear",
|
||||
Enable: []string{},
|
||||
Disable: []string{"example-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal default config: %w", err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write default config: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConfig_GetAllHosts(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: false},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "staging",
|
||||
Hosts: []Host{
|
||||
{Domain: "c.com", IP: "192.168.1.1", Alias: "c", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
hosts := cfg.GetAllHosts()
|
||||
assert.Len(t, hosts, 3)
|
||||
assert.Equal(t, "a.com", hosts[0].Domain)
|
||||
assert.Equal(t, "b.com", hosts[1].Domain)
|
||||
assert.Equal(t, "c.com", hosts[2].Domain)
|
||||
}
|
||||
|
||||
func TestConfig_FindHostByAlias(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
host, group := cfg.FindHostByAlias("example")
|
||||
require.NotNil(t, host)
|
||||
require.NotNil(t, group)
|
||||
assert.Equal(t, "example.com", host.Domain)
|
||||
assert.Equal(t, "dev", group.Name)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
host, group := cfg.FindHostByAlias("nonexistent")
|
||||
assert.Nil(t, host)
|
||||
assert.Nil(t, group)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_FindPreset(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"a"}, Disable: []string{"b"}},
|
||||
{Name: "staging", Enable: []string{"b"}, Disable: []string{"a"}},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
preset := cfg.FindPreset("local")
|
||||
require.NotNil(t, preset)
|
||||
assert.Equal(t, "local", preset.Name)
|
||||
assert.Equal(t, []string{"a"}, preset.Enable)
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
preset := cfg.FindPreset("nonexistent")
|
||||
assert.Nil(t, preset)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_SetHostEnabled(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("enable existing", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("example", true)
|
||||
assert.True(t, result)
|
||||
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
})
|
||||
|
||||
t.Run("disable existing", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("example", false)
|
||||
assert.True(t, result)
|
||||
assert.False(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
})
|
||||
|
||||
t.Run("nonexistent alias", func(t *testing.T) {
|
||||
result := cfg.SetHostEnabled("nonexistent", true)
|
||||
assert.False(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestConfig_ApplyPreset(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "a", Enabled: false},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "swap", Enable: []string{"a"}, Disable: []string{"b"}},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("valid preset", func(t *testing.T) {
|
||||
err := cfg.ApplyPreset("swap")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, cfg.Groups[0].Hosts[0].Enabled)
|
||||
assert.False(t, cfg.Groups[0].Hosts[1].Enabled)
|
||||
})
|
||||
|
||||
t.Run("nonexistent preset", func(t *testing.T) {
|
||||
err := cfg.ApplyPreset("nonexistent")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestManager_LoadAndGet(t *testing.T) {
|
||||
// Create temp config file
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
configContent := `
|
||||
settings:
|
||||
autoApply: true
|
||||
flushMethod: auto
|
||||
groups:
|
||||
- name: development
|
||||
hosts:
|
||||
- domain: example.com
|
||||
ip: 127.0.0.1
|
||||
alias: example-local
|
||||
enabled: true
|
||||
presets:
|
||||
- name: local
|
||||
enable: [example-local]
|
||||
disable: []
|
||||
`
|
||||
err := os.WriteFile(configPath, []byte(configContent), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.True(t, cfg.Settings.AutoApply)
|
||||
assert.Equal(t, FlushMethodAuto, cfg.Settings.FlushMethod)
|
||||
assert.Len(t, cfg.Groups, 1)
|
||||
assert.Equal(t, "development", cfg.Groups[0].Name)
|
||||
assert.Len(t, cfg.Groups[0].Hosts, 1)
|
||||
assert.Equal(t, "example.com", cfg.Groups[0].Hosts[0].Domain)
|
||||
}
|
||||
|
||||
func TestManager_Save(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
// Create initial config
|
||||
err := CreateDefault(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Load and modify
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
cfg.Groups[0].Hosts[0].Enabled = true
|
||||
|
||||
// Save
|
||||
err = manager.Save()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reload and verify
|
||||
manager2 := NewManager(configPath)
|
||||
err = manager2.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg2 := manager2.Get()
|
||||
assert.True(t, cfg2.Groups[0].Hosts[0].Enabled)
|
||||
}
|
||||
|
||||
func TestCreateDefault(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "subdir", "config.yaml")
|
||||
|
||||
err := CreateDefault(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify file exists
|
||||
_, err = os.Stat(configPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify content is valid
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := manager.Get()
|
||||
require.NotNil(t, cfg)
|
||||
assert.True(t, cfg.Settings.AutoApply)
|
||||
assert.Len(t, cfg.Groups, 1)
|
||||
assert.Len(t, cfg.Presets, 2)
|
||||
}
|
||||
|
||||
func TestManager_Load_InvalidYAML(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
configPath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
err := os.WriteFile(configPath, []byte("invalid: yaml: content:"), 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager := NewManager(configPath)
|
||||
err = manager.Load()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Load_FileNotFound(t *testing.T) {
|
||||
manager := NewManager("/nonexistent/path/config.yaml")
|
||||
err := manager.Load()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFlushMethod(t *testing.T) {
|
||||
methods := []FlushMethod{
|
||||
FlushMethodAuto,
|
||||
FlushMethodDscacheutil,
|
||||
FlushMethodKillall,
|
||||
FlushMethodBoth,
|
||||
}
|
||||
|
||||
for _, m := range methods {
|
||||
t.Run(string(m), func(t *testing.T) {
|
||||
assert.NotEmpty(t, string(m))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,211 @@
|
||||
// Package config provides validation functions for configuration.
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// domainRegex validates domain names.
|
||||
var domainRegex = regexp.MustCompile(`^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$|^localhost$`)
|
||||
|
||||
// aliasRegex validates alias names.
|
||||
var aliasRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]{0,62}$`)
|
||||
|
||||
// blockedDomains contains domains that cannot be modified.
|
||||
var blockedDomains = map[string]bool{
|
||||
"apple.com": true,
|
||||
"icloud.com": true,
|
||||
"icloud-content.com": true,
|
||||
"apple-dns.cn": true,
|
||||
"apple-dns.net": true,
|
||||
"mzstatic.com": true,
|
||||
"itunes.apple.com": true,
|
||||
"updates.apple.com": true,
|
||||
}
|
||||
|
||||
// ValidationError represents a configuration validation error.
|
||||
type ValidationError struct {
|
||||
Field string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return fmt.Sprintf("%s: %s", e.Field, e.Message)
|
||||
}
|
||||
|
||||
// ValidateConfig validates the entire configuration.
|
||||
func ValidateConfig(cfg *Config) error {
|
||||
if cfg == nil {
|
||||
return &ValidationError{Field: "config", Message: "config is nil"}
|
||||
}
|
||||
|
||||
if err := validateSettings(&cfg.Settings); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Track aliases for uniqueness
|
||||
aliases := make(map[string]bool)
|
||||
|
||||
for i, g := range cfg.Groups {
|
||||
if err := validateGroup(&g, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
for i, p := range cfg.Presets {
|
||||
if err := validatePreset(&p, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSettings(s *Settings) error {
|
||||
switch s.FlushMethod {
|
||||
case FlushMethodAuto, FlushMethodDscacheutil, FlushMethodKillall, FlushMethodBoth, "":
|
||||
// Valid
|
||||
default:
|
||||
return &ValidationError{
|
||||
Field: "settings.flushMethod",
|
||||
Message: fmt.Sprintf("invalid flush method: %s", s.FlushMethod),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateGroup(g *Group, index int, aliases map[string]bool) error {
|
||||
if strings.TrimSpace(g.Name) == "" {
|
||||
return &ValidationError{
|
||||
Field: fmt.Sprintf("groups[%d].name", index),
|
||||
Message: "group name is required",
|
||||
}
|
||||
}
|
||||
|
||||
for i, h := range g.Hosts {
|
||||
if err := validateHost(&h, index, i, aliases); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateHost(h *Host, groupIndex, hostIndex int, aliases map[string]bool) error {
|
||||
fieldPrefix := fmt.Sprintf("groups[%d].hosts[%d]", groupIndex, hostIndex)
|
||||
|
||||
// Validate domain
|
||||
if !ValidateDomain(h.Domain) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".domain",
|
||||
Message: fmt.Sprintf("invalid domain: %s", h.Domain),
|
||||
}
|
||||
}
|
||||
|
||||
// Check blocked domains
|
||||
if IsBlockedDomain(h.Domain) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".domain",
|
||||
Message: fmt.Sprintf("domain is blocked: %s", h.Domain),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate IP
|
||||
if !ValidateIP(h.IP) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".ip",
|
||||
Message: fmt.Sprintf("invalid IP address: %s", h.IP),
|
||||
}
|
||||
}
|
||||
|
||||
// Validate alias
|
||||
if !ValidateAlias(h.Alias) {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".alias",
|
||||
Message: fmt.Sprintf("invalid alias: %s", h.Alias),
|
||||
}
|
||||
}
|
||||
|
||||
// Check alias uniqueness
|
||||
if aliases[h.Alias] {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".alias",
|
||||
Message: fmt.Sprintf("duplicate alias: %s", h.Alias),
|
||||
}
|
||||
}
|
||||
aliases[h.Alias] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePreset(p *Preset, index int, aliases map[string]bool) error {
|
||||
fieldPrefix := fmt.Sprintf("presets[%d]", index)
|
||||
|
||||
if strings.TrimSpace(p.Name) == "" {
|
||||
return &ValidationError{
|
||||
Field: fieldPrefix + ".name",
|
||||
Message: "preset name is required",
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We don't validate preset aliases strictly anymore.
|
||||
// Unknown aliases in presets will simply be skipped when applying the preset.
|
||||
// This allows presets to survive when hosts are removed from the config.
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateDomain checks if a domain name is valid.
|
||||
func ValidateDomain(domain string) bool {
|
||||
if domain == "" {
|
||||
return false
|
||||
}
|
||||
return domainRegex.MatchString(domain)
|
||||
}
|
||||
|
||||
// ValidateIP checks if an IP address is valid (IPv4 or IPv6).
|
||||
func ValidateIP(ip string) bool {
|
||||
if ip == "" {
|
||||
return false
|
||||
}
|
||||
return net.ParseIP(ip) != nil
|
||||
}
|
||||
|
||||
// ValidateAlias checks if an alias is valid.
|
||||
func ValidateAlias(alias string) bool {
|
||||
if alias == "" {
|
||||
return false
|
||||
}
|
||||
return aliasRegex.MatchString(alias)
|
||||
}
|
||||
|
||||
// IsBlockedDomain checks if a domain is in the blocklist.
|
||||
func IsBlockedDomain(domain string) bool {
|
||||
domain = strings.ToLower(domain)
|
||||
|
||||
// Check exact match
|
||||
if blockedDomains[domain] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if it's a subdomain of a blocked domain
|
||||
for blocked := range blockedDomains {
|
||||
if strings.HasSuffix(domain, "."+blocked) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// GetBlockedDomains returns a copy of the blocked domains list.
|
||||
func GetBlockedDomains() []string {
|
||||
domains := make([]string, 0, len(blockedDomains))
|
||||
for d := range blockedDomains {
|
||||
domains = append(domains, d)
|
||||
}
|
||||
return domains
|
||||
}
|
||||
@@ -0,0 +1,436 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
valid bool
|
||||
}{
|
||||
{"example.com", true},
|
||||
{"sub.example.com", true},
|
||||
{"my-app.example.com", true},
|
||||
{"localhost", true},
|
||||
{"a.b.c.d.example.com", true},
|
||||
{"example123.com", true},
|
||||
|
||||
{"", false},
|
||||
{"-example.com", false},
|
||||
{"example-.com", false},
|
||||
{"example.c", false}, // TLD too short
|
||||
{"example", false}, // No TLD
|
||||
{".example.com", false},
|
||||
{"example..com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
result := ValidateDomain(tt.domain)
|
||||
assert.Equal(t, tt.valid, result, "domain: %s", tt.domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
ip string
|
||||
valid bool
|
||||
}{
|
||||
// Valid IPv4
|
||||
{"127.0.0.1", true},
|
||||
{"192.168.1.1", true},
|
||||
{"0.0.0.0", true},
|
||||
{"255.255.255.255", true},
|
||||
|
||||
// Valid IPv6
|
||||
{"::1", true},
|
||||
{"2001:db8::1", true},
|
||||
{"fe80::1", true},
|
||||
{"::ffff:192.168.1.1", true},
|
||||
|
||||
// Invalid
|
||||
{"", false},
|
||||
{"256.0.0.1", false},
|
||||
{"192.168.1", false},
|
||||
{"not-an-ip", false},
|
||||
{"192.168.1.1.1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
result := ValidateIP(tt.ip)
|
||||
assert.Equal(t, tt.valid, result, "ip: %s", tt.ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
alias string
|
||||
valid bool
|
||||
}{
|
||||
{"my-alias", true},
|
||||
{"myalias", true},
|
||||
{"my_alias", true},
|
||||
{"alias123", true},
|
||||
{"a", true},
|
||||
{"a-b_c-d", true},
|
||||
|
||||
{"", false},
|
||||
{"-startswithdash", false},
|
||||
{"_startswithunderscore", false},
|
||||
{"has spaces", false},
|
||||
{"has.dot", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.alias, func(t *testing.T) {
|
||||
result := ValidateAlias(tt.alias)
|
||||
assert.Equal(t, tt.valid, result, "alias: %s", tt.alias)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
domain string
|
||||
blocked bool
|
||||
}{
|
||||
// Blocked domains
|
||||
{"apple.com", true},
|
||||
{"icloud.com", true},
|
||||
{"sub.apple.com", true},
|
||||
{"deep.sub.icloud.com", true},
|
||||
{"APPLE.COM", true}, // Case insensitive
|
||||
|
||||
// Allowed domains
|
||||
{"example.com", false},
|
||||
{"myapp.com", false},
|
||||
{"applestore.com", false}, // Not a subdomain
|
||||
{"notapple.com", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.domain, func(t *testing.T) {
|
||||
result := IsBlockedDomain(tt.domain)
|
||||
assert.Equal(t, tt.blocked, result, "domain: %s", tt.domain)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetBlockedDomains(t *testing.T) {
|
||||
domains := GetBlockedDomains()
|
||||
assert.NotEmpty(t, domains)
|
||||
assert.Contains(t, domains, "apple.com")
|
||||
assert.Contains(t, domains, "icloud.com")
|
||||
}
|
||||
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{
|
||||
AutoApply: true,
|
||||
FlushMethod: FlushMethodAuto,
|
||||
},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "example", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"example"}, Disable: []string{}},
|
||||
},
|
||||
}
|
||||
|
||||
err := ValidateConfig(cfg)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("nil config", func(t *testing.T) {
|
||||
err := ValidateConfig(nil)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid flush method", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{FlushMethod: "invalid"},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty group name", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{{Name: "", Hosts: []Host{}}},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid domain", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "invalid", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("blocked domain", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "apple.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid IP", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "invalid", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("invalid alias", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "-invalid", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("duplicate alias", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
|
||||
{Domain: "b.com", IP: "127.0.0.1", Alias: "same", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty preset name", func(t *testing.T) {
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "", Enable: []string{}},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("preset with unknown alias is allowed", func(t *testing.T) {
|
||||
// Unknown aliases in presets are now allowed (they're simply skipped when applied)
|
||||
// This allows presets to survive when hosts are removed from the config
|
||||
cfg := &Config{
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "dev",
|
||||
Hosts: []Host{
|
||||
{Domain: "example.com", IP: "127.0.0.1", Alias: "test", Enabled: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"unknown"}},
|
||||
},
|
||||
}
|
||||
err := ValidateConfig(cfg)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidationError(t *testing.T) {
|
||||
err := &ValidationError{Field: "test.field", Message: "test message"}
|
||||
assert.Equal(t, "test.field: test message", err.Error())
|
||||
}
|
||||
|
||||
func TestValidateSettings(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
method FlushMethod
|
||||
wantErr bool
|
||||
}{
|
||||
{"auto", FlushMethodAuto, false},
|
||||
{"dscacheutil", FlushMethodDscacheutil, false},
|
||||
{"killall", FlushMethodKillall, false},
|
||||
{"both", FlushMethodBoth, false},
|
||||
{"empty", "", false},
|
||||
{"invalid", "invalid", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
settings := &Settings{FlushMethod: tt.method}
|
||||
err := validateSettings(settings)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix testing for domain validation
|
||||
func TestValidateDomain_Matrix(t *testing.T) {
|
||||
prefixes := []string{"", "sub.", "a.b."}
|
||||
domains := []string{"example", "my-app", "test123"}
|
||||
tlds := []string{".com", ".io", ".co.uk", ".dev"}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
for _, domain := range domains {
|
||||
for _, tld := range tlds {
|
||||
fullDomain := prefix + domain + tld
|
||||
t.Run(fullDomain, func(t *testing.T) {
|
||||
result := ValidateDomain(fullDomain)
|
||||
assert.True(t, result, "expected %s to be valid", fullDomain)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix testing for IP validation
|
||||
func TestValidateIP_Matrix(t *testing.T) {
|
||||
octets := []string{"0", "127", "192", "255"}
|
||||
|
||||
for _, o1 := range octets {
|
||||
for _, o2 := range octets {
|
||||
for _, o3 := range octets {
|
||||
for _, o4 := range octets {
|
||||
ip := o1 + "." + o2 + "." + o3 + "." + o4
|
||||
t.Run(ip, func(t *testing.T) {
|
||||
result := ValidateIP(ip)
|
||||
assert.True(t, result, "expected %s to be valid", ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkValidateDomain(b *testing.B) {
|
||||
domains := []string{
|
||||
"example.com",
|
||||
"sub.example.com",
|
||||
"very.long.subdomain.chain.example.com",
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
b.Run(domain, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateDomain(domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateIP(b *testing.B) {
|
||||
ips := []string{
|
||||
"127.0.0.1",
|
||||
"192.168.1.1",
|
||||
"::1",
|
||||
"2001:db8::1",
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
b.Run(ip, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ValidateIP(ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIsBlockedDomain(b *testing.B) {
|
||||
domains := []string{
|
||||
"example.com", // not blocked
|
||||
"apple.com", // blocked
|
||||
"sub.icloud.com", // blocked subdomain
|
||||
}
|
||||
|
||||
for _, domain := range domains {
|
||||
b.Run(domain, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
IsBlockedDomain(domain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkValidateConfig(b *testing.B) {
|
||||
cfg := &Config{
|
||||
Settings: Settings{AutoApply: true, FlushMethod: FlushMethodAuto},
|
||||
Groups: []Group{
|
||||
{
|
||||
Name: "development",
|
||||
Hosts: []Host{
|
||||
{Domain: "a.example.com", IP: "127.0.0.1", Alias: "a", Enabled: true},
|
||||
{Domain: "b.example.com", IP: "127.0.0.1", Alias: "b", Enabled: true},
|
||||
{Domain: "c.example.com", IP: "127.0.0.1", Alias: "c", Enabled: false},
|
||||
},
|
||||
},
|
||||
},
|
||||
Presets: []Preset{
|
||||
{Name: "local", Enable: []string{"a", "b"}, Disable: []string{"c"}},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
err := ValidateConfig(cfg)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user